first version of negative sampling
This commit is contained in:
@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding,
|
||||
PreTrainedTokenizerBase)
|
||||
|
||||
from .base_data_module import BaseDataModule
|
||||
from .processor import KGProcessor, get_dataset
|
||||
from .processor import KGProcessor, get_dataset, getNegativeEntityId
|
||||
import transformers
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
@ -106,8 +106,9 @@ class DataCollatorForSeq2Seq:
|
||||
if isinstance(l, int):
|
||||
new_labels[i][l] = 1
|
||||
else:
|
||||
for j in l:
|
||||
new_labels[i][j] = 1
|
||||
if (l[0] != getNegativeEntityId()):
|
||||
for j in l:
|
||||
new_labels[i][j] = 1
|
||||
labels = new_labels
|
||||
|
||||
features = self.tokenizer.pad(
|
||||
|
@ -314,6 +314,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
|
||||
global ent2id
|
||||
global ent2token
|
||||
global rel2id
|
||||
global negativeEntity
|
||||
|
||||
head_filter_entities = head
|
||||
tail_filter_entities = tail
|
||||
@ -322,11 +323,19 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
|
||||
ent2id = ent2id_
|
||||
ent2token = ent2token_
|
||||
rel2id = rel2id_
|
||||
negativeEntity = ent2id['[NEG]']
|
||||
|
||||
def delete_init(ent2text_):
    """Store ``ent2text_`` in the module-level ``ent2text`` global.

    Args:
        ent2text_: Object to publish as ``ent2text``; it is stored by
            reference, not copied.
    """
    # ``global`` is required here because the name is rebound.
    global ent2text
    ent2text = ent2text_
|
||||
|
||||
def getEntityIdByName(name):
    """Return the entry stored under ``name`` in the module-level ``ent2id``.

    ``ent2id`` is assigned by ``filter_init``; calling this function before
    that assignment raises ``NameError``.

    Args:
        name: Key to look up in ``ent2id`` (e.g. ``'[NEG]'``).

    Returns:
        ``ent2id[name]``.
    """
    # Read-only access to a module global needs no ``global`` declaration.
    return ent2id[name]
|
||||
|
||||
def getNegativeEntityId():
    """Return the module-level ``negativeEntity`` value.

    ``negativeEntity`` is assigned in ``filter_init`` as ``ent2id['[NEG]']``;
    calling this function before that assignment raises ``NameError``.

    Returns:
        The current value of ``negativeEntity``.
    """
    # Read-only access to a module global needs no ``global`` declaration.
    return negativeEntity
|
||||
|
||||
class KGProcessor(DataProcessor):
|
||||
"""Processor for knowledge graph data set."""
|
||||
@ -377,6 +386,7 @@ class KGProcessor(DataProcessor):
|
||||
"""Gets all entities in the knowledge graph."""
|
||||
with open(self.entity_path, 'r') as f:
|
||||
lines = f.readlines()
|
||||
lines.append('[NEG]\t')
|
||||
entities = []
|
||||
for line in lines:
|
||||
entities.append(line.strip().split("\t")[0])
|
||||
@ -403,6 +413,7 @@ class KGProcessor(DataProcessor):
|
||||
ent2text_with_type = {}
|
||||
with open(self.entity_path, 'r') as f:
|
||||
ent_lines = f.readlines()
|
||||
ent_lines.append('[NEG]\t')
|
||||
for line in ent_lines:
|
||||
temp = line.strip().split('\t')
|
||||
try:
|
||||
|
@ -364,13 +364,17 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
|
||||
self.id2entity = {}
|
||||
with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
|
||||
cnt = 0
|
||||
for line in file.readlines():
|
||||
lines = file.readlines()
|
||||
lines.append('[NEG]\t')
|
||||
for line in lines:
|
||||
e, d = line.strip().split("\t")
|
||||
self.id2entity[cnt] = e
|
||||
cnt += 1
|
||||
self.id2entity_t = {}
|
||||
with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
|
||||
for line in file.readlines():
|
||||
lines = file.readlines()
|
||||
lines.append('[NEG]\t')
|
||||
for line in lines:
|
||||
e, d = line.strip().split("\t")
|
||||
self.id2entity_t[e] = d
|
||||
for k, v in self.id2entity.items():
|
||||
|
0
pretrain/scripts/pretrain_fb15k-237.sh
Normal file → Executable file
0
pretrain/scripts/pretrain_fb15k-237.sh
Normal file → Executable file
Reference in New Issue
Block a user