diff --git a/.gitignore b/.gitignore
index 37dd143..88244ef 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ dataset/FB15k-237/masked_*.txt
 dataset/FB15k-237/cached_*.pkl
 **/__pycache__/
 **/.DS_Store
+nohup.out
diff --git a/data/data_module.py b/data/data_module.py
index 953a5e7..420b73e 100644
--- a/data/data_module.py
+++ b/data/data_module.py
@@ -79,6 +79,7 @@ class DataCollatorForSeq2Seq:
     label_pad_token_id: int = -100
     return_tensors: str = "pt"
     num_labels: int = 0
+    args: Any = None
 
     def __call__(self, features, return_tensors=None):
 
@@ -105,7 +106,7 @@ class DataCollatorForSeq2Seq:
                 if isinstance(l, int):
                     new_labels[i][l] = 1
                 else:
-                    if (l[0] != getNegativeEntityId()):
+                    if (l[0] != getNegativeEntityId(self.args)):
                         for j in l:
                             new_labels[i][j] = 1
             labels = new_labels
@@ -142,6 +143,7 @@ class KGC(BaseDataModule):
             padding="longest",
             max_length=self.args.max_seq_length,
             num_labels = len(entity_list),
+            args=args
         )
         relations_tokens = self.processor.get_relations(args.data_dir)
         self.num_relations = len(relations_tokens)
diff --git a/data/processor.py b/data/processor.py
index 971cecb..5c8cae3 100644
--- a/data/processor.py
+++ b/data/processor.py
@@ -110,7 +110,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
 
             model_name = my_args.model_name_or_path.split("/")[-1]
             is_pretrain = my_args.pretrain
-            cache_filepath = os.path.join(my_args.data_dir, f"cached_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl")
+            cache_filepath = os.path.join(my_args.data_dir, f"cached_{func.__name__}_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl")
             refresh = my_args.overwrite_cache
 
             if cache_filepath is not None and refresh is False:
@@ -137,6 +137,116 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
     return wrapper_
 
+
+def cache_results_load_once(_cache_fp, _refresh=False, _verbose=1, _global_var=None):
+    r"""
+    === USE IN TRAINING MODE ONLY ===
+    cache_results is the decorator used in fastNLP to cache data. The example below shows how to use it::
+
+        import time
+        import numpy as np
+        from fastNLP import cache_results
+
+        @cache_results('cache.pkl')
+        def process_data():
+            # some time-consuming work such as reading and preprocessing data; time.sleep() stands in for it here
+            time.sleep(1)
+            return np.random.randint(10, size=(5,))
+
+        start_time = time.time()
+        print("res =",process_data())
+        print(time.time() - start_time)
+
+        start_time = time.time()
+        print("res =",process_data())
+        print(time.time() - start_time)
+
+        # The output looks like this; both runs return the same result and the second one takes almost no time
+        # Save cache to cache.pkl.
+        # res = [5 4 9 1 8]
+        # 1.0042750835418701
+        # Read cache from cache.pkl.
+        # res = [5 4 9 1 8]
+        # 0.0040721893310546875
+
+    As you can see, the second run takes only about 0.0001s, because it reads the data directly from cache.pkl instead of running the preprocessing again::
+
+        # Using the same example: to generate a separate cache, e.g. for another dataset, call it as follows
+        process_data(_cache_fp='cache2.pkl')  # this does not affect the earlier 'cache.pkl' at all
+
+    The _cache_fp above is a parameter recognized by cache_results; the data is cached to / read from 'cache2.pkl', i.e. 'cache2.pkl' overrides the default 'cache.pkl'. Decorating a function with @cache_results() adds three parameters [_cache_fp, _refresh, _verbose]. The example above uses _cache_fp; these three parameters are never passed into your function, so your own function must not have parameters with these names::
+
+        process_data(_cache_fp='cache2.pkl', _refresh=True)  # force the preprocessing cache to be regenerated here.
+        # _verbose controls logging: 0 prints nothing; 1 reports whether the current step read an existing cache or generated a new one
+
+    :param str _cache_fp: where to cache the return value, or where to read the cache from. If None, cache_results has no effect unless _cache_fp is passed in at call time.
+    :param bool _refresh: whether to regenerate the cache.
+    :param int _verbose: whether to print cache information.
+    :return:
+    """
+
+    def wrapper_(func):
+        signature = inspect.signature(func)
+        for key, _ in signature.parameters.items():
+            if key in ('_cache_fp', '_refresh', '_verbose', '_global_var'):
+                raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))
+        v = globals().get(_global_var, None)
+        if (v is not None):
+            return v
+
+        def wrapper(*args, **kwargs):
+
+            my_args = args[0]
+            mode = "train"
+            if '_cache_fp' in kwargs:
+                cache_filepath = kwargs.pop('_cache_fp')
+                assert isinstance(cache_filepath, str), "_cache_fp can only be str."
+            else:
+                cache_filepath = _cache_fp
+            if '_refresh' in kwargs:
+                refresh = kwargs.pop('_refresh')
+                assert isinstance(refresh, bool), "_refresh can only be bool."
+            else:
+                refresh = _refresh
+            if '_verbose' in kwargs:
+                verbose = kwargs.pop('_verbose')
+                assert isinstance(verbose, int), "_verbose can only be integer."
+            else:
+                verbose = _verbose
+            refresh_flag = True
+
+            model_name = my_args.model_name_or_path.split("/")[-1]
+            is_pretrain = my_args.pretrain
+            cache_filepath = os.path.join(my_args.data_dir, f"cached_{func.__name__}_{mode}_features{model_name}_pretrain{is_pretrain}.pkl")
+            refresh = my_args.overwrite_cache
+
+            if cache_filepath is not None and refresh is False:
+                # load data
+                if os.path.exists(cache_filepath):
+                    with open(cache_filepath, 'rb') as f:
+                        results = pickle.load(f)
+                    globals()[_global_var] = results
+                    if verbose == 1:
+                        logger.info("Read cache from {}.".format(cache_filepath))
+                    refresh_flag = False
+
+            if refresh_flag:
+                results = func(*args, **kwargs)
+                if cache_filepath is not None:
+                    if results is None:
+                        raise RuntimeError("The return value is None. Delete the decorator.")
Delete the decorator.") + with open(cache_filepath, 'wb') as f: + pickle.dump(results, f) + logger.info("Save cache to {}.".format(cache_filepath)) + + return results + + return wrapper + + return wrapper_ + import argparse import csv @@ -254,7 +364,7 @@ class _LiveState(type): class LiveState(metaclass=_LiveState): def __init__(self): - self._pool_size = 16 + self._pool_size = 4 self._deq = deque(maxlen=self._pool_size) def put(self, item): self._deq.append(item) @@ -397,6 +507,9 @@ def solve(line, set_type="train", pretrain=1, max_triplet=32): for prev_ent in _prev: + if (prev_ent == line[0] or prev_ent == line[2]): + continue + z = head_filter_entities["\t".join([prev_ent,line[1]])] if (len(z) == 0): z.append('[NEG]') @@ -407,6 +520,8 @@ def solve(line, set_type="train", pretrain=1, max_triplet=32): masked_neg_graph_list = masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), []) if len(masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), [])) < max_triplet else \ random.sample(masked_tail_neighbor["\t".join([prev_ent, line[1]])], max_triplet) + if (len(masked_head_graph_list) == 0): + masked_head_graph_list.append(['[NEG]', line[1], '[NEG]']) for item in masked_neg_graph_list: masked_neg_seq.add(item[0]) @@ -419,11 +534,11 @@ def solve(line, set_type="train", pretrain=1, max_triplet=32): masked_neg_seq = masked_neg_seq.difference({line[0]}) masked_neg_seq = masked_neg_seq.difference({line[2]}) masked_neg_seq = masked_neg_seq.difference({line[1]}) - masked_neg_seq = masked_neg_seq.difference(prev_ent) + masked_neg_seq = masked_neg_seq.difference({prev_ent}) masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[0]]}) masked_neg_seq_id = masked_neg_seq_id.difference({rel2id[line[1]]}) masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[2]]}) - masked_neg_seq_id = masked_neg_seq_id.difference(prev_ent) + masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[prev_ent]}) # examples.append( # InputExample(guid=guid, text_a="[MASK]", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[PAD]" + " " + ' '.join(text_c.split(' ')[:16]), text_d = masked_head_seq, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]])) # examples.append( @@ -449,7 +564,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei global masked_head_neighbor global masked_tail_neighbor global rel2token - global negativeEntity + # global negativeEntity head_filter_entities = head tail_filter_entities = tail @@ -461,7 +576,8 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei masked_head_neighbor = masked_head_neighbor_ masked_tail_neighbor = masked_tail_neighbor_ rel2token = rel2token_ - negativeEntity = ent2id['[NEG]'] + # negativeEntity = ent2id['[NEG]'] + print("Initialized negative entity ID") def delete_init(ent2text_): global ent2text @@ -471,8 +587,10 @@ def getEntityIdByName(name): global ent2id return ent2id[name] -def getNegativeEntityId(): +@cache_results_load_once(_cache_fp="./dataset", _global_var='negativeEntity') +def getNegativeEntityId(args): global negativeEntity + negativeEntity = ent2id['[NEG]'] return negativeEntity @@ -654,6 +772,8 @@ class KGProcessor(DataProcessor): threads = min(1, cpu_count()) filter_init(head_filter_entities, tail_filter_entities,ent2text, rel2text, ent2id, ent2token, rel2id, masked_head_neighbor, masked_tail_neighbor, rel2token ) + # cache this + getNegativeEntityId(args) if hasattr(args, "faiss_init") and 
             annotate_ = partial(
diff --git a/lit_models/transformer.py b/lit_models/transformer.py
index 3cb610c..34717fc 100644
--- a/lit_models/transformer.py
+++ b/lit_models/transformer.py
@@ -81,13 +81,13 @@ class TransformerLitModel(BaseLitModel):
         pos = batch.pop("pos")
         try:
             en = batch.pop("en")
-            self.print("__DEBUG__: en", en)
+            # self.print("__DEBUG__: en", en)
             rel = batch.pop("rel")
-            self.print("__DEBUG__: rel", rel)
+            # self.print("__DEBUG__: rel", rel)
         except KeyError:
             pass
         input_ids = batch['input_ids']
-        self.print("__DEBUG__: input_ids", input_ids)
+        # self.print("__DEBUG__: input_ids", input_ids)
         distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])])
         distance = batch.pop("distance_attention")
diff --git a/scripts/fb15k-237/fb15k-237.sh b/scripts/fb15k-237/fb15k-237.sh
index ac9bc99..a04dec5 100755
--- a/scripts/fb15k-237/fb15k-237.sh
+++ b/scripts/fb15k-237/fb15k-237.sh
@@ -1,13 +1,12 @@
-nohup python -u main.py --gpus "1" --max_epochs=16 --num_workers=32 \
+nohup python -u main.py --gpus "2," --max_epochs=16 --num_workers=32 \
    --model_name_or_path bert-base-uncased \
    --accumulate_grad_batches 1 \
    --model_class BertKGC \
    --batch_size 64 \
-   --checkpoint /kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt \
+   --checkpoint /kg_374/Relphormer/output/FB15k-237/epoch=1-Eval/hits10=Eval/hits1=0.47-Eval/hits1=0.22.ckpt \
    --pretrain 0 \
    --bce 0 \
    --check_val_every_n_epoch 1 \
-   --overwrite_cache \
    --data_dir dataset/FB15k-237 \
    --eval_batch_size 128 \
    --max_seq_length 128 \
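
Note on the caching change: the sketch below illustrates the load-once pattern that cache_results_load_once is built around, in which a value is computed once, pickled to disk, and served to later calls from a module-level global, so that getNegativeEntityId(args) does not have to be recomputed per batch. It is a simplified stand-in rather than the repository's implementation: load_once, DEMO_CACHE, DEMO_GLOBAL, and compute_negative_id are illustrative names, and the real decorator additionally derives the cache path from args.data_dir and honors --overwrite_cache.

    import os
    import pickle


    def load_once(cache_fp, global_var):
        """Cache a function's result on disk and in a module-level global (sketch)."""
        def wrapper_(func):
            def wrapper(*args, **kwargs):
                # In-memory hit: a previous call in this process already stored the value.
                cached = globals().get(global_var, None)
                if cached is not None:
                    return cached
                # On-disk hit: an earlier run saved the value with pickle.
                if os.path.exists(cache_fp):
                    with open(cache_fp, "rb") as f:
                        result = pickle.load(f)
                else:
                    result = func(*args, **kwargs)
                    with open(cache_fp, "wb") as f:
                        pickle.dump(result, f)
                globals()[global_var] = result
                return result
            return wrapper
        return wrapper_


    DEMO_CACHE = "./negative_entity_id.pkl"  # hypothetical path, not the repo's cache file name


    @load_once(DEMO_CACHE, "DEMO_GLOBAL")
    def compute_negative_id(ent2id):
        # Stands in for getNegativeEntityId: a cheap lookup here, while in the real
        # pipeline the surrounding vocabulary construction is the expensive part.
        return ent2id["[NEG]"]


    if __name__ == "__main__":
        ent2id = {"[NEG]": 0, "/m/012qdp": 1}
        print(compute_negative_id(ent2id))  # computed, then written to disk and the global
        print(compute_negative_id(ent2id))  # served from the module-level global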