Relphormer baseline

2022-12-26 04:54:46 +00:00
commit c0d0be076f
117 changed files with 1574716 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
from .data_module import KGC
from .processor import convert_examples_to_features, KGProcessor

View File

@@ -0,0 +1,71 @@
"""Base DataModule class."""
from pathlib import Path
from typing import Dict
import argparse
import os
import pytorch_lightning as pl
from torch.utils.data import DataLoader
class Config(dict):
def __getattr__(self, name):
return self.get(name)
def __setattr__(self, name, val):
self[name] = val
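# Config is a thin dict subclass that exposes its entries as attributes, e.g. (illustrative
# sketch, not part of the original file): cfg = Config(batch_size=8); cfg.batch_size -> 8,
# and cfg.some_missing_key -> None instead of raising AttributeError.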
BATCH_SIZE = 8
NUM_WORKERS = 8
class BaseDataModule(pl.LightningDataModule):
"""
Base DataModule.
Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html
"""
def __init__(self, args: argparse.Namespace = None) -> None:
super().__init__()
self.args = Config(vars(args)) if args is not None else {}
self.batch_size = self.args.get("batch_size", BATCH_SIZE)
self.num_workers = self.args.get("num_workers", NUM_WORKERS)
@staticmethod
def add_to_argparse(parser):
parser.add_argument(
"--batch_size", type=int, default=BATCH_SIZE, help="Number of examples to operate on per forward step."
)
parser.add_argument(
"--num_workers", type=int, default=0, help="Number of additional processes to load data."
)
parser.add_argument(
"--dataset", type=str, default="./dataset/NELL", help="Number of additional processes to load data."
)
return parser
def prepare_data(self):
"""
Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`).
"""
pass
def setup(self, stage=None):
"""
Split into train, val, test, and set dims.
Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.
"""
self.data_train = None
self.data_val = None
self.data_test = None
def train_dataloader(self):
return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
def val_dataloader(self):
return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
def test_dataloader(self):
return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

View File

@@ -0,0 +1,196 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from enum import Enum
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BertTokenizer
# from transformers.configuration_bert import BertTokenizer, BertTokenizerFast
from transformers.tokenization_utils_base import (BatchEncoding,
PreTrainedTokenizerBase)
from .base_data_module import BaseDataModule
from .processor import KGProcessor, get_dataset
import transformers
transformers.logging.set_verbosity_error()
class ExplicitEnum(Enum):
"""
Enum with more explicit error message for missing values.
"""
@classmethod
def _missing_(cls, value):
raise ValueError(
f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
)
class PaddingStrategy(ExplicitEnum):
"""
Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
in an IDE.
"""
LONGEST = "longest"
MAX_LENGTH = "max_length"
DO_NOT_PAD = "do_not_pad"
import numpy as np
@dataclass
class DataCollatorForSeq2Seq:
"""
Data collator that will dynamically pad the inputs received, as well as the labels.
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
model (:class:`~transformers.PreTrainedModel`):
The model that is being trained. If set and has the `prepare_decoder_input_ids_from_labels`, use it to
prepare the `decoder_input_ids`
This is useful when using `label_smoothing` to avoid calculating loss twice.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
"""
tokenizer: PreTrainedTokenizerBase
model: Optional[Any] = None
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
return_tensors: str = "pt"
num_labels: int = 0
def __call__(self, features, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors
labels = [feature.pop("labels") for feature in features] if "labels" in features[0].keys() else None
label = [feature.pop("label") for feature in features]
features_keys = {}
name_keys = list(features[0].keys())
for k in name_keys:
# ignore the padding arguments
if k in ["input_ids", "attention_mask", "token_type_ids"]: continue
try:
features_keys[k] = [feature.pop(k) for feature in features]
except KeyError:
continue
# features_keys[k] = [feature.pop(k) for feature in features]
# We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
# same length to return tensors.
bsz = len(labels)
with torch.no_grad():
new_labels = torch.zeros(bsz, self.num_labels)
for i,l in enumerate(labels):
if isinstance(l, int):
new_labels[i][l] = 1
else:
for j in l:
new_labels[i][j] = 1
labels = new_labels
features = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=return_tensors,
)
features['labels'] = labels
features['label'] = torch.tensor(label)
features.update(features_keys)
return features
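# Illustrative usage sketch (not part of the original file; the tokenizer name and the feature
# values are assumptions). The collator receives the per-example feature dicts produced by the
# processor, pads input_ids/attention_mask via tokenizer.pad, and converts the int/list "labels"
# field into a multi-hot tensor of shape (batch_size, num_labels):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)
#   collator = DataCollatorForSeq2Seq(tok, padding="longest", num_labels=4)
#   batch = collator([
#       {"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1], "labels": [0, 2], "label": 0},
#       {"input_ids": [101, 2088, 2003, 102], "attention_mask": [1, 1, 1, 1], "labels": 1, "label": 1},
#   ])
#   # batch["input_ids"] is a padded LongTensor; batch["labels"] has shape (2, 4)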
class KGC(BaseDataModule):
def __init__(self, args, model) -> None:
super().__init__(args)
self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name_or_path, use_fast=False)
self.processor = KGProcessor(self.tokenizer, args)
self.label_list = self.processor.get_labels(args.data_dir)
entity_list = self.processor.get_entities(args.data_dir)
num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': entity_list})
self.sampler = DataCollatorForSeq2Seq(self.tokenizer,
model=model,
label_pad_token_id=self.tokenizer.pad_token_id,
pad_to_multiple_of=8 if self.args.precision == 16 else None,
padding="longest",
max_length=self.args.max_seq_length,
num_labels = len(entity_list),
)
relations_tokens = self.processor.get_relations(args.data_dir)
self.num_relations = len(relations_tokens)
num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': relations_tokens})
vocab = self.tokenizer.get_added_vocab()
self.relation_id_st = vocab[relations_tokens[0]]
self.relation_id_ed = vocab[relations_tokens[-1]] + 1
self.entity_id_st = vocab[entity_list[0]]
self.entity_id_ed = vocab[entity_list[-1]] + 1
def setup(self, stage=None):
self.data_train = get_dataset(self.args, self.processor, self.label_list, self.tokenizer, "train")
self.data_val = get_dataset(self.args, self.processor, self.label_list, self.tokenizer, "dev")
self.data_test = get_dataset(self.args, self.processor, self.label_list, self.tokenizer, "test")
def prepare_data(self):
pass
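# get_config (below) exports every attribute whose name contains "st" or "ed"; it is intended to
# capture the vocabulary boundaries set in __init__ (entity_id_st, entity_id_ed, relation_id_st,
# relation_id_ed), which the LitModel uses to slice the MLM logits down to the entity token range.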
def get_config(self):
d = {}
for k, v in self.__dict__.items():
if "st" in k or "ed" in k:
d.update({k:v})
return d
@staticmethod
def add_to_argparse(parser):
BaseDataModule.add_to_argparse(parser)
parser.add_argument("--model_name_or_path", type=str, default="roberta-base", help="the name or the path to the pretrained model")
parser.add_argument("--data_dir", type=str, default="roberta-base", help="the name or the path to the pretrained model")
parser.add_argument("--max_seq_length", type=int, default=256, help="Number of examples to operate on per forward step.")
parser.add_argument("--warm_up_radio", type=float, default=0.1, help="Number of examples to operate on per forward step.")
parser.add_argument("--eval_batch_size", type=int, default=8)
parser.add_argument("--overwrite_cache", action="store_true", default=False)
return parser
def get_tokenizer(self):
return self.tokenizer
def train_dataloader(self):
return DataLoader(self.data_train, num_workers=self.num_workers, pin_memory=True, collate_fn=self.sampler, batch_size=self.args.batch_size, shuffle=not self.args.faiss_init)
def val_dataloader(self):
return DataLoader(self.data_val, num_workers=self.num_workers, pin_memory=True, collate_fn=self.sampler, batch_size=self.args.eval_batch_size)
def test_dataloader(self):
return DataLoader(self.data_test, num_workers=self.num_workers, pin_memory=True, collate_fn=self.sampler, batch_size=self.args.eval_batch_size)

936
pretrain/data/processor.py Normal file
View File

@@ -0,0 +1,936 @@
import contextlib
import sys
from collections import Counter
from multiprocessing import Pool
from torch.utils.data import Dataset, Sampler, IterableDataset
from collections import defaultdict
from functools import partial
from multiprocessing import Pool
import os
import random
import json
import torch
import copy
import numpy as np
import pickle
from tqdm import tqdm
from dataclasses import dataclass, asdict, replace
import inspect
from transformers.models.auto.tokenization_auto import AutoTokenizer
from models.utils import get_entity_spans_pre_processing
def lmap(a, b):
return list(map(a,b))
def cache_results(_cache_fp, _refresh=False, _verbose=1):
r"""
cache_results is the decorator used in fastNLP to cache data. The example below shows how to use it::
import time
import numpy as np
from fastNLP import cache_results
@cache_results('cache.pkl')
def process_data():
# some time-consuming work such as reading and preprocessing data; time.sleep() stands in for it here
time.sleep(1)
return np.random.randint(10, size=(5,))
start_time = time.time()
print("res =",process_data())
print(time.time() - start_time)
start_time = time.time()
print("res =",process_data())
print(time.time() - start_time)
# The output looks like the following; both runs return the same result and the second run takes almost no time
# Save cache to cache.pkl.
# res = [5 4 9 1 8]
# 1.0042750835418701
# Read cache from cache.pkl.
# res = [5 4 9 1 8]
# 0.0040721893310546875
The second run is nearly instantaneous because it reads the data directly from cache.pkl instead of running the preprocessing again::
# Continuing the example above: to generate another cache (e.g. for a different dataset), call it like this
process_data(_cache_fp='cache2.pkl') # this does not affect the previous cache.pkl at all
The _cache_fp above is a parameter recognized by cache_results; results are cached to / read from 'cache2.pkl', i.e. 'cache2.pkl' here overrides the default
'cache.pkl'. If you decorate your function with @cache_results(), it gains three extra parameters: [_cache_fp, _refresh, _verbose].
The example above shows the _cache_fp case. These three parameters are never passed into your function, and your own function's parameter names must not include any of them::
process_data(_cache_fp='cache2.pkl', _refresh=True) # force the preprocessing cache to be regenerated here
# _verbose controls the logging output: if 0, nothing is printed; if 1, it reports whether the current step read an existing cache or generated a new one
:param str _cache_fp: where to cache the returned results, or where to read the cache from. If None, cache_results has no effect unless
_cache_fp is passed when the function is called.
:param bool _refresh: whether to regenerate the cache.
:param int _verbose: whether to print cache information.
:return:
"""
def wrapper_(func):
signature = inspect.signature(func)
for key, _ in signature.parameters.items():
if key in ('_cache_fp', '_refresh', '_verbose'):
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))
def wrapper(*args, **kwargs):
my_args = args[0]
mode = args[-1]
if '_cache_fp' in kwargs:
cache_filepath = kwargs.pop('_cache_fp')
assert isinstance(cache_filepath, str), "_cache_fp can only be str."
else:
cache_filepath = _cache_fp
if '_refresh' in kwargs:
refresh = kwargs.pop('_refresh')
assert isinstance(refresh, bool), "_refresh can only be bool."
else:
refresh = _refresh
if '_verbose' in kwargs:
verbose = kwargs.pop('_verbose')
assert isinstance(verbose, int), "_verbose can only be integer."
else:
verbose = _verbose
refresh_flag = True
model_name = my_args.model_name_or_path.split("/")[-1]
is_pretrain = my_args.pretrain
cache_filepath = os.path.join(my_args.data_dir, f"cached_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl")
refresh = my_args.overwrite_cache
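# Note: the cache path and refresh flag derived from my_args above override the decorator's
# _cache_fp/_refresh defaults, so the cache always lives in args.data_dir and is keyed by the
# model name, the pretrain/faiss flags, the sequence length, and the litmodel class.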
if cache_filepath is not None and refresh is False:
# load data
if os.path.exists(cache_filepath):
with open(cache_filepath, 'rb') as f:
results = pickle.load(f)
if verbose == 1:
logger.info("Read cache from {}.".format(cache_filepath))
refresh_flag = False
if refresh_flag:
results = func(*args, **kwargs)
if cache_filepath is not None:
if results is None:
raise RuntimeError("The return value is None. Delete the decorator.")
with open(cache_filepath, 'wb') as f:
pickle.dump(results, f)
logger.info("Save cache to {}.".format(cache_filepath))
return results
return wrapper
return wrapper_
import argparse
import csv
import logging
import os
import random
import sys
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
# from torch.nn import CrossEntropyLoss, MSELoss
# from scipy.stats import pearsonr, spearmanr
# from sklearn.metrics import matthews_corrcoef, f1_score
logger = logging.getLogger(__name__)
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, text_c=None, label=None, real_label=None, en=None, rel=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
text_c: (Optional) string. The untokenized text of the third sequence.
Only must be specified for sequence triple tasks.
label: (Optional) string. list of entities
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.text_c = text_c
self.label = label
self.real_label = real_label
self.en = en
self.rel = rel # rel id
@dataclass
class InputFeatures:
"""A single set of features of data."""
input_ids: torch.Tensor
attention_mask: torch.Tensor
labels: torch.Tensor = None
label: torch.Tensor = None
en: torch.Tensor = 0
rel: torch.Tensor = 0
pos: torch.Tensor = 0
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_labels(self, data_dir):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r", encoding="utf-8") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
if sys.version_info[0] == 2:
line = list(unicode(cell, 'utf-8') for cell in line)
lines.append(line)
return lines
import copy
def solve_get_knowledge_store(line, set_type="train", pretrain=1):
"""
Use the LM to get the entity embeddings.
Transductive setting: triples + text descriptions.
Inductive setting: text descriptions only.
"""
examples = []
head_ent_text = ent2text[line[0]]
tail_ent_text = ent2text[line[2]]
relation_text = rel2text[line[1]]
i=0
a = tail_filter_entities["\t".join([line[0],line[1]])]
b = head_filter_entities["\t".join([line[2],line[1]])]
guid = "%s-%s" % (set_type, i)
text_a = head_ent_text
text_b = relation_text
text_c = tail_ent_text
# use the description of c to predict A
examples.append(
InputExample(guid=guid, text_a="[PAD]", text_b=text_b + "[PAD]", text_c = "[PAD]" + " " + text_c, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[ent2id[line[0]], rel2id[line[1]], ent2id[line[2]]], rel=0)
)
examples.append(
InputExample(guid=guid, text_a="[PAD]", text_b=text_b + "[PAD]", text_c = "[PAD]" + " " + text_a, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[2]], en=[ent2id[line[0]], rel2id[line[1]], ent2id[line[2]]], rel=0)
)
return examples
def solve(line, set_type="train", pretrain=1):
examples = []
head_ent_text = ent2text[line[0]]
tail_ent_text = ent2text[line[2]]
relation_text = rel2text[line[1]]
i=0
a = tail_filter_entities["\t".join([line[0],line[1]])]
b = head_filter_entities["\t".join([line[2],line[1]])]
guid = "%s-%s" % (set_type, i)
text_a = head_ent_text
text_b = relation_text
text_c = tail_ent_text
if pretrain:
text_a_tokens = text_a.split()
for i in range(10):
st = random.randint(0, len(text_a_tokens))
examples.append(
InputExample(guid=guid, text_a="[MASK]", text_b=" ".join(text_a_tokens[st:min(st+64, len(text_a_tokens))]), text_c = "", label=ent2id[line[0]], real_label=ent2id[line[0]], en=0, rel=0)
)
examples.append(
InputExample(guid=guid, text_a="[MASK]", text_b=text_a, text_c = "", label=ent2id[line[0]], real_label=ent2id[line[0]], en=0, rel=0)
)
# examples.append(
# InputExample(guid=guid, text_a="[MASK]", text_b=text_c, text_c = "", label=ent2id[line[2]], real_label=ent2id[line[2]], en=0, rel=0)
# )
else:
# examples.append(
# InputExample(guid=guid, text_a="[MASK]", text_b=text_b + "[PAD]", text_c = "[UNK]" , label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=ent2id[line[2]], rel=rel2id[line[1]]))
# examples.append(
# InputExample(guid=guid, text_a="[UNK] ", text_b=text_b + "[PAD]", text_c = "[MASK]", label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=ent2id[line[0]], rel=rel2id[line[1]]))
# examples.append(
# InputExample(guid=guid, text_a="[UNK]" + " " + text_c, text_b=text_b + "[PAD]", text_c = "[MASK]", label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=ent2id[line[2]], rel=rel2id[line[1]]))
# examples.append(
# InputExample(guid=guid, text_a="[MASK]", text_b=text_b + "[PAD]", text_c = "[UNK]" + text_a, label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=ent2id[line[0]], rel=rel2id[line[1]]))
examples.append(
InputExample(guid=guid, text_a="[MASK]", text_b=text_b + "[PAD]", text_c = "[PAD]" + " " + text_c, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]]))
examples.append(
InputExample(guid=guid, text_a="[PAD] ", text_b=text_b + "[PAD]", text_c = "[MASK]" +" " + text_a, label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]]))
return examples
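# In the non-pretraining branch, solve() emits two masked examples per triple (h, r, t): one whose
# input is roughly "[MASK] [SEP] r ... t-description" with the head entity as the answer, and one
# where the [MASK] sits in front of the head-entity description with the tail entity as the answer.
# `label` holds every entity that satisfies the query (used for filtered ranking / BCE), while
# `real_label` is the single gold entity id.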
def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
global head_filter_entities
global tail_filter_entities
global ent2text
global rel2text
global ent2id
global ent2token
global rel2id
head_filter_entities = head
tail_filter_entities = tail
ent2text =t1
rel2text =t2
ent2id = ent2id_
ent2token = ent2token_
rel2id = rel2id_
def delete_init(ent2text_):
global ent2text
ent2text = ent2text_
class KGProcessor(DataProcessor):
"""Processor for knowledge graph data set."""
def __init__(self, tokenizer, args):
self.labels = set()
self.tokenizer = tokenizer
self.args = args
self.entity_path = os.path.join(args.data_dir, "entity2textlong.txt") if os.path.exists(os.path.join(args.data_dir, 'entity2textlong.txt')) \
else os.path.join(args.data_dir, "entity2text.txt")
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train", data_dir, self.args)
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev", data_dir, self.args)
def get_test_examples(self, data_dir, chunk=""):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, f"test{chunk}.tsv")), "test", data_dir, self.args)
def get_relations(self, data_dir):
"""Gets all labels (relations) in the knowledge graph."""
# return list(self.labels)
with open(os.path.join(data_dir, "relations.txt"), 'r') as f:
lines = f.readlines()
relations = []
for line in lines:
relations.append(line.strip().split('\t')[0])
rel2token = {ent : f"[RELATION_{i}]" for i, ent in enumerate(relations)}
return list(rel2token.values())
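# Each relation string is mapped to a synthetic placeholder token "[RELATION_<i>]" (entities are
# handled the same way with "[ENTITY_<i>]" below); these placeholders are registered as special
# tokens in KGC.__init__, so entities and relations are scored through the extended MLM vocabulary
# rather than through their surface text.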
def get_labels(self, data_dir):
"""Gets all labels (0, 1) for triples in the knowledge graph."""
relation = []
with open(os.path.join(data_dir, "relation2text.txt"), 'r') as f:
lines = f.readlines()
entities = []
for line in lines:
relation.append(line.strip().split("\t")[-1])
return relation
def get_entities(self, data_dir):
"""Gets all entities in the knowledge graph."""
with open(self.entity_path, 'r') as f:
lines = f.readlines()
entities = []
for line in lines:
entities.append(line.strip().split("\t")[0])
ent2token = {ent : f"[ENTITY_{i}]" for i, ent in enumerate(entities)}
return list(ent2token.values())
def get_train_triples(self, data_dir):
"""Gets training triples."""
return self._read_tsv(os.path.join(data_dir, "train.tsv"))
def get_dev_triples(self, data_dir):
"""Gets validation triples."""
return self._read_tsv(os.path.join(data_dir, "dev.tsv"))
def get_test_triples(self, data_dir, chunk=""):
"""Gets test triples."""
return self._read_tsv(os.path.join(data_dir, f"test{chunk}.tsv"))
def _create_examples(self, lines, set_type, data_dir, args):
"""Creates examples for the training and dev sets."""
# entity to text
ent2text = {}
ent2text_with_type = {}
with open(self.entity_path, 'r') as f:
ent_lines = f.readlines()
for line in ent_lines:
temp = line.strip().split('\t')
try:
end = temp[1]#.find(',')
if "wiki" in data_dir:
assert "Q" in temp[0]
ent2text[temp[0]] = temp[1].replace("\\n", " ").replace("\\", "") #[:end]
except IndexError:
# continue
end = " "#.find(',')
if "wiki" in data_dir:
assert "Q" in temp[0]
ent2text[temp[0]] = end #[:end]
entities = list(ent2text.keys())
ent2token = {ent : f"[ENTITY_{i}]" for i, ent in enumerate(entities)}
ent2id = {ent : i for i, ent in enumerate(entities)}
rel2text = {}
with open(os.path.join(data_dir, "relation2text.txt"), 'r') as f:
rel_lines = f.readlines()
for line in rel_lines:
temp = line.strip().split('\t')
rel2text[temp[0]] = temp[1]
relation_names = {}
with open(os.path.join(data_dir, "relations.txt"), "r") as file:
for line in file.readlines():
t = line.strip()
relation_names[t] = rel2text[t]
tmp_lines = []
not_in_text = 0
for line in tqdm(lines, desc="dropping triples whose entities or relations have no text"):
if (line[0] not in ent2text) or (line[2] not in ent2text) or (line[1] not in rel2text):
not_in_text += 1
continue
tmp_lines.append(line)
lines = tmp_lines
print(f"total entity not in text : {not_in_text} ")
# rel id -> relation token id
num_entities = len(self.get_entities(args.data_dir))
rel2id = {w:i+num_entities for i,w in enumerate(relation_names.keys())}
# add reverse relation
# tmp_rel2id = {}
# num_relations = len(rel2id)
# cnt = 0
# for k, v in rel2id.items():
# tmp_rel2id[k + " (reverse)"] = num_relations + cnt
# cnt += 1
# rel2id.update(tmp_rel2id)
examples = []
# head filter head entity
head_filter_entities = defaultdict(list)
tail_filter_entities = defaultdict(list)
dataset_list = ["train.tsv", "dev.tsv", "test.tsv"]
# in training, only use the train triples
if set_type == "train" and not args.pretrain: dataset_list = dataset_list[0:1]
for m in dataset_list:
with open(os.path.join(data_dir, m), 'r') as file:
train_lines = file.readlines()
for idx in range(len(train_lines)):
train_lines[idx] = train_lines[idx].strip().split("\t")
for line in train_lines:
tail_filter_entities["\t".join([line[0], line[1]])].append(line[2])
head_filter_entities["\t".join([line[2], line[1]])].append(line[0])
max_head_entities = max(len(_) for _ in head_filter_entities.values())
max_tail_entities = max(len(_) for _ in tail_filter_entities.values())
# use bce loss, ignore the mlm
if set_type == "train" and args.bce:
lines = []
for k, v in tail_filter_entities.items():
h, r = k.split('\t')
t = v[0]
lines.append([h, r, t])
for k, v in head_filter_entities.items():
t, r = k.split('\t')
h = v[0]
lines.append([h, r, t])
# for pretraining, build one (entity, rel, entity) line per entity so each entity's mask embedding can be learned
if args.pretrain:
rel = list(rel2text.keys())[0]
lines = []
for k in ent2text.keys():
lines.append([k, rel, k])
print(f"max number of filter entities : {max_head_entities} {max_tail_entities}")
from os import cpu_count
threads = min(1, cpu_count())
filter_init(head_filter_entities, tail_filter_entities,ent2text, rel2text, ent2id, ent2token, rel2id
)
if hasattr(args, "faiss_init") and args.faiss_init:
annotate_ = partial(
solve_get_knowledge_store,
pretrain=self.args.pretrain
)
else:
annotate_ = partial(
solve,
pretrain=self.args.pretrain
)
examples = list(
tqdm(
map(annotate_, lines),
total=len(lines),
desc="convert text to examples"
)
)
# with Pool(threads, initializer=filter_init, initargs=(head_filter_entities, tail_filter_entities,ent2text, rel2text,
# ent2text_with_type, rel2id,)) as pool:
# annotate_ = partial(
# solve,
# )
# examples = list(
# tqdm(
# map(annotate_, lines, chunksize= 128),
# total=len(lines),
# desc="convert text to examples"
# )
# )
tmp_examples = []
for e in examples:
for ee in e:
tmp_examples.append(ee)
examples = tmp_examples
# delete vars
del head_filter_entities, tail_filter_entities, ent2text, rel2text, ent2id, ent2token, rel2id
return examples
class Verbalizer(object):
def __init__(self, args):
if "WN18RR" in args.data_dir:
self.mode = "WN18RR"
elif "FB15k" in args.data_dir:
self.mode = "FB15k"
elif "umls" in args.data_dir:
self.mode = "umls"
elif "codexs" in args.data_dir:
self.mode = "codexs"
elif "codexl" in args.data_dir:
self.mode = "codexl"
elif "FB13" in args.data_dir:
self.mode = "FB13"
elif "WN11" in args.data_dir:
self.mode = "WN11"
def _convert(self, head, relation, tail):
if self.mode == "umls":
return f"The {relation} {head} is "
return f"{head} {relation}"
class KGCDataset(Dataset):
def __init__(self, features):
self.features = features
def __getitem__(self, index):
return self.features[index]
def __len__(self):
return len(self.features)
def convert_examples_to_features_init(tokenizer_for_convert):
global tokenizer
tokenizer = tokenizer_for_convert
def convert_examples_to_features(example, max_seq_length, mode, pretrain=1):
"""Loads a data file into a list of `InputBatch`s."""
# tokens_a = tokenizer.tokenize(example.text_a)
# tokens_b = tokenizer.tokenize(example.text_b)
# tokens_c = tokenizer.tokenize(example.text_c)
# _truncate_seq_triple(tokens_a, tokens_b, tokens_c, max_length= max_seq_length)
text_a = " ".join(example.text_a.split()[:128])
text_b = " ".join(example.text_b.split()[:128])
text_c = " ".join(example.text_c.split()[:128])
if pretrain:
input_text_a = text_a
input_text_b = text_b
else:
input_text_a = tokenizer.sep_token.join([text_a, text_b])
input_text_b = text_c
inputs = tokenizer(
input_text_a,
input_text_b,
truncation="longest_first",
max_length=max_seq_length,
padding="longest",
add_special_tokens=True,
)
# assert tokenizer.mask_token_id in inputs.input_ids, "mask token must in input"
features = asdict(InputFeatures(input_ids=inputs["input_ids"],
attention_mask=inputs['attention_mask'],
labels=torch.tensor(example.label),
label=torch.tensor(example.real_label)
)
)
return features
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
def _truncate_seq_triple(tokens_a, tokens_b, tokens_c, max_length):
"""Truncates a sequence triple in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b) + len(tokens_c)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b) and len(tokens_a) > len(tokens_c):
tokens_a.pop()
elif len(tokens_b) > len(tokens_a) and len(tokens_b) > len(tokens_c):
tokens_b.pop()
elif len(tokens_c) > len(tokens_a) and len(tokens_c) > len(tokens_b):
tokens_c.pop()
else:
tokens_c.pop()
@cache_results(_cache_fp="./dataset")
def get_dataset(args, processor, label_list, tokenizer, mode):
assert mode in ["train", "dev", "test"], "mode must be in train dev test!"
# use training data to construct the entity embedding
combine_train_and_test = False
if args.faiss_init and mode == "test" and not args.pretrain:
mode = "train"
if "ind" in args.data_dir: combine_train_and_test = True
else:
pass
if mode == "train":
train_examples = processor.get_train_examples(args.data_dir)
elif mode == "dev":
train_examples = processor.get_dev_examples(args.data_dir)
else:
train_examples = processor.get_test_examples(args.data_dir)
if combine_train_and_test:
logger.info("use all the dataset for getting the entity mask embedding in pretraining pretraining")
logger.info("use all the dataset for getting the entity mask embedding in pretraining pretraining")
train_examples = processor.get_test_examples(args.data_dir) + processor.get_train_examples(args.data_dir) + processor.get_dev_examples(args.data_dir)
from os import cpu_count
with open(os.path.join(args.data_dir, f"examples_{mode}.txt"), 'w') as file:
for line in train_examples:
d = {}
d.update(line.__dict__)
file.write(json.dumps(d) + '\n')
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False)
features = []
# # with open(os.path.join(args.data_dir, "cached_relation_pattern.pkl"), "rb") as file:
# # pattern = pickle.load(file)
# pattern = None
# convert_examples_to_features_init(tokenizer)
# annotate_ = partial(
# convert_examples_to_features,
# max_seq_length=args.max_seq_length,
# mode = mode,
# pretrain = args.pretrain
# )
# features = list(
# tqdm(
# map(annotate_, train_examples),
# total=len(train_examples)
# )
# )
# encoder = MultiprocessingEncoder(tokenizer, args)
# encoder.initializer()
# for t in tqdm(train_examples):
# features.append(encoder.encode_lines([json.dumps(t.__dict__)]))
# for example in tqdm(train_examples):
# text_a = example.text_a
# text_b = example.text_b
# text_c = example.text_c
# bpe = tokenizer
# if 0:
# input_text_a = text_a
# input_text_b = text_b
# else:
# if text_a == "[MASK]":
# input_text_a = bpe.sep_token.join([text_a, text_b])
# input_text_b = text_c
# else:
# input_text_a = text_a
# input_text_b = bpe.sep_token.join([text_b, text_c])
# inputs = tokenizer(
# input_text_a,
# # input_text_b,
# truncation="longest_first",
# max_length=128,
# padding="longest",
# add_special_tokens=True,
# )
# # assert tokenizer.mask_token_id in inputs.input_ids, "mask token must in input"
# # features.append(asdict(InputFeatures(input_ids=inputs["input_ids"],
# # attention_mask=inputs['attention_mask'],
# # labels=example.label,
# # label=example.real_label
# # )
# # ))
file_inputs = [os.path.join(args.data_dir, f"examples_{mode}.txt")]
file_outputs = [os.path.join(args.data_dir, f"features_{mode}.txt")]
with contextlib.ExitStack() as stack:
inputs = [
stack.enter_context(open(input, "r", encoding="utf-8"))
if input != "-" else sys.stdin
for input in file_inputs
]
outputs = [
stack.enter_context(open(output, "w", encoding="utf-8"))
if output != "-" else sys.stdout
for output in file_outputs
]
encoder = MultiprocessingEncoder(tokenizer, args)
pool = Pool(16, initializer=encoder.initializer)
encoder.initializer()
encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 1000)
# encoded_lines = map(encoder.encode_lines, zip(*inputs))
stats = Counter()
for i, (filt, enc_lines) in tqdm(enumerate(encoded_lines, start=1), total=len(train_examples)):
if filt == "PASS":
for enc_line, output_h in zip(enc_lines, outputs):
features.append(eval(enc_line))
# features.append(enc_line)
# print(enc_line, file=output_h)
else:
stats["num_filtered_" + filt] += 1
# if i % 10000 == 0:
# print("processed {} lines".format(i), file=sys.stderr)
for k, v in stats.most_common():
print("[{}] filtered {} lines".format(k, v), file=sys.stderr)
# threads = min(16, cpu_count())
# with Pool(threads, initializer=convert_examples_to_features_init, initargs=(tokenizer,)) as pool:
# annotate_ = partial(
# convert_examples_to_features,
# max_seq_length=args.max_seq_length,
# mode = mode,
# pretrain = args.pretrain
# )
# features = list(
# tqdm(
# pool.imap_unordered(annotate_, train_examples),
# total=len(train_examples),
# desc="convert examples to features",
# )
# )
# num_entities = len(processor.get_entities(args.data_dir))
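# The loop below patches each feature in place: the [PAD] placeholder positions produced by
# solve()/solve_get_knowledge_store() are overwritten with the ids of the added entity/relation
# special tokens (the stored indices offset by len(tokenizer), matching the ids those tokens get
# after add_special_tokens in the data module), and `pos` records where the gold entity token
# ends up, which the faiss-init path later uses to read out that entity's hidden state.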
for f_id, f in enumerate(features):
en = features[f_id].pop("en")
rel = features[f_id].pop("rel")
real_label = f['label']
cnt = 0
if not isinstance(en, list): break
pos = 0
for i,t in enumerate(f['input_ids']):
if t == tokenizer.pad_token_id:
features[f_id]['input_ids'][i] = en[cnt] + len(tokenizer)
cnt += 1
if features[f_id]['input_ids'][i] == real_label + len(tokenizer):
pos = i
if cnt == len(en): break
assert not (args.faiss_init and pos == 0)
features[f_id]['pos'] = pos
# for i,t in enumerate(f['input_ids']):
# if t == tokenizer.pad_token_id:
# features[f_id]['input_ids'][i] = rel + len(tokenizer) + num_entities
# break
features = KGCDataset(features)
return features
class MultiprocessingEncoder(object):
def __init__(self, tokenizer, args):
self.tokenizer = tokenizer
self.pretrain = args.pretrain
self.max_seq_length = args.max_seq_length
def initializer(self):
global bpe
bpe = self.tokenizer
def encode(self, line):
global bpe
ids = bpe.encode(line)
return list(map(str, ids))
def decode(self, tokens):
global bpe
return bpe.decode(tokens)
def encode_lines(self, lines):
"""
Encode a set of lines. All lines will be encoded together.
"""
enc_lines = []
for line in lines:
line = line.strip()
if len(line) == 0:
return ["EMPTY", None]
# enc_lines.append(" ".join(tokens))
enc_lines.append(json.dumps(self.convert_examples_to_features(example=eval(line))))
# enc_lines.append(" ")
# enc_lines.append("123")
return ["PASS", enc_lines]
def decode_lines(self, lines):
dec_lines = []
for line in lines:
tokens = map(int, line.strip().split())
dec_lines.append(self.decode(tokens))
return ["PASS", dec_lines]
def convert_examples_to_features(self, example):
pretrain = self.pretrain
max_seq_length = self.max_seq_length
global bpe
"""Loads a data file into a list of `InputBatch`s."""
# tokens_a = tokenizer.tokenize(example.text_a)
# tokens_b = tokenizer.tokenize(example.text_b)
# tokens_c = tokenizer.tokenize(example.text_c)
# _truncate_seq_triple(tokens_a, tokens_b, tokens_c, max_length= max_seq_length)
# text_a = " ".join(example['text_a'].split()[:128])
# text_b = " ".join(example['text_b'].split()[:128])
# text_c = " ".join(example['text_c'].split()[:128])
text_a = example['text_a']
text_b = example['text_b']
text_c = example['text_c']
if pretrain:
# the des of xxx is [MASK] .
input_text = f"The description of {text_a} is that {text_b} ."
inputs = bpe(
input_text,
truncation="longest_first",
max_length=max_seq_length,
padding="longest",
add_special_tokens=True,
)
else:
if text_a == "[MASK]":
input_text_a = bpe.sep_token.join([text_a, text_b])
input_text_b = text_c
else:
input_text_a = text_a
input_text_b = bpe.sep_token.join([text_b, text_c])
inputs = bpe(
input_text_a,
input_text_b,
truncation="longest_first",
max_length=max_seq_length,
padding="longest",
add_special_tokens=True,
)
# assert bpe.mask_token_id in inputs.input_ids, "mask token must in input"
features = asdict(InputFeatures(input_ids=inputs["input_ids"],
attention_mask=inputs['attention_mask'],
labels=example['label'],
label=example['real_label'],
en=example['en'],
rel=example['rel']
)
)
return features
if __name__ == "__main__":
dataset = KGCDataset('./dataset')

View File

@@ -0,0 +1,2 @@
from .transformer import *
from .base import *

View File

@@ -0,0 +1,97 @@
import argparse
import pytorch_lightning as pl
import torch
from typing import Dict, Any
OPTIMIZER = "AdamW"
LR = 5e-5
LOSS = "cross_entropy"
ONE_CYCLE_TOTAL_STEPS = 100
class Config(dict):
def __getattr__(self, name):
return self.get(name)
def __setattr__(self, name, val):
self[name] = val
class BaseLitModel(pl.LightningModule):
"""
Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
"""
def __init__(self, model, args: argparse.Namespace = None):
super().__init__()
self.model = model
self.args = Config(vars(args)) if args is not None else {}
optimizer = self.args.get("optimizer", OPTIMIZER)
self.optimizer_class = getattr(torch.optim, optimizer)
self.lr = self.args.get("lr", LR)
# default the one-cycle scheduler settings referenced in configure_optimizers
self.one_cycle_max_lr = self.args.get("one_cycle_max_lr", None)
self.one_cycle_total_steps = self.args.get("one_cycle_total_steps", ONE_CYCLE_TOTAL_STEPS)
@staticmethod
def add_to_argparse(parser):
parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim")
parser.add_argument("--lr", type=float, default=LR)
parser.add_argument("--weight_decay", type=float, default=0.01)
return parser
def configure_optimizers(self):
optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
if self.one_cycle_max_lr is None:
return optimizer
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps)
return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx): # pylint: disable=unused-argument
x, y = batch
logits = self(x)
loss = self.loss_fn(logits, y)
self.log("train_loss", loss)
self.train_acc(logits, y)
self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
return loss
def validation_step(self, batch, batch_idx): # pylint: disable=unused-argument
x, y = batch
logits = self(x)
loss = self.loss_fn(logits, y)
self.log("val_loss", loss, prog_bar=True)
self.val_acc(logits, y)
self.log("val_acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
def test_step(self, batch, batch_idx): # pylint: disable=unused-argument
x, y = batch
logits = self(x)
self.test_acc(logits, y)
self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)
@property
def num_training_steps(self) -> int:
"""Total training steps inferred from datamodule and devices."""
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
dataset_size = self.trainer.limit_train_batches
elif isinstance(self.trainer.limit_train_batches, float):
# limit_train_batches is a percentage of batches
dataset_size = len(self.trainer.datamodule.train_dataloader())
dataset_size = int(dataset_size * self.trainer.limit_train_batches)
else:
dataset_size = len(self.trainer.datamodule.train_dataloader())
num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
if self.trainer.tpu_cores:
num_devices = max(num_devices, self.trainer.tpu_cores)
effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs
if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
return self.trainer.max_steps
return max_estimated_steps

View File

@@ -0,0 +1,503 @@
from logging import debug
import random
import pytorch_lightning as pl
import torch
import pickle
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
# from transformers.utils.dummy_pt_objects import PrefixConstrainedLogitsProcessor
from .base import BaseLitModel
from transformers.optimization import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from functools import partial
from .utils import rank_score, acc, LabelSmoothSoftmaxCEV1
from typing import Callable, Iterable, List
def lmap(f: Callable, x: Iterable) -> List:
"""list(map(f, x))"""
return list(map(f, x))
def multilabel_categorical_crossentropy(y_pred, y_true):
y_pred = (1 - 2 * y_true) * y_pred
y_pred_neg = y_pred - y_true * 1e12
y_pred_pos = y_pred - (1 - y_true) * 1e12
zeros = torch.zeros_like(y_pred[..., :1])
y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
return (neg_loss + pos_loss).mean()
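# This is the zero-threshold multi-label cross entropy popularized by Su Jianlin: the two
# logsumexp terms above compute log(1 + sum_neg exp(s_neg)) + log(1 + sum_pos exp(-s_pos)),
# pushing positive-class scores above 0 and negative-class scores below 0 without a per-class
# sigmoid threshold. y_true is a {0,1} multi-hot tensor with the same shape as the scores y_pred.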
def decode(output_ids, tokenizer):
return lmap(str.strip, tokenizer.batch_decode(output_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True))
class TransformerLitModel(BaseLitModel):
def __init__(self, model, args, tokenizer=None, data_config={}):
super().__init__(model, args)
self.save_hyperparameters(args)
if args.bce:
self.loss_fn = torch.nn.BCEWithLogitsLoss()
elif args.label_smoothing != 0.0:
self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
else:
self.loss_fn = nn.CrossEntropyLoss()
self.best_acc = 0
self.first = True
self.tokenizer = tokenizer
self.__dict__.update(data_config)
# resize the word embedding layer
self.model.resize_token_embeddings(len(self.tokenizer))
self.decode = partial(decode, tokenizer=self.tokenizer)
if args.pretrain:
self._freaze_attention()
elif "ind" in args.data_dir:
# for the inductive setting, freeze the word embeddings
self._freaze_word_embedding()
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx): # pylint: disable=unused-argument
# embed();exit()
# print(self.optimizers().param_groups[1]['lr'])
labels = batch.pop("labels")
label = batch.pop("label")
pos = batch.pop("pos")
try:
en = batch.pop("en")
rel = batch.pop("rel")
except KeyError:
pass
input_ids = batch['input_ids']
logits = self.model(**batch, return_dict=True).logits
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
bs = input_ids.shape[0]
mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
assert mask_idx.shape[0] == bs, "only one mask in sequence!"
if self.args.bce:
loss = self.loss_fn(mask_logits, labels)
else:
loss = self.loss_fn(mask_logits, label)
if batch_idx == 0:
print('\n'.join(self.decode(batch['input_ids'][:4])))
return loss
def _eval(self, batch, batch_idx, ):
labels = batch.pop("labels")
input_ids = batch['input_ids']
# single label
label = batch.pop('label')
pos = batch.pop('pos')
my_keys = list(batch.keys())
for k in my_keys:
if k not in ["input_ids", "attention_mask", "token_type_ids"]:
batch.pop(k)
logits = self.model(**batch, return_dict=True).logits[:, :, self.entity_id_st:self.entity_id_ed]
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
bsz = input_ids.shape[0]
logits = logits[torch.arange(bsz), mask_idx]
# get the entity ranks
# filter the entity
assert labels[0][label[0]], "the correct id must be in the filter set!"
labels[torch.arange(bsz), label] = 0
assert logits.shape == labels.shape
logits += labels * -100 # mask the other correct (filtered) entities
# for i in range(bsz):
# logits[i][labels]
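# Double argsort gives the rank of every entity: the first sort below orders entities by logit
# (descending), the second sort turns that ordering into each entity's 0-based position, so
# outputs[b, label[b]] + 1 is the filtered rank of the gold entity.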
_, outputs = torch.sort(logits, dim=1, descending=True)
_, outputs = torch.sort(outputs, dim=1)
ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1
return dict(ranks = np.array(ranks))
def validation_step(self, batch, batch_idx):
result = self._eval(batch, batch_idx)
return result
def validation_epoch_end(self, outputs) -> None:
ranks = np.concatenate([_['ranks'] for _ in outputs])
total_ranks = ranks.shape[0]
if not self.args.pretrain:
l_ranks = ranks[np.array(list(np.arange(0, total_ranks, 2)))]
r_ranks = ranks[np.array(list(np.arange(0, total_ranks, 2))) + 1]
self.log("Eval/lhits10", (l_ranks<=10).mean())
self.log("Eval/rhits10", (r_ranks<=10).mean())
hits20 = (ranks<=20).mean()
hits10 = (ranks<=10).mean()
hits3 = (ranks<=3).mean()
hits1 = (ranks<=1).mean()
self.log("Eval/hits10", hits10)
self.log("Eval/hits20", hits20)
self.log("Eval/hits3", hits3)
self.log("Eval/hits1", hits1)
self.log("Eval/mean_rank", ranks.mean())
self.log("Eval/mrr", (1. / ranks).mean())
self.log("hits10", hits10, prog_bar=True)
self.log("hits1", hits1, prog_bar=True)
def test_step(self, batch, batch_idx): # pylint: disable=unused-argument
# ranks = self._eval(batch, batch_idx)
result = self._eval(batch, batch_idx)
# self.log("Test/ranks", np.mean(ranks))
return result
def test_epoch_end(self, outputs) -> None:
ranks = np.concatenate([_['ranks'] for _ in outputs])
hits20 = (ranks<=20).mean()
hits10 = (ranks<=10).mean()
hits3 = (ranks<=3).mean()
hits1 = (ranks<=1).mean()
self.log("Test/hits10", hits10)
self.log("Test/hits20", hits20)
self.log("Test/hits3", hits3)
self.log("Test/hits1", hits1)
self.log("Test/mean_rank", ranks.mean())
self.log("Test/mrr", (1. / ranks).mean())
def configure_optimizers(self):
no_decay_param = ["bias", "LayerNorm.weight"]
optimizer_group_parameters = [
{"params": [p for n, p in self.model.named_parameters() if p.requires_grad and not any(nd in n for nd in no_decay_param)], "weight_decay": self.args.weight_decay},
{"params": [p for n, p in self.model.named_parameters() if p.requires_grad and any(nd in n for nd in no_decay_param)], "weight_decay": 0}
]
optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.num_training_steps * self.args.warm_up_radio, num_training_steps=self.num_training_steps)
return {
"optimizer": optimizer,
"lr_scheduler":{
'scheduler': scheduler,
'interval': 'step', # or 'epoch'
'frequency': 1,
}
}
def _freaze_attention(self):
for k, v in self.model.named_parameters():
if "word" not in k:
v.requires_grad = False
else:
print(k)
def _freaze_word_embedding(self):
for k, v in self.model.named_parameters():
if "word" in k:
print(k)
v.requires_grad = False
@staticmethod
def add_to_argparse(parser):
parser = BaseLitModel.add_to_argparse(parser)
parser.add_argument("--label_smoothing", type=float, default=0.1, help="")
parser.add_argument("--bce", type=int, default=0, help="")
return parser
import faiss
import os
class GetEntityEmbeddingLitModel(TransformerLitModel):
def __init__(self, model, args, tokenizer, data_config={}):
super().__init__(model, args, tokenizer, data_config)
self.faissid2entityid = {}
# self.index = faiss.IndexFlatL2(d) # build the index
d, measure = self.model.config.hidden_size, faiss.METRIC_L2
# param = 'HNSW64'
# self.index = faiss.index_factory(d, param, measure)
self.index = faiss.IndexFlatL2(d) # build the index
# print(self.index.is_trained)  # prints True at this point
# index.add(xb)
self.cnt_batch = 0
self.total_embedding = []
def test_step(self, batch, batch_idx):
labels = batch.pop("labels")
mask_idx = batch.pop("pos")
input_ids = batch['input_ids']
# single label
label = batch.pop('label')
# last layer
hidden_states = self.model(**batch, return_dict=True, output_hidden_states=True).hidden_states[-1]
# _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
bsz = input_ids.shape[0]
entity_embedding = hidden_states[torch.arange(bsz), mask_idx].cpu()
# use normalize or not ?
# entity_embedding = F.normalize(entity_embedding, dim=-1, p = 2)
self.total_embedding.append(entity_embedding)
# self.index.add(np.array(entity_embedding, dtype=np.float32))
for i, l in zip(range(bsz), label):
self.faissid2entityid[i+self.cnt_batch] = l.cpu()
self.cnt_batch += bsz
def test_epoch_end(self, outputs) -> None:
self.total_embedding = np.concatenate(self.total_embedding, axis=0)
# self.index.train(self.total_embedding)
print(faiss.MatrixStats(self.total_embedding).comments)
self.index.add(self.total_embedding)
faiss.write_index(self.index, os.path.join(self.args.data_dir, "faiss_dump.index"))
with open(os.path.join(self.args.data_dir, "faissid2entityid.pkl") ,'wb') as file:
pickle.dump(self.faissid2entityid, file)
with open(os.path.join(self.args.data_dir, "total_embedding.pkl") ,'wb') as file:
pickle.dump(self.total_embedding, file)
# print(f"number of entity embedding : {len(self.faissid2entityid)}")
@staticmethod
def add_to_argparse(parser):
parser = TransformerLitModel.add_to_argparse(parser)
parser.add_argument("--faiss_init", type=int, default=1, help="get the embedding and save it the file.")
return parser
class UseEntityEmbeddingLitModel(TransformerLitModel):
def __init__(self, model, args, tokenizer, data_config={}):
super().__init__(model, args, tokenizer, data_config)
self.faissid2entityid = pickle.load(open(os.path.join(self.args.data_dir, "faissid2entityid.pkl") ,'rb'))
self.index = faiss.read_index(os.path.join(self.args.data_dir, "faiss_dump.index"))
self.dis2logits = distance2logits_2
def _eval(self, batch, batch_idx, ):
labels = batch.pop("labels")
pos = batch.pop("pos")
input_ids = batch['input_ids']
# single label
label = batch.pop('label')
hidden_states = self.model(**batch, return_dict=True, output_hidden_states=True).hidden_states[-1]
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
bsz = input_ids.shape[0]
mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
topk = 200
D, I = self.index.search(mask_embedding, topk)
labels[torch.arange(bsz), label] = 0
# logits = torch.zeros_like(labels)
# D = torch.softmax(torch.exp(-1. * torch.tensor(D)), dim=-1)
# for i in range(bsz):
# for j in range(topk):
# logits[i][self.faissid2entityid[I[i][j]]] += D[i][j]
# # get the entity ranks
# # filter the entity
# assert labels[0][label[0]], "correct ids must in filiter!"
# labels[torch.arange(bsz), label] = 0
# assert logits.shape == labels.shape
# logits += labels * -100 # mask entityj
# _, outputs = torch.sort(logits, dim=1, descending=True)
# _, outputs = torch.sort(outputs, dim=1)
# ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1
entity_logits = torch.full(labels.shape, -100.).to(self.device)
D = self.dis2logits(D)
for i in range(bsz):
for j in range(topk):
# filter entity in labels
if I[i][j] not in self.faissid2entityid:
print(I[i][j])
break
# assert I[i][j] in self.faissid2entityid, print(I[i][j])
if labels[i][self.faissid2entityid[I[i][j]]]: continue
if entity_logits[i][self.faissid2entityid[I[i][j]]] == -100.:
entity_logits[i][self.faissid2entityid[I[i][j]]] = D[i][j]
# no added together
# else:
# entity_logits[i][self.faissid2entityid[I[i][j]]] += D[i][j]
# get the entity ranks
# filter the entity
assert entity_logits.shape == labels.shape
_, outputs = torch.sort(entity_logits, dim=1, descending=True)
_, outputs = torch.sort(outputs, dim=1)
ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1
return dict(ranks = np.array(ranks))
@staticmethod
def add_to_argparse(parser):
parser = TransformerLitModel.add_to_argparse(parser)
parser.add_argument("--faiss_init", type=int, default=0, help="get the embedding and save it the file.")
parser.add_argument("--faiss_use", type=int, default=1, help="get the embedding and save it the file.")
return parser
class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
def __init__(self, model, args, tokenizer, data_config={}):
super().__init__(model, args, tokenizer, data_config=data_config)
self.dis2logits = distance2logits_2
self.id2entity = {}
with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
cnt = 0
for line in file.readlines():
e, d = line.strip().split("\t")
self.id2entity[cnt] = e
cnt += 1
self.id2entity_t = {}
with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
for line in file.readlines():
e, d = line.strip().split("\t")
self.id2entity_t[e] = d
for k, v in self.id2entity.items():
self.id2entity[k] = self.id2entity_t[v]
def _eval(self, batch, batch_idx, ):
labels = batch.pop("labels")
input_ids = batch['input_ids']
# single label
label = batch.pop('label')
pos = batch.pop("pos")
result = self.model(**batch, return_dict=True, output_hidden_states=True)
hidden_states = result.hidden_states[-1]
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
bsz = input_ids.shape[0]
mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
# mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
topk = self.args.knn_topk
D, I = self.index.search(mask_embedding, topk)
D = torch.from_numpy(D).to(self.device)
assert labels[0][label[0]], "the correct id must be in the filter set!"
labels[torch.arange(bsz), label] = 0
mask_logits = result.logits[:, :, self.entity_id_st:self.entity_id_ed]
mask_logits = mask_logits[torch.arange(bsz), mask_idx]
entity_logits = torch.full(labels.shape, 1000.).to(self.device)
# D = self.dis2logits(D)
for i in range(bsz):
for j in range(topk):
# filter entity in labels
if labels[i][self.faissid2entityid[I[i][j]]]: continue
if entity_logits[i][self.faissid2entityid[I[i][j]]] == 1000.:
entity_logits[i][self.faissid2entityid[I[i][j]]] = D[i][j]
# else:
# entity_logits[i][self.faissid2entityid[I[i][j]]] += D[i][j]
entity_logits = self.dis2logits(entity_logits)
# get the entity ranks
# filter the entity
assert entity_logits.shape == labels.shape
assert mask_logits.shape == labels.shape
# entity_logits = torch.softmax(entity_logits + labels * -100, dim=-1) # mask entityj
entity_logits = entity_logits + labels* -100.
mask_logits = torch.softmax(mask_logits + labels* -100, dim=-1)
# logits = mask_logits
logits = combine_knn_and_vocab_probs(entity_logits, mask_logits, self.args.knn_lambda)
# logits = entity_logits + mask_logits
knn_topk_logits, knn_topk_id = entity_logits.topk(20)
mask_topk_logits, mask_topk_id = mask_logits.topk(20)
union_topk = []
for i in range(bsz):
num_same = len(list(set(knn_topk_id[i].cpu().tolist()) & set(mask_topk_id[i].cpu().tolist())))
union_topk.append(num_same/ 20.)
knn_topk_id = knn_topk_id.to("cpu")
mask_topk_id = mask_topk_id.to("cpu")
mask_topk_logits = mask_topk_logits.to("cpu")
knn_topk_logits = knn_topk_logits.to("cpu")
label = label.to("cpu")
for t in range(bsz):
if knn_topk_id[t][0] == label[t] and knn_topk_logits[t][0] > mask_topk_logits[t][0] and mask_topk_logits[t][0] <= 0.4:
print(knn_topk_logits[t], knn_topk_id[t])
print(lmap(lambda x: self.id2entity[x.item()], knn_topk_id[t]))
print(mask_topk_logits[t], mask_topk_id[t])
print(lmap(lambda x: self.id2entity[x.item()], mask_topk_id[t]))
print(label[t])
print()
_, outputs = torch.sort(logits, dim=1, descending=True)
_, outputs = torch.sort(outputs, dim=1)
ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1
return dict(ranks = np.array(ranks), knn_topk_id=knn_topk_id, knn_topk_logits=knn_topk_logits,
mask_topk_id=mask_topk_id, mask_topk_logits=mask_topk_logits, num_same = np.array(union_topk))
def test_epoch_end(self, outputs) -> None:
ranks = np.concatenate([_['ranks'] for _ in outputs])
num_same = np.concatenate([_['num_same'] for _ in outputs])
results_keys = list(outputs[0].keys())
results = {}
# for k in results_keys:
# results.
self.log("Test/num_same", num_same.mean())
hits20 = (ranks<=20).mean()
hits10 = (ranks<=10).mean()
hits3 = (ranks<=3).mean()
hits1 = (ranks<=1).mean()
self.log("Test/hits10", hits10)
self.log("Test/hits20", hits20)
self.log("Test/hits3", hits3)
self.log("Test/hits1", hits1)
self.log("Test/mean_rank", ranks.mean())
self.log("Test/mrr", (1. / ranks).mean())
@staticmethod
def add_to_argparse(parser):
parser = TransformerLitModel.add_to_argparse(parser)
parser.add_argument("--knn_lambda", type=float, default=0.5, help="lambda * knn + (1-lambda) * mask logits , lambda of knn logits and mask logits.")
parser.add_argument("--knn_topk", type=int, default=100, help="")
return parser
def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff=0.5):
combine_probs = torch.stack([vocab_p, knn_p], dim=0)
coeffs = torch.ones_like(combine_probs)
coeffs[0] = np.log(1 - coeff)
coeffs[1] = np.log(coeff)
curr_prob = torch.logsumexp(combine_probs + coeffs, dim=0)
return curr_prob
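# kNN-LM style mixture computed in log space: for log-probability inputs this returns
# log((1 - coeff) * exp(vocab_p) + coeff * exp(knn_p)) via a numerically stable logsumexp,
# i.e. an interpolation between the MLM distribution and the faiss nearest-neighbour distribution.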
def distance2logits(D):
return torch.softmax( -1. * torch.tensor(D) / 30., dim=-1)
def distance2logits_2(D, n=10):
    # Temperature-scaled softmax over negative distances; n is the temperature.
    if not isinstance(D, torch.Tensor):
        D = torch.tensor(D)
    if torch.sum(D) != 0.0:
        return torch.exp(-D / n) / torch.sum(torch.exp(-D / n), dim=-1, keepdim=True)
    # Degenerate case (all distances are zero): return a uniform distribution
    # instead of implicitly returning None.
    return torch.full_like(D, 1.0 / D.shape[-1])

View File

@ -0,0 +1,66 @@
import json
import numpy as np
def rank_score(ranks):
    # compute hits@10, hits@5, hits@1 and MRR from a list of 1-based ranks
len_samples = len(ranks)
hits10 = [0] * len_samples
hits5 = [0] * len_samples
hits1 = [0] * len_samples
mrr = []
for idx, rank in enumerate(ranks):
if rank <= 10:
hits10[idx] = 1.
if rank <= 5:
hits5[idx] = 1.
if rank <= 1:
hits1[idx] = 1.
mrr.append(1./rank)
return np.mean(hits10), np.mean(hits5), np.mean(hits1), np.mean(mrr)
def acc(logits, labels):
preds = np.argmax(logits, axis=-1)
return (preds == labels).mean()
import torch.nn as nn
import torch
class LabelSmoothSoftmaxCEV1(nn.Module):
    '''
    This is the autograd version; you can also try LabelSmoothSoftmaxCEV2, which uses manually derived gradients.
    '''
def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
super(LabelSmoothSoftmaxCEV1, self).__init__()
self.lb_smooth = lb_smooth
self.reduction = reduction
self.lb_ignore = ignore_index
self.log_softmax = nn.LogSoftmax(dim=1)
def forward(self, logits, label):
        '''
        args: logits: tensor of shape (N, C) or (N, C, d1, d2, ...)
        args: label: tensor of shape (N,) or (N, d1, d2, ...) holding class indices
        '''
# overcome ignored label
with torch.no_grad():
num_classes = logits.size(1)
label = label.clone().detach()
ignore = label == self.lb_ignore
n_valid = (ignore == 0).sum()
label[ignore] = 0
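            # Build the smoothed target distribution: the gold class receives 1 - lb_smooth,
            # every other class receives lb_smooth / num_classes.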
lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
label = torch.empty_like(logits).fill_(
lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()
logs = self.log_softmax(logits)
loss = -torch.sum(logs * label, dim=1)
loss[ignore] = 0
if self.reduction == 'mean':
loss = loss.sum() / n_valid
if self.reduction == 'sum':
loss = loss.sum()
return loss

141
pretrain/main.py Normal file
View File

@ -0,0 +1,141 @@
import argparse
import importlib
from logging import debug
import numpy as np
import torch
import pytorch_lightning as pl
import lit_models
import yaml
import time
from transformers import AutoConfig
import os
# from utils import get_decoder_input_ids
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# In order to ensure reproducible experiments, we must set random seeds.
def _import_class(module_and_class_name: str) -> type:
"""Import class from a module, e.g. 'text_recognizer.models.MLP'"""
module_name, class_name = module_and_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
class_ = getattr(module, class_name)
return class_
def _setup_parser():
"""Set up Python's ArgumentParser with data, model, trainer, and other arguments."""
parser = argparse.ArgumentParser(add_help=False)
# Add Trainer specific arguments, such as --max_epochs, --gpus, --precision
trainer_parser = pl.Trainer.add_argparse_args(parser)
trainer_parser._action_groups[1].title = "Trainer Args" # pylint: disable=protected-access
parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser])
# Basic arguments
parser.add_argument("--wandb", action="store_true", default=False)
parser.add_argument("--litmodel_class", type=str, default="TransformerLitModel")
parser.add_argument("--seed", type=int, default=7)
parser.add_argument("--data_class", type=str, default="KGC")
parser.add_argument("--chunk", type=str, default="")
parser.add_argument("--model_class", type=str, default="RobertaUseLabelWord")
parser.add_argument("--checkpoint", type=str, default=None)
# Get the data and model classes, so that we can add their specific arguments
temp_args, _ = parser.parse_known_args()
data_class = _import_class(f"data.{temp_args.data_class}")
model_class = _import_class(f"models.{temp_args.model_class}")
lit_model_class = _import_class(f"lit_models.{temp_args.litmodel_class}")
# Get data, model, and LitModel specific arguments
data_group = parser.add_argument_group("Data Args")
data_class.add_to_argparse(data_group)
model_group = parser.add_argument_group("Model Args")
if hasattr(model_class, "add_to_argparse"):
model_class.add_to_argparse(model_group)
lit_model_group = parser.add_argument_group("LitModel Args")
lit_model_class.add_to_argparse(lit_model_group)
parser.add_argument("--help", "-h", action="help")
return parser
def _saved_pretrain(lit_model, tokenizer, path):
lit_model.model.save_pretrained(path)
tokenizer.save_pretrained(path)
def main():
parser = _setup_parser()
args = parser.parse_args()
np.random.seed(args.seed)
torch.manual_seed(args.seed)
pl.seed_everything(args.seed)
data_class = _import_class(f"data.{args.data_class}")
model_class = _import_class(f"models.{args.model_class}")
litmodel_class = _import_class(f"lit_models.{args.litmodel_class}")
config = AutoConfig.from_pretrained(args.model_name_or_path)
# update parameters
config.label_smoothing = args.label_smoothing
model = model_class.from_pretrained(args.model_name_or_path, config=config)
data = data_class(args, model)
tokenizer = data.tokenizer
lit_model = litmodel_class(args=args, model=model, tokenizer=tokenizer, data_config=data.get_config())
if args.checkpoint:
lit_model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["state_dict"])
logger = pl.loggers.TensorBoardLogger("training/logs")
if args.wandb:
logger = pl.loggers.WandbLogger(project="kgc_bert", name=args.data_dir.split("/")[-1])
logger.log_hyperparams(vars(args))
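    # Pretraining monitors Eval/hits1, checkpoints every 100 training steps and keeps the top 5;
    # otherwise the single best checkpoint by Eval/mrr is kept.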
metric_name = "Eval/mrr" if not args.pretrain else "Eval/hits1"
early_callback = pl.callbacks.EarlyStopping(monitor="Eval/mrr", mode="max", patience=10)
model_checkpoint = pl.callbacks.ModelCheckpoint(monitor=metric_name, mode="max",
filename=args.data_dir.split("/")[-1] + '/{epoch}-{Eval/hits10:.2f}-{Eval/hits1:.2f}' if not args.pretrain else args.data_dir.split("/")[-1] + '/{epoch}-{step}-{Eval/hits10:.2f}',
dirpath="output",
save_weights_only=True, # to be modified
every_n_train_steps=100 if args.pretrain else None,
save_top_k=5 if args.pretrain else 1
)
callbacks = [early_callback, model_checkpoint]
# args.weights_summary = "full" # Print full summary of the model
trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")
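    # Lit models whose class name contains "EntityEmbedding" skip training and testing here.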
if "EntityEmbedding" not in lit_model.__class__.__name__:
trainer.fit(lit_model, datamodule=data)
path = model_checkpoint.best_model_path
lit_model.load_state_dict(torch.load(path)["state_dict"])
result = trainer.test(lit_model, datamodule=data)
print(result)
# _saved_pretrain(lit_model, tokenizer, path)
if "EntityEmbedding" not in lit_model.__class__.__name__:
print("*path"*30)
print(path)
if __name__ == "__main__":
main()

6
pretrain/models/__init__.py Executable file
View File

@ -0,0 +1,6 @@
from transformers import BartForConditionalGeneration, T5ForConditionalGeneration, GPT2LMHeadModel
from .model import *

10
pretrain/models/model.py Normal file
View File

@ -0,0 +1,10 @@
from transformers.models.bert.modeling_bert import BertForMaskedLM
class BertKGC(BertForMaskedLM):
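    # BertForMaskedLM backbone used for KG completion; it only adds the --pretrain flag.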
@staticmethod
def add_to_argparse(parser):
parser.add_argument("--pretrain", type=int, default=0, help="")
return parser

1159
pretrain/models/utils.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,14 @@
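# Pretraining run for BertKGC on FB15k-237; runs in the background and logs to logs/pretrain_fb15k-237.log.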
nohup python -u main.py --gpus "1" --max_epochs=16 --num_workers=32 \
--model_name_or_path bert-base-uncased \
--accumulate_grad_batches 1 \
--model_class BertKGC \
--batch_size 128 \
--pretrain 1 \
--bce 0 \
--check_val_every_n_epoch 1 \
--overwrite_cache \
--data_dir /kg_374/Relphormer/dataset/FB15k-237 \
--eval_batch_size 256 \
--max_seq_length 64 \
--lr 1e-4 \
>logs/pretrain_fb15k-237.log 2>&1 &

View File

@ -0,0 +1,14 @@
nohup python -u main.py --gpus "0," --max_epochs=20 --num_workers=32 \
--model_name_or_path bert-base-uncased \
--accumulate_grad_batches 1 \
--model_class BertKGC \
--batch_size 128 \
--pretrain 1 \
--bce 0 \
--check_val_every_n_epoch 1 \
--overwrite_cache \
--data_dir xxx/Relphormer/dataset/umls \
--eval_batch_size 256 \
--max_seq_length 64 \
--lr 1e-4 \
>logs/pretrain_umls.log 2>&1 &

View File

@ -0,0 +1,16 @@
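# Pretraining run on WN18RR: same recipe, but with fp16 (--precision 16) and max_seq_length 32.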
nohup python -u main.py --gpus "0," --max_epochs=15 --num_workers=32 \
--model_name_or_path bert-base-uncased \
--accumulate_grad_batches 1 \
--bce 0 \
--model_class BertKGC \
--batch_size 128 \
--pretrain 1 \
--check_val_every_n_epoch 1 \
--data_dir xxx/Relphormer/dataset/WN18RR \
--overwrite_cache \
--eval_batch_size 256 \
--precision 16 \
--max_seq_length 32 \
--lr 1e-4 \
>logs/pretrain_wn18rr.log 2>&1 &