137 lines
4.2 KiB
Python
137 lines
4.2 KiB
Python
import torch
|
|
import numpy as np
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
class TrainDataset(Dataset):
|
|
"""
|
|
Training Dataset class.
|
|
Parameters
|
|
----------
|
|
triples: The triples used for training the model
|
|
params: Parameters for the experiments
|
|
|
|
Returns
|
|
-------
|
|
A training Dataset class instance used by DataLoader
|
|
"""
|
|
|
|
def __init__(self, triples, params):
|
|
self.triples = triples
|
|
self.p = params
|
|
self.strategy = self.p.train_strategy
|
|
self.entities = np.arange(self.p.num_ent, dtype=np.int32)
|
|
|
|
def __len__(self):
|
|
return len(self.triples)
|
|
|
|
def __getitem__(self, idx):
|
|
ele = self.triples[idx]
|
|
triple, label, sub_samp = torch.LongTensor(ele['triple']), np.int32(
|
|
ele['label']), np.float32(ele['sub_samp'])
|
|
trp_label = self.get_label(label)
|
|
|
|
if self.p.lbl_smooth != 0.0:
|
|
trp_label = (1.0 - self.p.lbl_smooth) * \
|
|
trp_label + (1.0/self.p.num_ent)
|
|
|
|
if self.strategy == 'one_to_n':
|
|
return triple, trp_label, None, None
|
|
|
|
elif self.strategy == 'one_to_x':
|
|
sub_samp = torch.FloatTensor([sub_samp])
|
|
neg_ent = torch.LongTensor(self.get_neg_ent(triple, label))
|
|
return triple, trp_label, neg_ent, sub_samp
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
@staticmethod
|
|
def collate_fn(data):
|
|
triple = torch.stack([_[0] for _ in data], dim=0)
|
|
trp_label = torch.stack([_[1] for _ in data], dim=0)
|
|
|
|
if not data[0][2] is None: # one_to_x
|
|
neg_ent = torch.stack([_[2] for _ in data], dim=0)
|
|
sub_samp = torch.cat([_[3] for _ in data], dim=0)
|
|
return triple, trp_label, neg_ent, sub_samp
|
|
else:
|
|
return triple, trp_label
|
|
|
|
def get_neg_ent(self, triple, label):
|
|
def get(triple, label):
|
|
if self.strategy == 'one_to_x':
|
|
pos_obj = triple[2]
|
|
mask = np.ones([self.p.num_ent], dtype=np.bool)
|
|
mask[label] = 0
|
|
neg_ent = np.int32(np.random.choice(
|
|
self.entities[mask], self.p.neg_num, replace=False)).reshape([-1])
|
|
neg_ent = np.concatenate((pos_obj.reshape([-1]), neg_ent))
|
|
else:
|
|
pos_obj = label
|
|
mask = np.ones([self.p.num_ent], dtype=np.bool)
|
|
mask[label] = 0
|
|
neg_ent = np.int32(np.random.choice(
|
|
self.entities[mask], self.p.neg_num - len(label), replace=False)).reshape([-1])
|
|
neg_ent = np.concatenate((pos_obj.reshape([-1]), neg_ent))
|
|
|
|
if len(neg_ent) > self.p.neg_num:
|
|
import pdb
|
|
pdb.set_trace()
|
|
|
|
return neg_ent
|
|
|
|
neg_ent = get(triple, label)
|
|
return neg_ent
|
|
|
|
def get_label(self, label):
|
|
if self.strategy == 'one_to_n':
|
|
y = np.zeros([self.p.num_ent], dtype=np.float32)
|
|
for e2 in label:
|
|
y[e2] = 1.0
|
|
elif self.strategy == 'one_to_x':
|
|
y = [1] + [0] * self.p.neg_num
|
|
else:
|
|
raise NotImplementedError
|
|
return torch.FloatTensor(y)
|
|
|
|
|
|
class TestDataset(Dataset):
|
|
"""
|
|
Evaluation Dataset class.
|
|
Parameters
|
|
----------
|
|
triples: The triples used for evaluating the model
|
|
params: Parameters for the experiments
|
|
|
|
Returns
|
|
-------
|
|
An evaluation Dataset class instance used by DataLoader for model evaluation
|
|
"""
|
|
|
|
def __init__(self, triples, params):
|
|
self.triples = triples
|
|
self.p = params
|
|
|
|
def __len__(self):
|
|
return len(self.triples)
|
|
|
|
def __getitem__(self, idx):
|
|
ele = self.triples[idx]
|
|
triple, label = torch.LongTensor(ele['triple']), np.int32(ele['label'])
|
|
label = self.get_label(label)
|
|
|
|
return triple, label
|
|
|
|
@staticmethod
|
|
def collate_fn(data):
|
|
triple = torch.stack([_[0] for _ in data], dim=0)
|
|
label = torch.stack([_[1] for _ in data], dim=0)
|
|
return triple, label
|
|
|
|
def get_label(self, label):
|
|
y = np.zeros([self.p.num_ent], dtype=np.float32)
|
|
for e2 in label:
|
|
y[e2] = 1.0
|
|
return torch.FloatTensor(y)
|