"""Data processing utilities for knowledge-graph completion with masked LMs.

The module turns (head, relation, tail) triples plus entity/relation text
files into tokenized feature dictionaries, and caches the result on disk.
"""
import argparse
import contextlib
import copy
import csv
import inspect
import json
import logging
import os
import pickle
import random
import sys
from collections import Counter, defaultdict
from dataclasses import asdict, dataclass, replace
from functools import partial, wraps
from multiprocessing import Pool
from os import cpu_count
# NOTE(review): `re.DEBUG` and `torch._C.HOIST_CONV_PACKED_PARAMS` look like
# accidental IDE auto-imports (neither is used in this file).  They are kept
# because other modules may re-import them from here, but the private
# `torch._C` symbol can raise ImportError on other torch versions.
from re import DEBUG

import numpy as np
import torch
from torch._C import HOIST_CONV_PACKED_PARAMS
from torch.utils.data import (DataLoader, Dataset, IterableDataset,
                              RandomSampler, Sampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from transformers.models.auto.tokenization_auto import AutoTokenizer

from models.utils import get_entity_spans_pre_processing

logger = logging.getLogger(__name__)


def lmap(a, b):
    """Return ``list(map(a, b))``."""
    return list(map(a, b))


def cache_results(_cache_fp, _refresh=False, _verbose=1):
    r"""Decorator (adapted from fastNLP) that pickles a function's result.

    Usage::

        @cache_results(_cache_fp="./dataset")
        def get_dataset(args, processor, label_list, tokenizer, mode):
            ...

    The decorated function accepts three extra keyword arguments,
    ``_cache_fp``, ``_refresh`` and ``_verbose``; they are consumed here and
    never forwarded, so the decorated function must not declare parameters
    with those names.

    Unlike the fastNLP original, the cache location and the refresh flag are
    always derived from the call itself: the first positional argument (or
    the ``args`` keyword) must expose ``data_dir``, ``model_name_or_path``,
    ``pretrain``, ``faiss_init``, ``max_seq_length``, ``litmodel_class`` and
    ``overwrite_cache``, and the last positional argument (or the ``mode``
    keyword) names the split.  ``_cache_fp`` / ``_refresh`` are therefore
    only type-checked; ``_verbose`` still controls the "Read cache" message.

    :param str _cache_fp: nominal cache path (overridden, see above).
    :param bool _refresh: nominal refresh flag (overridden by
        ``args.overwrite_cache``).
    :param int _verbose: when ``1``, log when a cache file is read.
    :return: a decorator.
    :raises RuntimeError: if the decorated function declares a reserved
        parameter name, or if it returns ``None`` (nothing to cache).
    """
    reserved = ('_cache_fp', '_refresh', '_verbose')

    def wrapper_(func):
        for key in inspect.signature(func).parameters:
            if key in reserved:
                raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Accept `args` / `mode` either positionally or by keyword.
            my_args = args[0] if args else kwargs["args"]
            mode = kwargs["mode"] if "mode" in kwargs else args[-1]

            # The reserved keywords must be removed before calling `func`.
            if '_cache_fp' in kwargs:
                given_fp = kwargs.pop('_cache_fp')
                assert isinstance(given_fp, str), "_cache_fp can only be str."
            if '_refresh' in kwargs:
                given_refresh = kwargs.pop('_refresh')
                assert isinstance(given_refresh, bool), "_refresh can only be bool."
            if '_verbose' in kwargs:
                verbose = kwargs.pop('_verbose')
                assert isinstance(verbose, int), "_verbose can only be integer."
            else:
                verbose = _verbose

            model_name = my_args.model_name_or_path.split("/")[-1]
            is_pretrain = my_args.pretrain
            cache_filepath = os.path.join(my_args.data_dir, f"cached_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl")
            # Truthiness instead of `is False`, so 0 / None also mean
            # "do not overwrite".
            refresh = my_args.overwrite_cache

            if not refresh and os.path.exists(cache_filepath):
                # NOTE: pickle can execute arbitrary code; only load cache
                # files produced by this program.
                with open(cache_filepath, 'rb') as f:
                    results = pickle.load(f)
                if verbose == 1:
                    logger.info("Read cache from %s.", cache_filepath)
                return results

            results = func(*args, **kwargs)
            if results is None:
                raise RuntimeError("The return value is None. Delete the decorator.")
            with open(cache_filepath, 'wb') as f:
                pickle.dump(results, f)
            logger.info("Save cache to %s.", cache_filepath)
            return results

        return wrapper

    return wrapper_


class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, text_c=None, label=None, real_label=None, en=None, rel=None):
        """Constructs a InputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence.
            text_b: (Optional) string. The untokenized text of the second
                sequence.
            text_c: (Optional) string. The untokenized text of the third
                sequence.
            label: (Optional) entity id, or list of entity ids.
            real_label: (Optional) id of the gold entity.
            en: (Optional) id(s) written into the [PAD] placeholder slots.
            rel: (Optional) relation id.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.text_c = text_c
        self.label = label
        self.real_label = real_label
        self.en = en
        self.rel = rel


@dataclass
class InputFeatures:
    """A single set of features of data."""

    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    labels: torch.Tensor = None
    label: torch.Tensor = None
    en: torch.Tensor = 0
    rel: torch.Tensor = 0
    pos: torch.Tensor = 0


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self, data_dir):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file into a list of rows."""
        with open(input_file, "r", encoding="utf-8") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))


