Relphormer baseline
2
pretrain/data/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .data_module import KGC
from .processor import convert_examples_to_features, KGProcessor
71
pretrain/data/base_data_module.py
Normal file
@@ -0,0 +1,71 @@
"""Base DataModule class."""
from pathlib import Path
from typing import Dict
import argparse
import os

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class Config(dict):
    def __getattr__(self, name):
        return self.get(name)

    def __setattr__(self, name, val):
        self[name] = val


BATCH_SIZE = 8
NUM_WORKERS = 8


class BaseDataModule(pl.LightningDataModule):
    """
    Base DataModule.
    Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html
    """

    def __init__(self, args: argparse.Namespace = None) -> None:
        super().__init__()
        self.args = Config(vars(args)) if args is not None else {}
        self.batch_size = self.args.get("batch_size", BATCH_SIZE)
        self.num_workers = self.args.get("num_workers", NUM_WORKERS)

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument(
            "--batch_size", type=int, default=BATCH_SIZE, help="Number of examples to operate on per forward step."
        )
        parser.add_argument(
            "--num_workers", type=int, default=0, help="Number of additional processes to load data."
        )
        parser.add_argument(
            "--dataset", type=str, default="./dataset/NELL", help="Path to the dataset directory."
        )
        return parser

    def prepare_data(self):
        """
        Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`).
        """
        pass

    def setup(self, stage=None):
        """
        Split into train, val, test, and set dims.
        Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.
        """
        self.data_train = None
        self.data_val = None
        self.data_test = None

    def train_dataloader(self):
        return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
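

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only; the training scripts do not run this block).
    # It parses the flags defined above with their defaults and checks that they reach the module.
    parser = argparse.ArgumentParser()
    parser = BaseDataModule.add_to_argparse(parser)
    dm = BaseDataModule(parser.parse_args([]))
    print(dm.batch_size, dm.num_workers)  # -> 8 0 with the defaults above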
196
pretrain/data/data_module.py
Normal file
@@ -0,0 +1,196 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from enum import Enum

import numpy as np
import torch

from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BertTokenizer
# from transformers.configuration_bert import BertTokenizer, BertTokenizerFast
from transformers.tokenization_utils_base import (BatchEncoding,
                                                  PreTrainedTokenizerBase)

from .base_data_module import BaseDataModule
from .processor import KGProcessor, get_dataset
import transformers
transformers.logging.set_verbosity_error()


class ExplicitEnum(Enum):
    """
    Enum with more explicit error message for missing values.
    """

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )


class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
    in an IDE.
    """

    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"


@dataclass
class DataCollatorForSeq2Seq:
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        model (:class:`~transformers.PreTrainedModel`):
            The model that is being trained. If set and has the `prepare_decoder_input_ids_from_labels`, use it to
            prepare the `decoder_input_ids`.

            This is useful when using `label_smoothing` to avoid calculating loss twice.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.

            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
    """

    tokenizer: PreTrainedTokenizerBase
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"
    num_labels: int = 0

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        labels = [feature.pop("labels") for feature in features] if "labels" in features[0].keys() else None
        label = [feature.pop("label") for feature in features]
        features_keys = {}
        name_keys = list(features[0].keys())
        for k in name_keys:
            # ignore the padding arguments
            if k in ["input_ids", "attention_mask", "token_type_ids"]: continue
            try:
                features_keys[k] = [feature.pop(k) for feature in features]
            except KeyError:
                continue

        # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
        # same length to return tensors.
        bsz = len(labels)
        with torch.no_grad():
            new_labels = torch.zeros(bsz, self.num_labels)
            for i, l in enumerate(labels):
                if isinstance(l, int):
                    new_labels[i][l] = 1
                else:
                    for j in l:
                        new_labels[i][j] = 1
            labels = new_labels

        features = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )
        features['labels'] = labels
        features['label'] = torch.tensor(label)
        features.update(features_keys)

        return features
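
# Minimal sketch of the label handling above (illustrative only; `num_labels` is the entity
# vocabulary size that KGC below passes in). An int label, or a list of filtered labels, per
# example becomes one multi-hot target row, which is what the BCE loss in the lit model expects:
#
#   labels = [3, [0, 2]]                    # one gold entity vs. several filtered positives
#   targets = torch.zeros(len(labels), 5)   # toy case with num_labels = 5
#   for i, l in enumerate(labels):
#       for j in ([l] if isinstance(l, int) else l):
#           targets[i][j] = 1
#   # targets == [[0., 0., 0., 1., 0.], [1., 0., 1., 0., 0.]]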


class KGC(BaseDataModule):
    def __init__(self, args, model) -> None:
        super().__init__(args)
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name_or_path, use_fast=False)
        self.processor = KGProcessor(self.tokenizer, args)
        self.label_list = self.processor.get_labels(args.data_dir)

        entity_list = self.processor.get_entities(args.data_dir)

        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': entity_list})
        self.sampler = DataCollatorForSeq2Seq(self.tokenizer,
            model=model,
            label_pad_token_id=self.tokenizer.pad_token_id,
            pad_to_multiple_of=8 if self.args.precision == 16 else None,
            padding="longest",
            max_length=self.args.max_seq_length,
            num_labels=len(entity_list),
        )
        relations_tokens = self.processor.get_relations(args.data_dir)
        self.num_relations = len(relations_tokens)
        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': relations_tokens})

        vocab = self.tokenizer.get_added_vocab()
        self.relation_id_st = vocab[relations_tokens[0]]
        self.relation_id_ed = vocab[relations_tokens[-1]] + 1
        self.entity_id_st = vocab[entity_list[0]]
        self.entity_id_ed = vocab[entity_list[-1]] + 1


    def setup(self, stage=None):
        self.data_train = get_dataset(self.args, self.processor, self.label_list, self.tokenizer, "train")
        self.data_val = get_dataset(self.args, self.processor, self.label_list, self.tokenizer, "dev")
        self.data_test = get_dataset(self.args, self.processor, self.label_list, self.tokenizer, "test")

    def prepare_data(self):
        pass

    def get_config(self):
        d = {}
        for k, v in self.__dict__.items():
            if "st" in k or "ed" in k:
                d.update({k: v})

        return d


    @staticmethod
    def add_to_argparse(parser):
        BaseDataModule.add_to_argparse(parser)
        parser.add_argument("--model_name_or_path", type=str, default="roberta-base", help="the name or the path to the pretrained model")
        parser.add_argument("--data_dir", type=str, default="roberta-base", help="the path to the dataset directory (e.g. ./dataset/WN18RR)")
        parser.add_argument("--max_seq_length", type=int, default=256, help="Maximum input sequence length after tokenization.")
        parser.add_argument("--warm_up_radio", type=float, default=0.1, help="Fraction of total training steps used for learning-rate warm-up.")
        parser.add_argument("--eval_batch_size", type=int, default=8)
        parser.add_argument("--overwrite_cache", action="store_true", default=False)
        return parser

    def get_tokenizer(self):
        return self.tokenizer

    def train_dataloader(self):
        return DataLoader(self.data_train, num_workers=self.num_workers, pin_memory=True, collate_fn=self.sampler, batch_size=self.args.batch_size, shuffle=not self.args.faiss_init)

    def val_dataloader(self):
        return DataLoader(self.data_val, num_workers=self.num_workers, pin_memory=True, collate_fn=self.sampler, batch_size=self.args.eval_batch_size)

    def test_dataloader(self):
        return DataLoader(self.data_test, num_workers=self.num_workers, pin_memory=True, collate_fn=self.sampler, batch_size=self.args.eval_batch_size)
936
pretrain/data/processor.py
Normal file
@@ -0,0 +1,936 @@
from re import DEBUG

import contextlib
import sys

from collections import Counter
from multiprocessing import Pool
from torch.utils.data import Dataset, Sampler, IterableDataset
from collections import defaultdict
from functools import partial
import os
import random
import json
import torch
import copy
import numpy as np
import pickle
from tqdm import tqdm
from dataclasses import dataclass, asdict, replace
import inspect

from transformers.models.auto.tokenization_auto import AutoTokenizer

from models.utils import get_entity_spans_pre_processing


def lmap(a, b):
    return list(map(a, b))


def cache_results(_cache_fp, _refresh=False, _verbose=1):
    r"""
    cache_results is the decorator used in fastNLP for caching intermediate data. The example below shows how to use it::

        import time
        import numpy as np
        from fastNLP import cache_results

        @cache_results('cache.pkl')
        def process_data():
            # some time-consuming work, e.g. reading and preprocessing data; time.sleep() stands in for it here
            time.sleep(1)
            return np.random.randint(10, size=(5,))

        start_time = time.time()
        print("res =", process_data())
        print(time.time() - start_time)

        start_time = time.time()
        print("res =", process_data())
        print(time.time() - start_time)

        # The output is shown below: both runs give the same result, and the second run takes almost no time
        # Save cache to cache.pkl.
        # res = [5 4 9 1 8]
        # 1.0042750835418701
        # Read cache from cache.pkl.
        # res = [5 4 9 1 8]
        # 0.0040721893310546875

    The second run takes only a few milliseconds because the result is read directly from cache.pkl instead of being
    preprocessed again::

        # Continuing the example above: to generate a separate cache, e.g. for another dataset, call it as follows
        process_data(_cache_fp='cache2.pkl')  # this does not affect the previous 'cache.pkl' at all

    `_cache_fp` above is a parameter recognized by cache_results; the result is cached to / read from 'cache2.pkl',
    which overrides the default 'cache.pkl'. Decorating a function with @cache_results() adds three parameters
    [_cache_fp, _refresh, _verbose]. These three parameters are not forwarded to your function, so your own
    parameter names must not include any of these three names::

        process_data(_cache_fp='cache2.pkl', _refresh=True)  # force the cache to be regenerated here
        # _verbose controls the output: 0 prints nothing; 1 reports whether the cache was read or newly generated

    :param str _cache_fp: where to cache the result, or where to read the cache from. If None, cache_results has no
        effect unless _cache_fp is passed at call time.
    :param bool _refresh: whether to regenerate the cache.
    :param int _verbose: whether to print cache information.
    :return:
    """
|
||||
|
||||
def wrapper_(func):
|
||||
signature = inspect.signature(func)
|
||||
for key, _ in signature.parameters.items():
|
||||
if key in ('_cache_fp', '_refresh', '_verbose'):
|
||||
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
my_args = args[0]
|
||||
mode = args[-1]
|
||||
if '_cache_fp' in kwargs:
|
||||
cache_filepath = kwargs.pop('_cache_fp')
|
||||
assert isinstance(cache_filepath, str), "_cache_fp can only be str."
|
||||
else:
|
||||
cache_filepath = _cache_fp
|
||||
if '_refresh' in kwargs:
|
||||
refresh = kwargs.pop('_refresh')
|
||||
assert isinstance(refresh, bool), "_refresh can only be bool."
|
||||
else:
|
||||
refresh = _refresh
|
||||
if '_verbose' in kwargs:
|
||||
verbose = kwargs.pop('_verbose')
|
||||
assert isinstance(verbose, int), "_verbose can only be integer."
|
||||
else:
|
||||
verbose = _verbose
|
||||
refresh_flag = True
|
||||
|
||||
model_name = my_args.model_name_or_path.split("/")[-1]
|
||||
is_pretrain = my_args.pretrain
|
||||
cache_filepath = os.path.join(my_args.data_dir, f"cached_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl")
|
||||
refresh = my_args.overwrite_cache
|
||||
|
||||
if cache_filepath is not None and refresh is False:
|
||||
# load data
|
||||
if os.path.exists(cache_filepath):
|
||||
with open(cache_filepath, 'rb') as f:
|
||||
results = pickle.load(f)
|
||||
if verbose == 1:
|
||||
logger.info("Read cache from {}.".format(cache_filepath))
|
||||
refresh_flag = False
|
||||
|
||||
if refresh_flag:
|
||||
results = func(*args, **kwargs)
|
||||
if cache_filepath is not None:
|
||||
if results is None:
|
||||
raise RuntimeError("The return value is None. Delete the decorator.")
|
||||
with open(cache_filepath, 'wb') as f:
|
||||
pickle.dump(results, f)
|
||||
logger.info("Save cache to {}.".format(cache_filepath))
|
||||
|
||||
return results
|
||||
|
||||
return wrapper
|
||||
|
||||
return wrapper_
|
||||
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
# from torch.nn import CrossEntropyLoss, MSELoss
|
||||
# from scipy.stats import pearsonr, spearmanr
|
||||
# from sklearn.metrics import matthews_corrcoef, f1_scoreclass
|
||||
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class InputExample(object):
|
||||
"""A single training/test example for simple sequence classification."""
|
||||
|
||||
def __init__(self, guid, text_a, text_b=None, text_c=None, label=None, real_label=None, en=None, rel=None):
|
||||
"""Constructs a InputExample.
|
||||
|
||||
Args:
|
||||
guid: Unique id for the example.
|
||||
text_a: string. The untokenized text of the first sequence. For single
|
||||
sequence tasks, only this sequence must be specified.
|
||||
text_b: (Optional) string. The untokenized text of the second sequence.
|
||||
Only must be specified for sequence pair tasks.
|
||||
text_c: (Optional) string. The untokenized text of the third sequence.
|
||||
Only must be specified for sequence triple tasks.
|
||||
label: (Optional) string. list of entities
|
||||
"""
|
||||
self.guid = guid
|
||||
self.text_a = text_a
|
||||
self.text_b = text_b
|
||||
self.text_c = text_c
|
||||
self.label = label
|
||||
self.real_label = real_label
|
||||
self.en = en
|
||||
self.rel = rel # rel id
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputFeatures:
|
||||
"""A single set of features of data."""
|
||||
|
||||
input_ids: torch.Tensor
|
||||
attention_mask: torch.Tensor
|
||||
labels: torch.Tensor = None
|
||||
label: torch.Tensor = None
|
||||
en: torch.Tensor = 0
|
||||
rel: torch.Tensor = 0
|
||||
pos: torch.Tensor = 0
|
||||
|
||||
|
||||
class DataProcessor(object):
|
||||
"""Base class for data converters for sequence classification data sets."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for the train set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for the dev set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_labels(self, data_dir):
|
||||
"""Gets the list of labels for this data set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def _read_tsv(cls, input_file, quotechar=None):
|
||||
"""Reads a tab separated value file."""
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
|
||||
lines = []
|
||||
for line in reader:
|
||||
if sys.version_info[0] == 2:
|
||||
line = list(unicode(cell, 'utf-8') for cell in line)
|
||||
lines.append(line)
|
||||
return lines
|
||||
|
||||
import copy
|
||||
|
||||
|
||||
def solve_get_knowledge_store(line, set_type="train", pretrain=1):
|
||||
"""
|
||||
use the LM to get the entity embedding.
|
||||
Transductive: triples + text description
|
||||
Inductive: text description
|
||||
|
||||
"""
|
||||
examples = []
|
||||
|
||||
head_ent_text = ent2text[line[0]]
|
||||
tail_ent_text = ent2text[line[2]]
|
||||
relation_text = rel2text[line[1]]
|
||||
|
||||
i=0
|
||||
|
||||
a = tail_filter_entities["\t".join([line[0],line[1]])]
|
||||
b = head_filter_entities["\t".join([line[2],line[1]])]
|
||||
|
||||
guid = "%s-%s" % (set_type, i)
|
||||
text_a = head_ent_text
|
||||
text_b = relation_text
|
||||
text_c = tail_ent_text
|
||||
|
||||
# use the description of c to predict A
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a="[PAD]", text_b=text_b + "[PAD]", text_c = "[PAD]" + " " + text_c, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[ent2id[line[0]], rel2id[line[1]], ent2id[line[2]]], rel=0)
|
||||
)
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a="[PAD]", text_b=text_b + "[PAD]", text_c = "[PAD]" + " " + text_a, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[2]], en=[ent2id[line[0]], rel2id[line[1]], ent2id[line[2]]], rel=0)
|
||||
)
|
||||
return examples
|
||||
|
||||
|
||||
def solve(line, set_type="train", pretrain=1):
|
||||
examples = []
|
||||
|
||||
head_ent_text = ent2text[line[0]]
|
||||
tail_ent_text = ent2text[line[2]]
|
||||
relation_text = rel2text[line[1]]
|
||||
|
||||
i=0
|
||||
|
||||
a = tail_filter_entities["\t".join([line[0],line[1]])]
|
||||
b = head_filter_entities["\t".join([line[2],line[1]])]
|
||||
|
||||
guid = "%s-%s" % (set_type, i)
|
||||
text_a = head_ent_text
|
||||
text_b = relation_text
|
||||
text_c = tail_ent_text
|
||||
|
||||
|
||||
if pretrain:
|
||||
text_a_tokens = text_a.split()
|
||||
for i in range(10):
|
||||
st = random.randint(0, len(text_a_tokens))
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a="[MASK]", text_b=" ".join(text_a_tokens[st:min(st+64, len(text_a_tokens))]), text_c = "", label=ent2id[line[0]], real_label=ent2id[line[0]], en=0, rel=0)
|
||||
)
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a="[MASK]", text_b=text_a, text_c = "", label=ent2id[line[0]], real_label=ent2id[line[0]], en=0, rel=0)
|
||||
)
|
||||
# examples.append(
|
||||
# InputExample(guid=guid, text_a="[MASK]", text_b=text_c, text_c = "", label=ent2id[line[2]], real_label=ent2id[line[2]], en=0, rel=0)
|
||||
# )
|
||||
else:
|
||||
|
||||
# examples.append(
|
||||
# InputExample(guid=guid, text_a="[MASK]", text_b=text_b + "[PAD]", text_c = "[UNK]" , label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=ent2id[line[2]], rel=rel2id[line[1]]))
|
||||
# examples.append(
|
||||
# InputExample(guid=guid, text_a="[UNK] ", text_b=text_b + "[PAD]", text_c = "[MASK]", label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=ent2id[line[0]], rel=rel2id[line[1]]))
|
||||
|
||||
# examples.append(
|
||||
# InputExample(guid=guid, text_a="[UNK]" + " " + text_c, text_b=text_b + "[PAD]", text_c = "[MASK]", label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=ent2id[line[2]], rel=rel2id[line[1]]))
|
||||
# examples.append(
|
||||
# InputExample(guid=guid, text_a="[MASK]", text_b=text_b + "[PAD]", text_c = "[UNK]" + text_a, label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=ent2id[line[0]], rel=rel2id[line[1]]))
|
||||
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a="[MASK]", text_b=text_b + "[PAD]", text_c = "[PAD]" + " " + text_c, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]]))
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a="[PAD] ", text_b=text_b + "[PAD]", text_c = "[MASK]" +" " + text_a, label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]]))
|
||||
return examples
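
# Descriptive note on solve(): in the pretraining branch each entity description is paired with a
# single [MASK] so the LM can learn the entity's embedding; in the link-prediction branch every
# triple (h, r, t) yields two masked examples, one predicting the head from the relation text plus
# the tail description and one predicting the tail from the relation text plus the head
# description, with `label` holding the filtered candidate ids and `real_label` the gold entity id.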
|
||||
|
||||
def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
|
||||
global head_filter_entities
|
||||
global tail_filter_entities
|
||||
global ent2text
|
||||
global rel2text
|
||||
global ent2id
|
||||
global ent2token
|
||||
global rel2id
|
||||
|
||||
head_filter_entities = head
|
||||
tail_filter_entities = tail
|
||||
ent2text =t1
|
||||
rel2text =t2
|
||||
ent2id = ent2id_
|
||||
ent2token = ent2token_
|
||||
rel2id = rel2id_
|
||||
|
||||
def delete_init(ent2text_):
|
||||
global ent2text
|
||||
ent2text = ent2text_
|
||||
|
||||
|
||||
class KGProcessor(DataProcessor):
|
||||
"""Processor for knowledge graph data set."""
|
||||
def __init__(self, tokenizer, args):
|
||||
self.labels = set()
|
||||
self.tokenizer = tokenizer
|
||||
self.args = args
|
||||
self.entity_path = os.path.join(args.data_dir, "entity2textlong.txt") if os.path.exists(os.path.join(args.data_dir, 'entity2textlong.txt')) \
|
||||
else os.path.join(args.data_dir, "entity2text.txt")
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train", data_dir, self.args)
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev", data_dir, self.args)
|
||||
|
||||
def get_test_examples(self, data_dir, chunk=""):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, f"test{chunk}.tsv")), "test", data_dir, self.args)
|
||||
|
||||
def get_relations(self, data_dir):
|
||||
"""Gets all labels (relations) in the knowledge graph."""
|
||||
# return list(self.labels)
|
||||
with open(os.path.join(data_dir, "relations.txt"), 'r') as f:
|
||||
lines = f.readlines()
|
||||
relations = []
|
||||
for line in lines:
|
||||
relations.append(line.strip().split('\t')[0])
|
||||
rel2token = {ent : f"[RELATION_{i}]" for i, ent in enumerate(relations)}
|
||||
return list(rel2token.values())
|
||||
|
||||
def get_labels(self, data_dir):
|
||||
"""Gets all labels (0, 1) for triples in the knowledge graph."""
|
||||
relation = []
|
||||
with open(os.path.join(data_dir, "relation2text.txt"), 'r') as f:
|
||||
lines = f.readlines()
|
||||
entities = []
|
||||
for line in lines:
|
||||
relation.append(line.strip().split("\t")[-1])
|
||||
return relation
|
||||
|
||||
def get_entities(self, data_dir):
|
||||
"""Gets all entities in the knowledge graph."""
|
||||
with open(self.entity_path, 'r') as f:
|
||||
lines = f.readlines()
|
||||
entities = []
|
||||
for line in lines:
|
||||
entities.append(line.strip().split("\t")[0])
|
||||
|
||||
ent2token = {ent : f"[ENTITY_{i}]" for i, ent in enumerate(entities)}
|
||||
return list(ent2token.values())
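
    # Illustrative note (not called anywhere in this file): the "[ENTITY_i]" / "[RELATION_j]"
    # strings returned by get_entities / get_relations are meant to be registered as additional
    # special tokens; their ids then form one contiguous block that KGC.__init__ recovers, e.g.
    # (the checkpoint name is only an example):
    #
    #   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)
    #   processor = KGProcessor(tokenizer, args)
    #   ents = processor.get_entities(args.data_dir)            # ["[ENTITY_0]", "[ENTITY_1]", ...]
    #   tokenizer.add_special_tokens({"additional_special_tokens": ents})
    #   vocab = tokenizer.get_added_vocab()
    #   ent_id_st, ent_id_ed = vocab[ents[0]], vocab[ents[-1]] + 1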
|
||||
|
||||
def get_train_triples(self, data_dir):
|
||||
"""Gets training triples."""
|
||||
return self._read_tsv(os.path.join(data_dir, "train.tsv"))
|
||||
|
||||
def get_dev_triples(self, data_dir):
|
||||
"""Gets validation triples."""
|
||||
return self._read_tsv(os.path.join(data_dir, "dev.tsv"))
|
||||
|
||||
def get_test_triples(self, data_dir, chunk=""):
|
||||
"""Gets test triples."""
|
||||
return self._read_tsv(os.path.join(data_dir, f"test{chunk}.tsv"))
|
||||
|
||||
def _create_examples(self, lines, set_type, data_dir, args):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
# entity to text
|
||||
ent2text = {}
|
||||
ent2text_with_type = {}
|
||||
with open(self.entity_path, 'r') as f:
|
||||
ent_lines = f.readlines()
|
||||
for line in ent_lines:
|
||||
temp = line.strip().split('\t')
|
||||
try:
|
||||
end = temp[1]#.find(',')
|
||||
if "wiki" in data_dir:
|
||||
assert "Q" in temp[0]
|
||||
ent2text[temp[0]] = temp[1].replace("\\n", " ").replace("\\", "") #[:end]
|
||||
except IndexError:
|
||||
# continue
|
||||
end = " "#.find(',')
|
||||
if "wiki" in data_dir:
|
||||
assert "Q" in temp[0]
|
||||
ent2text[temp[0]] = end #[:end]
|
||||
|
||||
entities = list(ent2text.keys())
|
||||
ent2token = {ent : f"[ENTITY_{i}]" for i, ent in enumerate(entities)}
|
||||
ent2id = {ent : i for i, ent in enumerate(entities)}
|
||||
|
||||
rel2text = {}
|
||||
with open(os.path.join(data_dir, "relation2text.txt"), 'r') as f:
|
||||
rel_lines = f.readlines()
|
||||
for line in rel_lines:
|
||||
temp = line.strip().split('\t')
|
||||
rel2text[temp[0]] = temp[1]
|
||||
relation_names = {}
|
||||
with open(os.path.join(data_dir, "relations.txt"), "r") as file:
|
||||
for line in file.readlines():
|
||||
t = line.strip()
|
||||
relation_names[t] = rel2text[t]
|
||||
|
||||
tmp_lines = []
|
||||
not_in_text = 0
|
||||
for line in tqdm(lines, desc="delete entities without text name."):
|
||||
if (line[0] not in ent2text) or (line[2] not in ent2text) or (line[1] not in rel2text):
|
||||
not_in_text += 1
|
||||
continue
|
||||
tmp_lines.append(line)
|
||||
lines = tmp_lines
|
||||
print(f"total entity not in text : {not_in_text} ")
|
||||
|
||||
# rel id -> relation token id
|
||||
num_entities = len(self.get_entities(args.data_dir))
|
||||
rel2id = {w:i+num_entities for i,w in enumerate(relation_names.keys())}
|
||||
|
||||
# add reverse relation
|
||||
# tmp_rel2id = {}
|
||||
# num_relations = len(rel2id)
|
||||
# cnt = 0
|
||||
# for k, v in rel2id.items():
|
||||
# tmp_rel2id[k + " (reverse)"] = num_relations + cnt
|
||||
# cnt += 1
|
||||
# rel2id.update(tmp_rel2id)
|
||||
|
||||
examples = []
|
||||
# head filter head entity
|
||||
head_filter_entities = defaultdict(list)
|
||||
tail_filter_entities = defaultdict(list)
|
||||
|
||||
dataset_list = ["train.tsv", "dev.tsv", "test.tsv"]
|
||||
# in training, only use the train triples
|
||||
if set_type == "train" and not args.pretrain: dataset_list = dataset_list[0:1]
|
||||
for m in dataset_list:
|
||||
with open(os.path.join(data_dir, m), 'r') as file:
|
||||
train_lines = file.readlines()
|
||||
for idx in range(len(train_lines)):
|
||||
train_lines[idx] = train_lines[idx].strip().split("\t")
|
||||
|
||||
for line in train_lines:
|
||||
tail_filter_entities["\t".join([line[0], line[1]])].append(line[2])
|
||||
head_filter_entities["\t".join([line[2], line[1]])].append(line[0])
|
||||
|
||||
|
||||
|
||||
max_head_entities = max(len(_) for _ in head_filter_entities.values())
|
||||
max_tail_entities = max(len(_) for _ in tail_filter_entities.values())
|
||||
|
||||
|
||||
# use bce loss, ignore the mlm
|
||||
if set_type == "train" and args.bce:
|
||||
lines = []
|
||||
for k, v in tail_filter_entities.items():
|
||||
h, r = k.split('\t')
|
||||
t = v[0]
|
||||
lines.append([h, r, t])
|
||||
for k, v in head_filter_entities.items():
|
||||
t, r = k.split('\t')
|
||||
h = v[0]
|
||||
lines.append([h, r, t])
|
||||
|
||||
|
||||
        # for pretraining, build one line per entity so that its [MASK] embedding can be extracted
|
||||
if args.pretrain:
|
||||
rel = list(rel2text.keys())[0]
|
||||
lines = []
|
||||
for k in ent2text.keys():
|
||||
lines.append([k, rel, k])
|
||||
|
||||
print(f"max number of filter entities : {max_head_entities} {max_tail_entities}")
|
||||
|
||||
from os import cpu_count
|
||||
threads = min(1, cpu_count())
|
||||
filter_init(head_filter_entities, tail_filter_entities,ent2text, rel2text, ent2id, ent2token, rel2id
|
||||
)
|
||||
|
||||
if hasattr(args, "faiss_init") and args.faiss_init:
|
||||
annotate_ = partial(
|
||||
solve_get_knowledge_store,
|
||||
pretrain=self.args.pretrain
|
||||
)
|
||||
else:
|
||||
annotate_ = partial(
|
||||
solve,
|
||||
pretrain=self.args.pretrain
|
||||
)
|
||||
examples = list(
|
||||
tqdm(
|
||||
map(annotate_, lines),
|
||||
total=len(lines),
|
||||
desc="convert text to examples"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# with Pool(threads, initializer=filter_init, initargs=(head_filter_entities, tail_filter_entities,ent2text, rel2text,
|
||||
# ent2text_with_type, rel2id,)) as pool:
|
||||
# annotate_ = partial(
|
||||
# solve,
|
||||
# )
|
||||
# examples = list(
|
||||
# tqdm(
|
||||
# map(annotate_, lines, chunksize= 128),
|
||||
# total=len(lines),
|
||||
# desc="convert text to examples"
|
||||
# )
|
||||
# )
|
||||
tmp_examples = []
|
||||
for e in examples:
|
||||
for ee in e:
|
||||
tmp_examples.append(ee)
|
||||
examples = tmp_examples
|
||||
# delete vars
|
||||
del head_filter_entities, tail_filter_entities, ent2text, rel2text, ent2id, ent2token, rel2id
|
||||
return examples
|
||||
|
||||
class Verbalizer(object):
|
||||
def __init__(self, args):
|
||||
if "WN18RR" in args.data_dir:
|
||||
self.mode = "WN18RR"
|
||||
elif "FB15k" in args.data_dir:
|
||||
self.mode = "FB15k"
|
||||
elif "umls" in args.data_dir:
|
||||
self.mode = "umls"
|
||||
elif "codexs" in args.data_dir:
|
||||
self.mode = "codexs"
|
||||
elif "codexl" in args.data_dir:
|
||||
self.mode = "codexl"
|
||||
elif "FB13" in args.data_dir:
|
||||
self.mode = "FB13"
|
||||
elif "WN11" in args.data_dir:
|
||||
self.mode = "WN11"
|
||||
|
||||
|
||||
def _convert(self, head, relation, tail):
|
||||
if self.mode == "umls":
|
||||
return f"The {relation} {head} is "
|
||||
|
||||
return f"{head} {relation}"
|
||||
|
||||
|
||||
class KGCDataset(Dataset):
|
||||
def __init__(self, features):
|
||||
self.features = features
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.features[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
def convert_examples_to_features_init(tokenizer_for_convert):
|
||||
global tokenizer
|
||||
tokenizer = tokenizer_for_convert
|
||||
|
||||
def convert_examples_to_features(example, max_seq_length, mode, pretrain=1):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
# tokens_a = tokenizer.tokenize(example.text_a)
|
||||
# tokens_b = tokenizer.tokenize(example.text_b)
|
||||
# tokens_c = tokenizer.tokenize(example.text_c)
|
||||
|
||||
# _truncate_seq_triple(tokens_a, tokens_b, tokens_c, max_length= max_seq_length)
|
||||
text_a = " ".join(example.text_a.split()[:128])
|
||||
text_b = " ".join(example.text_b.split()[:128])
|
||||
text_c = " ".join(example.text_c.split()[:128])
|
||||
|
||||
if pretrain:
|
||||
input_text_a = text_a
|
||||
input_text_b = text_b
|
||||
else:
|
||||
input_text_a = tokenizer.sep_token.join([text_a, text_b])
|
||||
input_text_b = text_c
|
||||
|
||||
|
||||
inputs = tokenizer(
|
||||
input_text_a,
|
||||
input_text_b,
|
||||
truncation="longest_first",
|
||||
max_length=max_seq_length,
|
||||
padding="longest",
|
||||
add_special_tokens=True,
|
||||
)
|
||||
# assert tokenizer.mask_token_id in inputs.input_ids, "mask token must in input"
|
||||
|
||||
features = asdict(InputFeatures(input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs['attention_mask'],
|
||||
labels=torch.tensor(example.label),
|
||||
label=torch.tensor(example.real_label)
|
||||
)
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
|
||||
# This is a simple heuristic which will always truncate the longer sequence
|
||||
# one token at a time. This makes more sense than truncating an equal percent
|
||||
# of tokens from each, since if one sequence is very short then each token
|
||||
# that's truncated likely contains more information than a longer sequence.
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_length:
|
||||
break
|
||||
if len(tokens_a) > len(tokens_b):
|
||||
tokens_a.pop()
|
||||
else:
|
||||
tokens_b.pop()
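
# Quick illustration of the heuristic above (assumed ad-hoc usage, not part of the pipeline):
# the longer list is popped from its end, one token at a time, until the pair fits the budget.
#
#   a, b = ["the", "quick", "brown", "fox", "jumps"], ["over", "the"]
#   _truncate_seq_pair(a, b, max_length=5)
#   # a == ["the", "quick", "brown"], b == ["over", "the"]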
|
||||
|
||||
def _truncate_seq_triple(tokens_a, tokens_b, tokens_c, max_length):
|
||||
"""Truncates a sequence triple in place to the maximum length."""
|
||||
|
||||
# This is a simple heuristic which will always truncate the longer sequence
|
||||
# one token at a time. This makes more sense than truncating an equal percent
|
||||
# of tokens from each, since if one sequence is very short then each token
|
||||
# that's truncated likely contains more information than a longer sequence.
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b) + len(tokens_c)
|
||||
if total_length <= max_length:
|
||||
break
|
||||
if len(tokens_a) > len(tokens_b) and len(tokens_a) > len(tokens_c):
|
||||
tokens_a.pop()
|
||||
elif len(tokens_b) > len(tokens_a) and len(tokens_b) > len(tokens_c):
|
||||
tokens_b.pop()
|
||||
elif len(tokens_c) > len(tokens_a) and len(tokens_c) > len(tokens_b):
|
||||
tokens_c.pop()
|
||||
else:
|
||||
tokens_c.pop()
|
||||
|
||||
|
||||
@cache_results(_cache_fp="./dataset")
|
||||
def get_dataset(args, processor, label_list, tokenizer, mode):
|
||||
|
||||
assert mode in ["train", "dev", "test"], "mode must be in train dev test!"
|
||||
|
||||
# use training data to construct the entity embedding
|
||||
combine_train_and_test = False
|
||||
if args.faiss_init and mode == "test" and not args.pretrain:
|
||||
mode = "train"
|
||||
if "ind" in args.data_dir: combine_train_and_test = True
|
||||
else:
|
||||
pass
|
||||
|
||||
if mode == "train":
|
||||
train_examples = processor.get_train_examples(args.data_dir)
|
||||
elif mode == "dev":
|
||||
train_examples = processor.get_dev_examples(args.data_dir)
|
||||
else:
|
||||
train_examples = processor.get_test_examples(args.data_dir)
|
||||
|
||||
if combine_train_and_test:
        logger.info("use the whole dataset (train + dev + test) to get the entity mask embeddings in pretraining")
        train_examples = processor.get_test_examples(args.data_dir) + processor.get_train_examples(args.data_dir) + processor.get_dev_examples(args.data_dir)
|
||||
|
||||
from os import cpu_count
|
||||
with open(os.path.join(args.data_dir, f"examples_{mode}.txt"), 'w') as file:
|
||||
for line in train_examples:
|
||||
d = {}
|
||||
d.update(line.__dict__)
|
||||
file.write(json.dumps(d) + '\n')
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False)
|
||||
|
||||
features = []
|
||||
# # with open(os.path.join(args.data_dir, "cached_relation_pattern.pkl"), "rb") as file:
|
||||
# # pattern = pickle.load(file)
|
||||
# pattern = None
|
||||
# convert_examples_to_features_init(tokenizer)
|
||||
# annotate_ = partial(
|
||||
# convert_examples_to_features,
|
||||
# max_seq_length=args.max_seq_length,
|
||||
# mode = mode,
|
||||
# pretrain = args.pretrain
|
||||
# )
|
||||
# features = list(
|
||||
# tqdm(
|
||||
# map(annotate_, train_examples),
|
||||
# total=len(train_examples)
|
||||
# )
|
||||
# )
|
||||
# encoder = MultiprocessingEncoder(tokenizer, args)
|
||||
# encoder.initializer()
|
||||
# for t in tqdm(train_examples):
|
||||
# features.append(encoder.encode_lines([json.dumps(t.__dict__)]))
|
||||
|
||||
# for example in tqdm(train_examples):
|
||||
# text_a = example.text_a
|
||||
# text_b = example.text_b
|
||||
# text_c = example.text_c
|
||||
|
||||
# bpe = tokenizer
|
||||
# if 0:
|
||||
# input_text_a = text_a
|
||||
# input_text_b = text_b
|
||||
# else:
|
||||
# if text_a == "[MASK]":
|
||||
# input_text_a = bpe.sep_token.join([text_a, text_b])
|
||||
# input_text_b = text_c
|
||||
# else:
|
||||
# input_text_a = text_a
|
||||
# input_text_b = bpe.sep_token.join([text_b, text_c])
|
||||
|
||||
|
||||
# inputs = tokenizer(
|
||||
# input_text_a,
|
||||
# # input_text_b,
|
||||
# truncation="longest_first",
|
||||
# max_length=128,
|
||||
# padding="longest",
|
||||
# add_special_tokens=True,
|
||||
# )
|
||||
# # assert tokenizer.mask_token_id in inputs.input_ids, "mask token must in input"
|
||||
|
||||
# # features.append(asdict(InputFeatures(input_ids=inputs["input_ids"],
|
||||
# # attention_mask=inputs['attention_mask'],
|
||||
# # labels=example.label,
|
||||
# # label=example.real_label
|
||||
# # )
|
||||
# # ))
|
||||
|
||||
|
||||
file_inputs = [os.path.join(args.data_dir, f"examples_{mode}.txt")]
|
||||
file_outputs = [os.path.join(args.data_dir, f"features_{mode}.txt")]
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
inputs = [
|
||||
stack.enter_context(open(input, "r", encoding="utf-8"))
|
||||
if input != "-" else sys.stdin
|
||||
for input in file_inputs
|
||||
]
|
||||
outputs = [
|
||||
stack.enter_context(open(output, "w", encoding="utf-8"))
|
||||
if output != "-" else sys.stdout
|
||||
for output in file_outputs
|
||||
]
|
||||
|
||||
encoder = MultiprocessingEncoder(tokenizer, args)
|
||||
pool = Pool(16, initializer=encoder.initializer)
|
||||
encoder.initializer()
|
||||
encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 1000)
|
||||
# encoded_lines = map(encoder.encode_lines, zip(*inputs))
|
||||
|
||||
stats = Counter()
|
||||
for i, (filt, enc_lines) in tqdm(enumerate(encoded_lines, start=1), total=len(train_examples)):
|
||||
if filt == "PASS":
|
||||
for enc_line, output_h in zip(enc_lines, outputs):
|
||||
features.append(eval(enc_line))
|
||||
# features.append(enc_line)
|
||||
# print(enc_line, file=output_h)
|
||||
else:
|
||||
stats["num_filtered_" + filt] += 1
|
||||
# if i % 10000 == 0:
|
||||
# print("processed {} lines".format(i), file=sys.stderr)
|
||||
|
||||
for k, v in stats.most_common():
|
||||
print("[{}] filtered {} lines".format(k, v), file=sys.stderr)
|
||||
# threads = min(16, cpu_count())
|
||||
# with Pool(threads, initializer=convert_examples_to_features_init, initargs=(tokenizer,)) as pool:
|
||||
# annotate_ = partial(
|
||||
# convert_examples_to_features,
|
||||
# max_seq_length=args.max_seq_length,
|
||||
# mode = mode,
|
||||
# pretrain = args.pretrain
|
||||
# )
|
||||
# features = list(
|
||||
# tqdm(
|
||||
# pool.imap_unordered(annotate_, train_examples),
|
||||
# total=len(train_examples),
|
||||
# desc="convert examples to features",
|
||||
# )
|
||||
# )
|
||||
|
||||
# num_entities = len(processor.get_entities(args.data_dir))
|
||||
for f_id, f in enumerate(features):
|
||||
en = features[f_id].pop("en")
|
||||
rel = features[f_id].pop("rel")
|
||||
real_label = f['label']
|
||||
cnt = 0
|
||||
if not isinstance(en, list): break
|
||||
|
||||
pos = 0
|
||||
for i,t in enumerate(f['input_ids']):
|
||||
if t == tokenizer.pad_token_id:
|
||||
features[f_id]['input_ids'][i] = en[cnt] + len(tokenizer)
|
||||
cnt += 1
|
||||
if features[f_id]['input_ids'][i] == real_label + len(tokenizer):
|
||||
pos = i
|
||||
if cnt == len(en): break
|
||||
assert not (args.faiss_init and pos == 0)
|
||||
features[f_id]['pos'] = pos
|
||||
|
||||
|
||||
# for i,t in enumerate(f['input_ids']):
|
||||
# if t == tokenizer.pad_token_id:
|
||||
# features[f_id]['input_ids'][i] = rel + len(tokenizer) + num_entities
|
||||
# break
|
||||
|
||||
|
||||
|
||||
features = KGCDataset(features)
|
||||
return features
|
||||
|
||||
|
||||
class MultiprocessingEncoder(object):
|
||||
def __init__(self, tokenizer, args):
|
||||
self.tokenizer = tokenizer
|
||||
self.pretrain = args.pretrain
|
||||
self.max_seq_length = args.max_seq_length
|
||||
|
||||
def initializer(self):
|
||||
global bpe
|
||||
bpe = self.tokenizer
|
||||
|
||||
def encode(self, line):
|
||||
global bpe
|
||||
ids = bpe.encode(line)
|
||||
return list(map(str, ids))
|
||||
|
||||
def decode(self, tokens):
|
||||
global bpe
|
||||
return bpe.decode(tokens)
|
||||
|
||||
def encode_lines(self, lines):
|
||||
"""
|
||||
Encode a set of lines. All lines will be encoded together.
|
||||
"""
|
||||
enc_lines = []
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if len(line) == 0:
|
||||
return ["EMPTY", None]
|
||||
# enc_lines.append(" ".join(tokens))
|
||||
enc_lines.append(json.dumps(self.convert_examples_to_features(example=eval(line))))
|
||||
# enc_lines.append(" ")
|
||||
# enc_lines.append("123")
|
||||
return ["PASS", enc_lines]
|
||||
|
||||
def decode_lines(self, lines):
|
||||
dec_lines = []
|
||||
for line in lines:
|
||||
tokens = map(int, line.strip().split())
|
||||
dec_lines.append(self.decode(tokens))
|
||||
return ["PASS", dec_lines]
|
||||
|
||||
def convert_examples_to_features(self, example):
|
||||
pretrain = self.pretrain
|
||||
max_seq_length = self.max_seq_length
|
||||
global bpe
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
# tokens_a = tokenizer.tokenize(example.text_a)
|
||||
# tokens_b = tokenizer.tokenize(example.text_b)
|
||||
# tokens_c = tokenizer.tokenize(example.text_c)
|
||||
|
||||
# _truncate_seq_triple(tokens_a, tokens_b, tokens_c, max_length= max_seq_length)
|
||||
# text_a = " ".join(example['text_a'].split()[:128])
|
||||
# text_b = " ".join(example['text_b'].split()[:128])
|
||||
# text_c = " ".join(example['text_c'].split()[:128])
|
||||
|
||||
text_a = example['text_a']
|
||||
text_b = example['text_b']
|
||||
text_c = example['text_c']
|
||||
|
||||
if pretrain:
|
||||
# the des of xxx is [MASK] .
|
||||
input_text = f"The description of {text_a} is that {text_b} ."
|
||||
inputs = bpe(
|
||||
input_text,
|
||||
truncation="longest_first",
|
||||
max_length=max_seq_length,
|
||||
padding="longest",
|
||||
add_special_tokens=True,
|
||||
)
|
||||
else:
|
||||
if text_a == "[MASK]":
|
||||
input_text_a = bpe.sep_token.join([text_a, text_b])
|
||||
input_text_b = text_c
|
||||
else:
|
||||
input_text_a = text_a
|
||||
input_text_b = bpe.sep_token.join([text_b, text_c])
|
||||
|
||||
|
||||
inputs = bpe(
|
||||
input_text_a,
|
||||
input_text_b,
|
||||
truncation="longest_first",
|
||||
max_length=max_seq_length,
|
||||
padding="longest",
|
||||
add_special_tokens=True,
|
||||
)
|
||||
# assert bpe.mask_token_id in inputs.input_ids, "mask token must in input"
|
||||
|
||||
features = asdict(InputFeatures(input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs['attention_mask'],
|
||||
labels=example['label'],
|
||||
label=example['real_label'],
|
||||
en=example['en'],
|
||||
rel=example['rel']
|
||||
)
|
||||
)
|
||||
return features
|
||||
if __name__ == "__main__":
|
||||
dataset = KGCDataset('./dataset')
|
2
pretrain/lit_models/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .transformer import *
from .base import *
97
pretrain/lit_models/base.py
Normal file
@@ -0,0 +1,97 @@
import argparse
import pytorch_lightning as pl
import torch
from typing import Dict, Any


OPTIMIZER = "AdamW"
LR = 5e-5
LOSS = "cross_entropy"
ONE_CYCLE_TOTAL_STEPS = 100

class Config(dict):
    def __getattr__(self, name):
        return self.get(name)

    def __setattr__(self, name, val):
        self[name] = val


class BaseLitModel(pl.LightningModule):
    """
    Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
    """

    def __init__(self, model, args: argparse.Namespace = None):
        super().__init__()
        self.model = model
        self.args = Config(vars(args)) if args is not None else {}

        optimizer = self.args.get("optimizer", OPTIMIZER)
        self.optimizer_class = getattr(torch.optim, optimizer)
        self.lr = self.args.get("lr", LR)
        # read from args (with defaults) so that configure_optimizers below has these attributes defined
        self.one_cycle_max_lr = self.args.get("one_cycle_max_lr", None)
        self.one_cycle_total_steps = self.args.get("one_cycle_total_steps", ONE_CYCLE_TOTAL_STEPS)

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim")
        parser.add_argument("--lr", type=float, default=LR)
        parser.add_argument("--weight_decay", type=float, default=0.01)
        return parser

    def configure_optimizers(self):
        optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
        if self.one_cycle_max_lr is None:
            return optimizer
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log("train_loss", loss)
        self.train_acc(logits, y)
        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log("val_loss", loss, prog_bar=True)
        self.val_acc(logits, y)
        self.log("val_acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        logits = self(x)
        self.test_acc(logits, y)
        self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)

    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from datamodule and devices."""
        if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
            dataset_size = self.trainer.limit_train_batches
        elif isinstance(self.trainer.limit_train_batches, float):
            # limit_train_batches is a percentage of batches
            dataset_size = len(self.trainer.datamodule.train_dataloader())
            dataset_size = int(dataset_size * self.trainer.limit_train_batches)
        else:
            dataset_size = len(self.trainer.datamodule.train_dataloader())

        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)

        effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
        max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs

        if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
            return self.trainer.max_steps
        return max_estimated_steps
503
pretrain/lit_models/transformer.py
Normal file
@@ -0,0 +1,503 @@
|
||||
from logging import debug
|
||||
import random
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import pickle
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import json
|
||||
# from transformers.utils.dummy_pt_objects import PrefixConstrainedLogitsProcessor
|
||||
|
||||
from .base import BaseLitModel
|
||||
from transformers.optimization import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
|
||||
|
||||
from functools import partial
|
||||
from .utils import rank_score, acc, LabelSmoothSoftmaxCEV1
|
||||
|
||||
from typing import Callable, Iterable, List
|
||||
|
||||
def lmap(f: Callable, x: Iterable) -> List:
|
||||
"""list(map(f, x))"""
|
||||
return list(map(f, x))
|
||||
|
||||
def multilabel_categorical_crossentropy(y_pred, y_true):
|
||||
y_pred = (1 - 2 * y_true) * y_pred
|
||||
y_pred_neg = y_pred - y_true * 1e12
|
||||
y_pred_pos = y_pred - (1 - y_true) * 1e12
|
||||
zeros = torch.zeros_like(y_pred[..., :1])
|
||||
y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
|
||||
y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
|
||||
neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
|
||||
pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
|
||||
return (neg_loss + pos_loss).mean()
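
# Sketch of what the loss above does (illustrative; shapes are assumed to be (batch, num_classes)
# with multi-hot y_true): positive classes are pushed above 0 and negative classes below 0 via two
# log-sum-exp terms against an appended zero logit, so no softmax over the full label set is needed.
#
#   y_true = torch.tensor([[1., 0., 1.]])
#   y_pred = torch.tensor([[2.0, -1.0, 0.5]])
#   loss = multilabel_categorical_crossentropy(y_pred, y_true)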
|
||||
|
||||
def decode(output_ids, tokenizer):
|
||||
return lmap(str.strip, tokenizer.batch_decode(output_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True))
|
||||
|
||||
class TransformerLitModel(BaseLitModel):
|
||||
def __init__(self, model, args, tokenizer=None, data_config={}):
|
||||
super().__init__(model, args)
|
||||
self.save_hyperparameters(args)
|
||||
if args.bce:
|
||||
self.loss_fn = torch.nn.BCEWithLogitsLoss()
|
||||
elif args.label_smoothing != 0.0:
|
||||
self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
|
||||
else:
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
self.best_acc = 0
|
||||
self.first = True
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.__dict__.update(data_config)
|
||||
# resize the word embedding layer
|
||||
self.model.resize_token_embeddings(len(self.tokenizer))
|
||||
self.decode = partial(decode, tokenizer=self.tokenizer)
|
||||
if args.pretrain:
|
||||
self._freaze_attention()
|
||||
        elif "ind" in args.data_dir:
            # for the inductive setting, freeze the word embeddings
            self._freaze_word_embedding()
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
def training_step(self, batch, batch_idx): # pylint: disable=unused-argument
|
||||
# embed();exit()
|
||||
# print(self.optimizers().param_groups[1]['lr'])
|
||||
labels = batch.pop("labels")
|
||||
label = batch.pop("label")
|
||||
pos = batch.pop("pos")
|
||||
try:
|
||||
en = batch.pop("en")
|
||||
rel = batch.pop("rel")
|
||||
except KeyError:
|
||||
pass
|
||||
input_ids = batch['input_ids']
|
||||
logits = self.model(**batch, return_dict=True).logits
|
||||
|
||||
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
|
||||
bs = input_ids.shape[0]
|
||||
mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]
|
||||
|
||||
assert mask_idx.shape[0] == bs, "only one mask in sequence!"
|
||||
if self.args.bce:
|
||||
loss = self.loss_fn(mask_logits, labels)
|
||||
else:
|
||||
loss = self.loss_fn(mask_logits, label)
|
||||
|
||||
if batch_idx == 0:
|
||||
print('\n'.join(self.decode(batch['input_ids'][:4])))
|
||||
|
||||
|
||||
return loss
|
||||
|
||||
def _eval(self, batch, batch_idx, ):
|
||||
labels = batch.pop("labels")
|
||||
input_ids = batch['input_ids']
|
||||
# single label
|
||||
label = batch.pop('label')
|
||||
pos = batch.pop('pos')
|
||||
my_keys = list(batch.keys())
|
||||
for k in my_keys:
|
||||
if k not in ["input_ids", "attention_mask", "token_type_ids"]:
|
||||
batch.pop(k)
|
||||
logits = self.model(**batch, return_dict=True).logits[:, :, self.entity_id_st:self.entity_id_ed]
|
||||
_, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
|
||||
bsz = input_ids.shape[0]
|
||||
logits = logits[torch.arange(bsz), mask_idx]
|
||||
# get the entity ranks
|
||||
# filter the entity
|
||||
        assert labels[0][label[0]], "the correct id must be in the filter set!"
        labels[torch.arange(bsz), label] = 0
        assert logits.shape == labels.shape
        logits += labels * -100  # mask the other correct entities
|
||||
# for i in range(bsz):
|
||||
# logits[i][labels]
|
||||
|
||||
_, outputs = torch.sort(logits, dim=1, descending=True)
|
||||
_, outputs = torch.sort(outputs, dim=1)
|
||||
ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1
|
||||
|
||||
|
||||
return dict(ranks = np.array(ranks))
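
    # Note on the double torch.sort in _eval above (illustrative): the first sort gives, per row,
    # the entity ids ordered by score; argsort-ing that ordering again gives, for every entity id,
    # its 0-based position, so `outputs[i, label[i]] + 1` is the 1-based rank of the gold entity.
    #
    #   logits = torch.tensor([[0.1, 0.9, 0.4]])
    #   _, order = torch.sort(logits, dim=1, descending=True)   # [[1, 2, 0]]
    #   _, rank  = torch.sort(order, dim=1)                     # [[2, 0, 1]]
    #   # entity 1 (highest score) has rank 0 + 1 = 1; entity 0 has rank 2 + 1 = 3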
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
result = self._eval(batch, batch_idx)
|
||||
return result
|
||||
|
||||
def validation_epoch_end(self, outputs) -> None:
|
||||
ranks = np.concatenate([_['ranks'] for _ in outputs])
|
||||
total_ranks = ranks.shape[0]
|
||||
|
||||
if not self.args.pretrain:
|
||||
l_ranks = ranks[np.array(list(np.arange(0, total_ranks, 2)))]
|
||||
r_ranks = ranks[np.array(list(np.arange(0, total_ranks, 2))) + 1]
|
||||
self.log("Eval/lhits10", (l_ranks<=10).mean())
|
||||
self.log("Eval/rhits10", (r_ranks<=10).mean())
|
||||
|
||||
hits20 = (ranks<=20).mean()
|
||||
hits10 = (ranks<=10).mean()
|
||||
hits3 = (ranks<=3).mean()
|
||||
hits1 = (ranks<=1).mean()
|
||||
|
||||
self.log("Eval/hits10", hits10)
|
||||
self.log("Eval/hits20", hits20)
|
||||
self.log("Eval/hits3", hits3)
|
||||
self.log("Eval/hits1", hits1)
|
||||
self.log("Eval/mean_rank", ranks.mean())
|
||||
self.log("Eval/mrr", (1. / ranks).mean())
|
||||
self.log("hits10", hits10, prog_bar=True)
|
||||
self.log("hits1", hits1, prog_bar=True)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_step(self, batch, batch_idx): # pylint: disable=unused-argument
|
||||
# ranks = self._eval(batch, batch_idx)
|
||||
result = self._eval(batch, batch_idx)
|
||||
# self.log("Test/ranks", np.mean(ranks))
|
||||
|
||||
return result
|
||||
|
||||
def test_epoch_end(self, outputs) -> None:
|
||||
ranks = np.concatenate([_['ranks'] for _ in outputs])
|
||||
|
||||
hits20 = (ranks<=20).mean()
|
||||
hits10 = (ranks<=10).mean()
|
||||
hits3 = (ranks<=3).mean()
|
||||
hits1 = (ranks<=1).mean()
|
||||
|
||||
|
||||
self.log("Test/hits10", hits10)
|
||||
self.log("Test/hits20", hits20)
|
||||
self.log("Test/hits3", hits3)
|
||||
self.log("Test/hits1", hits1)
|
||||
self.log("Test/mean_rank", ranks.mean())
|
||||
self.log("Test/mrr", (1. / ranks).mean())
|
||||
|

    def configure_optimizers(self):
        no_decay_param = ["bias", "LayerNorm.weight"]

        optimizer_group_parameters = [
            {"params": [p for n, p in self.model.named_parameters() if p.requires_grad and not any(nd in n for nd in no_decay_param)], "weight_decay": self.args.weight_decay},
            {"params": [p for n, p in self.model.named_parameters() if p.requires_grad and any(nd in n for nd in no_decay_param)], "weight_decay": 0}
        ]

        optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.num_training_steps * self.args.warm_up_radio, num_training_steps=self.num_training_steps)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                'scheduler': scheduler,
                'interval': 'step',  # or 'epoch'
                'frequency': 1,
            }
        }

    def _freaze_attention(self):
        # freeze every parameter except the word embeddings
        for k, v in self.model.named_parameters():
            if "word" not in k:
                v.requires_grad = False
            else:
                print(k)

    def _freaze_word_embedding(self):
        # freeze only the word-embedding parameters
        for k, v in self.model.named_parameters():
            if "word" in k:
                print(k)
                v.requires_grad = False

    @staticmethod
    def add_to_argparse(parser):
        parser = BaseLitModel.add_to_argparse(parser)

        parser.add_argument("--label_smoothing", type=float, default=0.1, help="label smoothing factor")
        parser.add_argument("--bce", type=int, default=0, help="use binary cross-entropy loss instead of cross-entropy")
        return parser

import faiss
import os

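# The two helper LitModels below are evaluation-only: GetEntityEmbeddingLitModel runs a forward
# pass over the data, takes the hidden state at the masked entity position as that entity's
# embedding and dumps everything into a faiss index; UseEntityEmbeddingLitModel reloads the index
# and ranks entities by nearest-neighbour distance instead of the masked-LM head.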
class GetEntityEmbeddingLitModel(TransformerLitModel):
    def __init__(self, model, args, tokenizer, data_config={}):
        super().__init__(model, args, tokenizer, data_config)

        self.faissid2entityid = {}
        # self.index = faiss.IndexFlatL2(d)  # build the index

        d, measure = self.model.config.hidden_size, faiss.METRIC_L2
        # param = 'HNSW64'
        # self.index = faiss.index_factory(d, param, measure)
        self.index = faiss.IndexFlatL2(d)  # build the index
        # print(self.index.is_trained)  # prints True at this point
        # index.add(xb)
        self.cnt_batch = 0
        self.total_embedding = []

    def test_step(self, batch, batch_idx):
        labels = batch.pop("labels")
        mask_idx = batch.pop("pos")
        input_ids = batch['input_ids']
        # single label
        label = batch.pop('label')
        # last layer
        hidden_states = self.model(**batch, return_dict=True, output_hidden_states=True).hidden_states[-1]
        # _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bsz = input_ids.shape[0]
        entity_embedding = hidden_states[torch.arange(bsz), mask_idx].cpu()
        # use normalize or not?
        # entity_embedding = F.normalize(entity_embedding, dim=-1, p=2)
        self.total_embedding.append(entity_embedding)
        # self.index.add(np.array(entity_embedding, dtype=np.float32))
        for i, l in zip(range(bsz), label):
            self.faissid2entityid[i + self.cnt_batch] = l.cpu()
        self.cnt_batch += bsz

    def test_epoch_end(self, outputs) -> None:
        self.total_embedding = np.concatenate(self.total_embedding, axis=0)
        # self.index.train(self.total_embedding)
        print(faiss.MatrixStats(self.total_embedding).comments)
        self.index.add(self.total_embedding)
        faiss.write_index(self.index, os.path.join(self.args.data_dir, "faiss_dump.index"))
        with open(os.path.join(self.args.data_dir, "faissid2entityid.pkl"), 'wb') as file:
            pickle.dump(self.faissid2entityid, file)

        with open(os.path.join(self.args.data_dir, "total_embedding.pkl"), 'wb') as file:
            pickle.dump(self.total_embedding, file)
        # print(f"number of entity embeddings: {len(self.faissid2entityid)}")

    @staticmethod
    def add_to_argparse(parser):
        parser = TransformerLitModel.add_to_argparse(parser)
        parser.add_argument("--faiss_init", type=int, default=1, help="compute the entity embeddings and save them to file.")
        return parser

class UseEntityEmbeddingLitModel(TransformerLitModel):
    def __init__(self, model, args, tokenizer, data_config={}):
        super().__init__(model, args, tokenizer, data_config)

        self.faissid2entityid = pickle.load(open(os.path.join(self.args.data_dir, "faissid2entityid.pkl"), 'rb'))
        self.index = faiss.read_index(os.path.join(self.args.data_dir, "faiss_dump.index"))

        self.dis2logits = distance2logits_2

    def _eval(self, batch, batch_idx):
        labels = batch.pop("labels")
        pos = batch.pop("pos")
        input_ids = batch['input_ids']
        # single label
        label = batch.pop('label')

        hidden_states = self.model(**batch, return_dict=True, output_hidden_states=True).hidden_states[-1]
        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bsz = input_ids.shape[0]
        mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
        topk = 200
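        # faiss returns, for each query embedding, the L2 distances D and the index ids I of the
        # topk nearest stored embeddings; I is mapped back to entity ids via faissid2entityid.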
        D, I = self.index.search(mask_embedding, topk)
        labels[torch.arange(bsz), label] = 0

        # logits = torch.zeros_like(labels)
        # D = torch.softmax(torch.exp(-1. * torch.tensor(D)), dim=-1)
        # for i in range(bsz):
        #     for j in range(topk):
        #         logits[i][self.faissid2entityid[I[i][j]]] += D[i][j]
        # # get the entity ranks
        # # filter the entity
        # assert labels[0][label[0]], "the correct id must be in the filter set!"
        # labels[torch.arange(bsz), label] = 0
        # assert logits.shape == labels.shape
        # logits += labels * -100  # mask the filtered entities

        # _, outputs = torch.sort(logits, dim=1, descending=True)
        # _, outputs = torch.sort(outputs, dim=1)
        # ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1

        entity_logits = torch.full(labels.shape, -100.).to(self.device)
        D = self.dis2logits(D)
        for i in range(bsz):
            for j in range(topk):
                # filter entities that are in labels
                if I[i][j] not in self.faissid2entityid:
                    print(I[i][j])
                    break
                # assert I[i][j] in self.faissid2entityid, print(I[i][j])
                if labels[i][self.faissid2entityid[I[i][j]]]: continue
                if entity_logits[i][self.faissid2entityid[I[i][j]]] == -100.:
                    entity_logits[i][self.faissid2entityid[I[i][j]]] = D[i][j]
                # scores are not added together
                # else:
                #     entity_logits[i][self.faissid2entityid[I[i][j]]] += D[i][j]
        # get the entity ranks
        # filter the entity

        assert entity_logits.shape == labels.shape

        _, outputs = torch.sort(entity_logits, dim=1, descending=True)
        _, outputs = torch.sort(outputs, dim=1)
        ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1

        return dict(ranks=np.array(ranks))

    @staticmethod
    def add_to_argparse(parser):
        parser = TransformerLitModel.add_to_argparse(parser)
        parser.add_argument("--faiss_init", type=int, default=0, help="compute the entity embeddings and save them to file.")
        parser.add_argument("--faiss_use", type=int, default=1, help="rank entities with the saved faiss embeddings.")
        return parser

class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
    def __init__(self, model, args, tokenizer, data_config={}):
        super().__init__(model, args, tokenizer, data_config=data_config)
        self.dis2logits = distance2logits_2
        self.id2entity = {}
        # NOTE: entity names/descriptions are read from hard-coded FB15k-237 paths here
        with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
            cnt = 0
            for line in file.readlines():
                e, d = line.strip().split("\t")
                self.id2entity[cnt] = e
                cnt += 1
        self.id2entity_t = {}
        with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
            for line in file.readlines():
                e, d = line.strip().split("\t")
                self.id2entity_t[e] = d
        for k, v in self.id2entity.items():
            self.id2entity[k] = self.id2entity_t[v]

    def _eval(self, batch, batch_idx):
        labels = batch.pop("labels")
        input_ids = batch['input_ids']
        # single label
        label = batch.pop('label')
        pos = batch.pop("pos")

        result = self.model(**batch, return_dict=True, output_hidden_states=True)
        hidden_states = result.hidden_states[-1]
        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bsz = input_ids.shape[0]
        mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
        # mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
        topk = self.args.knn_topk
        D, I = self.index.search(mask_embedding, topk)
        D = torch.from_numpy(D).to(self.device)
        assert labels[0][label[0]], "the correct id must be in the filter set!"
        labels[torch.arange(bsz), label] = 0

        mask_logits = result.logits[:, :, self.entity_id_st:self.entity_id_ed]
        mask_logits = mask_logits[torch.arange(bsz), mask_idx]
        entity_logits = torch.full(labels.shape, 1000.).to(self.device)
        # D = self.dis2logits(D)
        for i in range(bsz):
            for j in range(topk):
                # filter entities that are in labels
                if labels[i][self.faissid2entityid[I[i][j]]]: continue
                if entity_logits[i][self.faissid2entityid[I[i][j]]] == 1000.:
                    entity_logits[i][self.faissid2entityid[I[i][j]]] = D[i][j]
                # else:
                #     entity_logits[i][self.faissid2entityid[I[i][j]]] += D[i][j]
        entity_logits = self.dis2logits(entity_logits)
        # get the entity ranks
        # filter the entity
        assert entity_logits.shape == labels.shape
        assert mask_logits.shape == labels.shape
        # entity_logits = torch.softmax(entity_logits + labels * -100, dim=-1)  # mask the filtered entities
        entity_logits = entity_logits + labels * -100.
        mask_logits = torch.softmax(mask_logits + labels * -100, dim=-1)
        # logits = mask_logits
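        # Blend the kNN entity scores with the masked-LM vocabulary scores; --knn_lambda controls
        # how much weight the retrieved neighbours get relative to the MLM prediction.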
        logits = combine_knn_and_vocab_probs(entity_logits, mask_logits, self.args.knn_lambda)
        # logits = entity_logits + mask_logits

        knn_topk_logits, knn_topk_id = entity_logits.topk(20)
        mask_topk_logits, mask_topk_id = mask_logits.topk(20)
        union_topk = []
        for i in range(bsz):
            num_same = len(list(set(knn_topk_id[i].cpu().tolist()) & set(mask_topk_id[i].cpu().tolist())))
            union_topk.append(num_same / 20.)

        knn_topk_id = knn_topk_id.to("cpu")
        mask_topk_id = mask_topk_id.to("cpu")
        mask_topk_logits = mask_topk_logits.to("cpu")
        knn_topk_logits = knn_topk_logits.to("cpu")
        label = label.to("cpu")

        # debug output: cases where the kNN retrieval already ranks the gold entity first
        # while the masked-LM prediction is comparatively uncertain
        for t in range(bsz):
            if knn_topk_id[t][0] == label[t] and knn_topk_logits[t][0] > mask_topk_logits[t][0] and mask_topk_logits[t][0] <= 0.4:
                print(knn_topk_logits[t], knn_topk_id[t])
                print(lmap(lambda x: self.id2entity[x.item()], knn_topk_id[t]))
                print(mask_topk_logits[t], mask_topk_id[t])
                print(lmap(lambda x: self.id2entity[x.item()], mask_topk_id[t]))
                print(label[t])
                print()

        _, outputs = torch.sort(logits, dim=1, descending=True)
        _, outputs = torch.sort(outputs, dim=1)
        ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1

        return dict(ranks=np.array(ranks), knn_topk_id=knn_topk_id, knn_topk_logits=knn_topk_logits,
                    mask_topk_id=mask_topk_id, mask_topk_logits=mask_topk_logits, num_same=np.array(union_topk))

    def test_epoch_end(self, outputs) -> None:
        ranks = np.concatenate([_['ranks'] for _ in outputs])
        num_same = np.concatenate([_['num_same'] for _ in outputs])
        results_keys = list(outputs[0].keys())
        results = {}
        # for k in results_keys:
        #     results.

        self.log("Test/num_same", num_same.mean())

        hits20 = (ranks <= 20).mean()
        hits10 = (ranks <= 10).mean()
        hits3 = (ranks <= 3).mean()
        hits1 = (ranks <= 1).mean()

        self.log("Test/hits10", hits10)
        self.log("Test/hits20", hits20)
        self.log("Test/hits3", hits3)
        self.log("Test/hits1", hits1)
        self.log("Test/mean_rank", ranks.mean())
        self.log("Test/mrr", (1. / ranks).mean())

    @staticmethod
    def add_to_argparse(parser):
        parser = TransformerLitModel.add_to_argparse(parser)
        parser.add_argument("--knn_lambda", type=float, default=0.5, help="weight of the knn logits: lambda * knn + (1 - lambda) * mask logits.")
        parser.add_argument("--knn_topk", type=int, default=100, help="number of nearest neighbours retrieved from the faiss index.")

        return parser

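# Log-space interpolation in the style of kNN-LM decoding: when knn_p and vocab_p are
# log-probabilities, logsumexp([vocab_p + log(1 - coeff), knn_p + log(coeff)]) equals
# log((1 - coeff) * P_vocab + coeff * P_knn).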
def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff=0.5):
    combine_probs = torch.stack([vocab_p, knn_p], dim=0)
    coeffs = torch.ones_like(combine_probs)
    coeffs[0] = np.log(1 - coeff)
    coeffs[1] = np.log(coeff)
    curr_prob = torch.logsumexp(combine_probs + coeffs, dim=0)

    return curr_prob


def distance2logits(D):
    return torch.softmax(-1. * torch.tensor(D) / 30., dim=-1)

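# Turns faiss L2 distances into normalized weights via a temperature-scaled softmax over
# exp(-D / n); smaller distances get larger weights.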
def distance2logits_2(D, n=10):
    if not isinstance(D, torch.Tensor):
        D = torch.tensor(D)
    if torch.sum(D) != 0.0:
        D = torch.exp(-D / n) / torch.sum(torch.exp(-D / n), dim=-1, keepdim=True)
    # if all distances are zero, fall back to returning D unchanged instead of an unbound variable
    return D
66
pretrain/lit_models/utils.py
Normal file
66
pretrain/lit_models/utils.py
Normal file
@ -0,0 +1,66 @@
import json

import numpy as np
import torch
import torch.nn as nn


def rank_score(ranks):
    # prepare the dataset
    len_samples = len(ranks)
    hits10 = [0] * len_samples
    hits5 = [0] * len_samples
    hits1 = [0] * len_samples
    mrr = []

    for idx, rank in enumerate(ranks):
        if rank <= 10:
            hits10[idx] = 1.
        if rank <= 5:
            hits5[idx] = 1.
        if rank <= 1:
            hits1[idx] = 1.
        mrr.append(1. / rank)

    return np.mean(hits10), np.mean(hits5), np.mean(hits1), np.mean(mrr)


def acc(logits, labels):
    preds = np.argmax(logits, axis=-1)
    return (preds == labels).mean()

class LabelSmoothSoftmaxCEV1(nn.Module):
    '''
    This is the autograd version; you can also try LabelSmoothSoftmaxCEV2, which uses derived gradients.
    '''

    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothSoftmaxCEV1, self).__init__()
        self.lb_smooth = lb_smooth
        self.reduction = reduction
        self.lb_ignore = ignore_index
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, logits, label):
        '''
        args: logits: tensor of shape (N, C, H, W)
        args: label: tensor of shape (N, H, W)
        '''
        # overcome ignored label
        with torch.no_grad():
            num_classes = logits.size(1)
            label = label.clone().detach()
            ignore = label == self.lb_ignore
            n_valid = (ignore == 0).sum()
            label[ignore] = 0
            lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
            label = torch.empty_like(logits).fill_(
                lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()

        logs = self.log_softmax(logits)
        loss = -torch.sum(logs * label, dim=1)
        loss[ignore] = 0
        if self.reduction == 'mean':
            loss = loss.sum() / n_valid
        if self.reduction == 'sum':
            loss = loss.sum()

        return loss
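
# Minimal usage sketch (the shapes below are illustrative assumptions, not taken from this repo):
#   criterion = LabelSmoothSoftmaxCEV1(lb_smooth=0.1)
#   logits = torch.randn(8, 30522, 1, 1)          # (N, C, H, W), e.g. C = vocabulary size
#   labels = torch.randint(0, 30522, (8, 1, 1))   # (N, H, W)
#   loss = criterion(logits, labels)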
141
pretrain/main.py
Normal file
141
pretrain/main.py
Normal file
@ -0,0 +1,141 @@
import argparse
import importlib
from logging import debug

import numpy as np
import torch
import pytorch_lightning as pl
import lit_models
import yaml
import time
from transformers import AutoConfig
import os
# from utils import get_decoder_input_ids

os.environ["TOKENIZERS_PARALLELISM"] = "false"


# In order to ensure reproducible experiments, we must set random seeds.


def _import_class(module_and_class_name: str) -> type:
    """Import class from a module, e.g. 'text_recognizer.models.MLP'"""
    module_name, class_name = module_and_class_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    class_ = getattr(module, class_name)
    return class_
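# e.g. _import_class("lit_models.TransformerLitModel") returns the class object, so the
# --data_class / --model_class / --litmodel_class flags can stay plain strings on the command line.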

def _setup_parser():
    """Set up Python's ArgumentParser with data, model, trainer, and other arguments."""
    parser = argparse.ArgumentParser(add_help=False)

    # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision
    trainer_parser = pl.Trainer.add_argparse_args(parser)
    trainer_parser._action_groups[1].title = "Trainer Args"  # pylint: disable=protected-access
    parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser])

    # Basic arguments
    parser.add_argument("--wandb", action="store_true", default=False)
    parser.add_argument("--litmodel_class", type=str, default="TransformerLitModel")
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--data_class", type=str, default="KGC")
    parser.add_argument("--chunk", type=str, default="")
    parser.add_argument("--model_class", type=str, default="RobertaUseLabelWord")
    parser.add_argument("--checkpoint", type=str, default=None)

    # Get the data and model classes, so that we can add their specific arguments
    temp_args, _ = parser.parse_known_args()
    data_class = _import_class(f"data.{temp_args.data_class}")
    model_class = _import_class(f"models.{temp_args.model_class}")
    lit_model_class = _import_class(f"lit_models.{temp_args.litmodel_class}")

    # Get data, model, and LitModel specific arguments
    data_group = parser.add_argument_group("Data Args")
    data_class.add_to_argparse(data_group)

    model_group = parser.add_argument_group("Model Args")
    if hasattr(model_class, "add_to_argparse"):
        model_class.add_to_argparse(model_group)

    lit_model_group = parser.add_argument_group("LitModel Args")
    lit_model_class.add_to_argparse(lit_model_group)

    parser.add_argument("--help", "-h", action="help")
    return parser

def _saved_pretrain(lit_model, tokenizer, path):
    lit_model.model.save_pretrained(path)
    tokenizer.save_pretrained(path)

def main():
    parser = _setup_parser()
    args = parser.parse_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    pl.seed_everything(args.seed)

    data_class = _import_class(f"data.{args.data_class}")
    model_class = _import_class(f"models.{args.model_class}")
    litmodel_class = _import_class(f"lit_models.{args.litmodel_class}")

    config = AutoConfig.from_pretrained(args.model_name_or_path)
    # update parameters
    config.label_smoothing = args.label_smoothing

    model = model_class.from_pretrained(args.model_name_or_path, config=config)
    data = data_class(args, model)
    tokenizer = data.tokenizer

    lit_model = litmodel_class(args=args, model=model, tokenizer=tokenizer, data_config=data.get_config())
    if args.checkpoint:
        lit_model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["state_dict"])

    logger = pl.loggers.TensorBoardLogger("training/logs")
    if args.wandb:
        logger = pl.loggers.WandbLogger(project="kgc_bert", name=args.data_dir.split("/")[-1])
        logger.log_hyperparams(vars(args))

    metric_name = "Eval/mrr" if not args.pretrain else "Eval/hits1"

    early_callback = pl.callbacks.EarlyStopping(monitor="Eval/mrr", mode="max", patience=10)
    model_checkpoint = pl.callbacks.ModelCheckpoint(monitor=metric_name, mode="max",
        filename=args.data_dir.split("/")[-1] + '/{epoch}-{Eval/hits10:.2f}-{Eval/hits1:.2f}' if not args.pretrain else args.data_dir.split("/")[-1] + '/{epoch}-{step}-{Eval/hits10:.2f}',
        dirpath="output",
        save_weights_only=True,  # to be modified
        every_n_train_steps=100 if args.pretrain else None,
        save_top_k=5 if args.pretrain else 1
    )
    callbacks = [early_callback, model_checkpoint]

    # args.weights_summary = "full"  # Print full summary of the model
    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")
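    # The GetEntityEmbedding / UseEntityEmbedding / CombineEntityEmbedding LitModels are
    # evaluation-only, so trainer.fit is skipped for them and we go straight to trainer.test.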
    if "EntityEmbedding" not in lit_model.__class__.__name__:
        trainer.fit(lit_model, datamodule=data)
        path = model_checkpoint.best_model_path
        lit_model.load_state_dict(torch.load(path)["state_dict"])

    result = trainer.test(lit_model, datamodule=data)
    print(result)

    # _saved_pretrain(lit_model, tokenizer, path)
    if "EntityEmbedding" not in lit_model.__class__.__name__:
        print("*" * 30)
        print(path)


if __name__ == "__main__":
    main()
6
pretrain/models/__init__.py
Executable file
6
pretrain/models/__init__.py
Executable file
@ -0,0 +1,6 @@
from transformers import BartForConditionalGeneration, T5ForConditionalGeneration, GPT2LMHeadModel

from .model import *
10
pretrain/models/model.py
Normal file
10
pretrain/models/model.py
Normal file
@ -0,0 +1,10 @@
from transformers.models.bert.modeling_bert import BertForMaskedLM


class BertKGC(BertForMaskedLM):

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--pretrain", type=int, default=0, help="whether to run the pretraining stage")
        return parser
1159
pretrain/models/utils.py
Normal file
1159
pretrain/models/utils.py
Normal file
File diff suppressed because it is too large
14
pretrain/scripts/pretrain_fb15k-237.sh
Normal file
14
pretrain/scripts/pretrain_fb15k-237.sh
Normal file
@ -0,0 +1,14 @@
nohup python -u main.py --gpus "1" --max_epochs=16 --num_workers=32 \
    --model_name_or_path bert-base-uncased \
    --accumulate_grad_batches 1 \
    --model_class BertKGC \
    --batch_size 128 \
    --pretrain 1 \
    --bce 0 \
    --check_val_every_n_epoch 1 \
    --overwrite_cache \
    --data_dir /kg_374/Relphormer/dataset/FB15k-237 \
    --eval_batch_size 256 \
    --max_seq_length 64 \
    --lr 1e-4 \
    >logs/pretrain_fb15k-237.log 2>&1 &
14
pretrain/scripts/pretrain_umls.sh
Executable file
14
pretrain/scripts/pretrain_umls.sh
Executable file
@ -0,0 +1,14 @@
nohup python -u main.py --gpus "0," --max_epochs=20 --num_workers=32 \
    --model_name_or_path bert-base-uncased \
    --accumulate_grad_batches 1 \
    --model_class BertKGC \
    --batch_size 128 \
    --pretrain 1 \
    --bce 0 \
    --check_val_every_n_epoch 1 \
    --overwrite_cache \
    --data_dir xxx/Relphormer/dataset/umls \
    --eval_batch_size 256 \
    --max_seq_length 64 \
    --lr 1e-4 \
    >logs/pretrain_umls.log 2>&1 &
16
pretrain/scripts/pretrain_wn18rr.sh
Normal file
16
pretrain/scripts/pretrain_wn18rr.sh
Normal file
@ -0,0 +1,16 @@
nohup python -u main.py --gpus "0," --max_epochs=15 --num_workers=32 \
    --model_name_or_path bert-base-uncased \
    --accumulate_grad_batches 1 \
    --bce 0 \
    --model_class BertKGC \
    --batch_size 128 \
    --pretrain 1 \
    --check_val_every_n_epoch 1 \
    --data_dir xxx/Relphormer/dataset/WN18RR \
    --overwrite_cache \
    --eval_batch_size 256 \
    --precision 16 \
    --max_seq_length 32 \
    --lr 1e-4 \
    >logs/pretrain_wn18rr.log 2>&1 &