first version of negative sampling
@@ -81,10 +81,13 @@ class TransformerLitModel(BaseLitModel):
         pos = batch.pop("pos")
         try:
             en = batch.pop("en")
+            self.print("__DEBUG__: en", en)
             rel = batch.pop("rel")
+            self.print("__DEBUG__: rel", rel)
         except KeyError:
             pass
         input_ids = batch['input_ids']
+        self.print("__DEBUG__: input_ids", input_ids)
 
         distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])])
         distance = batch.pop("distance_attention")
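For context, pad_distance is a helper defined elsewhere in this repository and its body is not shown in the diff; the call site suggests it pads each per-example distance tensor so the whole batch can be passed to torch.stack. A minimal sketch under that assumption (the padding value and the square 2-D shape of distance are guesses, not taken from the source):

import torch
import torch.nn.functional as F

def pad_distance(pad_length, distance, pad_value=0):
    # Hypothetical stand-in for the repo's pad_distance: extend a square
    # distance matrix by pad_length rows and columns of pad_value so every
    # example in the batch reaches the same size before stacking.
    return F.pad(distance, (0, pad_length, 0, pad_length), value=pad_value)

# Mirrors the call in the hunk above on a toy batch of two examples.
distances = [torch.zeros(3, 3), torch.zeros(5, 5)]
input_ids = [torch.arange(7), torch.arange(7)]
padded = torch.stack([pad_distance(len(input_ids[i]) - len(d) - 1, d) for i, d in enumerate(distances)])
print(padded.shape)  # torch.Size([2, 6, 6])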
@@ -382,13 +385,17 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
         self.id2entity = {}
         with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
             cnt = 0
-            for line in file.readlines():
+            lines = file.readlines()
+            lines.append('[NEG]\t')
+            for line in lines:
                 e, d = line.strip().split("\t")
                 self.id2entity[cnt] = e
                 cnt += 1
         self.id2entity_t = {}
         with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
-            for line in file.readlines():
+            lines = file.readlines()
+            lines.append('[NEG]\t')
+            for line in lines:
                 e, d = line.strip().split("\t")
                 self.id2entity_t[e] = d
         for k, v in self.id2entity.items():
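This second hunk is where the commit's "negative sampling" shows up: after the FB15k-237 entity text files are read, both lookup tables get one extra [NEG] row appended, so the last id can serve as a placeholder entity for negative samples. A self-contained sketch of the same loading pattern, using an in-memory list instead of the dataset files (the entity ids and descriptions are toy values, and the placeholder is given a description here only so the two-field split stays well-formed):

# Toy stand-in for ./dataset/FB15k-237/entity2textlong.txt: one "<id>\t<text>" row per entity.
raw_lines = [
    "/m/000001\tfirst toy entity description",
    "/m/000002\tsecond toy entity description",
]

id2entity = {}
lines = list(raw_lines)
lines.append("[NEG]\tplaceholder text for negative samples")  # mirrors lines.append('[NEG]\t')
cnt = 0
for line in lines:
    e, d = line.strip().split("\t")
    id2entity[cnt] = e
    cnt += 1

print(id2entity)  # {0: '/m/000001', 1: '/m/000002', 2: '[NEG]'}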