def solve_get_knowledge_store(line, set_type="train", pretrain=1):
    """Build the two knowledge-store examples for one triple.

    Uses the LM to get the entity embedding.
    Transductive: triples + text description.
    Inductive: text description.

    Reads the module globals installed by `filter_init`.

    Args:
        line: ``[head, relation, tail]`` identifiers.
        set_type: split name used in the example guid.
        pretrain: unused here; kept for a signature parallel to `solve`.

    Returns:
        A list with two `InputExample`s (head side, then tail side).
    """
    head, relation, tail = line[0], line[1], line[2]
    head_text = ent2text[head]
    tail_text = ent2text[tail]
    relation_text = rel2text[relation]

    head_candidates = head_filter_entities["\t".join([tail, relation])]
    guid = "%s-%s" % (set_type, 0)
    triple_ids = [ent2id[head], rel2id[relation], ent2id[tail]]

    # NOTE(review): both examples use the head-side candidate list as
    # `label`; the second one may have been meant to use
    # `tail_filter_entities` - confirm before relying on `label` here.
    return [
        # use the description of the tail to predict the head
        InputExample(guid=guid, text_a="[PAD]", text_b=relation_text + "[PAD]",
                     text_c="[PAD]" + " " + tail_text,
                     label=lmap(lambda x: ent2id[x], head_candidates),
                     real_label=ent2id[head], en=triple_ids, rel=0),
        InputExample(guid=guid, text_a="[PAD]", text_b=relation_text + "[PAD]",
                     text_c="[PAD]" + " " + head_text,
                     label=lmap(lambda x: ent2id[x], head_candidates),
                     real_label=ent2id[tail], en=list(triple_ids), rel=0),
    ]


