diff --git a/main.py b/main.py index d5bc996..3a23e65 100644 --- a/main.py +++ b/main.py @@ -138,13 +138,13 @@ class Main(object): for sub, rel, obj, nt_rel in self.data['train']: rel_inv = rel + self.p.num_rel sub_samp = len(self.sr2o[(sub, rel, nt_rel)]) + \ - len(self.sr2o[(obj, rel_inv)]) + len(self.sr2o[(obj, rel_inv, nt_rel + self.p.num_rel)]) sub_samp = np.sqrt(1/sub_samp) self.triples['train'].append({'triple': ( sub, rel, obj, nt_rel), 'label': self.sr2o[(sub, rel, nt_rel)], 'sub_samp': sub_samp}) self.triples['train'].append({'triple': ( - obj, rel_inv, sub, nt_rel + self.p.num_rel), 'label': self.sr2o[(obj, rel_inv, nt_rel)], 'sub_samp': sub_samp}) + obj, rel_inv, sub, nt_rel + self.p.num_rel), 'label': self.sr2o[(obj, rel_inv, nt_rel + self.p.num_rel)], 'sub_samp': sub_samp}) for split in ['test', 'valid']: for sub, rel, obj, nt_rel in self.data[split]: