From 7448528eec781259e4a22ebef54b73585c2bf58d Mon Sep 17 00:00:00 2001 From: thanhvc3 Date: Wed, 19 Jun 2024 00:11:37 +0700 Subject: [PATCH] test --- main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 9e8a992..9909869 100644 --- a/main.py +++ b/main.py @@ -273,13 +273,14 @@ class Main(object): labels: The label for each triple """ if split == 'train': - print(triple.shape) if self.p.train_strategy == 'one_to_x': triple, label, neg_ent, sub_samp = [ _.to(self.device) for _ in batch] + print(triple.shape) return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp else: triple, label = [_.to(self.device) for _ in batch] + print(triple.shape) return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None else: triple, label = [_.to(self.device) for _ in batch]