diff --git a/main.py b/main.py index 9e8a992..9909869 100644 --- a/main.py +++ b/main.py @@ -273,13 +273,14 @@ class Main(object): labels: The label for each triple """ if split == 'train': - print(triple.shape) if self.p.train_strategy == 'one_to_x': triple, label, neg_ent, sub_samp = [ _.to(self.device) for _ in batch] + print(triple.shape) return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp else: triple, label = [_.to(self.device) for _ in batch] + print(triple.shape) return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None else: triple, label = [_.to(self.device) for _ in batch]