First version of negative sampling

This commit is contained in:
2023-01-08 14:25:31 +00:00
parent c0d0be076f
commit fcfeae2bd3
10 changed files with 163 additions and 12 deletions

View File

@ -81,10 +81,13 @@ class TransformerLitModel(BaseLitModel):
pos = batch.pop("pos")
try:
en = batch.pop("en")
self.print("__DEBUG__: en", en)
rel = batch.pop("rel")
self.print("__DEBUG__: rel", rel)
except KeyError:
pass
input_ids = batch['input_ids']
self.print("__DEBUG__: input_ids", input_ids)
distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])])
distance = batch.pop("distance_attention")
@ -382,13 +385,17 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
self.id2entity = {}
with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
cnt = 0
for line in file.readlines():
lines = file.readlines()
lines.append('[NEG]\t')
for line in lines:
e, d = line.strip().split("\t")
self.id2entity[cnt] = e
cnt += 1
self.id2entity_t = {}
with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
for line in file.readlines():
lines = file.readlines()
lines.append('[NEG]\t')
for line in lines:
e, d = line.strip().split("\t")
self.id2entity_t[e] = d
for k, v in self.id2entity.items():