def solve(line, set_type="train", pretrain=1):
    """Convert one triple into `InputExample`s.

    Reads the module globals installed by `filter_init`.

    Args:
        line: ``[head, relation, tail]`` identifiers.
        set_type: split name used in the example guid.
        pretrain: when truthy, emit entity-description examples for the head
            entity (ten random spans of up to 64 words plus the full text);
            otherwise emit one head-prediction and one tail-prediction
            example.

    Returns:
        A list of `InputExample`s.
    """
    head, relation, tail = line[0], line[1], line[2]
    head_text = ent2text[head]
    tail_text = ent2text[tail]
    relation_text = rel2text[relation]
    guid = "%s-%s" % (set_type, 0)

    examples = []
    if pretrain:
        head_id = ent2id[head]
        tokens = head_text.split()
        for _ in range(10):
            # randrange excludes len(tokens); the previous inclusive
            # randint could start past the end and yield an empty span.
            start = random.randrange(len(tokens)) if tokens else 0
            examples.append(
                InputExample(guid=guid, text_a="[MASK]",
                             text_b=" ".join(tokens[start:start + 64]),
                             text_c="", label=head_id, real_label=head_id,
                             en=0, rel=0)
            )
        examples.append(
            InputExample(guid=guid, text_a="[MASK]", text_b=head_text,
                         text_c="", label=head_id, real_label=head_id,
                         en=0, rel=0)
        )
    else:
        tail_candidates = tail_filter_entities["\t".join([head, relation])]
        head_candidates = head_filter_entities["\t".join([tail, relation])]
        relation_id = rel2id[relation]
        # predict the head from the relation and the tail description
        examples.append(
            InputExample(guid=guid, text_a="[MASK]",
                         text_b=relation_text + "[PAD]",
                         text_c="[PAD]" + " " + tail_text,
                         label=lmap(lambda x: ent2id[x], head_candidates),
                         real_label=ent2id[head],
                         en=[relation_id, ent2id[tail]], rel=relation_id))
        # predict the tail from the head description and the relation
        examples.append(
            InputExample(guid=guid, text_a="[PAD] ",
                         text_b=relation_text + "[PAD]",
                         text_c="[MASK]" + " " + head_text,
                         label=lmap(lambda x: ent2id[x], tail_candidates),
                         real_label=ent2id[tail],
                         en=[ent2id[head], relation_id], rel=relation_id))
    return examples


def filter_init(head, tail, t1, t2, ent2id_, ent2token_, rel2id_):
    """Install the lookup tables used by `solve` as module globals.

    Args:
        head: map "tail\\trelation" -> list of head entities.
        tail: map "head\\trelation" -> list of tail entities.
        t1: entity -> text.
        t2: relation -> text.
        ent2id_: entity -> id; must contain the ``'[NEG]'`` key.
        ent2token_: entity -> token string.
        rel2id_: relation -> id.
    """
    global head_filter_entities
    global tail_filter_entities
    global ent2text
    global rel2text
    global ent2id
    global ent2token
    global rel2id
    global negativeEntity

    head_filter_entities = head
    tail_filter_entities = tail
    ent2text = t1
    rel2text = t2
    ent2id = ent2id_
    ent2token = ent2token_
    rel2id = rel2id_
    negativeEntity = ent2id['[NEG]']


def delete_init(ent2text_):
    """Replace the global entity -> text table."""
    global ent2text
    ent2text = ent2text_


def getEntityIdByName(name):
    """Return the id of entity `name` from the global table."""
    global ent2id
    return ent2id[name]


def getNegativeEntityId():
    """Return the id of the ``'[NEG]'`` entity recorded by `filter_init`."""
    global negativeEntity
    return negativeEntity


