Compare commits
2 Commits: main ... negative_s

Author | SHA1 | Date
---|---|---
 | 45cd8e1396 |
 | fcfeae2bd3 |
.gitignore (vendored): 1 addition

@@ -7,3 +7,4 @@ dataset/FB15k-237/masked_*.txt
 dataset/FB15k-237/cached_*.pkl
 **/__pycache__/
 **/.DS_Store
+nohup.out
.vscode/launch.json (vendored, new file): 38 additions

@@ -0,0 +1,38 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "justMyCode": true,
+            "args": [
+                "--gpus", "1",
+                "--max_epochs=16",
+                "--num_workers=32",
+                "--model_name_or_path", "bert-base-uncased",
+                "--accumulate_grad_batches", "1",
+                "--model_class", "BertKGC",
+                "--batch_size", "64",
+                "--checkpoint", "/kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt",
+                "--pretrain", "0",
+                "--bce", "0",
+                "--check_val_every_n_epoch", "1",
+                "--overwrite_cache",
+                "--data_dir", "dataset/FB15k-237",
+                "--eval_batch_size", "128",
+                "--max_seq_length", "128",
+                "--lr", "3e-5",
+                "--max_triplet", "64",
+                "--add_attn_bias", "True",
+                "--use_global_node", "True",
+                "--fast_dev_run", "True",
+            ]
+        }
+    ]
+}
@@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding,
                                                    PreTrainedTokenizerBase)
 
 from .base_data_module import BaseDataModule
-from .processor import KGProcessor, get_dataset
+from .processor import KGProcessor, get_dataset, getNegativeEntityId
 import transformers
 transformers.logging.set_verbosity_error()
 
@@ -79,6 +79,7 @@ class DataCollatorForSeq2Seq:
     label_pad_token_id: int = -100
     return_tensors: str = "pt"
     num_labels: int = 0
+    args: Any = None
 
     def __call__(self, features, return_tensors=None):
 
@@ -105,8 +106,9 @@ class DataCollatorForSeq2Seq:
                 if isinstance(l, int):
                     new_labels[i][l] = 1
                 else:
-                    for j in l:
-                        new_labels[i][j] = 1
+                    if (l[0] != getNegativeEntityId(self.args)):
+                        for j in l:
+                            new_labels[i][j] = 1
             labels = new_labels
 
         features = self.tokenizer.pad(
@@ -141,6 +143,7 @@ class KGC(BaseDataModule):
             padding="longest",
             max_length=self.args.max_seq_length,
             num_labels = len(entity_list),
+            args=args
         )
         relations_tokens = self.processor.get_relations(args.data_dir)
         self.num_relations = len(relations_tokens)
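The collator hunk above carries the substance of this branch: a label list whose first id is the [NEG] placeholder should no longer contribute positive targets when labels are expanded to multi-hot vectors. A minimal standalone sketch of that behavior (the function name, shapes and the neg_id parameter are assumptions; the diff obtains the id via getNegativeEntityId(self.args)):

    import torch

    def to_multi_hot(labels, num_labels, neg_id):
        # labels: list where each element is an int or a list of entity ids
        new_labels = torch.zeros(len(labels), num_labels)
        for i, l in enumerate(labels):
            if isinstance(l, int):
                new_labels[i][l] = 1
            else:
                # guard added by the diff: skip pseudo-label lists built around [NEG]
                if l[0] != neg_id:
                    for j in l:
                        new_labels[i][j] = 1
        return new_labels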
@@ -5,7 +5,7 @@ import contextlib
 import sys
 
 from collections import Counter
-from multiprocessing import Pool
+from multiprocessing import Pool, synchronize
 from torch._C import HOIST_CONV_PACKED_PARAMS
 from torch.utils.data import Dataset, Sampler, IterableDataset
 from collections import defaultdict
@@ -110,7 +110,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
 
         model_name = my_args.model_name_or_path.split("/")[-1]
         is_pretrain = my_args.pretrain
-        cache_filepath = os.path.join(my_args.data_dir, f"cached_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl")
+        cache_filepath = os.path.join(my_args.data_dir, f"cached_{func.__name__}_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl")
         refresh = my_args.overwrite_cache
 
         if cache_filepath is not None and refresh is False:
@@ -137,6 +137,116 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
 
     return wrapper_
 
+def cache_results_load_once(_cache_fp, _refresh=False, _verbose=1, _global_var=None):
+    r"""
+    === USE IN TRAINING MODE ONLY ===
+    cache_results is the decorator fastNLP uses to cache data. The example below shows how to use it::
+
+        import time
+        import numpy as np
+        from fastNLP import cache_results
+
+        @cache_results('cache.pkl')
+        def process_data():
+            # some slow work such as reading and preprocessing data; time.sleep() stands in for it
+            time.sleep(1)
+            return np.random.randint(10, size=(5,))
+
+        start_time = time.time()
+        print("res =",process_data())
+        print(time.time() - start_time)
+
+        start_time = time.time()
+        print("res =",process_data())
+        print(time.time() - start_time)
+
+        # The output is shown below; both runs return the same result and the second run takes almost no time
+        # Save cache to cache.pkl.
+        # res = [5 4 9 1 8]
+        # 1.0042750835418701
+        # Read cache from cache.pkl.
+        # res = [5 4 9 1 8]
+        # 0.0040721893310546875
+
+    The second run finishes in a few milliseconds because it reads the data straight from cache.pkl instead of preprocessing it again::
+
+        # Continuing the example above: to generate a separate cache, e.g. for another dataset, call it like this
+        process_data(_cache_fp='cache2.pkl')  # leaves the earlier 'cache.pkl' untouched
+
+    _cache_fp above is a parameter recognized by cache_results; data is cached to and read from 'cache2.pkl', i.e. it overrides the default 'cache.pkl'. Decorating a function with @cache_results() adds three parameters [_cache_fp, _refresh, _verbose] to it. These three parameters are never passed through to your function, so your own parameter names must not collide with them::
+
+        process_data(_cache_fp='cache2.pkl', _refresh=True)  # force the preprocessing cache to be regenerated
+        # _verbose controls logging: 0 prints nothing; 1 reports whether this step read an existing cache or generated a new one
+
+    :param str _cache_fp: where to cache the returned result, or where to read the cache from. If None, cache_results has no effect unless _cache_fp is passed at call time.
+    :param bool _refresh: whether to regenerate the cache.
+    :param int _verbose: whether to print cache information.
+    :return:
+    """
+
+    def wrapper_(func):
+        signature = inspect.signature(func)
+        for key, _ in signature.parameters.items():
+            if key in ('_cache_fp', '_refresh', '_verbose', '_global_var'):
+                raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))
+        v = globals().get(_global_var, None)
+        if (v is not None):
+            return v
+
+        def wrapper(*args, **kwargs):
+
+            my_args = args[0]
+            mode = "train"
+            if '_cache_fp' in kwargs:
+                cache_filepath = kwargs.pop('_cache_fp')
+                assert isinstance(cache_filepath, str), "_cache_fp can only be str."
+            else:
+                cache_filepath = _cache_fp
+            if '_refresh' in kwargs:
+                refresh = kwargs.pop('_refresh')
+                assert isinstance(refresh, bool), "_refresh can only be bool."
+            else:
+                refresh = _refresh
+            if '_verbose' in kwargs:
+                verbose = kwargs.pop('_verbose')
+                assert isinstance(verbose, int), "_verbose can only be integer."
+            else:
+                verbose = _verbose
+            refresh_flag = True
+
+            model_name = my_args.model_name_or_path.split("/")[-1]
+            is_pretrain = my_args.pretrain
+            cache_filepath = os.path.join(my_args.data_dir, f"cached_{func.__name__}_{mode}_features{model_name}_pretrain{is_pretrain}.pkl")
+            refresh = my_args.overwrite_cache
+
+            if cache_filepath is not None and refresh is False:
+                # load data
+                if os.path.exists(cache_filepath):
+                    with open(cache_filepath, 'rb') as f:
+                        results = pickle.load(f)
+                        globals()[_global_var] = results
+                    if verbose == 1:
+                        logger.info("Read cache from {}.".format(cache_filepath))
+                    refresh_flag = False
+
+            if refresh_flag:
+                results = func(*args, **kwargs)
+                if cache_filepath is not None:
+                    if results is None:
+                        raise RuntimeError("The return value is None. Delete the decorator.")
+                    with open(cache_filepath, 'wb') as f:
+                        pickle.dump(results, f)
+                    logger.info("Save cache to {}.".format(cache_filepath))
+
+            return results
+
+        return wrapper
+
+    return wrapper_
+
+
 import argparse
 import csv
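The added cache_results_load_once decorator mirrors cache_results but also publishes the cached value into a module-level global named by _global_var. A hypothetical usage sketch (the function name and return value are assumptions; the wrapped function must take the argparse namespace as its first argument, carrying model_name_or_path, pretrain, data_dir and overwrite_cache):

    @cache_results_load_once(_cache_fp="./dataset", _global_var="my_cached_value")
    def compute_expensive_value(args):
        # stand-in for a slow preprocessing step
        return {"answer": 42}

    value = compute_expensive_value(args)  # first run: computes and pickles the result under args.data_dir
    value = compute_expensive_value(args)  # later runs: reloaded from the pickle and mirrored into globals()["my_cached_value"]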
@@ -235,6 +345,31 @@ class DataProcessor(object):
 
 import copy
 
+from collections import deque
+import threading
+
+
+class _LiveState(type):
+    _instances = {}
+    _lock = threading.Lock()
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            with cls._lock:
+                if cls not in cls._instances:
+                    cls._instances[cls] = super(_LiveState, cls).__call__(*args, **kwargs)
+
+        return cls._instances[cls]
+
+
+class LiveState(metaclass=_LiveState):
+
+    def __init__(self):
+        self._pool_size = 4
+        self._deq = deque(maxlen=self._pool_size)
+
+    def put(self, item):
+        self._deq.append(item)
+
+    def get(self):
+        return list(self._deq)
+
 def solve_get_knowledge_store(line, set_type="train", pretrain=1):
     """
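The _LiveState metaclass makes LiveState a thread-safe singleton, and the deque keeps only the most recent _pool_size items. A minimal sketch of the resulting behavior (the entity strings are placeholders):

    a = LiveState()
    b = LiveState()
    assert a is b                    # every construction returns the same instance

    for ent in ["e1", "e2", "e3", "e4", "e5"]:
        a.put(ent)
    print(b.get())                   # ['e2', 'e3', 'e4', 'e5']: bounded at the 4 most recent entries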
@@ -364,6 +499,58 @@ def solve(line, set_type="train", pretrain=1, max_triplet=32):
                 InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_head_seq), label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[line[1], line[2]], en_id = [rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]], text_d_id = list(masked_head_seq_id), graph_inf = masked_head_graph_list))
             examples.append(
                 InputExample(guid=guid, text_a="[PAD]", text_b="[PAD]", text_c = "[MASK]", text_d = list(masked_tail_seq), label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[line[0], line[1]], en_id = [ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]], text_d_id = list(masked_tail_seq_id), graph_inf = masked_tail_graph_list))
+
+            liveState = LiveState()
+            _prev = liveState.get()
+
+            if (set_type == "train" and len(_prev) > 0):
+
+                for prev_ent in _prev:
+
+                    if (prev_ent == line[0] or prev_ent == line[2]):
+                        continue
+
+                    z = head_filter_entities["\t".join([prev_ent,line[1]])]
+                    if (len(z) == 0):
+                        z.append('[NEG]')
+                    z.append(line[2])
+                    z.append(line[0])
+                    masked_neg_seq = set()
+                    masked_neg_seq_id = set()
+
+                    masked_neg_graph_list = masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), []) if len(masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), [])) < max_triplet else \
+                        random.sample(masked_tail_neighbor["\t".join([prev_ent, line[1]])], max_triplet)
+                    if (len(masked_head_graph_list) == 0):
+                        masked_head_graph_list.append(['[NEG]', line[1], '[NEG]'])
+
+                    for item in masked_neg_graph_list:
+                        masked_neg_seq.add(item[0])
+                        masked_neg_seq.add(item[1])
+                        masked_neg_seq.add(item[2])
+                        masked_neg_seq_id.add(ent2id[item[0]])
+                        masked_neg_seq_id.add(rel2id[item[1]])
+                        masked_neg_seq_id.add(ent2id[item[2]])
+
+                    masked_neg_seq = masked_neg_seq.difference({line[0]})
+                    masked_neg_seq = masked_neg_seq.difference({line[2]})
+                    masked_neg_seq = masked_neg_seq.difference({line[1]})
+                    masked_neg_seq = masked_neg_seq.difference({prev_ent})
+                    masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[0]]})
+                    masked_neg_seq_id = masked_neg_seq_id.difference({rel2id[line[1]]})
+                    masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[2]]})
+                    masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[prev_ent]})
+                    # examples.append(
+                    #     InputExample(guid=guid, text_a="[MASK]", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[PAD]" + " " + ' '.join(text_c.split(' ')[:16]), text_d = masked_head_seq, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]]))
+                    # examples.append(
+                    #     InputExample(guid=guid, text_a="[PAD] ", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[MASK]" +" " + ' '.join(text_a.split(' ')[:16]), text_d = masked_tail_seq, label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]]))
+                    examples.append(
+                        InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_neg_seq), label=lmap(lambda x: ent2id[x], z), real_label=ent2id[line[0]], en=[line[1], prev_ent], en_id = [rel2id[line[1]], ent2id[prev_ent]], rel=rel2id[line[1]], text_d_id = list(masked_neg_seq_id), graph_inf = masked_neg_graph_list))
+                    examples.append(
+                        InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_neg_seq), label=lmap(lambda x: ent2id[x], z), real_label=ent2id[line[2]], en=[line[1], prev_ent], en_id = [rel2id[line[1]], ent2id[prev_ent]], rel=rel2id[line[1]], text_d_id = list(masked_neg_seq_id), graph_inf = masked_neg_graph_list))
+
+
+            liveState.put(line[0])
+            liveState.put(line[2])
     return examples
 
 def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_neighbor_, masked_tail_neighbor_, rel2token_):
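In this hunk the LiveState pool supplies the negatives that give the negative_s branch its name: the up-to-four entities remembered from previously processed triples are paired with the current relation to build extra masked examples, labelled with the filtered answer set z (or the [NEG] placeholder when that set is empty). A hypothetical walk-through under assumed values (head_filter_entities is the global dictionary filled in filter_init):

    pool = LiveState()                              # singleton shared across calls to solve()
    pool.put("/m/02mjmr"); pool.put("/m/09c7w0")    # entities seen in earlier triples

    h, r, t = "/m/0d05w3", "/location/contains", "/m/0f8l9c"
    for prev_ent in pool.get():
        if prev_ent in (h, t):
            continue
        z = head_filter_entities["\t".join([prev_ent, r])]
        if len(z) == 0:
            z.append("[NEG]")                       # no known answers: fall back to the placeholder entity
        # two extra InputExamples are appended, labelled with z and with real_label h and t respectively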
@@ -377,6 +564,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei
     global masked_head_neighbor
     global masked_tail_neighbor
     global rel2token
+    # global negativeEntity
 
     head_filter_entities = head
     tail_filter_entities = tail
@@ -388,11 +576,23 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei
     masked_head_neighbor = masked_head_neighbor_
     masked_tail_neighbor = masked_tail_neighbor_
     rel2token = rel2token_
+    # negativeEntity = ent2id['[NEG]']
+    print("Initialized negative entity ID")
 
 def delete_init(ent2text_):
     global ent2text
     ent2text = ent2text_
 
+def getEntityIdByName(name):
+    global ent2id
+    return ent2id[name]
+
+@cache_results_load_once(_cache_fp="./dataset", _global_var='negativeEntity')
+def getNegativeEntityId(args):
+    global negativeEntity
+    negativeEntity = ent2id['[NEG]']
+    return negativeEntity
+
+
 class KGProcessor(DataProcessor):
     """Processor for knowledge graph data set."""
@@ -443,6 +643,7 @@ class KGProcessor(DataProcessor):
         """Gets all entities in the knowledge graph."""
         with open(self.entity_path, 'r') as f:
             lines = f.readlines()
+            lines.append('[NEG]\t')
             entities = []
             for line in lines:
                 entities.append(line.strip().split("\t")[0])
@@ -469,6 +670,7 @@ class KGProcessor(DataProcessor):
         ent2text_with_type = {}
         with open(self.entity_path, 'r') as f:
             ent_lines = f.readlines()
+            ent_lines.append('[NEG]\t')
             for line in ent_lines:
                 temp = line.strip().split('\t')
                 try:
@@ -570,6 +772,8 @@ class KGProcessor(DataProcessor):
         threads = min(1, cpu_count())
         filter_init(head_filter_entities, tail_filter_entities,ent2text, rel2text, ent2id, ent2token, rel2id, masked_head_neighbor, masked_tail_neighbor, rel2token
             )
+        # cache this
+        getNegativeEntityId(args)
 
         if hasattr(args, "faiss_init") and args.faiss_init:
             annotate_ = partial(
@@ -579,6 +783,7 @@ class KGProcessor(DataProcessor):
         else:
             annotate_ = partial(
                 solve,
+                set_type=set_type,
                 pretrain=self.args.pretrain,
                 max_triplet=self.args.max_triplet
             )
@@ -81,10 +81,13 @@ class TransformerLitModel(BaseLitModel):
         pos = batch.pop("pos")
         try:
             en = batch.pop("en")
+            # self.print("__DEBUG__: en", en)
             rel = batch.pop("rel")
+            # self.print("__DEBUG__: rel", rel)
         except KeyError:
             pass
         input_ids = batch['input_ids']
+        # self.print("__DEBUG__: input_ids", input_ids)
 
         distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])])
         distance = batch.pop("distance_attention")
@@ -382,13 +385,17 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
         self.id2entity = {}
         with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
             cnt = 0
-            for line in file.readlines():
+            lines = file.readlines()
+            lines.append('[NEG]\t')
+            for line in lines:
                 e, d = line.strip().split("\t")
                 self.id2entity[cnt] = e
                 cnt += 1
         self.id2entity_t = {}
         with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
-            for line in file.readlines():
+            lines = file.readlines()
+            lines.append('[NEG]\t')
+            for line in lines:
                 e, d = line.strip().split("\t")
                 self.id2entity_t[e] = d
         for k, v in self.id2entity.items():
main.py: 6 changes

@@ -98,6 +98,7 @@ def main():
     tokenizer = data.tokenizer
 
     lit_model = litmodel_class(args=args, model=model, tokenizer=tokenizer, data_config=data.get_config())
+    print("__DEBUG__: Initialized")
     if args.checkpoint:
         lit_model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["state_dict"], strict=False)
 
@@ -120,9 +121,12 @@ def main():
     callbacks = [early_callback, model_checkpoint]
 
     # args.weights_summary = "full"  # Print full summary of the model
-    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs", accelerator="ddp")
+    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")
 
+    print('__DEBUG__: Init trainer')
+
     if "EntityEmbedding" not in lit_model.__class__.__name__:
+        print('__DEBUG__: Fit trainer')
         trainer.fit(lit_model, datamodule=data)
         path = model_checkpoint.best_model_path
         lit_model.load_state_dict(torch.load(path)["state_dict"], strict=False)
@@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding,
                                                    PreTrainedTokenizerBase)
 
 from .base_data_module import BaseDataModule
-from .processor import KGProcessor, get_dataset
+from .processor import KGProcessor, get_dataset, getNegativeEntityId
 import transformers
 transformers.logging.set_verbosity_error()
 
@@ -106,8 +106,9 @@ class DataCollatorForSeq2Seq:
                 if isinstance(l, int):
                     new_labels[i][l] = 1
                 else:
-                    for j in l:
-                        new_labels[i][j] = 1
+                    if (l[0] != getNegativeEntityId()):
+                        for j in l:
+                            new_labels[i][j] = 1
             labels = new_labels
 
         features = self.tokenizer.pad(
@@ -314,6 +314,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
     global ent2id
     global ent2token
     global rel2id
+    global negativeEntity
 
     head_filter_entities = head
     tail_filter_entities = tail
@@ -322,11 +323,19 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_):
     ent2id = ent2id_
     ent2token = ent2token_
     rel2id = rel2id_
+    negativeEntity = ent2id['[NEG]']
 
 def delete_init(ent2text_):
     global ent2text
     ent2text = ent2text_
 
+def getEntityIdByName(name):
+    global ent2id
+    return ent2id[name]
+
+def getNegativeEntityId():
+    global negativeEntity
+    return negativeEntity
+
 class KGProcessor(DataProcessor):
     """Processor for knowledge graph data set."""
@@ -377,6 +386,7 @@ class KGProcessor(DataProcessor):
         """Gets all entities in the knowledge graph."""
         with open(self.entity_path, 'r') as f:
             lines = f.readlines()
+            lines.append('[NEG]\t')
             entities = []
             for line in lines:
                 entities.append(line.strip().split("\t")[0])
@@ -403,6 +413,7 @@ class KGProcessor(DataProcessor):
         ent2text_with_type = {}
         with open(self.entity_path, 'r') as f:
             ent_lines = f.readlines()
+            ent_lines.append('[NEG]\t')
             for line in ent_lines:
                 temp = line.strip().split('\t')
                 try:
@@ -364,13 +364,17 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel):
         self.id2entity = {}
         with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file:
             cnt = 0
-            for line in file.readlines():
+            lines = file.readlines()
+            lines.append('[NEG]\t')
+            for line in lines:
                 e, d = line.strip().split("\t")
                 self.id2entity[cnt] = e
                 cnt += 1
         self.id2entity_t = {}
         with open("./dataset/FB15k-237/entity2text.txt", 'r') as file:
-            for line in file.readlines():
+            lines = file.readlines()
+            lines.append('[NEG]\t')
+            for line in lines:
                 e, d = line.strip().split("\t")
                 self.id2entity_t[e] = d
         for k, v in self.id2entity.items():
pretrain/scripts/pretrain_fb15k-237.sh (Normal file → Executable file): 0 changes

scripts/fb15k-237/fb15k-237.sh (Normal file → Executable file): 5 changes

@@ -1,13 +1,12 @@
-nohup python -u main.py --gpus "1" --max_epochs=16 --num_workers=32 \
+nohup python -u main.py --gpus "2," --max_epochs=16 --num_workers=32 \
    --model_name_or_path bert-base-uncased \
    --accumulate_grad_batches 1 \
    --model_class BertKGC \
    --batch_size 64 \
-   --checkpoint /kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt \
+   --checkpoint /kg_374/Relphormer/output/FB15k-237/epoch=1-Eval/hits10=Eval/hits1=0.47-Eval/hits1=0.22.ckpt \
    --pretrain 0 \
    --bce 0 \
    --check_val_every_n_epoch 1 \
-   --overwrite_cache \
    --data_dir dataset/FB15k-237 \
    --eval_batch_size 128 \
    --max_seq_length 128 \