test
This commit is contained in:
parent
7194f8046c
commit
7448528eec
3
main.py
3
main.py
@ -273,13 +273,14 @@ class Main(object):
|
||||
labels: The label for each triple
|
||||
"""
|
||||
if split == 'train':
|
||||
print(triple.shape)
|
||||
if self.p.train_strategy == 'one_to_x':
|
||||
triple, label, neg_ent, sub_samp = [
|
||||
_.to(self.device) for _ in batch]
|
||||
print(triple.shape)
|
||||
return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp
|
||||
else:
|
||||
triple, label = [_.to(self.device) for _ in batch]
|
||||
print(triple.shape)
|
||||
return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None
|
||||
else:
|
||||
triple, label = [_.to(self.device) for _ in batch]
|
||||
|
Loading…
Reference in New Issue
Block a user