class KGProcessor(DataProcessor):
    """Processor for knowledge graph data set."""

    def __init__(self, tokenizer, args):
        self.labels = set()
        self.tokenizer = tokenizer
        self.args = args
        # Prefer the long descriptions when that file is present.
        long_path = os.path.join(args.data_dir, "entity2textlong.txt")
        if os.path.exists(long_path):
            self.entity_path = long_path
        else:
            self.entity_path = os.path.join(args.data_dir, "entity2text.txt")

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train", data_dir, self.args)

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev", data_dir, self.args)

    def get_test_examples(self, data_dir, chunk=""):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, f"test{chunk}.tsv")), "test", data_dir, self.args)

    def get_relations(self, data_dir):
        """Return the ``[RELATION_i]`` tokens for the names in relations.txt."""
        with open(os.path.join(data_dir, "relations.txt"), 'r', encoding="utf-8") as f:
            relations = [line.strip().split('\t')[0] for line in f]
        # Keyed by name so repeated names collapse to a single token.
        rel2token = {rel: f"[RELATION_{i}]" for i, rel in enumerate(relations)}
        return list(rel2token.values())

    def get_labels(self, data_dir):
        """Return the last tab-separated field of each relation2text.txt line."""
        with open(os.path.join(data_dir, "relation2text.txt"), 'r', encoding="utf-8") as f:
            return [line.strip().split("\t")[-1] for line in f]

    def get_entities(self, data_dir):
        """Return the ``[ENTITY_i]`` tokens for all entities plus ``[NEG]``."""
        with open(self.entity_path, 'r', encoding="utf-8") as f:
            lines = f.readlines()
        lines.append('[NEG]\t')
        entities = [line.strip().split("\t")[0] for line in lines]
        # Keyed by name so repeated names collapse to a single token.
        ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(entities)}
        return list(ent2token.values())

    def get_train_triples(self, data_dir):
        """Gets training triples."""
        return self._read_tsv(os.path.join(data_dir, "train.tsv"))

    def get_dev_triples(self, data_dir):
        """Gets validation triples."""
        return self._read_tsv(os.path.join(data_dir, "dev.tsv"))

    def get_test_triples(self, data_dir, chunk=""):
        """Gets test triples."""
        return self._read_tsv(os.path.join(data_dir, f"test{chunk}.tsv"))

    def _load_ent2text(self, data_dir):
        """Read the entity file into an ordered entity -> description dict."""
        is_wiki = "wiki" in data_dir
        ent2text = {}
        with open(self.entity_path, 'r', encoding="utf-8") as f:
            for line in f:
                fields = line.strip().split('\t')
                if is_wiki:
                    assert "Q" in fields[0]
                if len(fields) > 1:
                    ent2text[fields[0]] = fields[1].replace("\\n", " ").replace("\\", "")
                else:
                    # entity without description
                    ent2text[fields[0]] = " "
        # The sentinel is added by this code, not read from the data, so it
        # is exempt from the wiki-id check above.
        ent2text['[NEG]'] = " "
        return ent2text

    def _create_examples(self, lines, set_type, data_dir, args):
        """Creates examples for one split.

        Args:
            lines: triples as ``[head, relation, tail]`` rows.
            set_type: ``"train"``, ``"dev"`` or ``"test"``.
            data_dir: directory with the triple and text files.
            args: run configuration; reads ``data_dir``, ``pretrain``,
                ``bce`` and (optionally) ``faiss_init``.

        Returns:
            A flat list of `InputExample`s.
        """
        ent2text = self._load_ent2text(data_dir)
        ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(ent2text)}
        ent2id = {ent: i for i, ent in enumerate(ent2text)}

        rel2text = {}
        with open(os.path.join(data_dir, "relation2text.txt"), 'r', encoding="utf-8") as f:
            for line in f:
                fields = line.strip().split('\t')
                rel2text[fields[0]] = fields[1]

        relation_names = {}
        with open(os.path.join(data_dir, "relations.txt"), "r", encoding="utf-8") as f:
            for line in f:
                name = line.strip()
                relation_names[name] = rel2text[name]

        # Drop triples that mention an entity or relation without text.
        kept_lines = []
        not_in_text = 0
        for line in tqdm(lines, desc="delete entities without text name."):
            if line[0] in ent2text and line[2] in ent2text and line[1] in rel2text:
                kept_lines.append(line)
            else:
                not_in_text += 1
        lines = kept_lines
        print(f"total entity not in text : {not_in_text} ")

        # rel id -> relation token id (relation ids follow the entity ids)
        num_entities = len(self.get_entities(args.data_dir))
        rel2id = {name: i + num_entities for i, name in enumerate(relation_names)}

        # (tail, relation) -> known heads and (head, relation) -> known tails
        head_filter_entities = defaultdict(list)
        tail_filter_entities = defaultdict(list)

        dataset_list = ["train.tsv", "dev.tsv", "test.tsv"]
        # in training, only use the train triples
        if set_type == "train" and not args.pretrain:
            dataset_list = dataset_list[0:1]
        for file_name in dataset_list:
            with open(os.path.join(data_dir, file_name), 'r', encoding="utf-8") as f:
                triples = [row.strip().split("\t") for row in f]
            for triple in triples:
                tail_filter_entities["\t".join([triple[0], triple[1]])].append(triple[2])
                head_filter_entities["\t".join([triple[2], triple[1]])].append(triple[0])

        max_head_entities = max((len(v) for v in head_filter_entities.values()), default=0)
        max_tail_entities = max((len(v) for v in tail_filter_entities.values()), default=0)

        # use bce loss, ignore the mlm: one triple per distinct query
        if set_type == "train" and args.bce:
            lines = []
            for key, tails in tail_filter_entities.items():
                h, r = key.split('\t')
                lines.append([h, r, tails[0]])
            for key, heads in head_filter_entities.items():
                t, r = key.split('\t')
                lines.append([heads[0], r, t])

        # for pretraining, visit every entity once to get its mask embedding
        if args.pretrain:
            rel = list(rel2text.keys())[0]
            lines = [[ent, rel, ent] for ent in ent2text]

        print(f"max number of filter entities : {max_head_entities} {max_tail_entities}")

        filter_init(head_filter_entities, tail_filter_entities, ent2text, rel2text, ent2id, ent2token, rel2id)

        if hasattr(args, "faiss_init") and args.faiss_init:
            converter = solve_get_knowledge_store
        else:
            converter = solve
        annotate_ = partial(converter, set_type=set_type, pretrain=self.args.pretrain)

        progress = tqdm(lines, total=len(lines), desc="convert text to examples")
        return [example for line in progress for example in annotate_(line)]


