This commit is contained in:
thanhvc3 2024-06-19 00:10:51 +07:00
parent 417a38d2e5
commit 7194f8046c
2 changed files with 1 additions and 3 deletions

View File

@ -28,8 +28,6 @@ class TrainDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
ele = self.triples[idx] ele = self.triples[idx]
if (idx == 0):
print(ele)
triple, label, sub_samp = torch.LongTensor(ele['triple']), np.int32( triple, label, sub_samp = torch.LongTensor(ele['triple']), np.int32(
ele['label']), np.float32(ele['sub_samp']) ele['label']), np.float32(ele['sub_samp'])
trp_label = self.get_label(label) trp_label = self.get_label(label)
@ -51,7 +49,6 @@ class TrainDataset(Dataset):
@staticmethod @staticmethod
def collate_fn(data): def collate_fn(data):
triple = torch.stack([_[0] for _ in data], dim=0) triple = torch.stack([_[0] for _ in data], dim=0)
print(triple)
trp_label = torch.stack([_[1] for _ in data], dim=0) trp_label = torch.stack([_[1] for _ in data], dim=0)
if not data[0][2] is None: # one_to_x if not data[0][2] is None: # one_to_x

View File

@ -273,6 +273,7 @@ class Main(object):
labels: The label for each triple labels: The label for each triple
""" """
if split == 'train': if split == 'train':
print(triple.shape)
if self.p.train_strategy == 'one_to_x': if self.p.train_strategy == 'one_to_x':
triple, label, neg_ent, sub_samp = [ triple, label, neg_ent, sub_samp = [
_.to(self.device) for _ in batch] _.to(self.device) for _ in batch]