diff --git a/data_loader.py b/data_loader.py index c41a3ed..5eb7d73 100644 --- a/data_loader.py +++ b/data_loader.py @@ -51,6 +51,7 @@ class TrainDataset(Dataset): @staticmethod def collate_fn(data): triple = torch.stack([_[0] for _ in data], dim=0) + print(triple) trp_label = torch.stack([_[1] for _ in data], dim=0) if not data[0][2] is None: # one_to_x