runnable v1
This commit is contained in:
		@@ -79,6 +79,7 @@ class DataCollatorForSeq2Seq:
 | 
			
		||||
    label_pad_token_id: int = -100
 | 
			
		||||
    return_tensors: str = "pt"
 | 
			
		||||
    num_labels: int = 0
 | 
			
		||||
    args: Any = None
 | 
			
		||||
 | 
			
		||||
    def __call__(self, features, return_tensors=None):
 | 
			
		||||
 | 
			
		||||
@@ -105,7 +106,7 @@ class DataCollatorForSeq2Seq:
 | 
			
		||||
                if isinstance(l, int): 
 | 
			
		||||
                    new_labels[i][l] = 1
 | 
			
		||||
                else:
 | 
			
		||||
                    if (l[0] != getNegativeEntityId()):
 | 
			
		||||
                    if (l[0] != getNegativeEntityId(self.args)):
 | 
			
		||||
                        for j in l:
 | 
			
		||||
                            new_labels[i][j] = 1
 | 
			
		||||
            labels = new_labels
 | 
			
		||||
@@ -142,6 +143,7 @@ class KGC(BaseDataModule):
 | 
			
		||||
            padding="longest",
 | 
			
		||||
            max_length=self.args.max_seq_length,
 | 
			
		||||
            num_labels = len(entity_list),
 | 
			
		||||
            args=args
 | 
			
		||||
        )
 | 
			
		||||
        relations_tokens = self.processor.get_relations(args.data_dir)
 | 
			
		||||
        self.num_relations = len(relations_tokens)
 | 
			
		||||
 
 | 
			
		||||
@@ -110,7 +110,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
 | 
			
		||||
            
 | 
			
		||||
            model_name = my_args.model_name_or_path.split("/")[-1]
 | 
			
		||||
            is_pretrain = my_args.pretrain
 | 
			
		||||
            cache_filepath = os.path.join(my_args.data_dir, f"cached_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl")
 | 
			
		||||
            cache_filepath = os.path.join(my_args.data_dir, f"cached_{func.__name__}_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl")
 | 
			
		||||
            refresh = my_args.overwrite_cache
 | 
			
		||||
 | 
			
		||||
            if cache_filepath is not None and refresh is False:
 | 
			
		||||
@@ -137,6 +137,116 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
 | 
			
		||||
 | 
			
		||||
    return wrapper_
 | 
			
		||||
 | 
			
		||||