class Verbalizer(object):
    """Turns a (head, relation) pair into a dataset-specific text prompt."""

    _KNOWN_DATASETS = ("WN18RR", "FB15k", "umls", "codexs", "codexl", "FB13", "WN11")

    def __init__(self, args):
        # Default to None so `_convert` works for unrecognised datasets.
        self.mode = None
        for dataset_name in self._KNOWN_DATASETS:
            if dataset_name in args.data_dir:
                self.mode = dataset_name
                break

    def _convert(self, head, relation, tail):
        if self.mode == "umls":
            return f"The {relation} {head} is "
        return f"{head} {relation}"


class KGCDataset(Dataset):
    """Thin `Dataset` wrapper around an indexable collection of features."""

    def __init__(self, features):
        self.features = features

    def __getitem__(self, index):
        return self.features[index]

    def __len__(self):
        return len(self.features)


def convert_examples_to_features_init(tokenizer_for_convert):
    """Install the tokenizer used by `convert_examples_to_features`."""
    global tokenizer
    tokenizer = tokenizer_for_convert


def convert_examples_to_features(example, max_seq_length, mode, pretrain=1):
    """Tokenize one `InputExample` into a feature dict.

    Requires `convert_examples_to_features_init` to have been called.

    Args:
        example: the `InputExample` to convert.
        max_seq_length: tokenizer ``max_length``.
        mode: unused; kept for interface compatibility.
        pretrain: when truthy, encode ``text_a`` / ``text_b`` as the pair;
            otherwise encode ``text_a [SEP] text_b`` against ``text_c``.

    Returns:
        ``asdict(InputFeatures(...))`` for the example.
    """
    # Cap each text at 128 whitespace-separated words before tokenizing.
    text_a = " ".join(example.text_a.split()[:128])
    text_b = " ".join(example.text_b.split()[:128])
    text_c = " ".join(example.text_c.split()[:128])

    if pretrain:
        input_text_a = text_a
        input_text_b = text_b
    else:
        input_text_a = tokenizer.sep_token.join([text_a, text_b])
        input_text_b = text_c

    inputs = tokenizer(
        input_text_a,
        input_text_b,
        truncation="longest_first",
        max_length=max_seq_length,
        padding="longest",
        add_special_tokens=True,
    )

    return asdict(InputFeatures(input_ids=inputs["input_ids"],
                                attention_mask=inputs['attention_mask'],
                                labels=torch.tensor(example.label),
                                label=torch.tensor(example.real_label)))


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""
    # Always trim the longer sequence one token at a time: when one sequence
    # is very short, each of its tokens likely carries more information.
    while len(tokens_a) + len(tokens_b) > max_length:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


