first version of negative sampling
This commit is contained in:
parent
c0d0be076f
commit
fcfeae2bd3
38
.vscode/launch.json
vendored
Normal file
38
.vscode/launch.json
vendored
Normal file
@ -0,0 +1,38 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: Current File",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${file}",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"args": [
|
||||
"--gpus", "1",
|
||||
"--max_epochs=16",
|
||||
"--num_workers=32",
|
||||
"--model_name_or_path", "bert-base-uncased",
|
||||
"--accumulate_grad_batches", "1",
|
||||
"--model_class", "BertKGC",
|
||||
"--batch_size", "64",
|
||||
"--checkpoint", "/kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt",
|
||||
"--pretrain", "0",
|
||||
"--bce", "0",
|
||||
"--check_val_every_n_epoch", "1",
|
||||
"--overwrite_cache",
|
||||
"--data_dir", "dataset/FB15k-237",
|
||||
"--eval_batch_size", "128",
|
||||
"--max_seq_length", "128",
|
||||
"--lr", "3e-5",
|
||||
"--max_triplet", "64",
|
||||
"--add_attn_bias", "True",
|
||||
"--use_global_node", "True",
|
||||
"--fast_dev_run", "True",
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding,
|
||||
PreTrainedTokenizerBase)
|
||||
|
||||
from .base_data_module import BaseDataModule
|
||||
from .processor import KGProcessor, get_dataset
|
||||
from .processor import KGProcessor, get_dataset, getNegativeEntityId
|
||||
import transformers
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
@ -105,6 +105,7 @@ class DataCollatorForSeq2Seq:
|
||||
if isinstance(l, int):
|
||||
new_labels[i][l] = 1
|
||||
else:
|
||||
if (l[0] != getNegativeEntityId()):
|
||||
for j in l:
|
||||
new_labels[i][j] = 1
|
||||
labels = new_labels
|
||||
|
@ -5,7 +5,7 @@ import contextlib
|
||||
import sys
|
||||
|
||||
from collections import Counter
|
||||
from multiprocessing import Pool
|
||||
from multiprocessing import Pool, synchronize
|
||||
from torch._C import HOIST_CONV_PACKED_PARAMS
|
||||
from torch.utils.data import Dataset, Sampler, IterableDataset
|
||||
from collections import defaultdict
|
||||
@ -235,6 +235,31 @@ class DataProcessor(object):
|
||||
|
||||
import copy
|
||||
|
||||
from collections import deque
|
||||
import threading
|
||||
|
||||
|
||||
class _LiveState(type):
|
||||
_instances = {}
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls not in cls._instances:
|
||||
with cls._lock:
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(_LiveState, cls).__call__(*args, **kwargs)
|
||||
|
||||
return cls._instances[cls]
|
||||
|
||||
class LiveState(metaclass=_LiveState):
|
||||
|
||||
def __init__(self):
|
||||
self._pool_size = 16
|
||||
self._deq = deque(maxlen=self._pool_size)
|
||||
def put(self, item):
|
||||
self._deq.append(item)
|
||||
def get(self):
|
||||
return list(self._deq)
|
||||
|
||||
def solve_get_knowledge_store(line, set_type="train", pretrain=1):
|
||||
"""
|
||||
@ -364,6 +389,53 @@ def solve(line, set_type="train", pretrain=1, max_triplet=32):
|
||||
InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_head_seq), label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[line[1], line[2]], en_id = [rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]], text_d_id = list(masked_head_seq_id), graph_inf = masked_head_graph_list))
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a="[PAD]", text_b="[PAD]", text_c = "[MASK]", text_d = list(masked_tail_seq), label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[line[0], line[1]], en_id = [ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]], text_d_id = list(masked_tail_seq_id), graph_inf = masked_tail_graph_list))
|
||||
|
||||
liveState = LiveState()
|
||||
_prev = liveState.get()
|
||||
|
||||
if (set_type == "train" and len(_prev) > 0):
|
||||
|
||||
for prev_ent in _prev:
|
||||
|
||||
z = head_filter_entities["\t".join([prev_ent,line[1]])]
|
||||
if (len(z) == 0):
|
||||
z.append('[NEG]')
|
||||
z.append(line[2])
|
||||
z.append(line[0])
|
||||
masked_neg_seq = set()
|
||||
masked_neg_seq_id = set()
|
||||
|
||||
masked_neg_graph_list = masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), []) if len(masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), [])) < max_triplet else \
|
||||
random.sample(masked_tail_neighbor["\t".join([prev_ent, line[1]])], max_triplet)
|
||||
|
||||
for item in masked_neg_graph_list:
|
||||
masked_neg_seq.add(item[0])
|
||||
masked_neg_seq.add(item[1])
|
||||
masked_neg_seq.add(item[2])
|
||||
masked_neg_seq_id.add(ent2id[item[0]])
|
||||
masked_neg_seq_id.add(rel2id[item[1]])
|
||||
masked_neg_seq_id.add(ent2id[item[2]])
|
||||
|
||||
masked_neg_seq = masked_neg_seq.difference({line[0]})
|
||||
masked_neg_seq = masked_neg_seq.difference({line[2]})
|
||||
masked_neg_seq = masked_neg_seq.difference({line[1]})
|
||||
masked_neg_seq = masked_neg_seq.difference(prev_ent)
|
||||
masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[0]]})
|
||||
masked_neg_seq_id = masked_neg_seq_id.difference({rel2id[line[1]]})
|
||||
masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[2]]})
|
||||
masked_neg_seq_id = masked_neg_seq_id.difference(prev_ent)
|
||||
# examples.append(
|
||||
# InputExample(guid=guid, text_a="[MASK]", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[PAD]" + " " + ' '.join(text_c.split(' ')[:16]), text_d = masked_head_seq, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]]))
|
||||
# examples.append(
|
||||
# InputExample(guid=guid, text_a="[PAD] ", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[MASK]" +" " + ' '.join(text_a.split(' ')[:16]), text_d = masked_tail_seq, label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]]))
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_neg_seq), label=lmap(lambda x: ent2id[x], z), real_label=ent2id[line[0]], en=[line[1], prev_ent], en_id = [rel2id[line[1]], ent2id[prev_ent]], rel=rel2id[line[1]], text_d_id = list(masked_neg_seq_id), graph_inf = masked_neg_graph_list))
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_neg_seq), label=lmap(lambda x: ent2id[x], z), real_label=ent2id[line[2]], en=[line[1], prev_ent], en_id = [rel2id[line[1]], ent2id[prev_ent]], rel=rel2id[line[1]], text_d_id = list(masked_neg_seq_id), graph_inf = masked_neg_graph_list))
|
||||
|
||||
|
||||
liveState.put(line[0])
|
||||
liveState.put(line[2])
|
||||
return examples
|
||||
|
||||
def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_neighbor_, masked_tail_neighbor_, rel2token_):
|
||||
@ -377,6 +449,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei
|
||||
global masked_head_neighbor
|
||||
global masked_tail_neighbor
|
||||
global rel2token
|
||||
global negativeEntity
|
||||
|
||||
head_filter_entities = head
|
||||
tail_filter_entities = tail
|
||||
@ -388,11 +461,20 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei
|
||||
masked_head_neighbor = masked_head_neighbor_
|
||||
masked_tail_neighbor = masked_tail_neighbor_
|
||||
rel2token = rel2token_
|
||||
negativeEntity = ent2id['[NEG]']
|
||||
|
||||
def delete_init(ent2text_):
|
||||
global ent2text
|
||||
ent2text = ent2text_
|
||||
|
||||
def getEntityIdByName(name):
|
||||
global ent2id
|
||||
return ent2id[name]
|
||||
|
||||
def getNegativeEntityId():
|
||||
global negativeEntity
|
||||
return negativeEntity
|
||||
|
||||
|
||||
class KGProcessor(DataProcessor):
|
||||
"""Processor for knowledge graph data set."""
|
||||
@ -443,6 +525,7 @@ class KGProcessor(DataProcessor):
|
||||
"""Gets all entities in the knowledge graph."""
|
||||
with open(self.entity_path, 'r') as f:
|
||||
lines = f.readlines()
|
||||
lines.append('[NEG]\t')
|
||||
entities = []
|
||||
for line in lines:
|
||||
entities.append(line.strip().split("\t")[0])
|
||||
@ -469,6 +552,7 @@ class KGProcessor(DataProcessor):
|
||||
ent2text_with_type = {}
|
||||
with open(self.entity_path, 'r') as f:
|
||||
ent_lines = f.readlines()
|
||||
ent_lines.append('[NEG]\t')
|
||||
for line in ent_lines:
|
||||
temp = line.strip().split('\t')
|
||||
try:
|
||||
@ -579,6 +663,7 @@ class KGProcessor(DataProcessor):
|
||||
else:
|
||||
annotate_ = partial(
|
||||
solve,
|
||||
set_type=set_type,
|
||||
pretrain=self.args.pretrain,
|
||||
max_triplet=self.args.max_triplet
|
||||
)
|
||||
|
@ -81,10 +81,13 @@ class TransformerLitModel(BaseLitModel):
|
||||
pos = batch.pop("pos")
|
||||
try:
|
||||
en = batch.pop("en")
|
||||
self.print("__DEBUG__: en", en)
|
||||
rel = batch.pop("rel")
|
||||
self.print("__DEBUG__: rel", rel)
|
||||
except KeyError:
|
||||
pass
|
||||
input_ids = batch['input_ids']
|
||||
self.print("__DEBUG__: input_ids", input_ids)
|
||||
|
||||
distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])])
|
||||
distance = batch.pop("distance_attention")
|
||||
@ -382,13 +385,17 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
|
||||
self.id2entity = {}
|
||||
with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
|
||||
cnt = 0
|
||||
for line in file.readlines():
|
||||
lines = file.readlines()
|
||||
lines.append('[NEG]\t')
|
||||
for line in lines:
|
||||
e, d = line.strip().split("\t")
|
||||
self.id2entity[cnt] = e
|
||||
cnt += 1
|
||||
self.id2entity_t = {}
|
||||
with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
|
||||
for line in file.readlines():
|
||||
lines = file.readlines()
|
||||
lines.append('[NEG]\t')
|
||||
for line in lines:
|
||||
e, d = line.strip().split("\t")
|
||||
self.id2entity_t[e] = d
|
||||
for k, v in self.id2entity.items():
|
||||
|
6
main.py
6
main.py
@ -98,6 +98,7 @@ def main():
|
||||
tokenizer = data.tokenizer
|
||||
|
||||
lit_model = litmodel_class(args=args, model=model, tokenizer=tokenizer, data_config=data.get_config())
|
||||
print("__DEBUG__: Initialized")
|
||||
if args.checkpoint:
|
||||
lit_model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["state_dict"], strict=False)
|
||||
|
||||
@ -120,9 +121,12 @@ def main():
|
||||
callbacks = [early_callback, model_checkpoint]
|
||||
|
||||
# args.weights_summary = "full" # Print full summary of the model
|
||||
trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs", accelerator="ddp")
|
||||
trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")
|
||||
|
||||
print('__DEBUG__: Init trainer')
|
||||
|
||||
if "EntityEmbedding" not in lit_model.__class__.__name__:
|
||||
print('__DEBUG__: Fit trainer')
|
||||
trainer.fit(lit_model, datamodule=data)
|
||||
path = model_checkpoint.best_model_path
|
||||
lit_model.load_state_dict(torch.load(path)["state_dict"], strict=False)
|
||||
|
@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding,
|
||||
PreTrainedTokenizerBase)
|
||||
|
||||
from .base_data_module import BaseDataModule
|
||||
from .processor import KGProcessor, get_dataset
|
||||
from .processor import KGProcessor, get_dataset, getNegativeEntityId
|
||||
import transformers
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
@ -106,6 +106,7 @@ class DataCollatorForSeq2Seq:
|
||||
if isinstance(l, int):
|
||||
new_labels[i][l] = 1
|
||||
else:
|
||||
if (l[0] != getNegativeEntityId()):
|
||||
for j in l:
|
||||
new_labels[i][j] = 1
|
||||
labels = new_labels
|
||||
|
@ -314,6 +314,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
|
||||
global ent2id
|
||||
global ent2token
|
||||
global rel2id
|
||||
global negativeEntity
|
||||
|
||||
head_filter_entities = head
|
||||
tail_filter_entities = tail
|
||||
@ -322,11 +323,19 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
|
||||
ent2id = ent2id_
|
||||
ent2token = ent2token_
|
||||
rel2id = rel2id_
|
||||
negativeEntity = ent2id['[NEG]']
|
||||
|
||||
def delete_init(ent2text_):
|
||||
global ent2text
|
||||
ent2text = ent2text_
|
||||
|
||||
def getEntityIdByName(name):
|
||||
global ent2id
|
||||
return ent2id[name]
|
||||
|
||||
def getNegativeEntityId():
|
||||
global negativeEntity
|
||||
return negativeEntity
|
||||
|
||||
class KGProcessor(DataProcessor):
|
||||
"""Processor for knowledge graph data set."""
|
||||
@ -377,6 +386,7 @@ class KGProcessor(DataProcessor):
|
||||
"""Gets all entities in the knowledge graph."""
|
||||
with open(self.entity_path, 'r') as f:
|
||||
lines = f.readlines()
|
||||
lines.append('[NEG]\t')
|
||||
entities = []
|
||||
for line in lines:
|
||||
entities.append(line.strip().split("\t")[0])
|
||||
@ -403,6 +413,7 @@ class KGProcessor(DataProcessor):
|
||||
ent2text_with_type = {}
|
||||
with open(self.entity_path, 'r') as f:
|
||||
ent_lines = f.readlines()
|
||||
ent_lines.append('[NEG]\t')
|
||||
for line in ent_lines:
|
||||
temp = line.strip().split('\t')
|
||||
try:
|
||||
|
@ -364,13 +364,17 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
|
||||
self.id2entity = {}
|
||||
with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
|
||||
cnt = 0
|
||||
for line in file.readlines():
|
||||
lines = file.readlines()
|
||||
lines.append('[NEG]\t')
|
||||
for line in lines:
|
||||
e, d = line.strip().split("\t")
|
||||
self.id2entity[cnt] = e
|
||||
cnt += 1
|
||||
self.id2entity_t = {}
|
||||
with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
|
||||
for line in file.readlines():
|
||||
lines = file.readlines()
|
||||
lines.append('[NEG]\t')
|
||||
for line in lines:
|
||||
e, d = line.strip().split("\t")
|
||||
self.id2entity_t[e] = d
|
||||
for k, v in self.id2entity.items():
|
||||
|
0
pretrain/scripts/pretrain_fb15k-237.sh
Normal file → Executable file
0
pretrain/scripts/pretrain_fb15k-237.sh
Normal file → Executable file
0
scripts/fb15k-237/fb15k-237.sh
Normal file → Executable file
0
scripts/fb15k-237/fb15k-237.sh
Normal file → Executable file
Loading…
Reference in New Issue
Block a user