From fcfeae2bd3006c4ef15e76c28cceb90b3ace1801 Mon Sep 17 00:00:00 2001
From: Cong Thanh Vu
Date: Sun, 8 Jan 2023 14:25:31 +0000
Subject: [PATCH] first version of negative sampling

---
 .vscode/launch.json                    | 38 +++++++++++
 data/data_module.py                    |  7 ++-
 data/processor.py                      | 87 +++++++++++++++++++++++++-
 lit_models/transformer.py              | 11 +++-
 main.py                                |  6 +-
 pretrain/data/data_module.py           |  7 ++-
 pretrain/data/processor.py             | 11 ++++
 pretrain/lit_models/transformer.py     |  8 ++-
 pretrain/scripts/pretrain_fb15k-237.sh |  0
 scripts/fb15k-237/fb15k-237.sh         |  0
 10 files changed, 163 insertions(+), 12 deletions(-)
 create mode 100644 .vscode/launch.json
 mode change 100644 => 100755 pretrain/scripts/pretrain_fb15k-237.sh
 mode change 100644 => 100755 scripts/fb15k-237/fb15k-237.sh

diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..34bb231
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,38 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "justMyCode": true,
+            "args": [
+                "--gpus", "1",
+                "--max_epochs=16",
+                "--num_workers=32",
+                "--model_name_or_path", "bert-base-uncased",
+                "--accumulate_grad_batches", "1",
+                "--model_class", "BertKGC",
+                "--batch_size", "64",
+                "--checkpoint", "/kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt",
+                "--pretrain", "0",
+                "--bce", "0",
+                "--check_val_every_n_epoch", "1",
+                "--overwrite_cache",
+                "--data_dir", "dataset/FB15k-237",
+                "--eval_batch_size", "128",
+                "--max_seq_length", "128",
+                "--lr", "3e-5",
+                "--max_triplet", "64",
+                "--add_attn_bias", "True",
+                "--use_global_node", "True",
+                "--fast_dev_run", "True",
+            ]
+        }
+    ]
+}
\ No newline at end of file
diff --git a/data/data_module.py b/data/data_module.py
index 1520c5f..953a5e7 100644
--- a/data/data_module.py
+++ b/data/data_module.py
@@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding,
                                                   PreTrainedTokenizerBase)
 
 from .base_data_module import BaseDataModule
-from .processor import KGProcessor, get_dataset
+from .processor import KGProcessor, get_dataset, getNegativeEntityId
 import transformers
 transformers.logging.set_verbosity_error()
 
@@ -105,8 +105,9 @@ class DataCollatorForSeq2Seq:
                 if isinstance(l, int):
                     new_labels[i][l] = 1
                 else:
-                    for j in l:
-                        new_labels[i][j] = 1
+                    if (l[0] != getNegativeEntityId()):
+                        for j in l:
+                            new_labels[i][j] = 1
             labels = new_labels
 
         features = self.tokenizer.pad(
diff --git a/data/processor.py b/data/processor.py
index 1b1f972..971cecb 100644
--- a/data/processor.py
+++ b/data/processor.py
@@ -5,7 +5,7 @@ import contextlib
 import sys
 
 from collections import Counter
-from multiprocessing import Pool
+from multiprocessing import Pool, synchronize
 from torch._C import HOIST_CONV_PACKED_PARAMS
 from torch.utils.data import Dataset, Sampler, IterableDataset
 from collections import defaultdict
@@ -235,6 +235,31 @@ class DataProcessor(object):
 
 
 import copy
+from collections import deque
+import threading
+
+
+class _LiveState(type):
+    _instances = {}
+    _lock = threading.Lock()
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            with cls._lock:
+                if cls not in cls._instances:
+                    cls._instances[cls] = super(_LiveState, cls).__call__(*args, **kwargs)
+
+        return cls._instances[cls]
+
+class LiveState(metaclass=_LiveState):
+
+    def __init__(self):
+        self._pool_size = 16
+        self._deq = deque(maxlen=self._pool_size)
+    def put(self, item):
+        self._deq.append(item)
+    def get(self):
+        return list(self._deq)
 
 def solve_get_knowledge_store(line, set_type="train", pretrain=1):
     """
@@ -364,6 +389,53 @@ def solve(line, set_type="train", pretrain=1, max_triplet=32):
             InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_head_seq), label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[line[1], line[2]], en_id = [rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]], text_d_id = list(masked_head_seq_id), graph_inf = masked_head_graph_list))
         examples.append(
             InputExample(guid=guid, text_a="[PAD]", text_b="[PAD]", text_c = "[MASK]", text_d = list(masked_tail_seq), label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[line[0], line[1]], en_id = [ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]], text_d_id = list(masked_tail_seq_id), graph_inf = masked_tail_graph_list))
+
+        liveState = LiveState()
+        _prev = liveState.get()
+
+        if (set_type == "train" and len(_prev) > 0):
+
+            for prev_ent in _prev:
+
+                z = head_filter_entities["\t".join([prev_ent,line[1]])]
+                if (len(z) == 0):
+                    z.append('[NEG]')
+                z.append(line[2])
+                z.append(line[0])
+                masked_neg_seq = set()
+                masked_neg_seq_id = set()
+
+                masked_neg_graph_list = masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), []) if len(masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), [])) < max_triplet else \
+                    random.sample(masked_tail_neighbor["\t".join([prev_ent, line[1]])], max_triplet)
+
+                for item in masked_neg_graph_list:
+                    masked_neg_seq.add(item[0])
+                    masked_neg_seq.add(item[1])
+                    masked_neg_seq.add(item[2])
+                    masked_neg_seq_id.add(ent2id[item[0]])
+                    masked_neg_seq_id.add(rel2id[item[1]])
+                    masked_neg_seq_id.add(ent2id[item[2]])
+
+                masked_neg_seq = masked_neg_seq.difference({line[0]})
+                masked_neg_seq = masked_neg_seq.difference({line[2]})
+                masked_neg_seq = masked_neg_seq.difference({line[1]})
+                masked_neg_seq = masked_neg_seq.difference(prev_ent)
+                masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[0]]})
+                masked_neg_seq_id = masked_neg_seq_id.difference({rel2id[line[1]]})
+                masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[2]]})
+                masked_neg_seq_id = masked_neg_seq_id.difference(prev_ent)
+                # examples.append(
+                #     InputExample(guid=guid, text_a="[MASK]", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[PAD]" + " " + ' '.join(text_c.split(' ')[:16]), text_d = masked_head_seq, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]]))
+                # examples.append(
+                #     InputExample(guid=guid, text_a="[PAD] ", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[MASK]" +" " + ' '.join(text_a.split(' ')[:16]), text_d = masked_tail_seq, label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]]))
+                examples.append(
+                    InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_neg_seq), label=lmap(lambda x: ent2id[x], z), real_label=ent2id[line[0]], en=[line[1], prev_ent], en_id = [rel2id[line[1]], ent2id[prev_ent]], rel=rel2id[line[1]], text_d_id = list(masked_neg_seq_id), graph_inf = masked_neg_graph_list))
+                examples.append(
+                    InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_neg_seq), label=lmap(lambda x: ent2id[x], z), real_label=ent2id[line[2]], en=[line[1], prev_ent], en_id = [rel2id[line[1]], ent2id[prev_ent]], rel=rel2id[line[1]], text_d_id = list(masked_neg_seq_id), graph_inf = masked_neg_graph_list))
+
+
+        liveState.put(line[0])
+        liveState.put(line[2])
     return examples
 
 def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_neighbor_, masked_tail_neighbor_, rel2token_):
@@ -377,6 +449,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei
     global masked_head_neighbor
     global masked_tail_neighbor
     global rel2token
+    global negativeEntity
 
     head_filter_entities = head
     tail_filter_entities = tail
@@ -388,11 +461,20 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei
     masked_head_neighbor = masked_head_neighbor_
     masked_tail_neighbor = masked_tail_neighbor_
     rel2token = rel2token_
+    negativeEntity = ent2id['[NEG]']
 
 def delete_init(ent2text_):
     global ent2text
     ent2text = ent2text_
 
+def getEntityIdByName(name):
+    global ent2id
+    return ent2id[name]
+
+def getNegativeEntityId():
+    global negativeEntity
+    return negativeEntity
+
 class KGProcessor(DataProcessor):
     """Processor for knowledge graph data set."""
 
@@ -443,6 +525,7 @@ class KGProcessor(DataProcessor):
         """Gets all entities in the knowledge graph."""
         with open(self.entity_path, 'r') as f:
            lines = f.readlines()
+           lines.append('[NEG]\t')
            entities = []
            for line in lines:
                entities.append(line.strip().split("\t")[0])
@@ -469,6 +552,7 @@ class KGProcessor(DataProcessor):
         ent2text_with_type = {}
         with open(self.entity_path, 'r') as f:
             ent_lines = f.readlines()
+            ent_lines.append('[NEG]\t')
             for line in ent_lines:
                 temp = line.strip().split('\t')
                 try:
@@ -579,6 +663,7 @@ class KGProcessor(DataProcessor):
         else:
             annotate_ = partial(
                 solve,
+                set_type=set_type,
                 pretrain=self.args.pretrain,
                 max_triplet=self.args.max_triplet
             )
diff --git a/lit_models/transformer.py b/lit_models/transformer.py
index 6075a28..3cb610c 100644
--- a/lit_models/transformer.py
+++ b/lit_models/transformer.py
@@ -81,10 +81,13 @@ class TransformerLitModel(BaseLitModel):
         pos = batch.pop("pos")
         try:
             en = batch.pop("en")
+            self.print("__DEBUG__: en", en)
             rel = batch.pop("rel")
+            self.print("__DEBUG__: rel", rel)
         except KeyError:
             pass
         input_ids = batch['input_ids']
+        self.print("__DEBUG__: input_ids", input_ids)
         distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])])
         distance = batch.pop("distance_attention")
 
@@ -382,13 +385,17 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
         self.id2entity = {}
         with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
             cnt = 0
-            for line in file.readlines():
+            lines = file.readlines()
+            lines.append('[NEG]\t')
+            for line in lines:
                 e, d = line.strip().split("\t")
                 self.id2entity[cnt] = e
                 cnt += 1
         self.id2entity_t = {}
         with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
-            for line in file.readlines():
+            lines = file.readlines()
+            lines.append('[NEG]\t')
+            for line in lines:
                 e, d = line.strip().split("\t")
                 self.id2entity_t[e] = d
         for k, v in self.id2entity.items():
diff --git a/main.py b/main.py
index b9c8a71..66c8c3a 100644
--- a/main.py
+++ b/main.py
@@ -98,6 +98,7 @@ def main():
     tokenizer = data.tokenizer
 
     lit_model = litmodel_class(args=args, model=model, tokenizer=tokenizer, data_config=data.get_config())
+    print("__DEBUG__: Initialized")
     if args.checkpoint:
         lit_model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["state_dict"], strict=False)
 
@@ -120,9 +121,12 @@ def main():
     callbacks = [early_callback, model_checkpoint]
 
     # args.weights_summary = "full" # Print full summary of the model
-    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs", accelerator="ddp")
+    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")
+    print('__DEBUG__: Init trainer')
+
 
     if "EntityEmbedding" not in lit_model.__class__.__name__:
+        print('__DEBUG__: Fit trainer')
         trainer.fit(lit_model, datamodule=data)
         path = model_checkpoint.best_model_path
         lit_model.load_state_dict(torch.load(path)["state_dict"], strict=False)
diff --git a/pretrain/data/data_module.py b/pretrain/data/data_module.py
index eeae904..a9b8452 100644
--- a/pretrain/data/data_module.py
+++ b/pretrain/data/data_module.py
@@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding,
                                                   PreTrainedTokenizerBase)
 
 from .base_data_module import BaseDataModule
-from .processor import KGProcessor, get_dataset
+from .processor import KGProcessor, get_dataset, getNegativeEntityId
 import transformers
 transformers.logging.set_verbosity_error()
 
@@ -106,8 +106,9 @@ class DataCollatorForSeq2Seq:
                 if isinstance(l, int):
                     new_labels[i][l] = 1
                 else:
-                    for j in l:
-                        new_labels[i][j] = 1
+                    if (l[0] != getNegativeEntityId()):
+                        for j in l:
+                            new_labels[i][j] = 1
             labels = new_labels
 
         features = self.tokenizer.pad(
diff --git a/pretrain/data/processor.py b/pretrain/data/processor.py
index 8915e38..d272a10 100644
--- a/pretrain/data/processor.py
+++ b/pretrain/data/processor.py
@@ -314,6 +314,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
     global ent2id
     global ent2token
     global rel2id
+    global negativeEntity
 
     head_filter_entities = head
     tail_filter_entities = tail
@@ -322,11 +323,19 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
     ent2id = ent2id_
     ent2token = ent2token_
     rel2id = rel2id_
+    negativeEntity = ent2id['[NEG]']
 
 def delete_init(ent2text_):
     global ent2text
     ent2text = ent2text_
 
+def getEntityIdByName(name):
+    global ent2id
+    return ent2id[name]
+
+def getNegativeEntityId():
+    global negativeEntity
+    return negativeEntity
 class KGProcessor(DataProcessor):
     """Processor for knowledge graph data set."""
 
@@ -377,6 +386,7 @@ class KGProcessor(DataProcessor):
         """Gets all entities in the knowledge graph."""
         with open(self.entity_path, 'r') as f:
            lines = f.readlines()
+           lines.append('[NEG]\t')
            entities = []
            for line in lines:
                entities.append(line.strip().split("\t")[0])
@@ -403,6 +413,7 @@ class KGProcessor(DataProcessor):
         ent2text_with_type = {}
         with open(self.entity_path, 'r') as f:
             ent_lines = f.readlines()
+            ent_lines.append('[NEG]\t')
             for line in ent_lines:
                 temp = line.strip().split('\t')
                 try:
diff --git a/pretrain/lit_models/transformer.py b/pretrain/lit_models/transformer.py
index caff13f..4aa6dfc 100644
--- a/pretrain/lit_models/transformer.py
+++ b/pretrain/lit_models/transformer.py
@@ -364,13 +364,17 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
         self.id2entity = {}
         with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
             cnt = 0
-            for line in file.readlines():
+            lines = file.readlines()
+            lines.append('[NEG]\t')
+            for line in lines:
                 e, d = line.strip().split("\t")
                 self.id2entity[cnt] = e
                 cnt += 1
         self.id2entity_t = {}
         with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
-            for line in file.readlines():
+            lines = file.readlines()
+            lines.append('[NEG]\t')
+            for line in lines:
                 e, d = line.strip().split("\t")
                 self.id2entity_t[e] = d
         for k, v in self.id2entity.items():
diff --git a/pretrain/scripts/pretrain_fb15k-237.sh b/pretrain/scripts/pretrain_fb15k-237.sh
old mode 100644
new mode 100755
diff --git a/scripts/fb15k-237/fb15k-237.sh b/scripts/fb15k-237/fb15k-237.sh
old mode 100644
new mode 100755
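
For orientation, the negative-sampling flow introduced above works roughly as follows: a process-local LiveState singleton keeps a small deque of recently seen entities; solve() pairs the current relation with those stale entities to build corrupted examples whose label list starts with the reserved '[NEG]' entity, and the collator skips one-hot label smearing whenever a label list begins with that id. Below is a minimal, self-contained sketch of that mechanism only; the toy ent2id vocabulary and the helper names make_negative_examples and smear_labels are illustrative stand-ins, not code from the patch.

    # Sketch of the negative-sampling mechanism; names other than LiveState,
    # ent2id and '[NEG]' are hypothetical and only illustrate the control flow.
    import threading
    from collections import deque


    class _LiveState(type):
        """Thread-safe singleton metaclass (mirrors data/processor.py)."""
        _instances = {}
        _lock = threading.Lock()

        def __call__(cls, *args, **kwargs):
            if cls not in cls._instances:
                with cls._lock:
                    if cls not in cls._instances:
                        cls._instances[cls] = super().__call__(*args, **kwargs)
            return cls._instances[cls]


    class LiveState(metaclass=_LiveState):
        """Keeps the last few entities seen while building training examples."""

        def __init__(self, pool_size=16):
            self._deq = deque(maxlen=pool_size)

        def put(self, item):
            self._deq.append(item)

        def get(self):
            return list(self._deq)


    # Toy vocabulary; '[NEG]' is the placeholder entity appended to the entity files.
    ent2id = {"[NEG]": 0, "/m/head": 1, "/m/tail": 2, "/m/stale": 3}
    NEG_ID = ent2id["[NEG]"]


    def make_negative_examples(head, rel, tail):
        """For each recently seen entity, emit a corrupted example whose label
        list starts with the [NEG] id, as in the patch when the filter list is empty."""
        state = LiveState()
        negatives = []
        for prev_ent in state.get():
            negatives.append({
                "labels": [NEG_ID, ent2id[tail], ent2id[head]],  # z = ['[NEG]', tail, head]
                "corrupted_with": prev_ent,
                "rel": rel,
            })
        # Remember the genuine head/tail for later lines, as liveState.put() does.
        state.put(head)
        state.put(tail)
        return negatives


    def smear_labels(label_list, num_entities):
        """Collator-side check: skip one-hot smearing for negative examples."""
        row = [0.0] * num_entities
        if label_list and label_list[0] != NEG_ID:
            for j in label_list:
                row[j] = 1.0
        return row


    if __name__ == "__main__":
        make_negative_examples("/m/head", "rel", "/m/tail")        # first line: no negatives yet
        negs = make_negative_examples("/m/stale", "rel", "/m/tail")
        print(len(negs), smear_labels(negs[0]["labels"], len(ent2id)))

Because LiveState lives in module-level state that is initialised per worker, each process spawned by the multiprocessing Pool appears to keep its own deque, so negatives are drawn from entities that worker has recently processed rather than from a global history; the '[NEG]' entry appended to the entity files gives those corrupted examples a reserved id that the collator can recognise and leave as an all-zero label row.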