def _truncate_seq_triple(tokens_a, tokens_b, tokens_c, max_length):
    """Truncates a sequence triple in place to the maximum length."""
    # Trim the strictly longest sequence; on any tie trim `tokens_c`.
    while len(tokens_a) + len(tokens_b) + len(tokens_c) > max_length:
        len_a, len_b, len_c = len(tokens_a), len(tokens_b), len(tokens_c)
        if len_a > len_b and len_a > len_c:
            tokens_a.pop()
        elif len_b > len_a and len_b > len_c:
            tokens_b.pop()
        else:
            tokens_c.pop()


@cache_results(_cache_fp="./dataset")
def get_dataset(args, processor, label_list, tokenizer, mode):
    """Build (or load from cache) the feature dataset for one split.

    Args:
        args: run configuration; reads ``data_dir``, ``faiss_init``,
            ``pretrain``, ``model_name_or_path`` and ``max_seq_length``.
        processor: object providing ``get_{train,dev,test}_examples``.
        label_list: unused; kept for interface compatibility.
        tokenizer: ignored - a slow tokenizer is re-loaded from
            ``args.model_name_or_path`` (presumably so ``len(tokenizer)`` is
            the base vocabulary size used as the entity-id offset; confirm
            with the model code).
        mode: ``"train"``, ``"dev"`` or ``"test"``.

    Returns:
        A `KGCDataset` of feature dicts.

    Side effects:
        Writes ``examples_{mode}.txt`` (one JSON object per line) into
        ``args.data_dir``.
    """
    assert mode in ["train", "dev", "test"], "mode must be in train dev test!"

    # use training data to construct the entity embedding
    combine_train_and_test = False
    if args.faiss_init and mode == "test" and not args.pretrain:
        mode = "train"
        combine_train_and_test = "ind" in args.data_dir

    if combine_train_and_test:
        logger.info("use all the dataset for getting the entity mask embedding in pretraining pretraining")
        train_examples = (processor.get_test_examples(args.data_dir)
                          + processor.get_train_examples(args.data_dir)
                          + processor.get_dev_examples(args.data_dir))
    elif mode == "train":
        train_examples = processor.get_train_examples(args.data_dir)
    elif mode == "dev":
        train_examples = processor.get_dev_examples(args.data_dir)
    else:
        train_examples = processor.get_test_examples(args.data_dir)

    # The examples go through a JSON-lines file that the pool then reads.
    examples_path = os.path.join(args.data_dir, f"examples_{mode}.txt")
    with open(examples_path, 'w', encoding="utf-8") as file:
        for example in train_examples:
            file.write(json.dumps(dict(example.__dict__)) + '\n')

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False)

    features = []
    stats = Counter()
    encoder = MultiprocessingEncoder(tokenizer, args)
    encoder.initializer()
    num_workers = min(16, cpu_count() or 1)
    # Context managers make sure the file and worker processes are released.
    with open(examples_path, "r", encoding="utf-8") as examples_file, \
            Pool(num_workers, initializer=encoder.initializer) as pool:
        encoded_lines = pool.imap(encoder.encode_lines, zip(examples_file), 1000)
        for filt, enc_lines in tqdm(encoded_lines, total=len(train_examples)):
            if filt == "PASS":
                # json.loads, not eval: the payload is JSON, and eval would
                # execute arbitrary code and fail on null/true/false.
                features.extend(json.loads(enc_line) for enc_line in enc_lines)
            else:
                stats["num_filtered_" + filt] += 1

    for k, v in stats.most_common():
        print("[{}] filtered {} lines".format(k, v), file=sys.stderr)

    # Replace [PAD] placeholders with entity/relation ids shifted past the
    # tokenizer vocabulary, and record where the gold entity ended up.
    vocab_size = len(tokenizer)
    for feature in features:
        en = feature.pop("en")
        feature.pop("rel")
        real_label = feature['label']
        # NOTE(review): this stops at the first feature whose `en` is not a
        # list, so later features keep their "en"/"rel" keys - confirm that
        # is intended before changing it to `continue`.
        if not isinstance(en, list):
            break

        filled = 0
        pos = 0
        input_ids = feature['input_ids']
        for i, token_id in enumerate(input_ids):
            if token_id == tokenizer.pad_token_id:
                input_ids[i] = en[filled] + vocab_size
                filled += 1
            if input_ids[i] == real_label + vocab_size:
                pos = i
            if filled == len(en):
                break
        assert not (args.faiss_init and pos == 0)
        feature['pos'] = pos

    return KGCDataset(features)


