runnable v1

This commit is contained in:
Cong Thanh Vu 2023-01-14 10:40:58 +00:00
parent fcfeae2bd3
commit 45cd8e1396
5 changed files with 136 additions and 14 deletions

1
.gitignore vendored
View File

@@ -7,3 +7,4 @@ dataset/FB15k-237/masked_*.txt
 dataset/FB15k-237/cached_*.pkl
 **/__pycache__/
 **/.DS_Store
+nohup.out

View File

@@ -79,6 +79,7 @@ class DataCollatorForSeq2Seq:
     label_pad_token_id: int = -100
     return_tensors: str = "pt"
     num_labels: int = 0
+    args: Any = None
     def __call__(self, features, return_tensors=None):
@@ -105,7 +106,7 @@ class DataCollatorForSeq2Seq:
             if isinstance(l, int):
                 new_labels[i][l] = 1
             else:
-                if (l[0] != getNegativeEntityId()):
+                if (l[0] != getNegativeEntityId(self.args)):
                     for j in l:
                         new_labels[i][j] = 1
         labels = new_labels
@@ -142,6 +143,7 @@ class KGC(BaseDataModule):
             padding="longest",
             max_length=self.args.max_seq_length,
             num_labels = len(entity_list),
+            args=args
         )
         relations_tokens = self.processor.get_relations(args.data_dir)
         self.num_relations = len(relations_tokens)
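
The collator change above builds a multi-hot label matrix and now skips candidate lists whose first entry is the [NEG] placeholder id. A minimal, self-contained sketch of that logic (the function and variable names here are illustrative, not from the repo):

import torch

def build_multi_hot_labels(raw_labels, num_entities, neg_entity_id):
    # raw_labels: one entry per example, either an int (the single gold entity id)
    # or a list of candidate entity ids; a list whose first element is the [NEG]
    # placeholder id contributes no positives, mirroring the getNegativeEntityId(self.args) check.
    new_labels = torch.zeros(len(raw_labels), num_entities)
    for i, l in enumerate(raw_labels):
        if isinstance(l, int):
            new_labels[i][l] = 1
        elif l[0] != neg_entity_id:
            for j in l:
                new_labels[i][j] = 1
    return new_labels

# Example: entity 3 is gold for the first row, entities {1, 4} for the second,
# and the third row's list starts with the [NEG] id (0 here), so it stays all zero.
labels = build_multi_hot_labels([3, [1, 4], [0, 2]], num_entities=5, neg_entity_id=0)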

View File

@@ -110,7 +110,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
             model_name = my_args.model_name_or_path.split("/")[-1]
             is_pretrain = my_args.pretrain
-            cache_filepath = os.path.join(my_args.data_dir, f"cached_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl")
+            cache_filepath = os.path.join(my_args.data_dir, f"cached_{func.__name__}_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl")
             refresh = my_args.overwrite_cache
             if cache_filepath is not None and refresh is False:
@@ -137,6 +137,116 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
     return wrapper_
+def cache_results_load_once(_cache_fp, _refresh=False, _verbose=1, _global_var=None):
+    r"""
+    === USE IN TRAINING MODE ONLY ===
+    cache_results is the decorator fastNLP uses to cache data. The example below shows how to use it::
+        import time
+        import numpy as np
+        from fastNLP import cache_results
+        @cache_results('cache.pkl')
+        def process_data():
+            # some time-consuming work such as reading and preprocessing data; time.sleep() stands in for it here
+            time.sleep(1)
+            return np.random.randint(10, size=(5,))
+        start_time = time.time()
+        print("res =",process_data())
+        print(time.time() - start_time)
+        start_time = time.time()
+        print("res =",process_data())
+        print(time.time() - start_time)
+        # The output looks as follows: both runs return the same result, and the second one takes almost no time
+        # Save cache to cache.pkl.
+        # res = [5 4 9 1 8]
+        # 1.0042750835418701
+        # Read cache from cache.pkl.
+        # res = [5 4 9 1 8]
+        # 0.0040721893310546875
+    The second run takes almost no time because it reads the data straight from cache.pkl instead of preprocessing it again::
+        # Continuing the example above: to generate a separate cache (e.g. for another dataset), call the function like this
+        process_data(_cache_fp='cache2.pkl')  # this does not affect the original cache.pkl
+    _cache_fp above is a parameter that cache_results recognizes: the result is cached to / read from 'cache2.pkl', i.e. 'cache2.pkl'
+    overrides the default 'cache.pkl'. Decorating a function with @cache_results() adds three parameters to it
+    [_cache_fp, _refresh, _verbose]; the example above uses _cache_fp. These three parameters are not passed into your
+    function, so your own parameter names must not include these three names::
+        process_data(_cache_fp='cache2.pkl', _refresh=True)  # force the preprocessing cache to be regenerated
+        # _verbose controls logging: if 0, nothing is printed; if 1, it reports whether the current step read an existing cache or generated a new one
+    :param str _cache_fp: where to cache the return value, or where to read the cache from. If None, cache_results has no
+        effect unless _cache_fp is passed when the function is called.
+    :param bool _refresh: whether to regenerate the cache.
+    :param int _verbose: whether to print cache information.
+    :return:
+    """
+    def wrapper_(func):
+        signature = inspect.signature(func)
+        for key, _ in signature.parameters.items():
+            if key in ('_cache_fp', '_refresh', '_verbose', '_global_var'):
+                raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))
+        v = globals().get(_global_var, None)
+        if (v is not None):
+            return v
+        def wrapper(*args, **kwargs):
+            my_args = args[0]
+            mode = "train"
+            if '_cache_fp' in kwargs:
+                cache_filepath = kwargs.pop('_cache_fp')
+                assert isinstance(cache_filepath, str), "_cache_fp can only be str."
+            else:
+                cache_filepath = _cache_fp
+            if '_refresh' in kwargs:
+                refresh = kwargs.pop('_refresh')
+                assert isinstance(refresh, bool), "_refresh can only be bool."
+            else:
+                refresh = _refresh
+            if '_verbose' in kwargs:
+                verbose = kwargs.pop('_verbose')
+                assert isinstance(verbose, int), "_verbose can only be integer."
+            else:
+                verbose = _verbose
+            refresh_flag = True
+            model_name = my_args.model_name_or_path.split("/")[-1]
+            is_pretrain = my_args.pretrain
+            cache_filepath = os.path.join(my_args.data_dir, f"cached_{func.__name__}_{mode}_features{model_name}_pretrain{is_pretrain}.pkl")
+            refresh = my_args.overwrite_cache
+            if cache_filepath is not None and refresh is False:
+                # load data
+                if os.path.exists(cache_filepath):
+                    with open(cache_filepath, 'rb') as f:
+                        results = pickle.load(f)
+                        globals()[_global_var] = results
+                    if verbose == 1:
+                        logger.info("Read cache from {}.".format(cache_filepath))
+                    refresh_flag = False
+            if refresh_flag:
+                results = func(*args, **kwargs)
+                if cache_filepath is not None:
+                    if results is None:
+                        raise RuntimeError("The return value is None. Delete the decorator.")
+                    with open(cache_filepath, 'wb') as f:
+                        pickle.dump(results, f)
+                    logger.info("Save cache to {}.".format(cache_filepath))
+            return results
+        return wrapper
+    return wrapper_
 import argparse
 import csv
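
A minimal sketch of how the new cache_results_load_once decorator above is meant to be used (the decorated function name and the args values are placeholders, not from the repo): the decorated function takes the experiment args namespace as its first positional argument, its return value is pickled under args.data_dir on the first call, and later calls read that pickle back and also store it in the module-level global named by _global_var.

from argparse import Namespace

@cache_results_load_once(_cache_fp="./dataset", _global_var="entity_cache")
def build_entity_cache(args):
    # stand-in for an expensive preprocessing step
    return {"[NEG]": 0}

args = Namespace(model_name_or_path="bert-base-uncased", pretrain=0,
                 data_dir="dataset/FB15k-237", overwrite_cache=False)

cache = build_entity_cache(args)  # first call: computes the dict and pickles it under args.data_dir
cache = build_entity_cache(args)  # later calls: read the cached pickle instead of recomputing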
@@ -254,7 +364,7 @@ class _LiveState(type):
 class LiveState(metaclass=_LiveState):
     def __init__(self):
-        self._pool_size = 16
+        self._pool_size = 4
         self._deq = deque(maxlen=self._pool_size)
     def put(self, item):
         self._deq.append(item)
@@ -397,6 +507,9 @@ def solve(line, set_type="train", pretrain=1, max_triplet=32):
         for prev_ent in _prev:
+            if (prev_ent == line[0] or prev_ent == line[2]):
+                continue
             z = head_filter_entities["\t".join([prev_ent,line[1]])]
             if (len(z) == 0):
                 z.append('[NEG]')
@@ -407,6 +520,8 @@ def solve(line, set_type="train", pretrain=1, max_triplet=32):
             masked_neg_graph_list = masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), []) if len(masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), [])) < max_triplet else \
                 random.sample(masked_tail_neighbor["\t".join([prev_ent, line[1]])], max_triplet)
+            if (len(masked_head_graph_list) == 0):
+                masked_head_graph_list.append(['[NEG]', line[1], '[NEG]'])
             for item in masked_neg_graph_list:
                 masked_neg_seq.add(item[0])
@@ -419,11 +534,11 @@ def solve(line, set_type="train", pretrain=1, max_triplet=32):
             masked_neg_seq = masked_neg_seq.difference({line[0]})
             masked_neg_seq = masked_neg_seq.difference({line[2]})
             masked_neg_seq = masked_neg_seq.difference({line[1]})
-            masked_neg_seq = masked_neg_seq.difference(prev_ent)
+            masked_neg_seq = masked_neg_seq.difference({prev_ent})
             masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[0]]})
             masked_neg_seq_id = masked_neg_seq_id.difference({rel2id[line[1]]})
             masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[2]]})
-            masked_neg_seq_id = masked_neg_seq_id.difference(prev_ent)
+            masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[prev_ent]})
             # examples.append(
             #     InputExample(guid=guid, text_a="[MASK]", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[PAD]" + " " + ' '.join(text_c.split(' ')[:16]), text_d = masked_head_seq, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]]))
             # examples.append(
@@ -449,7 +564,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei
     global masked_head_neighbor
     global masked_tail_neighbor
     global rel2token
-    global negativeEntity
+    # global negativeEntity
     head_filter_entities = head
     tail_filter_entities = tail
@@ -461,7 +576,8 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei
     masked_head_neighbor = masked_head_neighbor_
     masked_tail_neighbor = masked_tail_neighbor_
     rel2token = rel2token_
-    negativeEntity = ent2id['[NEG]']
+    # negativeEntity = ent2id['[NEG]']
+    print("Initialized negative entity ID")
 def delete_init(ent2text_):
     global ent2text
@@ -471,8 +587,10 @@ def getEntityIdByName(name):
     global ent2id
     return ent2id[name]
-def getNegativeEntityId():
+@cache_results_load_once(_cache_fp="./dataset", _global_var='negativeEntity')
+def getNegativeEntityId(args):
     global negativeEntity
+    negativeEntity = ent2id['[NEG]']
     return negativeEntity
@@ -654,6 +772,8 @@ class KGProcessor(DataProcessor):
         threads = min(1, cpu_count())
         filter_init(head_filter_entities, tail_filter_entities,ent2text, rel2text, ent2id, ent2token, rel2id, masked_head_neighbor, masked_tail_neighbor, rel2token
             )
+        # cache this
+        getNegativeEntityId(args)
         if hasattr(args, "faiss_init") and args.faiss_init:
             annotate_ = partial(

View File

@@ -81,13 +81,13 @@ class TransformerLitModel(BaseLitModel):
         pos = batch.pop("pos")
         try:
             en = batch.pop("en")
-            self.print("__DEBUG__: en", en)
+            # self.print("__DEBUG__: en", en)
             rel = batch.pop("rel")
-            self.print("__DEBUG__: rel", rel)
+            # self.print("__DEBUG__: rel", rel)
         except KeyError:
             pass
         input_ids = batch['input_ids']
-        self.print("__DEBUG__: input_ids", input_ids)
+        # self.print("__DEBUG__: input_ids", input_ids)
         distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])])
         distance = batch.pop("distance_attention")
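
pad_distance itself is defined elsewhere in the repo and is not part of this diff. Purely as a hypothetical illustration of the shape bookkeeping in the line above, a helper of that general form could pad each example's square distance matrix on the right and bottom so that all matrices in the batch reach one common size before torch.stack (the padding value, and the exact role of the -1 in the call, are assumptions):

import torch
import torch.nn.functional as F

def pad_distance_sketch(pad_length, distance, pad_value=0):
    # distance: (n, n) distance matrix over one example's unpadded tokens.
    # Pad pad_length rows at the bottom and pad_length columns on the right
    # so every matrix in the batch shares the same square size.
    return F.pad(distance, (0, pad_length, 0, pad_length), value=pad_value)

m = torch.tensor([[0, 1], [1, 0]])
print(pad_distance_sketch(3, m).shape)  # torch.Size([5, 5])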

View File

@@ -1,13 +1,12 @@
-nohup python -u main.py --gpus "1" --max_epochs=16 --num_workers=32 \
+nohup python -u main.py --gpus "2," --max_epochs=16 --num_workers=32 \
    --model_name_or_path bert-base-uncased \
    --accumulate_grad_batches 1 \
    --model_class BertKGC \
    --batch_size 64 \
-   --checkpoint /kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt \
+   --checkpoint /kg_374/Relphormer/output/FB15k-237/epoch=1-Eval/hits10=Eval/hits1=0.47-Eval/hits1=0.22.ckpt \
    --pretrain 0 \
    --bce 0 \
    --check_val_every_n_epoch 1 \
-   --overwrite_cache \
    --data_dir dataset/FB15k-237 \
    --eval_batch_size 128 \
    --max_seq_length 128 \