def cache_results_load_once(_cache_fp, _refresh=False, _verbose=1, _global_var=None):
    r"""Disk-caching decorator (fastNLP-style ``cache_results``) that additionally
    mirrors the result into a module-level global so the pickle file is read at
    most once per process.

    === USE IN TRAINING MODE ONLY ===

    Like ``cache_results``, the decorated function gains three special keyword
    arguments that are stripped before the wrapped function is called:

    - ``_cache_fp`` (str): cache file location, overriding the decorator default.
    - ``_refresh`` (bool): force regeneration of the cache.
    - ``_verbose`` (int): 1 to log whether the cache was read, 0 for silence.

    :param str _cache_fp: default cache path; ``None`` disables caching unless
        ``_cache_fp`` is supplied at call time.
    :param bool _refresh: default for forced regeneration.
    :param int _verbose: default verbosity for cache-read logging.
    :param str _global_var: name of the ``globals()`` slot used as the
        in-process "load once" cache; ``None`` disables the fast path.
    :return: the decorating function.
    """

    def wrapper_(func):
        # The special keywords are reserved for the decorator; the wrapped
        # function must not declare them itself.
        signature = inspect.signature(func)
        for key in signature.parameters:
            if key in ('_cache_fp', '_refresh', '_verbose', '_global_var'):
                raise RuntimeError(
                    "The function decorated by cache_results cannot have keyword `{}`.".format(key))

        def wrapper(*args, **kwargs):
            # BUGFIX: the original checked the global slot inside ``wrapper_``
            # at decoration time and returned the cached *value* instead of a
            # wrapper function, so the "function" handed back to callers was
            # not callable. The fast path belongs here, per call.
            if _global_var is not None:
                cached = globals().get(_global_var, None)
                if cached is not None:
                    return cached

            my_args = args[0]
            mode = "train"

            # Per-call overrides of the decorator defaults. The pops matter
            # even though the values are recomputed below: they strip the
            # special keywords before ``func`` is invoked.
            if '_cache_fp' in kwargs:
                cache_filepath = kwargs.pop('_cache_fp')
                assert isinstance(cache_filepath, str), "_cache_fp can only be str."
            else:
                cache_filepath = _cache_fp
            if '_refresh' in kwargs:
                refresh = kwargs.pop('_refresh')
                assert isinstance(refresh, bool), "_refresh can only be bool."
            else:
                refresh = _refresh
            if '_verbose' in kwargs:
                verbose = kwargs.pop('_verbose')
                assert isinstance(verbose, int), "_verbose can only be integer."
            else:
                verbose = _verbose

            refresh_flag = True

            # NOTE(review): mirroring the sibling ``cache_results``, the path
            # and refresh flag are deliberately derived from the first
            # positional argument (the argparse namespace), overriding the
            # kwarg-supplied values above.
            model_name = my_args.model_name_or_path.split("/")[-1]
            is_pretrain = my_args.pretrain
            cache_filepath = os.path.join(
                my_args.data_dir,
                f"cached_{func.__name__}_{mode}_features{model_name}_pretrain{is_pretrain}.pkl")
            refresh = my_args.overwrite_cache

            if cache_filepath is not None and refresh is False:
                # Load from disk and publish into the global slot.
                if os.path.exists(cache_filepath):
                    with open(cache_filepath, 'rb') as f:
                        results = pickle.load(f)
                    if _global_var is not None:
                        globals()[_global_var] = results
                    if verbose == 1:
                        logger.info("Read cache from {}.".format(cache_filepath))
                    refresh_flag = False

            if refresh_flag:
                results = func(*args, **kwargs)
                if cache_filepath is not None:
                    if results is None:
                        raise RuntimeError("The return value is None. Delete the decorator.")
                    with open(cache_filepath, 'wb') as f:
                        pickle.dump(results, f)
                    logger.info("Save cache to {}.".format(cache_filepath))
                # BUGFIX: also populate the global slot after a fresh compute,
                # otherwise "load once" never engages until the next process.
                if _global_var is not None and results is not None:
                    globals()[_global_var] = results

            return results

        return wrapper

    return wrapper_
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import csv
 | 
			
		||||
@@ -254,7 +364,7 @@ class _LiveState(type):
 | 
			
		||||
class LiveState(metaclass=_LiveState):
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
        # Bounded buffer: keep only the newest ``_pool_size`` items; the deque
        # evicts from the left automatically once full.
        # NOTE(review): removed the stale ``self._pool_size = 16`` assignment
        # (leftover pre-change line) that was immediately overwritten by 4.
        self._pool_size = 4
        self._deq = deque(maxlen=self._pool_size)
 | 
			
		||||
    def put(self, item):
        """Append *item* to the pool; the oldest entry is evicted once the
        deque reaches its ``maxlen``."""
        self._deq.append(item)
 | 
			
		||||
@@ -397,6 +507,9 @@ def solve(line,  set_type="train", pretrain=1, max_triplet=32):
 | 
			
		||||
 | 
			
		||||
            for prev_ent in _prev:
 | 
			
		||||
 | 
			
		||||
                if (prev_ent == line[0] or prev_ent == line[2]):
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                z = head_filter_entities["\t".join([prev_ent,line[1]])]
 | 
			
		||||
                if (len(z) == 0):
 | 
			
		||||
                    z.append('[NEG]')
 | 
			
		||||
@@ -407,6 +520,8 @@ def solve(line,  set_type="train", pretrain=1, max_triplet=32):
 | 
			
		||||
 | 
			
		||||
                masked_neg_graph_list = masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), []) if len(masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), [])) < max_triplet else \
 | 
			
		||||
                    random.sample(masked_tail_neighbor["\t".join([prev_ent, line[1]])], max_triplet)
 | 
			
		||||
                if (len(masked_head_graph_list) == 0):
 | 
			
		||||
                    masked_head_graph_list.append(['[NEG]', line[1], '[NEG]'])
 | 
			
		||||
 | 
			
		||||
                for item in masked_neg_graph_list:
 | 
			
		||||
                    masked_neg_seq.add(item[0])
 | 
			
		||||
