Thesis/pretrain/data/processor.py

947 lines
35 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from re import DEBUG
import contextlib
import sys
from collections import Counter
from multiprocessing import Pool
from torch._C import HOIST_CONV_PACKED_PARAMS
from torch.utils.data import Dataset, Sampler, IterableDataset
from collections import defaultdict
from functools import partial
from multiprocessing import Pool
import os
import random
import json
import torch
import copy
import numpy as np
import pickle
from tqdm import tqdm
from dataclasses import dataclass, asdict, replace
import inspect
from transformers.models.auto.tokenization_auto import AutoTokenizer
from models.utils import get_entity_spans_pre_processing
def lmap(a, b):
    """Apply the callable ``a`` to every item of ``b`` and collect the results in a list."""
    return [a(item) for item in b]
def cache_results(_cache_fp, _refresh=False, _verbose=1):
    r"""Decorator that pickles a dataset-building function's result to disk.

    Adapted from fastNLP's ``cache_results`` but specialised for
    ``get_dataset``-style functions. The decorated function must receive an
    argparse-like ``args`` namespace as its FIRST positional argument and the
    split name (``mode``) as its LAST positional argument. Unless overridden
    per call, the cache file is::

        {args.data_dir}/cached_{mode}_features{model}_pretrain{args.pretrain}
            _faiss{args.faiss_init}_seqlength{args.max_seq_length}
            _{args.litmodel_class}.pkl

    where ``model`` is the last path component of ``args.model_name_or_path``.
    This means different configurations never share a cache file.

    Example::

        @cache_results(_cache_fp=None)
        def get_dataset(args, processor, label_list, tokenizer, mode):
            ...

        get_dataset(args, proc, labels, tok, "train")  # builds and saves
        get_dataset(args, proc, labels, tok, "train")  # reads the pickle

    The decorated function also accepts three per-call keyword arguments,
    which are consumed by the wrapper and never forwarded to the function:

    * ``_cache_fp`` (str): explicit cache path that replaces the derived one.
    * ``_refresh`` (bool): force the function to run and the cache to be rewritten.
    * ``_verbose`` (int): ``1`` logs whether the cache was read or written;
      ``0`` is silent.

    A truthy ``args.overwrite_cache`` (bool or int flag) also forces a rebuild.

    :param str _cache_fp: decorator-level path. Kept for API compatibility;
        the path derived from ``args`` is used instead.
    :param bool _refresh: decorator-level default for forcing a rebuild.
    :param int _verbose: decorator-level default verbosity.
    :raises RuntimeError: if the decorated function declares one of the
        reserved parameter names, or if it returns ``None``.
    """
    from functools import wraps

    def wrapper_(func):
        signature = inspect.signature(func)
        for key in signature.parameters:
            if key in ('_cache_fp', '_refresh', '_verbose'):
                raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))

        @wraps(func)
        def wrapper(*args, **kwargs):
            my_args = args[0]
            mode = args[-1]

            # Per-call controls are popped so they never reach `func`.
            cache_filepath = kwargs.pop('_cache_fp', None)
            if cache_filepath is not None:
                assert isinstance(cache_filepath, str), "_cache_fp can only be str."
            if '_refresh' in kwargs:
                refresh = kwargs.pop('_refresh')
                assert isinstance(refresh, bool), "_refresh can only be bool."
            else:
                refresh = _refresh
            if '_verbose' in kwargs:
                verbose = kwargs.pop('_verbose')
                assert isinstance(verbose, int), "_verbose can only be integer."
            else:
                verbose = _verbose

            if cache_filepath is None:
                model_name = my_args.model_name_or_path.split("/")[-1]
                cache_filepath = os.path.join(
                    my_args.data_dir,
                    f"cached_{mode}_features{model_name}_pretrain{my_args.pretrain}"
                    f"_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}"
                    f"_{my_args.litmodel_class}.pkl",
                )
            # `overwrite_cache` may be an int flag (0/1): test truthiness rather
            # than `is False`, which made an int 0 always bypass the cache.
            refresh = bool(refresh) or bool(my_args.overwrite_cache)

            if not refresh and os.path.exists(cache_filepath):
                # NOTE: pickle.load can execute arbitrary code; only point
                # data_dir at trusted locations.
                with open(cache_filepath, 'rb') as f:
                    results = pickle.load(f)
                if verbose == 1:
                    logger.info("Read cache from {}.".format(cache_filepath))
                return results

            results = func(*args, **kwargs)
            if results is None:
                raise RuntimeError("The return value is None. Delete the decorator.")
            with open(cache_filepath, 'wb') as f:
                pickle.dump(results, f)
            if verbose == 1:
                logger.info("Save cache to {}.".format(cache_filepath))
            return results

        return wrapper

    return wrapper_
