Relphormer baseline

pretrain/data/__init__.py | 2 (new file)
@@ -0,0 +1,2 @@
from .data_module import KGC
from .processor import convert_examples_to_features, KGProcessor

pretrain/data/base_data_module.py | 71 (new file)
@@ -0,0 +1,71 @@
"""Base DataModule class."""
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Dict
 | 
			
		||||
import argparse
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
from torch.utils.data import DataLoader
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
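# Config wraps the parsed argparse namespace in a dict with attribute-style
# access; missing keys return None instead of raising AttributeError.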
class Config(dict):
    def __getattr__(self, name):
        return self.get(name)

    def __setattr__(self, name, val):
        self[name] = val


BATCH_SIZE = 8
NUM_WORKERS = 8

class BaseDataModule(pl.LightningDataModule):
    """
    Base DataModule.
    Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html
    """

    def __init__(self, args: argparse.Namespace = None) -> None:
        super().__init__()
        self.args = Config(vars(args)) if args is not None else {}
        self.batch_size = self.args.get("batch_size", BATCH_SIZE)
        self.num_workers = self.args.get("num_workers", NUM_WORKERS)

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument(
            "--batch_size", type=int, default=BATCH_SIZE, help="Number of examples to operate on per forward step."
        )
        parser.add_argument(
            "--num_workers", type=int, default=0, help="Number of additional processes to load data."
        )
        parser.add_argument(
            "--dataset", type=str, default="./dataset/NELL", help="Path to the dataset directory."
        )
        return parser

    def prepare_data(self):
        """
        Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`).
        """
        pass

    def setup(self, stage=None):
        """
        Split into train, val, test, and set dims.
        Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.
        """
        self.data_train = None
        self.data_val = None
        self.data_test = None

    def train_dataloader(self):
        return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
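
# Minimal usage sketch (illustrative; `MyDataModule` and the `my_*_dataset`
# names are hypothetical, not part of this commit): subclasses override
# `setup` and inherit the dataloaders defined above.
#
#     class MyDataModule(BaseDataModule):
#         def setup(self, stage=None):
#             self.data_train = my_train_dataset
#             self.data_val = my_val_dataset
#             self.data_test = my_test_dataset
#
#     dm = MyDataModule(args)
#     dm.setup()
#     loader = dm.train_dataloader()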

pretrain/data/data_module.py | 196 (new file)
@@ -0,0 +1,196 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from enum import Enum
import torch
import numpy as np

from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BertTokenizer
# from transformers.configuration_bert import BertTokenizer, BertTokenizerFast
from transformers.tokenization_utils_base import (BatchEncoding,
                                                  PreTrainedTokenizerBase)

from .base_data_module import BaseDataModule
from .processor import KGProcessor, get_dataset
import transformers
transformers.logging.set_verbosity_error()


class ExplicitEnum(Enum):
    """
    Enum with more explicit error message for missing values.
    """

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )


class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
    in an IDE.
    """

    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"
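
# Illustrative behavior (not part of the commit): PaddingStrategy("longest")
# resolves to PaddingStrategy.LONGEST, while PaddingStrategy("foo") raises a
# ValueError listing the valid values via ExplicitEnum._missing_.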


@dataclass
class DataCollatorForSeq2Seq:
    """
 | 
			
		||||
    Data collator that will dynamically pad the inputs received, as well as the labels.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
 | 
			
		||||
            The tokenizer used for encoding the data.
 | 
			
		||||
        model (:class:`~transformers.PreTrainedModel`):
 | 
			
		||||
            The model that is being trained. If set and has the `prepare_decoder_input_ids_from_labels`, use it to
 | 
			
		||||
            prepare the `decoder_input_ids`
 | 
			
		||||
 | 
			
		||||
            This is useful when using `label_smoothing` to avoid calculating loss twice.
 | 
			
		||||
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
 | 
			
		||||
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
 | 
			
		||||
            among:
 | 
			
		||||
 | 
			
		||||
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
 | 
			
		||||
              sequence is provided).
 | 
			
		||||
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
 | 
			
		||||
              maximum acceptable input length for the model if that argument is not provided.
 | 
			
		||||
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
 | 
			
		||||
              different lengths).
 | 
			
		||||
        max_length (:obj:`int`, `optional`):
 | 
			
		||||
            Maximum length of the returned list and optionally padding length (see above).
 | 
			
		||||
        pad_to_multiple_of (:obj:`int`, `optional`):
 | 
			
		||||
            If set will pad the sequence to a multiple of the provided value.
 | 
			
		||||
 | 
			
		||||
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
 | 
			
		||||
            7.5 (Volta).
 | 
			
		||||
        label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
 | 
			
		||||
            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    tokenizer: PreTrainedTokenizerBase
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"
    num_labels: int = 0

    def __call__(self, features, return_tensors=None):

        if return_tensors is None:
            return_tensors = self.return_tensors
        labels = [feature.pop("labels") for feature in features] if "labels" in features[0].keys() else None
        label = [feature.pop("label") for feature in features]
        features_keys = {}
        name_keys = list(features[0].keys())
        for k in name_keys:
            # ignore the padding arguments
            if k in ["input_ids", "attention_mask", "token_type_ids"]: continue
            try:
                features_keys[k] = [feature.pop(k) for feature in features]
            except KeyError:
                continue

        # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
        # same length to return tensors.
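        # Here the labels are instead expanded into a multi-hot vector over all
        # `num_labels` entities: an example may have several correct entities
        # (its filtered candidate set), so every gold id j of example i gets
        # new_labels[i][j] = 1.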
        bsz = len(labels)
        with torch.no_grad():
            new_labels = torch.zeros(bsz, self.num_labels)
            for i, l in enumerate(labels):
                if isinstance(l, int):
                    new_labels[i][l] = 1
                else:
                    for j in l:
                        new_labels[i][j] = 1
            labels = new_labels

        features = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )
        features['labels'] = labels
        features['label'] = torch.tensor(label)
        features.update(features_keys)

        return features
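
# Illustrative output (assumed, not asserted by this commit): calling the
# collator on a list of tokenized feature dicts yields a BatchEncoding with
# padded `input_ids`/`attention_mask`, a float multi-hot `labels` tensor of
# shape (batch, num_labels), and an integer `label` tensor of gold entity ids.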


class KGC(BaseDataModule):
    def __init__(self, args, model) -> None:
        super().__init__(args)
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name_or_path, use_fast=False)
        self.processor = KGProcessor(self.tokenizer, args)
        self.label_list = self.processor.get_labels(args.data_dir)

        entity_list = self.processor.get_entities(args.data_dir)

        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': entity_list})
        self.sampler = DataCollatorForSeq2Seq(self.tokenizer,
            model=model,
            label_pad_token_id=self.tokenizer.pad_token_id,
            pad_to_multiple_of=8 if self.args.precision == 16 else None,
            padding="longest",
            max_length=self.args.max_seq_length,
            num_labels=len(entity_list),
        )
        relations_tokens = self.processor.get_relations(args.data_dir)
        self.num_relations = len(relations_tokens)
        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': relations_tokens})

        vocab = self.tokenizer.get_added_vocab()
        self.relation_id_st = vocab[relations_tokens[0]]
        self.relation_id_ed = vocab[relations_tokens[-1]] + 1
        self.entity_id_st = vocab[entity_list[0]]
        self.entity_id_ed = vocab[entity_list[-1]] + 1
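        # The special tokens are added in order, so (assuming the tokenizer
        # assigns consecutive ids to newly added tokens, as the HF slow
        # tokenizers do) entities and relations occupy the contiguous id ranges
        # [entity_id_st, entity_id_ed) and [relation_id_st, relation_id_ed),
        # which is convenient for slicing logits over entities or relations.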

    def setup(self, stage=None):
        self.data_train = get_dataset(self.args, self.processor, self.label_list, self.tokenizer, "train")
        self.data_val = get_dataset(self.args, self.processor, self.label_list, self.tokenizer, "dev")
        self.data_test = get_dataset(self.args, self.processor, self.label_list, self.tokenizer, "test")

    def prepare_data(self):
        pass

    def get_config(self):
        d = {}
        for k, v in self.__dict__.items():
            if "st" in k or "ed" in k:
                d.update({k: v})

        return d

    @staticmethod
    def add_to_argparse(parser):
        BaseDataModule.add_to_argparse(parser)
        parser.add_argument("--model_name_or_path", type=str, default="roberta-base", help="the name or the path to the pretrained model")
        parser.add_argument("--data_dir", type=str, default="roberta-base", help="the directory of the dataset")
        parser.add_argument("--max_seq_length", type=int, default=256, help="Maximum input sequence length after tokenization.")
        parser.add_argument("--warm_up_radio", type=float, default=0.1, help="Proportion of training steps used for learning-rate warm-up.")
        parser.add_argument("--eval_batch_size", type=int, default=8)
        parser.add_argument("--overwrite_cache", action="store_true", default=False)
        return parser

    def get_tokenizer(self):
        return self.tokenizer

    def train_dataloader(self):
        return DataLoader(self.data_train, num_workers=self.num_workers, pin_memory=True, collate_fn=self.sampler, batch_size=self.args.batch_size, shuffle=not self.args.faiss_init)

    def val_dataloader(self):
        return DataLoader(self.data_val, num_workers=self.num_workers, pin_memory=True, collate_fn=self.sampler, batch_size=self.args.eval_batch_size)

    def test_dataloader(self):
        return DataLoader(self.data_test, num_workers=self.num_workers, pin_memory=True, collate_fn=self.sampler, batch_size=self.args.eval_batch_size)
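
# Sketch of the intended wiring (illustrative; assumes the remaining flags used
# by the processor, e.g. --pretrain and --faiss_init, are registered elsewhere
# in the repo, and that the dataset path is valid):
#
#     parser = KGC.add_to_argparse(argparse.ArgumentParser())
#     args = parser.parse_args(["--data_dir", "./dataset/WN18RR"])
#     dm = KGC(args, model=None)
#     dm.setup()
#     batch = next(iter(dm.train_dataloader()))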

pretrain/data/processor.py | 936 (new file)
@@ -0,0 +1,936 @@
import contextlib
import sys

from collections import Counter
from multiprocessing import Pool
from torch.utils.data import Dataset, Sampler, IterableDataset
from collections import defaultdict
from functools import partial
import os
import random
import json
import torch
import copy
import numpy as np
import pickle
from tqdm import tqdm
from dataclasses import dataclass, asdict, replace
import inspect

from transformers.models.auto.tokenization_auto import AutoTokenizer

from models.utils import get_entity_spans_pre_processing

def lmap(a, b):
    return list(map(a, b))

def cache_results(_cache_fp, _refresh=False, _verbose=1):
    r"""
    cache_results is the decorator fastNLP uses to cache data. An example of how to use it::

        import time
        import numpy as np
        from fastNLP import cache_results

        @cache_results('cache.pkl')
        def process_data():
            # some time-consuming work such as reading and preprocessing data;
            # time.sleep() stands in for the cost here
            time.sleep(1)
            return np.random.randint(10, size=(5,))

        start_time = time.time()
        print("res =", process_data())
        print(time.time() - start_time)

        start_time = time.time()
        print("res =", process_data())
        print(time.time() - start_time)

        # The output looks like the following: both runs return the same result,
        # and the second run takes almost no time.
        # Save cache to cache.pkl.
        # res = [5 4 9 1 8]
        # 1.0042750835418701
        # Read cache from cache.pkl.
        # res = [5 4 9 1 8]
        # 0.0040721893310546875

    The second run takes only about 0.004s because it reads the result directly from cache.pkl
    instead of preprocessing again::

        # Continuing the example above: to generate a different cache, e.g. for another
        # dataset, call the function as follows
        process_data(_cache_fp='cache2.pkl')  # does not affect the previous 'cache.pkl'

    Here `_cache_fp` is a parameter recognized by cache_results: the data is cached to / read
    from 'cache2.pkl', overriding the default 'cache.pkl'. Decorating a function with
    @cache_results() adds three parameters [_cache_fp, _refresh, _verbose]. They are not passed
    through to the decorated function, so the function's own parameters must not use these
    three names::

        process_data(_cache_fp='cache2.pkl', _refresh=True)  # force regeneration of this cache
        # _verbose controls logging: 0 prints nothing; 1 reports whether the current step read
        # an existing cache or generated a new one

    :param str _cache_fp: where to cache the return value, or where to read the cache from.
        If None, cache_results has no effect, unless _cache_fp is passed at call time.
    :param bool _refresh: whether to regenerate the cache.
    :param int _verbose: whether to print cache information.
    :return:
    """

    def wrapper_(func):
        signature = inspect.signature(func)
        for key, _ in signature.parameters.items():
            if key in ('_cache_fp', '_refresh', '_verbose'):
                raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))

        def wrapper(*args, **kwargs):
            my_args = args[0]
            mode = args[-1]
            if '_cache_fp' in kwargs:
                cache_filepath = kwargs.pop('_cache_fp')
                assert isinstance(cache_filepath, str), "_cache_fp can only be str."
            else:
                cache_filepath = _cache_fp
            if '_refresh' in kwargs:
                refresh = kwargs.pop('_refresh')
                assert isinstance(refresh, bool), "_refresh can only be bool."
            else:
                refresh = _refresh
            if '_verbose' in kwargs:
                verbose = kwargs.pop('_verbose')
                assert isinstance(verbose, int), "_verbose can only be integer."
            else:
                verbose = _verbose
            refresh_flag = True
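
            # Note: this copy then rebuilds the cache path from the run arguments
            # below, so it effectively overrides any `_cache_fp` resolved above.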
            model_name = my_args.model_name_or_path.split("/")[-1]
            is_pretrain = my_args.pretrain
            cache_filepath = os.path.join(my_args.data_dir, f"cached_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl")
            refresh = my_args.overwrite_cache

            if cache_filepath is not None and refresh is False:
                # load data
                if os.path.exists(cache_filepath):
                    with open(cache_filepath, 'rb') as f:
                        results = pickle.load(f)
                    if verbose == 1:
                        logger.info("Read cache from {}.".format(cache_filepath))
                    refresh_flag = False

            if refresh_flag:
                results = func(*args, **kwargs)
                if cache_filepath is not None:
                    if results is None:
                        raise RuntimeError("The return value is None. Delete the decorator.")
                    with open(cache_filepath, 'wb') as f:
                        pickle.dump(results, f)
                    logger.info("Save cache to {}.".format(cache_filepath))

            return results

        return wrapper

    return wrapper_


import argparse
import csv
import logging

from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import trange

# from torch.nn import CrossEntropyLoss, MSELoss
# from scipy.stats import pearsonr, spearmanr
# from sklearn.metrics import matthews_corrcoef, f1_score


logger = logging.getLogger(__name__)

class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, text_c=None, label=None, real_label=None, en=None, rel=None):
        """Constructs an InputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            text_c: (Optional) string. The untokenized text of the third sequence.
            Only must be specified for sequence triple tasks.
            label: (Optional) string. list of entities
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.text_c = text_c
        self.label = label
        self.real_label = real_label
        self.en = en
        self.rel = rel  # rel id


@dataclass
class InputFeatures:
    """A single set of features of data."""

    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    labels: torch.Tensor = None
    label: torch.Tensor = None
    en: torch.Tensor = 0
    rel: torch.Tensor = 0
    pos: torch.Tensor = 0


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self, data_dir):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with open(input_file, "r", encoding="utf-8") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                if sys.version_info[0] == 2:
                    line = list(unicode(cell, 'utf-8') for cell in line)
                lines.append(line)
            return lines


def solve_get_knowledge_store(line, set_type="train", pretrain=1):
    """
    use the LM to get the entity embedding.
    Transductive: triples + text description
    Inductive: text description

    """
    examples = []

    head_ent_text = ent2text[line[0]]
    tail_ent_text = ent2text[line[2]]
    relation_text = rel2text[line[1]]

    i = 0

    a = tail_filter_entities["\t".join([line[0], line[1]])]
    b = head_filter_entities["\t".join([line[2], line[1]])]

    guid = "%s-%s" % (set_type, i)
    text_a = head_ent_text
    text_b = relation_text
    text_c = tail_ent_text

    # use the description of c to predict A
    examples.append(
        InputExample(guid=guid, text_a="[PAD]", text_b=text_b + "[PAD]", text_c="[PAD]" + " " + text_c, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[ent2id[line[0]], rel2id[line[1]], ent2id[line[2]]], rel=0)
    )
    examples.append(
        InputExample(guid=guid, text_a="[PAD]", text_b=text_b + "[PAD]", text_c="[PAD]" + " " + text_a, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[2]], en=[ent2id[line[0]], rel2id[line[1]], ent2id[line[2]]], rel=0)
    )
    return examples
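

# In `solve` below, "[MASK]" marks the entity to be predicted and the "[PAD]"
# placeholders are later overwritten with entity/relation special-token ids in
# get_dataset; `label` carries all filtered gold entities for the query while
# `real_label` is the single target entity.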
def solve(line, set_type="train", pretrain=1):
    examples = []

    head_ent_text = ent2text[line[0]]
    tail_ent_text = ent2text[line[2]]
    relation_text = rel2text[line[1]]

    i = 0

    a = tail_filter_entities["\t".join([line[0], line[1]])]
    b = head_filter_entities["\t".join([line[2], line[1]])]

    guid = "%s-%s" % (set_type, i)
    text_a = head_ent_text
    text_b = relation_text
    text_c = tail_ent_text

    if pretrain:
        text_a_tokens = text_a.split()
        for i in range(10):
            st = random.randint(0, len(text_a_tokens))
            examples.append(
                InputExample(guid=guid, text_a="[MASK]", text_b=" ".join(text_a_tokens[st:min(st + 64, len(text_a_tokens))]), text_c="", label=ent2id[line[0]], real_label=ent2id[line[0]], en=0, rel=0)
            )
        examples.append(
            InputExample(guid=guid, text_a="[MASK]", text_b=text_a, text_c="", label=ent2id[line[0]], real_label=ent2id[line[0]], en=0, rel=0)
        )
        # examples.append(
        #     InputExample(guid=guid, text_a="[MASK]", text_b=text_c, text_c = "", label=ent2id[line[2]], real_label=ent2id[line[2]], en=0, rel=0)
        # )
    else:

        # examples.append(
        #     InputExample(guid=guid, text_a="[MASK]", text_b=text_b + "[PAD]", text_c = "[UNK]" , label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=ent2id[line[2]], rel=rel2id[line[1]]))
        # examples.append(
        #     InputExample(guid=guid, text_a="[UNK] ", text_b=text_b + "[PAD]", text_c = "[MASK]", label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=ent2id[line[0]], rel=rel2id[line[1]]))

        # examples.append(
        #     InputExample(guid=guid, text_a="[UNK]" + " " + text_c, text_b=text_b + "[PAD]", text_c = "[MASK]", label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=ent2id[line[2]], rel=rel2id[line[1]]))
        # examples.append(
        #     InputExample(guid=guid, text_a="[MASK]", text_b=text_b + "[PAD]", text_c = "[UNK]" + text_a, label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=ent2id[line[0]], rel=rel2id[line[1]]))

        examples.append(
            InputExample(guid=guid, text_a="[MASK]", text_b=text_b + "[PAD]", text_c="[PAD]" + " " + text_c, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]]))
        examples.append(
            InputExample(guid=guid, text_a="[PAD] ", text_b=text_b + "[PAD]", text_c="[MASK]" + " " + text_a, label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]]))
    return examples

def filter_init(head, tail, t1, t2, ent2id_, ent2token_, rel2id_):
    global head_filter_entities
    global tail_filter_entities
    global ent2text
    global rel2text
    global ent2id
    global ent2token
    global rel2id

    head_filter_entities = head
    tail_filter_entities = tail
    ent2text = t1
    rel2text = t2
    ent2id = ent2id_
    ent2token = ent2token_
    rel2id = rel2id_


def delete_init(ent2text_):
    global ent2text
    ent2text = ent2text_


class KGProcessor(DataProcessor):
    """Processor for knowledge graph data set."""
    def __init__(self, tokenizer, args):
        self.labels = set()
        self.tokenizer = tokenizer
        self.args = args
        self.entity_path = os.path.join(args.data_dir, "entity2textlong.txt") if os.path.exists(os.path.join(args.data_dir, 'entity2textlong.txt')) \
            else os.path.join(args.data_dir, "entity2text.txt")

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train", data_dir, self.args)

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev", data_dir, self.args)

    def get_test_examples(self, data_dir, chunk=""):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, f"test{chunk}.tsv")), "test", data_dir, self.args)

    def get_relations(self, data_dir):
        """Gets the relation placeholder tokens ([RELATION_i]) for all relations in the knowledge graph."""
        # return list(self.labels)
        with open(os.path.join(data_dir, "relations.txt"), 'r') as f:
            lines = f.readlines()
            relations = []
            for line in lines:
                relations.append(line.strip().split('\t')[0])
        rel2token = {ent: f"[RELATION_{i}]" for i, ent in enumerate(relations)}
        return list(rel2token.values())

    def get_labels(self, data_dir):
        """Gets the text of every relation in the knowledge graph."""
        relation = []
        with open(os.path.join(data_dir, "relation2text.txt"), 'r') as f:
            lines = f.readlines()
            for line in lines:
                relation.append(line.strip().split("\t")[-1])
        return relation

    def get_entities(self, data_dir):
        """Gets the entity placeholder tokens ([ENTITY_i]) for all entities in the knowledge graph."""
        with open(self.entity_path, 'r') as f:
            lines = f.readlines()
            entities = []
            for line in lines:
                entities.append(line.strip().split("\t")[0])

        ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(entities)}
        return list(ent2token.values())

    def get_train_triples(self, data_dir):
        """Gets training triples."""
        return self._read_tsv(os.path.join(data_dir, "train.tsv"))

    def get_dev_triples(self, data_dir):
        """Gets validation triples."""
        return self._read_tsv(os.path.join(data_dir, "dev.tsv"))

    def get_test_triples(self, data_dir, chunk=""):
        """Gets test triples."""
        return self._read_tsv(os.path.join(data_dir, f"test{chunk}.tsv"))

    def _create_examples(self, lines, set_type, data_dir, args):
        """Creates examples for the training and dev sets."""
        # entity to text
        ent2text = {}
        ent2text_with_type = {}
        with open(self.entity_path, 'r') as f:
            ent_lines = f.readlines()
            for line in ent_lines:
                temp = line.strip().split('\t')
                try:
                    end = temp[1]  # .find(',')
                    if "wiki" in data_dir:
                        assert "Q" in temp[0]
                    ent2text[temp[0]] = temp[1].replace("\\n", " ").replace("\\", "")  # [:end]
                except IndexError:
                    # continue
                    end = " "  # .find(',')
                    if "wiki" in data_dir:
                        assert "Q" in temp[0]
                    ent2text[temp[0]] = end  # [:end]

        entities = list(ent2text.keys())
        ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(entities)}
        ent2id = {ent: i for i, ent in enumerate(entities)}

        rel2text = {}
        with open(os.path.join(data_dir, "relation2text.txt"), 'r') as f:
            rel_lines = f.readlines()
            for line in rel_lines:
                temp = line.strip().split('\t')
                rel2text[temp[0]] = temp[1]
        relation_names = {}
        with open(os.path.join(data_dir, "relations.txt"), "r") as file:
            for line in file.readlines():
                t = line.strip()
                relation_names[t] = rel2text[t]

        tmp_lines = []
        not_in_text = 0
        for line in tqdm(lines, desc="delete entities without text name."):
            if (line[0] not in ent2text) or (line[2] not in ent2text) or (line[1] not in rel2text):
                not_in_text += 1
                continue
            tmp_lines.append(line)
        lines = tmp_lines
        print(f"total entity not in text : {not_in_text} ")

        # rel id -> relation token id
        num_entities = len(self.get_entities(args.data_dir))
        rel2id = {w: i + num_entities for i, w in enumerate(relation_names.keys())}

        # add reverse relation
        # tmp_rel2id = {}
        # num_relations = len(rel2id)
        # cnt = 0
        # for k, v in rel2id.items():
        #     tmp_rel2id[k + " (reverse)"] = num_relations + cnt
        #     cnt += 1
        # rel2id.update(tmp_rel2id)

        examples = []
        # filtered candidate entities for each (entity, relation) query
        head_filter_entities = defaultdict(list)
        tail_filter_entities = defaultdict(list)

        dataset_list = ["train.tsv", "dev.tsv", "test.tsv"]
        # in training, only use the train triples
        if set_type == "train" and not args.pretrain: dataset_list = dataset_list[0:1]
        for m in dataset_list:
            with open(os.path.join(data_dir, m), 'r') as file:
                train_lines = file.readlines()
                for idx in range(len(train_lines)):
                    train_lines[idx] = train_lines[idx].strip().split("\t")

            for line in train_lines:
                tail_filter_entities["\t".join([line[0], line[1]])].append(line[2])
                head_filter_entities["\t".join([line[2], line[1]])].append(line[0])

        max_head_entities = max(len(_) for _ in head_filter_entities.values())
        max_tail_entities = max(len(_) for _ in tail_filter_entities.values())

        # use bce loss, ignore the mlm
        if set_type == "train" and args.bce:
            lines = []
            for k, v in tail_filter_entities.items():
                h, r = k.split('\t')
                t = v[0]
                lines.append([h, r, t])
            for k, v in head_filter_entities.items():
                t, r = k.split('\t')
                h = v[0]
                lines.append([h, r, t])

        # for pretraining, select each entity once to get its mask embedding.
        if args.pretrain:
            rel = list(rel2text.keys())[0]
            lines = []
            for k in ent2text.keys():
                lines.append([k, rel, k])

        print(f"max number of filter entities : {max_head_entities} {max_tail_entities}")

        from os import cpu_count
        threads = min(1, cpu_count())
        filter_init(head_filter_entities, tail_filter_entities, ent2text, rel2text, ent2id, ent2token, rel2id
            )

        if hasattr(args, "faiss_init") and args.faiss_init:
            annotate_ = partial(
                solve_get_knowledge_store,
                pretrain=self.args.pretrain
            )
        else:
            annotate_ = partial(
                solve,
                pretrain=self.args.pretrain
            )
        examples = list(
            tqdm(
                map(annotate_, lines),
                total=len(lines),
                desc="convert text to examples"
            )
        )

        # with Pool(threads, initializer=filter_init, initargs=(head_filter_entities, tail_filter_entities,ent2text, rel2text,
        #     ent2text_with_type, rel2id,)) as pool:
        #     annotate_ = partial(
        #         solve,
        #     )
        #     examples = list(
        #         tqdm(
        #             map(annotate_, lines, chunksize= 128),
        #             total=len(lines),
        #             desc="convert text to examples"
        #         )
        #     )
        tmp_examples = []
        for e in examples:
            for ee in e:
                tmp_examples.append(ee)
        examples = tmp_examples
        # delete vars
        del head_filter_entities, tail_filter_entities, ent2text, rel2text, ent2id, ent2token, rel2id
        return examples

class Verbalizer(object):
    def __init__(self, args):
        if "WN18RR" in args.data_dir:
            self.mode = "WN18RR"
        elif "FB15k" in args.data_dir:
            self.mode = "FB15k"
        elif "umls" in args.data_dir:
            self.mode = "umls"
        elif "codexs" in args.data_dir:
            self.mode = "codexs"
        elif "codexl" in args.data_dir:
            self.mode = "codexl"
        elif "FB13" in args.data_dir:
            self.mode = "FB13"
        elif "WN11" in args.data_dir:
            self.mode = "WN11"

    def _convert(self, head, relation, tail):
        if self.mode == "umls":
            return f"The {relation} {head} is "

        return f"{head} {relation}"


class KGCDataset(Dataset):
    def __init__(self, features):
        self.features = features

    def __getitem__(self, index):
        return self.features[index]

    def __len__(self):
        return len(self.features)


def convert_examples_to_features_init(tokenizer_for_convert):
    global tokenizer
    tokenizer = tokenizer_for_convert


def convert_examples_to_features(example, max_seq_length, mode, pretrain=1):
    """Loads a data file into a list of `InputBatch`s."""
    # tokens_a = tokenizer.tokenize(example.text_a)
    # tokens_b = tokenizer.tokenize(example.text_b)
    # tokens_c = tokenizer.tokenize(example.text_c)

    # _truncate_seq_triple(tokens_a, tokens_b, tokens_c, max_length= max_seq_length)
    text_a = " ".join(example.text_a.split()[:128])
    text_b = " ".join(example.text_b.split()[:128])
    text_c = " ".join(example.text_c.split()[:128])

    if pretrain:
        input_text_a = text_a
        input_text_b = text_b
    else:
        input_text_a = tokenizer.sep_token.join([text_a, text_b])
        input_text_b = text_c

    inputs = tokenizer(
        input_text_a,
        input_text_b,
        truncation="longest_first",
        max_length=max_seq_length,
        padding="longest",
        add_special_tokens=True,
    )
    # assert tokenizer.mask_token_id in inputs.input_ids, "mask token must in input"

    features = asdict(InputFeatures(input_ids=inputs["input_ids"],
                                    attention_mask=inputs['attention_mask'],
                                    labels=torch.tensor(example.label),
                                    label=torch.tensor(example.real_label)
                                    )
                      )
    return features


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


def _truncate_seq_triple(tokens_a, tokens_b, tokens_c, max_length):
    """Truncates a sequence triple in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b) + len(tokens_c)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b) and len(tokens_a) > len(tokens_c):
            tokens_a.pop()
        elif len(tokens_b) > len(tokens_a) and len(tokens_b) > len(tokens_c):
            tokens_b.pop()
        elif len(tokens_c) > len(tokens_a) and len(tokens_c) > len(tokens_b):
            tokens_c.pop()
        else:
            tokens_c.pop()


@cache_results(_cache_fp="./dataset")
def get_dataset(args, processor, label_list, tokenizer, mode):

    assert mode in ["train", "dev", "test"], "mode must be in train dev test!"

    # use training data to construct the entity embedding
    combine_train_and_test = False
    if args.faiss_init and mode == "test" and not args.pretrain:
        mode = "train"
        if "ind" in args.data_dir: combine_train_and_test = True
    else:
        pass

    if mode == "train":
        train_examples = processor.get_train_examples(args.data_dir)
    elif mode == "dev":
        train_examples = processor.get_dev_examples(args.data_dir)
    else:
        train_examples = processor.get_test_examples(args.data_dir)

    if combine_train_and_test:
        logger.info("use the whole dataset (train + dev + test) to get the entity mask embedding in pretraining")
        train_examples = processor.get_test_examples(args.data_dir) + processor.get_train_examples(args.data_dir) + processor.get_dev_examples(args.data_dir)

    from os import cpu_count
    with open(os.path.join(args.data_dir, f"examples_{mode}.txt"), 'w') as file:
        for line in train_examples:
            d = {}
            d.update(line.__dict__)
            file.write(json.dumps(d) + '\n')

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False)

    features = []
    # # with open(os.path.join(args.data_dir, "cached_relation_pattern.pkl"), "rb") as file:
    # #     pattern = pickle.load(file)
    # pattern = None
    # convert_examples_to_features_init(tokenizer)
    # annotate_ = partial(
    #     convert_examples_to_features,
    #     max_seq_length=args.max_seq_length,
    #     mode = mode,
    #     pretrain = args.pretrain
    # )
    # features = list(
    #     tqdm(
    #         map(annotate_, train_examples),
    #         total=len(train_examples)
    #     )
    # )
    # encoder = MultiprocessingEncoder(tokenizer, args)
    # encoder.initializer()
    # for t in tqdm(train_examples):
    #     features.append(encoder.encode_lines([json.dumps(t.__dict__)]))

    # for example in tqdm(train_examples):
    #     text_a = example.text_a
    #     text_b = example.text_b
    #     text_c = example.text_c

    #     bpe = tokenizer
    #     if 0:
    #         input_text_a = text_a
    #         input_text_b = text_b
    #     else:
    #         if text_a == "[MASK]":
    #             input_text_a = bpe.sep_token.join([text_a, text_b])
    #             input_text_b = text_c
    #         else:
    #             input_text_a = text_a
    #             input_text_b = bpe.sep_token.join([text_b, text_c])

    #     inputs = tokenizer(
    #         input_text_a,
    #         # input_text_b,
    #         truncation="longest_first",
    #         max_length=128,
    #         padding="longest",
    #         add_special_tokens=True,
    #     )
    #     # assert tokenizer.mask_token_id in inputs.input_ids, "mask token must in input"

    #     # features.append(asdict(InputFeatures(input_ids=inputs["input_ids"],
    #     #                         attention_mask=inputs['attention_mask'],
    #     #                         labels=example.label,
    #     #                         label=example.real_label
    #     #     )
    #     # ))

    file_inputs = [os.path.join(args.data_dir, f"examples_{mode}.txt")]
    file_outputs = [os.path.join(args.data_dir, f"features_{mode}.txt")]

    with contextlib.ExitStack() as stack:
        inputs = [
            stack.enter_context(open(input, "r", encoding="utf-8"))
            if input != "-" else sys.stdin
            for input in file_inputs
        ]
        outputs = [
            stack.enter_context(open(output, "w", encoding="utf-8"))
            if output != "-" else sys.stdout
            for output in file_outputs
        ]

        encoder = MultiprocessingEncoder(tokenizer, args)
        pool = Pool(16, initializer=encoder.initializer)
        encoder.initializer()
        encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 1000)
        # encoded_lines = map(encoder.encode_lines, zip(*inputs))

        stats = Counter()
        for i, (filt, enc_lines) in tqdm(enumerate(encoded_lines, start=1), total=len(train_examples)):
            if filt == "PASS":
                for enc_line, output_h in zip(enc_lines, outputs):
                    # enc_line is assumed to be a dict literal emitted by encode_lines
                    features.append(eval(enc_line))
                    # features.append(enc_line)
                    # print(enc_line, file=output_h)
            else:
                stats["num_filtered_" + filt] += 1
            # if i % 10000 == 0:
            #     print("processed {} lines".format(i), file=sys.stderr)

        for k, v in stats.most_common():
            print("[{}] filtered {} lines".format(k, v), file=sys.stderr)
    # threads = min(16, cpu_count())
    # with Pool(threads, initializer=convert_examples_to_features_init, initargs=(tokenizer,)) as pool:
    #     annotate_ = partial(
    #         convert_examples_to_features,
    #         max_seq_length=args.max_seq_length,
    #         mode = mode,
    #         pretrain = args.pretrain
    #     )
    #     features = list(
    #         tqdm(
    #             pool.imap_unordered(annotate_, train_examples),
    #             total=len(train_examples),
    #             desc="convert examples to features",
    #         )
    #     )

    # num_entities = len(processor.get_entities(args.data_dir))
 | 
			
		||||
    # Replace the pad-token placeholders with entity ids offset past the original
    # vocabulary, and record the position of the gold entity in each sequence.
    for f_id, f in enumerate(features):
        en = features[f_id].pop("en")
        rel = features[f_id].pop("rel")
        real_label = f['label']
        cnt = 0
        if not isinstance(en, list): break

        pos = 0
        for i, t in enumerate(f['input_ids']):
            if t == tokenizer.pad_token_id:
                features[f_id]['input_ids'][i] = en[cnt] + len(tokenizer)
                cnt += 1
            if features[f_id]['input_ids'][i] == real_label + len(tokenizer):
                pos = i
            if cnt == len(en): break
        # with faiss_init the gold entity must appear in the sequence
        assert not (args.faiss_init and pos == 0)
        features[f_id]['pos'] = pos

        # for i, t in enumerate(f['input_ids']):
        #     if t == tokenizer.pad_token_id:
        #         features[f_id]['input_ids'][i] = rel + len(tokenizer) + num_entities
        #         break

    features = KGCDataset(features)
    return features

class MultiprocessingEncoder(object):

    def __init__(self, tokenizer, args):
        self.tokenizer = tokenizer
        self.pretrain = args.pretrain
        self.max_seq_length = args.max_seq_length

    def initializer(self):
        global bpe
        bpe = self.tokenizer

    def encode(self, line):
        global bpe
        ids = bpe.encode(line)
        return list(map(str, ids))

    def decode(self, tokens):
        global bpe
        return bpe.decode(tokens)

    def encode_lines(self, lines):
        """
        Encode a set of lines. All lines will be encoded together.
        """
        enc_lines = []
        for line in lines:
            line = line.strip()
            if len(line) == 0:
                return ["EMPTY", None]
            # NOTE: the example files are written by this pipeline, so eval on
            # their dict literals is trusted input here.
            enc_lines.append(json.dumps(self.convert_examples_to_features(example=eval(line))))
        return ["PASS", enc_lines]

    def decode_lines(self, lines):
        dec_lines = []
        for line in lines:
            tokens = map(int, line.strip().split())
            dec_lines.append(self.decode(tokens))
        return ["PASS", dec_lines]

    def convert_examples_to_features(self, example):
        """Converts one example dict into an `InputFeatures` dict."""
        pretrain = self.pretrain
        max_seq_length = self.max_seq_length
        global bpe

        text_a = example['text_a']
        text_b = example['text_b']
        text_c = example['text_c']

        if pretrain:
            # "The description of <entity> is that <description> ."
            input_text = f"The description of {text_a} is that {text_b} ."
            inputs = bpe(
                input_text,
                truncation="longest_first",
                max_length=max_seq_length,
                padding="longest",
                add_special_tokens=True,
            )
        else:
            # keep the [MASK] side alone and join the other two parts with the sep token
            if text_a == "[MASK]":
                input_text_a = bpe.sep_token.join([text_a, text_b])
                input_text_b = text_c
            else:
                input_text_a = text_a
                input_text_b = bpe.sep_token.join([text_b, text_c])

            inputs = bpe(
                input_text_a,
                input_text_b,
                truncation="longest_first",
                max_length=max_seq_length,
                padding="longest",
                add_special_tokens=True,
            )
        # assert bpe.mask_token_id in inputs.input_ids, "mask token must be in input"

        features = asdict(InputFeatures(input_ids=inputs["input_ids"],
                                        attention_mask=inputs['attention_mask'],
                                        labels=example['label'],
                                        label=example['real_label'],
                                        en=example['en'],
                                        rel=example['rel']))
        return features
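
A minimal sketch of driving the encoder by hand (hypothetical values: the tokenizer name and the example dict are made up, with keys matching InputFeatures above):

from argparse import Namespace
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
enc = MultiprocessingEncoder(tokenizer, Namespace(pretrain=1, max_seq_length=64))
enc.initializer()  # binds the tokenizer to the module-level `bpe`
status, enc_lines = enc.encode_lines(
    ['{"text_a": "aspirin", "text_b": "a common analgesic", "text_c": "", '
     '"label": 0, "real_label": 0, "en": [0], "rel": 0}'])
# status == "PASS"; enc_lines[0] is a JSON-serialized InputFeatures dict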
if __name__ == "__main__":
    dataset = KGCDataset('./dataset')
							
								
								
									
pretrain/lit_models/__init__.py (new file, +2 lines)
@@ -0,0 +1,2 @@
from .transformer import *
from .base import *
							
								
								
									
pretrain/lit_models/base.py (new file, +97 lines)
@@ -0,0 +1,97 @@
import argparse
import pytorch_lightning as pl
import torch
from typing import Dict, Any


OPTIMIZER = "AdamW"
LR = 5e-5
LOSS = "cross_entropy"
ONE_CYCLE_TOTAL_STEPS = 100

class Config(dict):
    def __getattr__(self, name):
        return self.get(name)

    def __setattr__(self, name, val):
        self[name] = val


class BaseLitModel(pl.LightningModule):
    """
    Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
    Subclasses are expected to provide `self.loss_fn` and the
    `train_acc` / `val_acc` / `test_acc` metrics used below.
    """

    def __init__(self, model, args: argparse.Namespace = None):
        super().__init__()
        self.model = model
        self.args = Config(vars(args)) if args is not None else {}

        optimizer = self.args.get("optimizer", OPTIMIZER)
        self.optimizer_class = getattr(torch.optim, optimizer)
        self.lr = self.args.get("lr", LR)
        # used by configure_optimizers; default to a plain optimizer (no scheduler)
        self.one_cycle_max_lr = self.args.get("one_cycle_max_lr", None)
        self.one_cycle_total_steps = self.args.get("one_cycle_total_steps", ONE_CYCLE_TOTAL_STEPS)

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim")
        parser.add_argument("--lr", type=float, default=LR)
        parser.add_argument("--weight_decay", type=float, default=0.01)
        return parser

    def configure_optimizers(self):
        optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
        if self.one_cycle_max_lr is None:
            return optimizer
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log("train_loss", loss)
        self.train_acc(logits, y)
        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log("val_loss", loss, prog_bar=True)
        self.val_acc(logits, y)
        self.log("val_acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        logits = self(x)
        self.test_acc(logits, y)
        self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)

    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from datamodule and devices."""
        if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
            dataset_size = self.trainer.limit_train_batches
        elif isinstance(self.trainer.limit_train_batches, float):
            # limit_train_batches is a fraction of the batches
            dataset_size = len(self.trainer.datamodule.train_dataloader())
            dataset_size = int(dataset_size * self.trainer.limit_train_batches)
        else:
            dataset_size = len(self.trainer.datamodule.train_dataloader())

        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)

        effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
        max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs

        if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
            return self.trainer.max_steps
        return max_estimated_steps
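
A minimal sketch of a concrete subclass wiring up the attributes BaseLitModel expects (hypothetical class; assumes the torchmetrics package, whose newer releases require the task argument):

import torchmetrics

class ToyLitModel(BaseLitModel):
    def __init__(self, model, args=None):
        super().__init__(model, args)
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)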
							
								
								
									
pretrain/lit_models/transformer.py (new file, +503 lines)
@@ -0,0 +1,503 @@
from logging import debug
import random
import pytorch_lightning as pl
import torch
import pickle
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
# from transformers.utils.dummy_pt_objects import PrefixConstrainedLogitsProcessor

from .base import BaseLitModel
from transformers.optimization import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup

from functools import partial
from .utils import rank_score, acc, LabelSmoothSoftmaxCEV1

from typing import Callable, Iterable, List

def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))

def multilabel_categorical_crossentropy(y_pred, y_true):
    # multi-label softmax cross entropy: positive classes should score above 0
    # and negative classes below 0 (y_true is a 0/1 multi-hot tensor)
    y_pred = (1 - 2 * y_true) * y_pred
    y_pred_neg = y_pred - y_true * 1e12
    y_pred_pos = y_pred - (1 - y_true) * 1e12
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return (neg_loss + pos_loss).mean()
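
As a quick sanity check of the loss (toy numbers, not from the repo):

y_true = torch.tensor([[1., 0., 1., 0.]])        # positives at indices 0 and 2
y_pred = torch.tensor([[3.0, -2.0, 2.5, -1.0]])  # well-separated logits
loss = multilabel_categorical_crossentropy(y_pred, y_true)  # ≈ 0.53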
def decode(output_ids, tokenizer):
    return lmap(str.strip, tokenizer.batch_decode(output_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True))

class TransformerLitModel(BaseLitModel):
    def __init__(self, model, args, tokenizer=None, data_config={}):
        super().__init__(model, args)
        self.save_hyperparameters(args)
        if args.bce:
            self.loss_fn = torch.nn.BCEWithLogitsLoss()
        elif args.label_smoothing != 0.0:
            self.loss_fn = LabelSmoothSoftmaxCEV1(lb_smooth=args.label_smoothing)
        else:
            self.loss_fn = nn.CrossEntropyLoss()
        self.best_acc = 0
        self.first = True

        self.tokenizer = tokenizer

        self.__dict__.update(data_config)
        # resize the word embedding layer
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.decode = partial(decode, tokenizer=self.tokenizer)
        if args.pretrain:
            self._freeze_attention()
        elif "ind" in args.data_dir:
            # for the inductive setting, freeze the word embeddings
            self._freeze_word_embedding()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        labels = batch.pop("labels")
        label = batch.pop("label")
        pos = batch.pop("pos")
        try:
            en = batch.pop("en")
            rel = batch.pop("rel")
        except KeyError:
            pass
        input_ids = batch['input_ids']
        logits = self.model(**batch, return_dict=True).logits

        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bs = input_ids.shape[0]
        mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed]

        assert mask_idx.shape[0] == bs, "exactly one mask per sequence!"
        if self.args.bce:
            loss = self.loss_fn(mask_logits, labels)
        else:
            loss = self.loss_fn(mask_logits, label)

        if batch_idx == 0:
            print('\n'.join(self.decode(batch['input_ids'][:4])))

        return loss

    def _eval(self, batch, batch_idx, ):
        labels = batch.pop("labels")
        input_ids = batch['input_ids']
        # single gold label per example
        label = batch.pop('label')
        pos = batch.pop('pos')
        my_keys = list(batch.keys())
        for k in my_keys:
            if k not in ["input_ids", "attention_mask", "token_type_ids"]:
                batch.pop(k)
        logits = self.model(**batch, return_dict=True).logits[:, :, self.entity_id_st:self.entity_id_ed]
        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bsz = input_ids.shape[0]
        logits = logits[torch.arange(bsz), mask_idx]
        # filtered setting: suppress every other true entity before ranking
        assert labels[0][label[0]], "correct ids must be in the filter!"
        labels[torch.arange(bsz), label] = 0
        assert logits.shape == labels.shape
        logits += labels * -100  # mask the filtered entities

        # the double argsort turns scores into ranks (see the sketch below)
        _, outputs = torch.sort(logits, dim=1, descending=True)
        _, outputs = torch.sort(outputs, dim=1)
        ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1

        return dict(ranks=np.array(ranks))

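    # A toy illustration (not part of the model) of the double-argsort trick:
    #   scores = torch.tensor([[0.1, 0.9, 0.5]])
    #   _, order = torch.sort(scores, dim=1, descending=True)  # order: [[1, 2, 0]]
    #   _, rank = torch.sort(order, dim=1)                     # rank:  [[2, 0, 1]]
    # rank[0, i] is the 0-based rank of entity i by score, so +1 yields 1-based ranks.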
    def validation_step(self, batch, batch_idx):
        result = self._eval(batch, batch_idx)
        return result

    def validation_epoch_end(self, outputs) -> None:
        ranks = np.concatenate([_['ranks'] for _ in outputs])
        total_ranks = ranks.shape[0]

        if not self.args.pretrain:
            # left- and right-entity predictions alternate, so split by parity
            l_ranks = ranks[np.array(list(np.arange(0, total_ranks, 2)))]
            r_ranks = ranks[np.array(list(np.arange(0, total_ranks, 2))) + 1]
            self.log("Eval/lhits10", (l_ranks <= 10).mean())
            self.log("Eval/rhits10", (r_ranks <= 10).mean())

        hits20 = (ranks <= 20).mean()
        hits10 = (ranks <= 10).mean()
        hits3 = (ranks <= 3).mean()
        hits1 = (ranks <= 1).mean()

        self.log("Eval/hits10", hits10)
        self.log("Eval/hits20", hits20)
        self.log("Eval/hits3", hits3)
        self.log("Eval/hits1", hits1)
        self.log("Eval/mean_rank", ranks.mean())
        self.log("Eval/mrr", (1. / ranks).mean())
        self.log("hits10", hits10, prog_bar=True)
        self.log("hits1", hits1, prog_bar=True)

    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        result = self._eval(batch, batch_idx)
        return result

    def test_epoch_end(self, outputs) -> None:
        ranks = np.concatenate([_['ranks'] for _ in outputs])

        hits20 = (ranks <= 20).mean()
        hits10 = (ranks <= 10).mean()
        hits3 = (ranks <= 3).mean()
        hits1 = (ranks <= 1).mean()

        self.log("Test/hits10", hits10)
        self.log("Test/hits20", hits20)
        self.log("Test/hits3", hits3)
        self.log("Test/hits1", hits1)
        self.log("Test/mean_rank", ranks.mean())
        self.log("Test/mrr", (1. / ranks).mean())

    def configure_optimizers(self):
        no_decay_param = ["bias", "LayerNorm.weight"]

        optimizer_group_parameters = [
            {"params": [p for n, p in self.model.named_parameters() if p.requires_grad and not any(nd in n for nd in no_decay_param)], "weight_decay": self.args.weight_decay},
            {"params": [p for n, p in self.model.named_parameters() if p.requires_grad and any(nd in n for nd in no_decay_param)], "weight_decay": 0}
        ]

        optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.num_training_steps * self.args.warm_up_radio, num_training_steps=self.num_training_steps)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                'scheduler': scheduler,
                'interval': 'step',  # or 'epoch'
                'frequency': 1,
            }
        }

    def _freeze_attention(self):
        # keep only the word embeddings trainable
        for k, v in self.model.named_parameters():
            if "word" not in k:
                v.requires_grad = False
            else:
                print(k)

    def _freeze_word_embedding(self):
        for k, v in self.model.named_parameters():
            if "word" in k:
                print(k)
                v.requires_grad = False

    @staticmethod
    def add_to_argparse(parser):
        parser = BaseLitModel.add_to_argparse(parser)

        parser.add_argument("--label_smoothing", type=float, default=0.1, help="label smoothing factor for the cross-entropy loss")
        parser.add_argument("--bce", type=int, default=0, help="use binary cross entropy instead of cross entropy")
        return parser

import faiss
import os

class GetEntityEmbeddingLitModel(TransformerLitModel):
    def __init__(self, model, args, tokenizer, data_config={}):
        super().__init__(model, args, tokenizer, data_config)

        self.faissid2entityid = {}

        d, measure = self.model.config.hidden_size, faiss.METRIC_L2
        # param = 'HNSW64'
        # self.index = faiss.index_factory(d, param, measure)
        self.index = faiss.IndexFlatL2(d)   # build the index
        # print(self.index.is_trained)      # prints True here (a flat index needs no training)
        self.cnt_batch = 0
        self.total_embedding = []

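    # For readers unfamiliar with faiss, the IndexFlatL2 add/search cycle in a
    # nutshell (toy data, not from the repo):
    #   xb = np.random.rand(100, 8).astype("float32")  # 100 vectors, d = 8
    #   index = faiss.IndexFlatL2(8)                   # exact L2 index, no training step
    #   index.add(xb)
    #   D, I = index.search(xb[:3], 5)                 # distances and ids of the 5 nearest neighbors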
    def test_step(self, batch, batch_idx):
        labels = batch.pop("labels")
        mask_idx = batch.pop("pos")
        input_ids = batch['input_ids']
        # single gold label per example
        label = batch.pop('label')
        # last hidden layer
        hidden_states = self.model(**batch, return_dict=True, output_hidden_states=True).hidden_states[-1]
        # _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bsz = input_ids.shape[0]
        entity_embedding = hidden_states[torch.arange(bsz), mask_idx].cpu()
        # to normalize or not?
        # entity_embedding = F.normalize(entity_embedding, dim=-1, p=2)
        self.total_embedding.append(entity_embedding)
        # self.index.add(np.array(entity_embedding, dtype=np.float32))
        for i, l in zip(range(bsz), label):
            self.faissid2entityid[i + self.cnt_batch] = l.cpu()
        self.cnt_batch += bsz

    def test_epoch_end(self, outputs) -> None:
        self.total_embedding = np.concatenate(self.total_embedding, axis=0)
        # self.index.train(self.total_embedding)
        print(faiss.MatrixStats(self.total_embedding).comments)
        self.index.add(self.total_embedding)
        faiss.write_index(self.index, os.path.join(self.args.data_dir, "faiss_dump.index"))
        with open(os.path.join(self.args.data_dir, "faissid2entityid.pkl"), 'wb') as file:
            pickle.dump(self.faissid2entityid, file)

        with open(os.path.join(self.args.data_dir, "total_embedding.pkl"), 'wb') as file:
            pickle.dump(self.total_embedding, file)
        # print(f"number of entity embeddings: {len(self.faissid2entityid)}")

    @staticmethod
    def add_to_argparse(parser):
        parser = TransformerLitModel.add_to_argparse(parser)
        parser.add_argument("--faiss_init", type=int, default=1, help="compute the entity embeddings and dump them to file.")
        return parser

class UseEntityEmbeddingLitModel(TransformerLitModel):
    def __init__(self, model, args, tokenizer, data_config={}):
        super().__init__(model, args, tokenizer, data_config)

        self.faissid2entityid = pickle.load(open(os.path.join(self.args.data_dir, "faissid2entityid.pkl"), 'rb'))
        self.index = faiss.read_index(os.path.join(self.args.data_dir, "faiss_dump.index"))

        self.dis2logits = distance2logits_2

    def _eval(self, batch, batch_idx, ):
        labels = batch.pop("labels")
        pos = batch.pop("pos")
        input_ids = batch['input_ids']
        # single gold label per example
        label = batch.pop('label')

        hidden_states = self.model(**batch, return_dict=True, output_hidden_states=True).hidden_states[-1]
        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bsz = input_ids.shape[0]
        mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
        topk = 200
        D, I = self.index.search(mask_embedding, topk)
        labels[torch.arange(bsz), label] = 0

        # score entities by their nearest-neighbor distance; -100. marks "not retrieved"
        entity_logits = torch.full(labels.shape, -100.).to(self.device)
        D = self.dis2logits(D)
        for i in range(bsz):
            for j in range(topk):
                if I[i][j] not in self.faissid2entityid:
                    print(I[i][j])
                    break
                # skip entities removed by the filtered setting
                if labels[i][self.faissid2entityid[I[i][j]]]: continue
                # keep only the first (closest) hit; repeats are not added together
                if entity_logits[i][self.faissid2entityid[I[i][j]]] == -100.:
                    entity_logits[i][self.faissid2entityid[I[i][j]]] = D[i][j]

        assert entity_logits.shape == labels.shape

        _, outputs = torch.sort(entity_logits, dim=1, descending=True)
        _, outputs = torch.sort(outputs, dim=1)
        ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1

        return dict(ranks=np.array(ranks))

    @staticmethod
    def add_to_argparse(parser):
        parser = TransformerLitModel.add_to_argparse(parser)
        parser.add_argument("--faiss_init", type=int, default=0, help="compute the entity embeddings and dump them to file.")
        parser.add_argument("--faiss_use", type=int, default=1, help="rank entities with the dumped faiss index.")
        return parser

class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
    def __init__(self, model, args, tokenizer, data_config={}):
        super().__init__(model, args, tokenizer, data_config=data_config)
        self.dis2logits = distance2logits_2
        # map entity index -> entity id, then entity id -> readable text
        self.id2entity = {}
        with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
            cnt = 0
            for line in file.readlines():
                e, d = line.strip().split("\t")
                self.id2entity[cnt] = e
                cnt += 1
        self.id2entity_t = {}
        with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
            for line in file.readlines():
                e, d = line.strip().split("\t")
                self.id2entity_t[e] = d
        for k, v in self.id2entity.items():
            self.id2entity[k] = self.id2entity_t[v]

    def _eval(self, batch, batch_idx, ):
        labels = batch.pop("labels")
        input_ids = batch['input_ids']
        # single gold label per example
        label = batch.pop('label')
        pos = batch.pop("pos")

        result = self.model(**batch, return_dict=True, output_hidden_states=True)
        hidden_states = result.hidden_states[-1]
        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bsz = input_ids.shape[0]
        mask_embedding = np.array(hidden_states[torch.arange(bsz), mask_idx].cpu(), dtype=np.float32)
        topk = self.args.knn_topk
        D, I = self.index.search(mask_embedding, topk)
        D = torch.from_numpy(D).to(self.device)
        assert labels[0][label[0]], "correct ids must be in the filter!"
        labels[torch.arange(bsz), label] = 0

        mask_logits = result.logits[:, :, self.entity_id_st:self.entity_id_ed]
        mask_logits = mask_logits[torch.arange(bsz), mask_idx]
        # 1000. marks "not retrieved"; distances are small, so lower is better
        entity_logits = torch.full(labels.shape, 1000.).to(self.device)
        for i in range(bsz):
            for j in range(topk):
                # skip entities removed by the filtered setting
                if labels[i][self.faissid2entityid[I[i][j]]]: continue
                # keep only the first (closest) hit; repeats are not added together
                if entity_logits[i][self.faissid2entityid[I[i][j]]] == 1000.:
                    entity_logits[i][self.faissid2entityid[I[i][j]]] = D[i][j]
        entity_logits = self.dis2logits(entity_logits)
        assert entity_logits.shape == labels.shape
        assert mask_logits.shape == labels.shape
        # mask the filtered entities, then combine knn scores with the MLM scores
        entity_logits = entity_logits + labels * -100.
        mask_logits = torch.softmax(mask_logits + labels * -100, dim=-1)
        logits = combine_knn_and_vocab_probs(entity_logits, mask_logits, self.args.knn_lambda)

        # diagnostics: overlap between the knn top-20 and the MLM top-20
        knn_topk_logits, knn_topk_id = entity_logits.topk(20)
        mask_topk_logits, mask_topk_id = mask_logits.topk(20)
        union_topk = []
        for i in range(bsz):
            num_same = len(list(set(knn_topk_id[i].cpu().tolist()) & set(mask_topk_id[i].cpu().tolist())))
            union_topk.append(num_same / 20.)

        knn_topk_id = knn_topk_id.to("cpu")
        mask_topk_id = mask_topk_id.to("cpu")
        mask_topk_logits = mask_topk_logits.to("cpu")
        knn_topk_logits = knn_topk_logits.to("cpu")
        label = label.to("cpu")

        # print cases where knn alone finds the answer while the MLM head is unsure
        for t in range(bsz):
            if knn_topk_id[t][0] == label[t] and knn_topk_logits[t][0] > mask_topk_logits[t][0] and mask_topk_logits[t][0] <= 0.4:
                print(knn_topk_logits[t], knn_topk_id[t])
                print(lmap(lambda x: self.id2entity[x.item()], knn_topk_id[t]))
                print(mask_topk_logits[t], mask_topk_id[t])
                print(lmap(lambda x: self.id2entity[x.item()], mask_topk_id[t]))
                print(label[t])
                print()

        _, outputs = torch.sort(logits, dim=1, descending=True)
        _, outputs = torch.sort(outputs, dim=1)
        ranks = outputs[torch.arange(bsz), label].detach().cpu() + 1

        return dict(ranks=np.array(ranks), knn_topk_id=knn_topk_id, knn_topk_logits=knn_topk_logits,
                    mask_topk_id=mask_topk_id, mask_topk_logits=mask_topk_logits, num_same=np.array(union_topk))

    def test_epoch_end(self, outputs) -> None:
        ranks = np.concatenate([_['ranks'] for _ in outputs])
        num_same = np.concatenate([_['num_same'] for _ in outputs])

        self.log("Test/num_same", num_same.mean())

        hits20 = (ranks <= 20).mean()
        hits10 = (ranks <= 10).mean()
        hits3 = (ranks <= 3).mean()
        hits1 = (ranks <= 1).mean()

        self.log("Test/hits10", hits10)
        self.log("Test/hits20", hits20)
        self.log("Test/hits3", hits3)
        self.log("Test/hits1", hits1)
        self.log("Test/mean_rank", ranks.mean())
        self.log("Test/mrr", (1. / ranks).mean())

    @staticmethod
    def add_to_argparse(parser):
        parser = TransformerLitModel.add_to_argparse(parser)
        parser.add_argument("--knn_lambda", type=float, default=0.5, help="weight of the knn logits: lambda * knn + (1 - lambda) * mask logits.")
        parser.add_argument("--knn_topk", type=int, default=100, help="number of nearest neighbors retrieved from the faiss index.")
        return parser

def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff=0.5):
    # log-space interpolation: log((1 - coeff) * exp(vocab_p) + coeff * exp(knn_p))
    combine_probs = torch.stack([vocab_p, knn_p], dim=0)
    coeffs = torch.ones_like(combine_probs)
    coeffs[0] = np.log(1 - coeff)
    coeffs[1] = np.log(coeff)
    curr_prob = torch.logsumexp(combine_probs + coeffs, dim=0)

    return curr_prob
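
A quick numeric check of the interpolation (toy log-probabilities, coeff = 0.5):

p_vocab = torch.log(torch.tensor([0.2, 0.8]))
p_knn = torch.log(torch.tensor([0.6, 0.4]))
mixed = combine_knn_and_vocab_probs(p_knn, p_vocab, coeff=0.5)
# elementwise log(0.5 * 0.2 + 0.5 * 0.6) and log(0.5 * 0.8 + 0.5 * 0.4)
assert torch.allclose(mixed, torch.log(torch.tensor([0.4, 0.6])))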
def distance2logits(D):
    return torch.softmax(-1. * torch.tensor(D) / 30., dim=-1)

def distance2logits_2(D, n=10):
    if not isinstance(D, torch.Tensor):
        D = torch.tensor(D)
    # softmax over exp(-D / n); fall back to a uniform distribution when all
    # distances are zero
    if torch.sum(D) != 0.0:
        distances = torch.exp(-D / n) / torch.sum(torch.exp(-D / n), dim=-1, keepdim=True)
    else:
        distances = torch.full_like(D, 1.0 / D.shape[-1])
    return distances
							
								
								
									
pretrain/lit_models/utils.py (new file, +66 lines)
@@ -0,0 +1,66 @@
import json
import numpy as np

def rank_score(ranks):
    # compute hits@10/5/1 and MRR from a list of 1-based ranks
    len_samples = len(ranks)
    hits10 = [0] * len_samples
    hits5 = [0] * len_samples
    hits1 = [0] * len_samples
    mrr = []

    for idx, rank in enumerate(ranks):
        if rank <= 10:
            hits10[idx] = 1.
            if rank <= 5:
                hits5[idx] = 1.
                if rank <= 1:
                    hits1[idx] = 1.
        mrr.append(1. / rank)

    return np.mean(hits10), np.mean(hits5), np.mean(hits1), np.mean(mrr)

def acc(logits, labels):
    preds = np.argmax(logits, axis=-1)
    return (preds == labels).mean()
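
For instance, on a toy rank array:

ranks = np.array([1, 3, 12])
hits10, hits5, hits1, mrr = rank_score(ranks)
# hits10 = 2/3, hits5 = 2/3, hits1 = 1/3, mrr = (1 + 1/3 + 1/12) / 3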
import torch.nn as nn
import torch

class LabelSmoothSoftmaxCEV1(nn.Module):
    '''
    This is the autograd version; you can also try LabelSmoothSoftmaxCEV2,
    which uses derived gradients.
    '''

    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothSoftmaxCEV1, self).__init__()
        self.lb_smooth = lb_smooth
        self.reduction = reduction
        self.lb_ignore = ignore_index
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, logits, label):
        '''
        args: logits: tensor of shape (N, C) or (N, C, H, W)
        args: label: tensor of shape (N,) or (N, H, W)
        '''
        # overcome ignored label
        with torch.no_grad():
            num_classes = logits.size(1)
            label = label.clone().detach()
            ignore = label == self.lb_ignore
            n_valid = (ignore == 0).sum()
            label[ignore] = 0
            lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
            label = torch.empty_like(logits).fill_(
                lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()

        logs = self.log_softmax(logits)
        loss = -torch.sum(logs * label, dim=1)
        loss[ignore] = 0
        if self.reduction == 'mean':
            loss = loss.sum() / n_valid
        if self.reduction == 'sum':
            loss = loss.sum()

        return loss
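
A short usage sketch of the smoothed loss on dummy logits:

criterion = LabelSmoothSoftmaxCEV1(lb_smooth=0.1)
logits = torch.randn(4, 10)          # (N, C)
labels = torch.randint(0, 10, (4,))  # (N,)
loss = criterion(logits, labels)     # scalar, averaged over the 4 samples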
							
								
								
									
pretrain/main.py (new file, +141 lines)
@@ -0,0 +1,141 @@
import argparse
import importlib
from logging import debug

import numpy as np
import torch
import pytorch_lightning as pl
import lit_models
import yaml
import time
from transformers import AutoConfig
import os
# from utils import get_decoder_input_ids

os.environ["TOKENIZERS_PARALLELISM"] = "false"


# In order to ensure reproducible experiments, we must set random seeds.


def _import_class(module_and_class_name: str) -> type:
    """Import class from a module, e.g. 'text_recognizer.models.MLP'"""
    module_name, class_name = module_and_class_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    class_ = getattr(module, class_name)
    return class_

def _setup_parser():
    """Set up Python's ArgumentParser with data, model, trainer, and other arguments."""
    parser = argparse.ArgumentParser(add_help=False)

    # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision
    trainer_parser = pl.Trainer.add_argparse_args(parser)
    trainer_parser._action_groups[1].title = "Trainer Args"  # pylint: disable=protected-access
    parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser])

    # Basic arguments
    parser.add_argument("--wandb", action="store_true", default=False)
    parser.add_argument("--litmodel_class", type=str, default="TransformerLitModel")
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--data_class", type=str, default="KGC")
    parser.add_argument("--chunk", type=str, default="")
    parser.add_argument("--model_class", type=str, default="RobertaUseLabelWord")
    parser.add_argument("--checkpoint", type=str, default=None)

    # Get the data and model classes, so that we can add their specific arguments
    temp_args, _ = parser.parse_known_args()
    data_class = _import_class(f"data.{temp_args.data_class}")
    model_class = _import_class(f"models.{temp_args.model_class}")
    lit_model_class = _import_class(f"lit_models.{temp_args.litmodel_class}")

    # Get data, model, and LitModel specific arguments
    data_group = parser.add_argument_group("Data Args")
    data_class.add_to_argparse(data_group)

    model_group = parser.add_argument_group("Model Args")
    if hasattr(model_class, "add_to_argparse"):
        model_class.add_to_argparse(model_group)

    lit_model_group = parser.add_argument_group("LitModel Args")
    lit_model_class.add_to_argparse(lit_model_group)

    parser.add_argument("--help", "-h", action="help")
    return parser

def _saved_pretrain(lit_model, tokenizer, path):
    lit_model.model.save_pretrained(path)
    tokenizer.save_pretrained(path)

def main():
    parser = _setup_parser()
    args = parser.parse_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    pl.seed_everything(args.seed)

    data_class = _import_class(f"data.{args.data_class}")
    model_class = _import_class(f"models.{args.model_class}")
    litmodel_class = _import_class(f"lit_models.{args.litmodel_class}")

    config = AutoConfig.from_pretrained(args.model_name_or_path)
    # update parameters
    config.label_smoothing = args.label_smoothing

    model = model_class.from_pretrained(args.model_name_or_path, config=config)
    data = data_class(args, model)
    tokenizer = data.tokenizer

    lit_model = litmodel_class(args=args, model=model, tokenizer=tokenizer, data_config=data.get_config())
    if args.checkpoint:
        lit_model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["state_dict"])

    logger = pl.loggers.TensorBoardLogger("training/logs")
    if args.wandb:
        logger = pl.loggers.WandbLogger(project="kgc_bert", name=args.data_dir.split("/")[-1])
        logger.log_hyperparams(vars(args))

    metric_name = "Eval/mrr" if not args.pretrain else "Eval/hits1"

    early_callback = pl.callbacks.EarlyStopping(monitor="Eval/mrr", mode="max", patience=10)
    model_checkpoint = pl.callbacks.ModelCheckpoint(monitor=metric_name, mode="max",
        filename=args.data_dir.split("/")[-1] + '/{epoch}-{Eval/hits10:.2f}-{Eval/hits1:.2f}' if not args.pretrain else args.data_dir.split("/")[-1] + '/{epoch}-{step}-{Eval/hits10:.2f}',
        dirpath="output",
        save_weights_only=True,  # to be modified
        every_n_train_steps=100 if args.pretrain else None,
        save_top_k=5 if args.pretrain else 1
    )
    callbacks = [early_callback, model_checkpoint]

    # args.weights_summary = "full"  # Print full summary of the model
    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")

    # the embedding-dump lit models only run test; everything else trains first
    if "EntityEmbedding" not in lit_model.__class__.__name__:
        trainer.fit(lit_model, datamodule=data)
        path = model_checkpoint.best_model_path
        lit_model.load_state_dict(torch.load(path)["state_dict"])

    result = trainer.test(lit_model, datamodule=data)
    print(result)

    # _saved_pretrain(lit_model, tokenizer, path)
    if "EntityEmbedding" not in lit_model.__class__.__name__:
        print("*" * 30)
        print(path)


if __name__ == "__main__":

    main()
							
								
								
									
pretrain/models/__init__.py (new executable file, +6 lines)
@@ -0,0 +1,6 @@
from transformers import BartForConditionalGeneration, T5ForConditionalGeneration, GPT2LMHeadModel

from .model import *
							
								
								
									
pretrain/models/model.py (new file, +10 lines)
@@ -0,0 +1,10 @@
from transformers.models.bert.modeling_bert import BertForMaskedLM


class BertKGC(BertForMaskedLM):
    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--pretrain", type=int, default=0, help="run the entity-description pretraining objective")
        return parser
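
A minimal sketch of constructing the model, mirroring main.py (the checkpoint name is the one the pretraining scripts pass):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("bert-base-uncased")
model = BertKGC.from_pretrained("bert-base-uncased", config=config)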
							
								
								
									
pretrain/models/utils.py (new file, +1159 lines)
(file diff suppressed because it is too large)
pretrain/scripts/pretrain_fb15k-237.sh (new file, +14 lines)
@@ -0,0 +1,14 @@
nohup python -u main.py --gpus "1" --max_epochs=16  --num_workers=32 \
   --model_name_or_path  bert-base-uncased \
   --accumulate_grad_batches 1 \
   --model_class BertKGC \
   --batch_size 128 \
   --pretrain 1 \
   --bce 0 \
   --check_val_every_n_epoch 1 \
   --overwrite_cache \
   --data_dir /kg_374/Relphormer/dataset/FB15k-237 \
   --eval_batch_size 256 \
   --max_seq_length 64 \
   --lr 1e-4 \
   >logs/pretrain_fb15k-237.log 2>&1 &
							
								
								
									
pretrain/scripts/pretrain_umls.sh (new executable file, +14 lines)
@@ -0,0 +1,14 @@
nohup python -u main.py --gpus "0," --max_epochs=20  --num_workers=32 \
   --model_name_or_path  bert-base-uncased \
   --accumulate_grad_batches 1 \
   --model_class BertKGC \
   --batch_size 128 \
   --pretrain 1 \
   --bce 0 \
   --check_val_every_n_epoch 1 \
   --overwrite_cache \
   --data_dir xxx/Relphormer/dataset/umls \
   --eval_batch_size 256 \
   --max_seq_length 64 \
   --lr 1e-4 \
   >logs/pretrain_umls.log 2>&1 &
							
								
								
									
pretrain/scripts/pretrain_wn18rr.sh (new file, +16 lines)
@@ -0,0 +1,16 @@
nohup python -u main.py --gpus "0," --max_epochs=15  --num_workers=32 \
   --model_name_or_path  bert-base-uncased \
   --accumulate_grad_batches 1 \
   --bce 0 \
   --model_class BertKGC \
   --batch_size 128 \
   --pretrain 1 \
   --check_val_every_n_epoch 1 \
   --data_dir xxx/Relphormer/dataset/WN18RR \
   --overwrite_cache \
   --eval_batch_size 256 \
   --precision 16 \
   --max_seq_length 32 \
   --lr 1e-4 \
   >logs/pretrain_wn18rr.log 2>&1 &