runnable v1
This commit is contained in:
@ -81,13 +81,13 @@ class TransformerLitModel(BaseLitModel):
|
||||
pos = batch.pop("pos")
|
||||
try:
|
||||
en = batch.pop("en")
|
||||
self.print("__DEBUG__: en", en)
|
||||
# self.print("__DEBUG__: en", en)
|
||||
rel = batch.pop("rel")
|
||||
self.print("__DEBUG__: rel", rel)
|
||||
# self.print("__DEBUG__: rel", rel)
|
||||
except KeyError:
|
||||
pass
|
||||
input_ids = batch['input_ids']
|
||||
self.print("__DEBUG__: input_ids", input_ids)
|
||||
# self.print("__DEBUG__: input_ids", input_ids)
|
||||
|
||||
distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])])
|
||||
distance = batch.pop("distance_attention")
|
||||
|
Reference in New Issue
Block a user