import argparse
import csv
import logging
import os
import random
import sys
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
# from torch.nn import CrossEntropyLoss, MSELoss
# from scipy.stats import pearsonr, spearmanr
# from sklearn.metrics import matthews_corrcoef, f1_scoreclass
# Module-wide logger named after this module.
logger = logging.getLogger(name=__name__)
class InputExample(object):
    """One training/evaluation example before tokenization.

    Attributes mirror the constructor arguments:
        guid: unique identifier of the example (e.g. ``"train-0"``).
        text_a, text_b, text_c: up to three raw text segments.
        label: optional target; in this module either one entity id or a
            list of entity ids.
        real_label: optional id of the gold entity.
        en: optional list of ids (or ``0``) carried to the feature stage.
        rel: optional relation id.
    """

    def __init__(self, guid, text_a, text_b=None, text_c=None, label=None, real_label=None, en=None, rel=None):
        """Store every argument verbatim as an instance attribute."""
        # One ordered update keeps the original attribute order; instances
        # are serialised through ``__dict__`` elsewhere in this module.
        vars(self).update(
            guid=guid,
            text_a=text_a,
            text_b=text_b,
            text_c=text_c,
            label=label,
            real_label=real_label,
            en=en,
            rel=rel,  # rel id
        )
@dataclass
class InputFeatures:
    """A single set of features of data; converted to a plain dict with ``asdict``.

    NOTE(review): despite the ``torch.Tensor`` annotations, the encoders in this
    file fill most fields with plain Python lists/ints (tokenizer output and
    JSON-decoded example fields). Only ``convert_examples_to_features`` stores
    real tensors, and only in ``labels``/``label``.
    """
    input_ids: torch.Tensor  # token ids returned by the tokenizer
    attention_mask: torch.Tensor  # tokenizer attention mask aligned with input_ids
    labels: torch.Tensor = None  # example.label: one entity id or a list of them
    label: torch.Tensor = None  # example.real_label: id of the gold entity
    en: torch.Tensor = 0  # example.en; when a list, get_dataset writes its ids over [PAD] positions
    rel: torch.Tensor = 0  # example.rel: relation id
    pos: torch.Tensor = 0  # set in get_dataset: index of the injected token matching `label`
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self, data_dir):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Read a tab-separated UTF-8 file into a list of rows (lists of str).

        With ``quotechar=None``, quote characters are kept literally. Blank
        lines produce empty rows, exactly as ``csv.reader`` yields them.
        """
        # The former Python-2 branch referenced the undefined name `unicode`
        # and could never run on Python 3; it has been removed.
        with open(input_file, "r", encoding="utf-8", newline="") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
import copy
def solve_get_knowledge_store(line, set_type="train", pretrain=1):
    """Build the two examples used to collect entity embeddings with the LM.

    Transductive setting: triples + text description.
    Inductive setting: text description.

    Reads the module-level tables installed by ``filter_init``. Both examples
    use ``[PAD]`` placeholders. The first pairs the tail description with the
    head entity as ``real_label``; the second pairs the head description with
    the tail entity. ``set_type`` only prefixes the guid, and ``pretrain`` is
    unused.
    """
    head_key, rel_key, tail_key = line[0], line[1], line[2]
    head_text = ent2text[head_key]
    tail_text = ent2text[tail_key]
    relation_text = rel2text[rel_key]
    # The value is unused, but indexing the defaultdict registers the key,
    # so the lookup is kept for identical side effects.
    _ = tail_filter_entities["\t".join([head_key, rel_key])]
    known_heads = head_filter_entities["\t".join([tail_key, rel_key])]
    guid = "%s-%s" % (set_type, 0)

    def build(description, gold_key):
        # NOTE(review): both examples use the known *heads* as `label`, even
        # when `real_label` is the tail entity; confirm this is intended.
        return InputExample(
            guid=guid,
            text_a="[PAD]",
            text_b=relation_text + "[PAD]",
            text_c="[PAD]" + " " + description,
            label=lmap(lambda ent: ent2id[ent], known_heads),
            real_label=ent2id[gold_key],
            en=[ent2id[head_key], rel2id[rel_key], ent2id[tail_key]],
            rel=0,
        )

    # The tail description predicts the head, and vice versa.
    return [build(tail_text, head_key), build(head_text, tail_key)]
def solve(line, set_type="train", pretrain=1):
    """Turn one ``(head, relation, tail)`` row into ``InputExample`` objects.

    Reads the module-level tables installed by ``filter_init``.

    * ``pretrain`` truthy: 10 examples whose ``text_b`` is a random window of
      up to 64 whitespace tokens from the head description, plus one example
      with the full description. The head entity id is both ``label`` and
      ``real_label``.
    * otherwise: two examples.
      - Head prediction: ``[MASK]`` in ``text_a`` and the tail description in
        ``text_c``, labelled with the known heads of (tail, relation).
      - Tail prediction: ``[MASK]`` before the head description in ``text_c``,
        labelled with the known tails of (head, relation).

    :param line: sequence whose first three items are head, relation, tail ids.
    :param set_type: guid prefix; every guid is ``"{set_type}-0"``.
    :param pretrain: selects the branch described above.
    :return: list of ``InputExample``.
    """
    head, relation, tail = line[0], line[1], line[2]
    examples = []
    head_text = ent2text[head]
    tail_text = ent2text[tail]
    relation_text = rel2text[relation]
    # Known tails of (head, relation) and known heads of (tail, relation).
    known_tails = tail_filter_entities["\t".join([head, relation])]
    known_heads = head_filter_entities["\t".join([tail, relation])]
    guid = "%s-%s" % (set_type, 0)

    if pretrain:
        tokens = head_text.split()
        head_id = ent2id[head]
        for _ in range(10):
            # randrange keeps the window start on an existing token. The old
            # randint(0, len) could start past the end and yield an empty
            # description. max(..., 1) keeps empty descriptions working.
            start = random.randrange(max(len(tokens), 1))
            examples.append(
                InputExample(guid=guid, text_a="[MASK]", text_b=" ".join(tokens[start:start + 64]), text_c="",
                             label=head_id, real_label=head_id, en=0, rel=0)
            )
        examples.append(
            InputExample(guid=guid, text_a="[MASK]", text_b=head_text, text_c="",
                         label=head_id, real_label=head_id, en=0, rel=0)
        )
    else:
        examples.append(
            InputExample(guid=guid, text_a="[MASK]", text_b=relation_text + "[PAD]",
                         text_c="[PAD]" + " " + tail_text,
                         label=lmap(lambda x: ent2id[x], known_heads), real_label=ent2id[head],
                         en=[rel2id[relation], ent2id[tail]], rel=rel2id[relation])
        )
        examples.append(
            InputExample(guid=guid, text_a="[PAD] ", text_b=relation_text + "[PAD]",
                         text_c="[MASK]" + " " + head_text,
                         label=lmap(lambda x: ent2id[x], known_tails), real_label=ent2id[tail],
                         en=[ent2id[head], rel2id[relation]], rel=rel2id[relation])
        )
    return examples
def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
    """Publish the lookup tables read by ``solve``/``solve_get_knowledge_store``.

    :param head: mapping ``"tail\\trel"`` -> known head entities.
    :param tail: mapping ``"head\\trel"`` -> known tail entities.
    :param t1: entity id -> description text.
    :param t2: relation id -> relation text.
    :param ent2id_: entity -> integer id; must contain ``"[NEG]"``.
    :param ent2token_: entity -> entity token string.
    :param rel2id_: relation -> integer id.
    :raises KeyError: if ``"[NEG]"`` is missing from ``ent2id_`` (all other
        tables are already installed at that point).
    """
    global head_filter_entities, tail_filter_entities, ent2text, rel2text
    global ent2id, ent2token, rel2id, negativeEntity
    (head_filter_entities, tail_filter_entities, ent2text, rel2text,
     ent2id, ent2token, rel2id) = (head, tail, t1, t2, ent2id_, ent2token_, rel2id_)
    # Id of the sentinel negative entity.
    negativeEntity = ent2id['[NEG]']
def delete_init(ent2text_):
    """Replace the module-level ``ent2text`` table with ``ent2text_``."""
    globals()["ent2text"] = ent2text_
def getEntityIdByName(name):
    """Return the id of entity ``name`` from the table installed by ``filter_init``.

    :raises KeyError: if ``name`` is not a known entity.
    """
    # Reading a module global needs no `global` declaration.
    return ent2id[name]
def getNegativeEntityId():
    """Return the id of the ``[NEG]`` sentinel entity recorded by ``filter_init``."""
    # Reading a module global needs no `global` declaration.
    return negativeEntity
class KGProcessor(DataProcessor):
    """Processor for knowledge-graph data sets.

    Files read from the data directory:
    * ``train.tsv``, ``dev.tsv`` and ``test{chunk}.tsv``: tab-separated
      head / relation / tail rows.
    * ``entity2textlong.txt`` (preferred when present) or ``entity2text.txt``.
    * ``relation2text.txt`` and ``relations.txt``.
    """

    def __init__(self, tokenizer, args):
        self.labels = set()
        self.tokenizer = tokenizer
        self.args = args
        # Prefer the long entity descriptions when the data set ships them.
        long_entity_path = os.path.join(args.data_dir, "entity2textlong.txt")
        if os.path.exists(long_entity_path):
            self.entity_path = long_entity_path
        else:
            self.entity_path = os.path.join(args.data_dir, "entity2text.txt")

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train", data_dir, self.args)

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev", data_dir, self.args)

    def get_test_examples(self, data_dir, chunk=""):
        """See base class; ``chunk`` selects ``test{chunk}.tsv``."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, f"test{chunk}.tsv")), "test", data_dir, self.args)

    def get_relations(self, data_dir):
        """Return one ``[RELATION_i]`` token per distinct relation in ``relations.txt``."""
        with open(os.path.join(data_dir, "relations.txt"), "r", encoding="utf-8") as f:
            relations = [line.strip().split("\t")[0] for line in f]
        rel2token = {rel: f"[RELATION_{i}]" for i, rel in enumerate(relations)}
        return list(rel2token.values())

    def get_labels(self, data_dir):
        """Return the relation texts (last column of ``relation2text.txt``)."""
        with open(os.path.join(data_dir, "relation2text.txt"), "r", encoding="utf-8") as f:
            return [line.strip().split("\t")[-1] for line in f]

    def get_entities(self, data_dir):
        """Return one ``[ENTITY_i]`` token per distinct entity, including ``[NEG]``.

        ``data_dir`` is unused; the entity file chosen in ``__init__`` is read.
        """
        with open(self.entity_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        lines.append("[NEG]\t")
        entities = [line.strip().split("\t")[0] for line in lines]
        ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(entities)}
        return list(ent2token.values())

    def get_train_triples(self, data_dir):
        """Gets training triples."""
        return self._read_tsv(os.path.join(data_dir, "train.tsv"))

    def get_dev_triples(self, data_dir):
        """Gets validation triples."""
        return self._read_tsv(os.path.join(data_dir, "dev.tsv"))

    def get_test_triples(self, data_dir, chunk=""):
        """Gets test triples."""
        return self._read_tsv(os.path.join(data_dir, f"test{chunk}.tsv"))

    def _load_entity_text(self, data_dir):
        r"""Map entity id -> description, including the ``[NEG]`` sentinel.

        Lines without a description map to a single space. Two-character
        ``\n`` sequences become spaces, and remaining backslashes are dropped.
        For ``wiki`` data sets every real entity id must contain ``"Q"``.
        """
        with open(self.entity_path, "r", encoding="utf-8") as f:
            ent_lines = f.readlines()
        ent_lines.append("[NEG]\t")
        check_wiki_ids = "wiki" in data_dir
        ent2text = {}
        for line in ent_lines:
            fields = line.strip().split("\t")
            ent = fields[0]
            # The appended "[NEG]" sentinel is exempt. It used to trip this
            # assertion on every wiki data set.
            if check_wiki_ids and ent != "[NEG]":
                assert "Q" in ent
            if len(fields) > 1:
                ent2text[ent] = fields[1].replace("\\n", " ").replace("\\", "")
            else:
                ent2text[ent] = " "
        return ent2text

    @staticmethod
    def _load_relation_text(data_dir):
        """Map relation id -> relation text from ``relation2text.txt``, skipping blank lines."""
        rel2text = {}
        with open(os.path.join(data_dir, "relation2text.txt"), "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                fields = line.strip().split("\t")
                rel2text[fields[0]] = fields[1]
        return rel2text

    @staticmethod
    def _build_filter_entities(data_dir, file_names):
        r"""Index every triple in ``file_names`` for filtered ranking.

        Returns ``(head_filter_entities, tail_filter_entities)``:
        * ``head_filter_entities`` maps ``"tail\trel"`` -> list of heads.
        * ``tail_filter_entities`` maps ``"head\trel"`` -> list of tails.

        Both are ``defaultdict(list)``. Blank lines are skipped.
        """
        head_filter_entities = defaultdict(list)
        tail_filter_entities = defaultdict(list)
        for name in file_names:
            with open(os.path.join(data_dir, name), "r", encoding="utf-8") as file:
                for raw in file:
                    if not raw.strip():
                        continue
                    fields = raw.strip().split("\t")
                    head, rel, tail = fields[0], fields[1], fields[2]
                    tail_filter_entities["\t".join([head, rel])].append(tail)
                    head_filter_entities["\t".join([tail, rel])].append(head)
        return head_filter_entities, tail_filter_entities

    def _create_examples(self, lines, set_type, data_dir, args):
        """Create ``InputExample`` objects for one split.

        Steps:
        1. Load entity and relation texts.
        2. Drop rows lacking text (or with fewer than 3 columns).
        3. Build the filter dictionaries.
        4. Optionally replace ``lines`` (BCE training or pretraining).
        5. Install the lookup tables as module globals via ``filter_init``.
        6. Map every row through ``solve``, or through
           ``solve_get_knowledge_store`` when ``args.faiss_init`` is set.

        :param lines: rows from ``_read_tsv``.
        :param set_type: ``"train"``, ``"dev"`` or ``"test"``; used in guids.
        :param data_dir: directory holding the data files.
        :param args: namespace with ``pretrain``. ``bce`` is read for the
            training split; ``faiss_init`` is optional.
        :return: flat list of ``InputExample``.
        """
        ent2text = self._load_entity_text(data_dir)
        entities = list(ent2text.keys())
        ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(entities)}
        ent2id = {ent: i for i, ent in enumerate(entities)}

        rel2text = self._load_relation_text(data_dir)
        with open(os.path.join(data_dir, "relations.txt"), "r", encoding="utf-8") as file:
            relation_names = {name: rel2text[name] for name in (row.strip() for row in file)}

        kept_lines = []
        not_in_text = 0
        for line in tqdm(lines, desc="delete entities without text name."):
            # Short or blank rows used to raise IndexError; count them as dropped.
            if (len(line) < 3 or line[0] not in ent2text or line[2] not in ent2text
                    or line[1] not in rel2text):
                not_in_text += 1
                continue
            kept_lines.append(line)
        lines = kept_lines
        print(f"total entity not in text : {not_in_text} ")

        # Relation ids are offset by the number of entity tokens so the two id
        # ranges do not overlap.
        num_entities = len(self.get_entities(args.data_dir))
        rel2id = {w: i + num_entities for i, w in enumerate(relation_names.keys())}

        # The filters use all splits, except when fine-tuning on the training
        # split, where only the training triples are used.
        dataset_list = ["train.tsv", "dev.tsv", "test.tsv"]
        if set_type == "train" and not args.pretrain:
            dataset_list = dataset_list[:1]
        head_filter_entities, tail_filter_entities = self._build_filter_entities(data_dir, dataset_list)
        max_head_entities = max((len(v) for v in head_filter_entities.values()), default=0)
        max_tail_entities = max((len(v) for v in tail_filter_entities.values()), default=0)

        # BCE training: one row per (head, relation) and per (tail, relation)
        # query, using the first entity recorded for it.
        if set_type == "train" and args.bce:
            lines = []
            for key, tails in tail_filter_entities.items():
                h, r = key.split("\t")
                lines.append([h, r, tails[0]])
            for key, heads in head_filter_entities.items():
                t, r = key.split("\t")
                lines.append([heads[0], r, t])

        # Pretraining: one (entity, first relation, entity) row per entity.
        if args.pretrain:
            rel = list(rel2text.keys())[0]
            lines = [[ent, rel, ent] for ent in ent2text.keys()]

        print(f"max number of filter entities : {max_head_entities} {max_tail_entities}")

        # solve / solve_get_knowledge_store read these tables as module
        # globals; they stay installed after this method returns.
        filter_init(head_filter_entities, tail_filter_entities, ent2text, rel2text, ent2id, ent2token, rel2id)

        solver = solve_get_knowledge_store if getattr(args, "faiss_init", False) else solve
        annotate_ = partial(solver, pretrain=self.args.pretrain)
        per_line = list(tqdm(map(annotate_, lines), total=len(lines), desc="convert text to examples"))
        return [example for group in per_line for example in group]
