first version of negative sampling

This commit is contained in:
2023-01-08 14:25:31 +00:00
parent c0d0be076f
commit fcfeae2bd3
10 changed files with 163 additions and 12 deletions

View File

@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding,
PreTrainedTokenizerBase)
from .base_data_module import BaseDataModule
from .processor import KGProcessor, get_dataset
from .processor import KGProcessor, get_dataset, getNegativeEntityId
import transformers
transformers.logging.set_verbosity_error()
@ -106,8 +106,9 @@ class DataCollatorForSeq2Seq:
if isinstance(l, int):
new_labels[i][l] = 1
else:
for j in l:
new_labels[i][j] = 1
if (l[0] != getNegativeEntityId()):
for j in l:
new_labels[i][j] = 1
labels = new_labels
features = self.tokenizer.pad(

View File

@ -314,6 +314,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
global ent2id
global ent2token
global rel2id
global negativeEntity
head_filter_entities = head
tail_filter_entities = tail
@ -322,11 +323,19 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
ent2id = ent2id_
ent2token = ent2token_
rel2id = rel2id_
negativeEntity = ent2id['[NEG]']
def delete_init(ent2text_):
global ent2text
ent2text = ent2text_
def getEntityIdByName(name):
global ent2id
return ent2id[name]
def getNegativeEntityId():
global negativeEntity
return negativeEntity
class KGProcessor(DataProcessor):
"""Processor for knowledge graph data set."""
@ -377,6 +386,7 @@ class KGProcessor(DataProcessor):
"""Gets all entities in the knowledge graph."""
with open(self.entity_path, 'r') as f:
lines = f.readlines()
lines.append('[NEG]\t')
entities = []
for line in lines:
entities.append(line.strip().split("\t")[0])
@ -403,6 +413,7 @@ class KGProcessor(DataProcessor):
ent2text_with_type = {}
with open(self.entity_path, 'r') as f:
ent_lines = f.readlines()
ent_lines.append('[NEG]\t')
for line in ent_lines:
temp = line.strip().split('\t')
try:

View File

@ -364,13 +364,17 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
self.id2entity = {}
with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
cnt = 0
for line in file.readlines():
lines = file.readlines()
lines.append('[NEG]\t')
for line in lines:
e, d = line.strip().split("\t")
self.id2entity[cnt] = e
cnt += 1
self.id2entity_t = {}
with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
for line in file.readlines():
lines = file.readlines()
lines.append('[NEG]\t')
for line in lines:
e, d = line.strip().split("\t")
self.id2entity_t[e] = d
for k, v in self.id2entity.items():

0
pretrain/scripts/pretrain_fb15k-237.sh Normal file → Executable file
View File