This commit is contained in:
thanhvc3 2024-06-19 00:08:27 +07:00
parent 03f42561c6
commit 417a38d2e5

View File

@ -51,6 +51,7 @@ class TrainDataset(Dataset):
@staticmethod @staticmethod
def collate_fn(data): def collate_fn(data):
triple = torch.stack([_[0] for _ in data], dim=0) triple = torch.stack([_[0] for _ in data], dim=0)
print(triple)
trp_label = torch.stack([_[1] for _ in data], dim=0) trp_label = torch.stack([_[1] for _ in data], dim=0)
if not data[0][2] is None: # one_to_x if not data[0][2] is None: # one_to_x