test
This commit is contained in:
parent
7194f8046c
commit
7448528eec
3
main.py
3
main.py
@ -273,13 +273,14 @@ class Main(object):
|
|||||||
labels: The label for each triple
|
labels: The label for each triple
|
||||||
"""
|
"""
|
||||||
if split == 'train':
|
if split == 'train':
|
||||||
print(triple.shape)
|
|
||||||
if self.p.train_strategy == 'one_to_x':
|
if self.p.train_strategy == 'one_to_x':
|
||||||
triple, label, neg_ent, sub_samp = [
|
triple, label, neg_ent, sub_samp = [
|
||||||
_.to(self.device) for _ in batch]
|
_.to(self.device) for _ in batch]
|
||||||
|
print(triple.shape)
|
||||||
return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp
|
return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp
|
||||||
else:
|
else:
|
||||||
triple, label = [_.to(self.device) for _ in batch]
|
triple, label = [_.to(self.device) for _ in batch]
|
||||||
|
print(triple.shape)
|
||||||
return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None
|
return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None
|
||||||
else:
|
else:
|
||||||
triple, label = [_.to(self.device) for _ in batch]
|
triple, label = [_.to(self.device) for _ in batch]
|
||||||
|
Loading…
Reference in New Issue
Block a user