runnable v1

This commit is contained in:
2023-01-14 10:40:58 +00:00
parent fcfeae2bd3
commit 45cd8e1396
5 changed files with 136 additions and 14 deletions

View File

@ -81,13 +81,13 @@ class TransformerLitModel(BaseLitModel):
pos = batch.pop("pos")
try:
en = batch.pop("en")
self.print("__DEBUG__: en", en)
# self.print("__DEBUG__: en", en)
rel = batch.pop("rel")
self.print("__DEBUG__: rel", rel)
# self.print("__DEBUG__: rel", rel)
except KeyError:
pass
input_ids = batch['input_ids']
self.print("__DEBUG__: input_ids", input_ids)
# self.print("__DEBUG__: input_ids", input_ids)
distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])])
distance = batch.pop("distance_attention")