test
This commit is contained in:
parent
417a38d2e5
commit
7194f8046c
@ -28,8 +28,6 @@ class TrainDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
ele = self.triples[idx]
|
||||
if (idx == 0):
|
||||
print(ele)
|
||||
triple, label, sub_samp = torch.LongTensor(ele['triple']), np.int32(
|
||||
ele['label']), np.float32(ele['sub_samp'])
|
||||
trp_label = self.get_label(label)
|
||||
@ -51,7 +49,6 @@ class TrainDataset(Dataset):
|
||||
@staticmethod
|
||||
def collate_fn(data):
|
||||
triple = torch.stack([_[0] for _ in data], dim=0)
|
||||
print(triple)
|
||||
trp_label = torch.stack([_[1] for _ in data], dim=0)
|
||||
|
||||
if not data[0][2] is None: # one_to_x
|
||||
|
Loading…
Reference in New Issue
Block a user