test
This commit is contained in:
parent
39734013c4
commit
936c37d0f6
4
main.py
4
main.py
@ -138,13 +138,13 @@ class Main(object):
|
|||||||
for sub, rel, obj, nt_rel in self.data['train']:
|
for sub, rel, obj, nt_rel in self.data['train']:
|
||||||
rel_inv = rel + self.p.num_rel
|
rel_inv = rel + self.p.num_rel
|
||||||
sub_samp = len(self.sr2o[(sub, rel, nt_rel)]) + \
|
sub_samp = len(self.sr2o[(sub, rel, nt_rel)]) + \
|
||||||
len(self.sr2o[(obj, rel_inv)])
|
len(self.sr2o[(obj, rel_inv, nt_rel + self.p.num_rel)])
|
||||||
sub_samp = np.sqrt(1/sub_samp)
|
sub_samp = np.sqrt(1/sub_samp)
|
||||||
|
|
||||||
self.triples['train'].append({'triple': (
|
self.triples['train'].append({'triple': (
|
||||||
sub, rel, obj, nt_rel), 'label': self.sr2o[(sub, rel, nt_rel)], 'sub_samp': sub_samp})
|
sub, rel, obj, nt_rel), 'label': self.sr2o[(sub, rel, nt_rel)], 'sub_samp': sub_samp})
|
||||||
self.triples['train'].append({'triple': (
|
self.triples['train'].append({'triple': (
|
||||||
obj, rel_inv, sub, nt_rel + self.p.num_rel), 'label': self.sr2o[(obj, rel_inv, nt_rel)], 'sub_samp': sub_samp})
|
obj, rel_inv, sub, nt_rel + self.p.num_rel), 'label': self.sr2o[(obj, rel_inv, nt_rel + self.p.num_rel)], 'sub_samp': sub_samp})
|
||||||
|
|
||||||
for split in ['test', 'valid']:
|
for split in ['test', 'valid']:
|
||||||
for sub, rel, obj, nt_rel in self.data[split]:
|
for sub, rel, obj, nt_rel in self.data[split]:
|
||||||
|
Loading…
Reference in New Issue
Block a user