class MultiprocessingEncoder(object):
    """Tokenizes JSON-encoded examples; usable from `multiprocessing.Pool`.

    `initializer` must run in every process before any other method, since
    the tokenizer is accessed through the module global `bpe`.
    """

    def __init__(self, tokenizer, args):
        self.tokenizer = tokenizer
        self.pretrain = args.pretrain
        self.max_seq_length = args.max_seq_length

    def initializer(self):
        """Publish the tokenizer as the process-wide global `bpe`."""
        global bpe
        bpe = self.tokenizer

    def encode(self, line):
        """Return the token ids of `line` as a list of strings."""
        global bpe
        ids = bpe.encode(line)
        return list(map(str, ids))

    def decode(self, tokens):
        """Decode token ids back to text."""
        global bpe
        return bpe.decode(tokens)

    def encode_lines(self, lines):
        """Encode a set of lines. All lines will be encoded together.

        Each line is a JSON object as written by `get_dataset`.

        Returns:
            ``["PASS", [json_feature, ...]]``, or ``["EMPTY", None]`` as soon
            as a blank line is met.
        """
        enc_lines = []
        for line in lines:
            line = line.strip()
            if len(line) == 0:
                return ["EMPTY", None]
            example = json.loads(line)
            enc_lines.append(json.dumps(self.convert_examples_to_features(example=example)))
        return ["PASS", enc_lines]

    def decode_lines(self, lines):
        """Decode lines of space-separated token ids back to text."""
        dec_lines = []
        for line in lines:
            tokens = [int(tok) for tok in line.strip().split()]
            dec_lines.append(self.decode(tokens))
        return ["PASS", dec_lines]

    def convert_examples_to_features(self, example):
        """Tokenize one example dict into a feature dict.

        Args:
            example: dict with keys ``text_a``, ``text_b``, ``text_c``,
                ``label``, ``real_label``, ``en`` and ``rel``.

        Returns:
            ``asdict(InputFeatures(...))`` for the example.
        """
        global bpe
        text_a = example['text_a']
        text_b = example['text_b']
        text_c = example['text_c']

        if self.pretrain:
            # the description of xxx is [MASK] .
            input_text = f"The description of {text_a} is that {text_b} ."
            inputs = bpe(
                input_text,
                truncation="longest_first",
                max_length=self.max_seq_length,
                padding="longest",
                add_special_tokens=True,
            )
        else:
            # Keep the [MASK] side as the first segment.
            if text_a == "[MASK]":
                input_text_a = bpe.sep_token.join([text_a, text_b])
                input_text_b = text_c
            else:
                input_text_a = text_a
                input_text_b = bpe.sep_token.join([text_b, text_c])
            inputs = bpe(
                input_text_a,
                input_text_b,
                truncation="longest_first",
                max_length=self.max_seq_length,
                padding="longest",
                add_special_tokens=True,
            )

        return asdict(InputFeatures(input_ids=inputs["input_ids"],
                                    attention_mask=inputs['attention_mask'],
                                    labels=example['label'],
                                    label=example['real_label'],
                                    en=example['en'],
                                    rel=example['rel']))


if __name__ == "__main__":
    dataset = KGCDataset('./dataset')