diff --git a/main.py b/main.py index 0cc5d42..d5bc996 100644 --- a/main.py +++ b/main.py @@ -417,8 +417,8 @@ class Main(object): obj_pred = [] obj_pred_score = [] for step, batch in enumerate(train_iter): - sub, rel, obj, label = self.read_batch(batch, split) - pred = self.model.forward(sub, rel, None, 'one_to_n') + sub, rel, obj, nt_rel, label = self.read_batch(batch, split) + pred = self.model.forward(sub, rel, nt_rel, None, 'one_to_n') b_range = torch.arange(pred.size()[0], device=self.device) target_pred = pred[b_range, obj] pred = torch.where(label.byte(), torch.zeros_like(pred), pred) @@ -691,7 +691,7 @@ if __name__ == "__main__": collate_fn=TrainDataset.collate_fn )) for step, batch in enumerate(dataloader): - sub, rel, obj, label, neg_ent, sub_samp = model.read_batch( + sub, rel, obj, nt_rel, label, neg_ent, sub_samp = model.read_batch( batch, 'train') if (neg_ent is None):