From 7194f8046cf3e124a25cdd86bf81c047a079f9da Mon Sep 17 00:00:00 2001 From: thanhvc3 Date: Wed, 19 Jun 2024 00:10:51 +0700 Subject: [PATCH] test --- data_loader.py | 3 --- main.py | 1 + 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/data_loader.py b/data_loader.py index 5eb7d73..03e0af5 100644 --- a/data_loader.py +++ b/data_loader.py @@ -28,8 +28,6 @@ class TrainDataset(Dataset): def __getitem__(self, idx): ele = self.triples[idx] - if (idx == 0): - print(ele) triple, label, sub_samp = torch.LongTensor(ele['triple']), np.int32( ele['label']), np.float32(ele['sub_samp']) trp_label = self.get_label(label) @@ -51,7 +49,6 @@ class TrainDataset(Dataset): @staticmethod def collate_fn(data): triple = torch.stack([_[0] for _ in data], dim=0) - print(triple) trp_label = torch.stack([_[1] for _ in data], dim=0) if not data[0][2] is None: # one_to_x diff --git a/main.py b/main.py index 3a23e65..9e8a992 100644 --- a/main.py +++ b/main.py @@ -273,6 +273,7 @@ class Main(object): labels: The label for each triple """ if split == 'train': + print(triple.shape) if self.p.train_strategy == 'one_to_x': triple, label, neg_ent, sub_samp = [ _.to(self.device) for _ in batch]