This commit is contained in:
thanhvc3 2024-06-19 00:03:14 +07:00
parent 39734013c4
commit 936c37d0f6

View File

@ -138,13 +138,13 @@ class Main(object):
for sub, rel, obj, nt_rel in self.data['train']: for sub, rel, obj, nt_rel in self.data['train']:
rel_inv = rel + self.p.num_rel rel_inv = rel + self.p.num_rel
sub_samp = len(self.sr2o[(sub, rel, nt_rel)]) + \ sub_samp = len(self.sr2o[(sub, rel, nt_rel)]) + \
len(self.sr2o[(obj, rel_inv)]) len(self.sr2o[(obj, rel_inv, nt_rel + self.p.num_rel)])
sub_samp = np.sqrt(1/sub_samp) sub_samp = np.sqrt(1/sub_samp)
self.triples['train'].append({'triple': ( self.triples['train'].append({'triple': (
sub, rel, obj, nt_rel), 'label': self.sr2o[(sub, rel, nt_rel)], 'sub_samp': sub_samp}) sub, rel, obj, nt_rel), 'label': self.sr2o[(sub, rel, nt_rel)], 'sub_samp': sub_samp})
self.triples['train'].append({'triple': ( self.triples['train'].append({'triple': (
obj, rel_inv, sub, nt_rel + self.p.num_rel), 'label': self.sr2o[(obj, rel_inv, nt_rel)], 'sub_samp': sub_samp}) obj, rel_inv, sub, nt_rel + self.p.num_rel), 'label': self.sr2o[(obj, rel_inv, nt_rel + self.p.num_rel)], 'sub_samp': sub_samp})
for split in ['test', 'valid']: for split in ['test', 'valid']:
for sub, rel, obj, nt_rel in self.data[split]: for sub, rel, obj, nt_rel in self.data[split]: