test
This commit is contained in:
parent
417a38d2e5
commit
7194f8046c
@ -28,8 +28,6 @@ class TrainDataset(Dataset):
|
|||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
ele = self.triples[idx]
|
ele = self.triples[idx]
|
||||||
if (idx == 0):
|
|
||||||
print(ele)
|
|
||||||
triple, label, sub_samp = torch.LongTensor(ele['triple']), np.int32(
|
triple, label, sub_samp = torch.LongTensor(ele['triple']), np.int32(
|
||||||
ele['label']), np.float32(ele['sub_samp'])
|
ele['label']), np.float32(ele['sub_samp'])
|
||||||
trp_label = self.get_label(label)
|
trp_label = self.get_label(label)
|
||||||
@ -51,7 +49,6 @@ class TrainDataset(Dataset):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def collate_fn(data):
|
def collate_fn(data):
|
||||||
triple = torch.stack([_[0] for _ in data], dim=0)
|
triple = torch.stack([_[0] for _ in data], dim=0)
|
||||||
print(triple)
|
|
||||||
trp_label = torch.stack([_[1] for _ in data], dim=0)
|
trp_label = torch.stack([_[1] for _ in data], dim=0)
|
||||||
|
|
||||||
if not data[0][2] is None: # one_to_x
|
if not data[0][2] is None: # one_to_x
|
||||||
|
1
main.py
1
main.py
@ -273,6 +273,7 @@ class Main(object):
|
|||||||
labels: The label for each triple
|
labels: The label for each triple
|
||||||
"""
|
"""
|
||||||
if split == 'train':
|
if split == 'train':
|
||||||
|
print(triple.shape)
|
||||||
if self.p.train_strategy == 'one_to_x':
|
if self.p.train_strategy == 'one_to_x':
|
||||||
triple, label, neg_ent, sub_samp = [
|
triple, label, neg_ent, sub_samp = [
|
||||||
_.to(self.device) for _ in batch]
|
_.to(self.device) for _ in batch]
|
||||||
|
Loading…
Reference in New Issue
Block a user