This commit is contained in:
thanhvc3 2024-06-19 00:11:37 +07:00
parent 7194f8046c
commit 7448528eec

View File

@ -273,13 +273,14 @@ class Main(object):
labels: The label for each triple labels: The label for each triple
""" """
if split == 'train': if split == 'train':
print(triple.shape)
if self.p.train_strategy == 'one_to_x': if self.p.train_strategy == 'one_to_x':
triple, label, neg_ent, sub_samp = [ triple, label, neg_ent, sub_samp = [
_.to(self.device) for _ in batch] _.to(self.device) for _ in batch]
print(triple.shape)
return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp
else: else:
triple, label = [_.to(self.device) for _ in batch] triple, label = [_.to(self.device) for _ in batch]
print(triple.shape)
return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None
else: else:
triple, label = [_.to(self.device) for _ in batch] triple, label = [_.to(self.device) for _ in batch]