class Verbalizer(object):
    """Choose a prompt template from the data-set name contained in ``args.data_dir``."""

    # Checked in this order; the first name contained in data_dir wins.
    _MODES = ("WN18RR", "FB15k", "umls", "codexs", "codexl", "FB13", "WN11")

    def __init__(self, args):
        # Unknown data sets used to leave `self.mode` unset, so `_convert`
        # raised AttributeError. They now fall back to the generic template.
        self.mode = next((name for name in self._MODES if name in args.data_dir), None)

    def _convert(self, head, relation, tail):
        """Return the prompt prefix for ``(head, relation)``; ``tail`` is unused."""
        if self.mode == "umls":
            return f"The {relation} {head} is "
        return f"{head} {relation}"
class KGCDataset(Dataset):
    """Map-style dataset over an in-memory sequence of features."""

    def __init__(self, features):
        self.features = features

    def __len__(self):
        """Number of stored features."""
        return len(self.features)

    def __getitem__(self, index):
        """Return the feature stored at ``index`` unchanged."""
        return self.features[index]
def convert_examples_to_features_init(tokenizer_for_convert):
    """Install ``tokenizer_for_convert`` as the module-global ``tokenizer``."""
    globals()["tokenizer"] = tokenizer_for_convert
def convert_examples_to_features(example, max_seq_length, mode, pretrain=1):
    """Tokenize one ``InputExample`` and return its ``InputFeatures`` as a dict.

    Uses the module-level ``tokenizer`` installed by
    ``convert_examples_to_features_init``. Each text field is first cut to its
    first 128 whitespace-separated words.

    * pretrain: the pair is ``(text_a, text_b)``.
    * otherwise: the pair is ``(text_a + sep_token + text_b, text_c)``.

    ``labels``/``label`` are wrapped in ``torch.tensor``. ``mode`` is accepted
    but unused.
    """
    def first_words(text, limit=128):
        return " ".join(text.split()[:limit])

    text_a, text_b, text_c = (first_words(t) for t in (example.text_a, example.text_b, example.text_c))
    if pretrain:
        pair = (text_a, text_b)
    else:
        pair = (tokenizer.sep_token.join([text_a, text_b]), text_c)
    inputs = tokenizer(
        *pair,
        truncation="longest_first",
        max_length=max_seq_length,
        padding="longest",
        add_special_tokens=True,
    )
    feature = InputFeatures(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        labels=torch.tensor(example.label),
        label=torch.tensor(example.real_label),
    )
    return asdict(feature)
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
def _truncate_seq_triple(tokens_a, tokens_b, tokens_c, max_length):
"""Truncates a sequence triple in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b) + len(tokens_c)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b) and len(tokens_a) > len(tokens_c):
tokens_a.pop()
elif len(tokens_b) > len(tokens_a) and len(tokens_b) > len(tokens_c):
tokens_b.pop()
elif len(tokens_c) > len(tokens_a) and len(tokens_c) > len(tokens_b):
tokens_c.pop()
else:
tokens_c.pop()
@cache_results(_cache_fp="./dataset")
def get_dataset(args, processor, label_list, tokenizer, mode):
    """Build (or, through ``cache_results``, load) the ``KGCDataset`` for one split.

    Pipeline:
    1. Create ``InputExample``s with ``processor`` and dump them as JSON lines
       to ``{data_dir}/examples_{mode}.txt``.
    2. Tokenize that file in a 16-process pool with ``MultiprocessingEncoder``.
    3. For each feature whose ``en`` is a list, overwrite successive pad-token
       positions with ``en[k] + len(tokenizer)``. Record in ``pos`` the
       position holding ``label + len(tokenizer)``.

    :param args: namespace with ``data_dir``, ``model_name_or_path``,
        ``pretrain``, ``faiss_init`` and ``max_seq_length``, plus whatever the
        encoder and processor read.
    :param processor: object providing ``get_{train,dev,test}_examples``.
    :param label_list: unused.
    :param tokenizer: ignored; a fresh slow tokenizer is loaded from
        ``args.model_name_or_path``.
    :param mode: ``"train"``, ``"dev"`` or ``"test"``.
    :return: ``KGCDataset`` of feature dicts.
    """
    assert mode in ["train", "dev", "test"], "mode must be in train dev test!"
    data_dir = args.data_dir

    # With faiss init outside pretraining, a "test" request is served from the
    # training split, which is used to build the entity embedding store.
    # Inductive ("ind") data sets use all three splits instead.
    combine_train_and_test = False
    if args.faiss_init and mode == "test" and not args.pretrain:
        mode = "train"
        if "ind" in data_dir:
            combine_train_and_test = True

    if combine_train_and_test:
        logger.info("use all the dataset for getting the entity mask embedding in pretraining pretraining")
        train_examples = (
            processor.get_test_examples(data_dir)
            + processor.get_train_examples(data_dir)
            + processor.get_dev_examples(data_dir)
        )
    elif mode == "train":
        train_examples = processor.get_train_examples(data_dir)
    elif mode == "dev":
        train_examples = processor.get_dev_examples(data_dir)
    else:
        train_examples = processor.get_test_examples(data_dir)

    examples_path = os.path.join(data_dir, f"examples_{mode}.txt")
    with open(examples_path, 'w') as file:
        for example in train_examples:
            file.write(json.dumps(dict(vars(example))) + '\n')

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False)
    features = []
    file_inputs = [examples_path]
    file_outputs = [os.path.join(data_dir, f"features_{mode}.txt")]
    with contextlib.ExitStack() as stack:
        inputs = [
            stack.enter_context(open(path, "r", encoding="utf-8")) if path != "-" else sys.stdin
            for path in file_inputs
        ]
        # NOTE: the output file is created/truncated but never written. It
        # only bounds the zip() below to one encoded line per input line.
        outputs = [
            stack.enter_context(open(path, "w", encoding="utf-8")) if path != "-" else sys.stdout
            for path in file_outputs
        ]
        encoder = MultiprocessingEncoder(tokenizer, args)
        # Registered with the ExitStack so the workers are terminated once all
        # results have been consumed, instead of being leaked.
        pool = stack.enter_context(Pool(16, initializer=encoder.initializer))
        encoder.initializer()
        encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 1000)
        stats = Counter()
        for filt, enc_lines in tqdm(encoded_lines, total=len(train_examples)):
            if filt == "PASS":
                for enc_line, _ in zip(enc_lines, outputs):
                    # enc_line is json.dumps output: parse it rather than eval() it.
                    features.append(json.loads(enc_line))
            else:
                stats["num_filtered_" + filt] += 1
        for k, v in stats.most_common():
            print("[{}] filtered {} lines".format(k, v), file=sys.stderr)

    pad_token_id = tokenizer.pad_token_id
    vocab_size = len(tokenizer)
    for feature in features:
        en = feature.pop("en")
        feature.pop("rel")
        real_label = feature['label']
        # NOTE(review): this `break` (kept from the original) stops processing
        # ALL remaining features at the first non-list `en`. Only that feature
        # loses its "en"/"rel" keys. Confirm whether `continue` was intended.
        if not isinstance(en, list):
            break
        input_ids = feature['input_ids']
        cnt = 0
        pos = 0
        for i, token_id in enumerate(input_ids):
            # Successive pad tokens are replaced by ids placed after the base
            # vocabulary: en[k] + len(tokenizer).
            if token_id == pad_token_id:
                input_ids[i] = en[cnt] + vocab_size
                cnt += 1
            if input_ids[i] == real_label + vocab_size:
                pos = i
            if cnt == len(en):
                break
        assert not (args.faiss_init and pos == 0)
        feature['pos'] = pos
    return KGCDataset(features)
class MultiprocessingEncoder(object):
    """Tokenize JSON-encoded ``InputExample`` lines; designed for a process pool.

    ``initializer`` publishes the tokenizer as the module-global ``bpe``, which
    the encode/decode methods read. Call it in every process that uses them;
    ``get_dataset`` passes it as the ``Pool`` initializer.
    """

    def __init__(self, tokenizer, args):
        self.tokenizer = tokenizer
        self.pretrain = args.pretrain
        self.max_seq_length = args.max_seq_length

    def initializer(self):
        """Expose ``self.tokenizer`` as the module-global ``bpe``."""
        global bpe
        bpe = self.tokenizer

    def encode(self, line):
        """Return the token ids of ``line`` as strings."""
        return [str(token_id) for token_id in bpe.encode(line)]

    def decode(self, tokens):
        """Decode token ids back to text with ``bpe``."""
        return bpe.decode(tokens)

    def encode_lines(self, lines):
        """Encode a group of JSON example lines together.

        :return: ``["EMPTY", None]`` as soon as a blank line is met. Otherwise
            ``["PASS", enc_lines]``, where each entry is the JSON-serialised
            feature dict of the corresponding line.
        """
        enc_lines = []
        for line in lines:
            line = line.strip()
            if len(line) == 0:
                return ["EMPTY", None]
            # Lines are written with json.dumps in get_dataset. Parse them as
            # JSON: eval() failed on null/true/false, mangled non-BMP escapes
            # and would execute arbitrary code.
            example = json.loads(line)
            enc_lines.append(json.dumps(self.convert_examples_to_features(example=example)))
        return ["PASS", enc_lines]

    def decode_lines(self, lines):
        """Decode whitespace-separated id strings; returns ``["PASS", texts]``."""
        dec_lines = [self.decode(map(int, line.strip().split())) for line in lines]
        return ["PASS", dec_lines]

    def convert_examples_to_features(self, example):
        """Tokenize one example dict and return its ``InputFeatures`` as a dict.

        ``example`` is a decoded ``InputExample.__dict__`` with keys
        ``text_a``, ``text_b``, ``text_c``, ``label``, ``real_label``, ``en``
        and ``rel``.

        * pretrain: a single sequence
          ``"The description of {text_a} is that {text_b} ."``.
        * otherwise: a sequence pair.
          - If ``text_a`` is ``"[MASK]"``: ``(text_a + sep + text_b, text_c)``.
          - Else: ``(text_a, text_b + sep + text_c)``.
        """
        text_a = example['text_a']
        text_b = example['text_b']
        text_c = example['text_c']
        if self.pretrain:
            input_text = f"The description of {text_a} is that {text_b} ."
            inputs = bpe(
                input_text,
                truncation="longest_first",
                max_length=self.max_seq_length,
                padding="longest",
                add_special_tokens=True,
            )
        else:
            if text_a == "[MASK]":
                input_text_a = bpe.sep_token.join([text_a, text_b])
                input_text_b = text_c
            else:
                input_text_a = text_a
                input_text_b = bpe.sep_token.join([text_b, text_c])
            inputs = bpe(
                input_text_a,
                input_text_b,
                truncation="longest_first",
                max_length=self.max_seq_length,
                padding="longest",
                add_special_tokens=True,
            )
        return asdict(InputFeatures(input_ids=inputs["input_ids"],
                                    attention_mask=inputs['attention_mask'],
                                    labels=example['label'],
                                    label=example['real_label'],
                                    en=example['en'],
                                    rel=example['rel']))
if __name__ == "__main__":
    # NOTE(review): this wraps the path *string* itself as the feature
    # sequence (its length is the number of characters); nothing is loaded
    # from ./dataset.
    dataset = KGCDataset(features='./dataset')