test
This commit is contained in:
parent
bb9856ecd1
commit
39734013c4
6
main.py
6
main.py
@ -417,8 +417,8 @@ class Main(object):
|
|||||||
obj_pred = []
|
obj_pred = []
|
||||||
obj_pred_score = []
|
obj_pred_score = []
|
||||||
for step, batch in enumerate(train_iter):
|
for step, batch in enumerate(train_iter):
|
||||||
sub, rel, obj, label = self.read_batch(batch, split)
|
sub, rel, obj, nt_rel, label = self.read_batch(batch, split)
|
||||||
pred = self.model.forward(sub, rel, None, 'one_to_n')
|
pred = self.model.forward(sub, rel, nt_rel, None, 'one_to_n')
|
||||||
b_range = torch.arange(pred.size()[0], device=self.device)
|
b_range = torch.arange(pred.size()[0], device=self.device)
|
||||||
target_pred = pred[b_range, obj]
|
target_pred = pred[b_range, obj]
|
||||||
pred = torch.where(label.byte(), torch.zeros_like(pred), pred)
|
pred = torch.where(label.byte(), torch.zeros_like(pred), pred)
|
||||||
@ -691,7 +691,7 @@ if __name__ == "__main__":
|
|||||||
collate_fn=TrainDataset.collate_fn
|
collate_fn=TrainDataset.collate_fn
|
||||||
))
|
))
|
||||||
for step, batch in enumerate(dataloader):
|
for step, batch in enumerate(dataloader):
|
||||||
sub, rel, obj, label, neg_ent, sub_samp = model.read_batch(
|
sub, rel, obj, nt_rel, label, neg_ent, sub_samp = model.read_batch(
|
||||||
batch, 'train')
|
batch, 'train')
|
||||||
|
|
||||||
if (neg_ent is None):
|
if (neg_ent is None):
|
||||||
|
Loading…
Reference in New Issue
Block a user