diff --git a/data_loader.py b/data_loader.py index 03e0af5..c41a3ed 100644 --- a/data_loader.py +++ b/data_loader.py @@ -28,6 +28,8 @@ class TrainDataset(Dataset): def __getitem__(self, idx): ele = self.triples[idx] + if (idx == 0): + print(ele) triple, label, sub_samp = torch.LongTensor(ele['triple']), np.int32( ele['label']), np.float32(ele['sub_samp']) trp_label = self.get_label(label)