@@ -419,11 +534,11 @@ def solve(line,  set_type="train", pretrain=1, max_triplet=32):
 | 
			
		||||
                masked_neg_seq = masked_neg_seq.difference({line[0]})
 | 
			
		||||
                masked_neg_seq = masked_neg_seq.difference({line[2]})
 | 
			
		||||
                masked_neg_seq = masked_neg_seq.difference({line[1]})
 | 
			
		||||
                masked_neg_seq = masked_neg_seq.difference(prev_ent)
 | 
			
		||||
                masked_neg_seq = masked_neg_seq.difference({prev_ent})
 | 
			
		||||
                masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[0]]})
 | 
			
		||||
                masked_neg_seq_id = masked_neg_seq_id.difference({rel2id[line[1]]})
 | 
			
		||||
                masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[2]]})
 | 
			
		||||
                masked_neg_seq_id = masked_neg_seq_id.difference(prev_ent)
 | 
			
		||||
                masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[prev_ent]})
 | 
			
		||||
                # examples.append(
 | 
			
		||||
                #     InputExample(guid=guid, text_a="[MASK]", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[PAD]" + " " + ' '.join(text_c.split(' ')[:16]), text_d = masked_head_seq, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]]))
 | 
			
		||||
                # examples.append(
 | 
			
		||||
@@ -449,7 +564,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei
 | 
			
		||||
    global masked_head_neighbor
 | 
			
		||||
    global masked_tail_neighbor
 | 
			
		||||
    global rel2token
 | 
			
		||||
    global negativeEntity
 | 
			
		||||
    # global negativeEntity
 | 
			
		||||
 | 
			
		||||
    head_filter_entities = head
 | 
			
		||||
    tail_filter_entities = tail
 | 
			
		||||
@@ -461,7 +576,8 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei
 | 
			
		||||
    masked_head_neighbor = masked_head_neighbor_
 | 
			
		||||
    masked_tail_neighbor = masked_tail_neighbor_
 | 
			
		||||
    rel2token = rel2token_
 | 
			
		||||
    negativeEntity = ent2id['[NEG]']
 | 
			
		||||
    # negativeEntity = ent2id['[NEG]']
 | 
			
		||||
    print("Initialized negative entity ID")
 | 
			
		||||
 | 
			
		||||
def delete_init(ent2text_):
 | 
			
		||||
    global ent2text
 | 
			
		||||
@@ -471,8 +587,10 @@ def getEntityIdByName(name):
 | 
			
		||||
    global ent2id
 | 
			
		||||
    return ent2id[name]
 | 
			
		||||
 | 
			
		||||
@cache_results_load_once(_cache_fp="./dataset", _global_var='negativeEntity')
def getNegativeEntityId(args):
    """Return the entity id of the '[NEG]' placeholder entity.

    The result is cached on disk and mirrored into the module-global
    ``negativeEntity`` slot by ``cache_results_load_once``, so repeated calls
    within a process are free.

    NOTE(review): removed the stale zero-argument signature line
    (``def getNegativeEntityId():``) left over from before this refactor;
    keeping both definitions would shadow this one.

    :param args: argparse namespace; consumed by the caching decorator only.
    :return: int id of '[NEG]' — relies on the process-global ``ent2id``
        populated by ``filter_init``.
    """
    global negativeEntity
    negativeEntity = ent2id['[NEG]']
    return negativeEntity
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -654,6 +772,8 @@ class KGProcessor(DataProcessor):
 | 
			
		||||
        threads = min(1, cpu_count())
 | 
			
		||||
        filter_init(head_filter_entities, tail_filter_entities,ent2text, rel2text, ent2id, ent2token, rel2id, masked_head_neighbor, masked_tail_neighbor, rel2token
 | 
			
		||||
            )
 | 
			
		||||
        # cache this
 | 
			
		||||
        getNegativeEntityId(args)
 | 
			
		||||
        
 | 
			
		||||
        if hasattr(args, "faiss_init") and args.faiss_init:
 | 
			
		||||
            annotate_ = partial(
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user