Compare commits
	
		
			1 Commits
		
	
	
		
			negative_s
			...
			negative_s
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 6cc55301ad | 
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -7,4 +7,3 @@ dataset/FB15k-237/masked_*.txt | |||||||
| dataset/FB15k-237/cached_*.pkl | dataset/FB15k-237/cached_*.pkl | ||||||
| **/__pycache__/ | **/__pycache__/ | ||||||
| **/.DS_Store | **/.DS_Store | ||||||
| nohup.out |  | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								.vscode/launch.json
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.vscode/launch.json
									
									
									
									
										vendored
									
									
								
							| @@ -12,18 +12,17 @@ | |||||||
|             "console": "integratedTerminal", |             "console": "integratedTerminal", | ||||||
|             "justMyCode": true, |             "justMyCode": true, | ||||||
|             "args": [ |             "args": [ | ||||||
|                 "--gpus", "1",  |                 "--gpus", "1,",  | ||||||
|                 "--max_epochs=16",   |                 "--max_epochs=16",   | ||||||
|                 "--num_workers=32",  |                 "--num_workers=32",  | ||||||
|                 "--model_name_or_path",  "bert-base-uncased", |                 "--model_name_or_path",  "bert-base-uncased", | ||||||
|                 "--accumulate_grad_batches", "1",  |                 "--accumulate_grad_batches", "1",  | ||||||
|                 "--model_class", "BertKGC", |                 "--model_class", "BertKGC", | ||||||
|                 "--batch_size", "64", |                 "--batch_size", "32", | ||||||
|                 "--checkpoint", "/kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt", |                 "--checkpoint", "/root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=38899-Eval/hits10=0.96.ckpt", | ||||||
|                 "--pretrain", "0", |                 "--pretrain", "0", | ||||||
|                 "--bce", "0", |                 "--bce", "0", | ||||||
|                 "--check_val_every_n_epoch", "1", |                 "--check_val_every_n_epoch", "1", | ||||||
|                 "--overwrite_cache", |  | ||||||
|                 "--data_dir", "dataset/FB15k-237",  |                 "--data_dir", "dataset/FB15k-237",  | ||||||
|                 "--eval_batch_size", "128", |                 "--eval_batch_size", "128", | ||||||
|                 "--max_seq_length", "128", |                 "--max_seq_length", "128", | ||||||
| @@ -31,7 +30,6 @@ | |||||||
|                 "--max_triplet", "64", |                 "--max_triplet", "64", | ||||||
|                 "--add_attn_bias", "True", |                 "--add_attn_bias", "True", | ||||||
|                 "--use_global_node", "True", |                 "--use_global_node", "True", | ||||||
|                 "--fast_dev_run", "True", |  | ||||||
|             ] |             ] | ||||||
|         } |         } | ||||||
|     ] |     ] | ||||||
|   | |||||||
| @@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding, | |||||||
|                                                   PreTrainedTokenizerBase) |                                                   PreTrainedTokenizerBase) | ||||||
|  |  | ||||||
| from .base_data_module import BaseDataModule | from .base_data_module import BaseDataModule | ||||||
| from .processor import KGProcessor, get_dataset, getNegativeEntityId | from .processor import KGProcessor, get_dataset | ||||||
| import transformers | import transformers | ||||||
| transformers.logging.set_verbosity_error() | transformers.logging.set_verbosity_error() | ||||||
|  |  | ||||||
| @@ -79,7 +79,6 @@ class DataCollatorForSeq2Seq: | |||||||
|     label_pad_token_id: int = -100 |     label_pad_token_id: int = -100 | ||||||
|     return_tensors: str = "pt" |     return_tensors: str = "pt" | ||||||
|     num_labels: int = 0 |     num_labels: int = 0 | ||||||
|     args: Any = None |  | ||||||
|  |  | ||||||
|     def __call__(self, features, return_tensors=None): |     def __call__(self, features, return_tensors=None): | ||||||
|  |  | ||||||
| @@ -106,7 +105,6 @@ class DataCollatorForSeq2Seq: | |||||||
|                 if isinstance(l, int):  |                 if isinstance(l, int):  | ||||||
|                     new_labels[i][l] = 1 |                     new_labels[i][l] = 1 | ||||||
|                 else: |                 else: | ||||||
|                     if (l[0] != getNegativeEntityId(self.args)): |  | ||||||
|                     for j in l: |                     for j in l: | ||||||
|                         new_labels[i][j] = 1 |                         new_labels[i][j] = 1 | ||||||
|             labels = new_labels |             labels = new_labels | ||||||
| @@ -143,7 +141,6 @@ class KGC(BaseDataModule): | |||||||
|             padding="longest", |             padding="longest", | ||||||
|             max_length=self.args.max_seq_length, |             max_length=self.args.max_seq_length, | ||||||
|             num_labels = len(entity_list), |             num_labels = len(entity_list), | ||||||
|             args=args |  | ||||||
|         ) |         ) | ||||||
|         relations_tokens = self.processor.get_relations(args.data_dir) |         relations_tokens = self.processor.get_relations(args.data_dir) | ||||||
|         self.num_relations = len(relations_tokens) |         self.num_relations = len(relations_tokens) | ||||||
|   | |||||||
| @@ -5,7 +5,7 @@ import contextlib | |||||||
| import sys | import sys | ||||||
|  |  | ||||||
| from collections import Counter | from collections import Counter | ||||||
| from multiprocessing import Pool, synchronize | from multiprocessing import Pool | ||||||
| from torch._C import HOIST_CONV_PACKED_PARAMS | from torch._C import HOIST_CONV_PACKED_PARAMS | ||||||
| from torch.utils.data import Dataset, Sampler, IterableDataset | from torch.utils.data import Dataset, Sampler, IterableDataset | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| @@ -110,7 +110,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||||||
|              |              | ||||||
|             model_name = my_args.model_name_or_path.split("/")[-1] |             model_name = my_args.model_name_or_path.split("/")[-1] | ||||||
|             is_pretrain = my_args.pretrain |             is_pretrain = my_args.pretrain | ||||||
|             cache_filepath = os.path.join(my_args.data_dir, f"cached_{func.__name__}_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl") |             cache_filepath = os.path.join(my_args.data_dir, f"cached_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl") | ||||||
|             refresh = my_args.overwrite_cache |             refresh = my_args.overwrite_cache | ||||||
|  |  | ||||||
|             if cache_filepath is not None and refresh is False: |             if cache_filepath is not None and refresh is False: | ||||||
| @@ -137,116 +137,6 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||||||
|  |  | ||||||
|     return wrapper_ |     return wrapper_ | ||||||
|  |  | ||||||
| def cache_results_load_once(_cache_fp, _refresh=False, _verbose=1, _global_var=None): |  | ||||||
|     r""" |  | ||||||
|     === USE IN TRAINING MODE ONLY === |  | ||||||
|     cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用:: |  | ||||||
|  |  | ||||||
|         import time |  | ||||||
|         import numpy as np |  | ||||||
|         from fastNLP import cache_results |  | ||||||
|          |  | ||||||
|         @cache_results('cache.pkl') |  | ||||||
|         def process_data(): |  | ||||||
|             # 一些比较耗时的工作,比如读取数据,预处理数据等,这里用time.sleep()代替耗时 |  | ||||||
|             time.sleep(1) |  | ||||||
|             return np.random.randint(10, size=(5,)) |  | ||||||
|          |  | ||||||
|         start_time = time.time() |  | ||||||
|         print("res =",process_data()) |  | ||||||
|         print(time.time() - start_time) |  | ||||||
|          |  | ||||||
|         start_time = time.time() |  | ||||||
|         print("res =",process_data()) |  | ||||||
|         print(time.time() - start_time) |  | ||||||
|          |  | ||||||
|         # 输出内容如下,可以看到两次结果相同,且第二次几乎没有花费时间 |  | ||||||
|         # Save cache to cache.pkl. |  | ||||||
|         # res = [5 4 9 1 8] |  | ||||||
|         # 1.0042750835418701 |  | ||||||
|         # Read cache from cache.pkl. |  | ||||||
|         # res = [5 4 9 1 8] |  | ||||||
|         # 0.0040721893310546875 |  | ||||||
|  |  | ||||||
|     可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理:: |  | ||||||
|  |  | ||||||
|         # 还是以上面的例子为例,如果需要重新生成另一个cache,比如另一个数据集的内容,通过如下的方式调用即可 |  | ||||||
|         process_data(_cache_fp='cache2.pkl')  # 完全不影响之前的‘cache.pkl' |  | ||||||
|  |  | ||||||
|     上面的_cache_fp是cache_results会识别的参数,它将从'cache2.pkl'这里缓存/读取数据,即这里的'cache2.pkl'覆盖默认的 |  | ||||||
|     'cache.pkl'。如果在你的函数前面加上了@cache_results()则你的函数会增加三个参数[_cache_fp, _refresh, _verbose]。 |  | ||||||
|     上面的例子即为使用_cache_fp的情况,这三个参数不会传入到你的函数中,当然你写的函数参数名也不可能包含这三个名称:: |  | ||||||
|  |  | ||||||
|         process_data(_cache_fp='cache2.pkl', _refresh=True)  # 这里强制重新生成一份对预处理的cache。 |  | ||||||
|         #  _verbose是用于控制输出信息的,如果为0,则不输出任何内容;如果为1,则会提醒当前步骤是读取的cache还是生成了新的cache |  | ||||||
|  |  | ||||||
|     :param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在 |  | ||||||
|         函数调用的时候传入_cache_fp这个参数。 |  | ||||||
|     :param bool _refresh: 是否重新生成cache。 |  | ||||||
|     :param int _verbose: 是否打印cache的信息。 |  | ||||||
|     :return: |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     def wrapper_(func): |  | ||||||
|         signature = inspect.signature(func) |  | ||||||
|         for key, _ in signature.parameters.items(): |  | ||||||
|             if key in ('_cache_fp', '_refresh', '_verbose', '_global_var'): |  | ||||||
|                 raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) |  | ||||||
|         v = globals().get(_global_var, None) |  | ||||||
|         if (v is not None): |  | ||||||
|             return v |  | ||||||
|  |  | ||||||
|         def wrapper(*args, **kwargs): |  | ||||||
|              |  | ||||||
|             my_args = args[0] |  | ||||||
|             mode = "train" |  | ||||||
|             if '_cache_fp' in kwargs: |  | ||||||
|                 cache_filepath = kwargs.pop('_cache_fp') |  | ||||||
|                 assert isinstance(cache_filepath, str), "_cache_fp can only be str." |  | ||||||
|             else: |  | ||||||
|                 cache_filepath = _cache_fp |  | ||||||
|             if '_refresh' in kwargs: |  | ||||||
|                 refresh = kwargs.pop('_refresh') |  | ||||||
|                 assert isinstance(refresh, bool), "_refresh can only be bool." |  | ||||||
|             else: |  | ||||||
|                 refresh = _refresh |  | ||||||
|             if '_verbose' in kwargs: |  | ||||||
|                 verbose = kwargs.pop('_verbose') |  | ||||||
|                 assert isinstance(verbose, int), "_verbose can only be integer." |  | ||||||
|             else: |  | ||||||
|                 verbose = _verbose |  | ||||||
|             refresh_flag = True |  | ||||||
|              |  | ||||||
|             model_name = my_args.model_name_or_path.split("/")[-1] |  | ||||||
|             is_pretrain = my_args.pretrain |  | ||||||
|             cache_filepath = os.path.join(my_args.data_dir, f"cached_{func.__name__}_{mode}_features{model_name}_pretrain{is_pretrain}.pkl") |  | ||||||
|             refresh = my_args.overwrite_cache |  | ||||||
|  |  | ||||||
|             if cache_filepath is not None and refresh is False: |  | ||||||
|                 # load data |  | ||||||
|                 if os.path.exists(cache_filepath): |  | ||||||
|                     with open(cache_filepath, 'rb') as f: |  | ||||||
|                         results = pickle.load(f) |  | ||||||
|                         globals()[_global_var] = results |  | ||||||
|                     if verbose == 1: |  | ||||||
|                         logger.info("Read cache from {}.".format(cache_filepath)) |  | ||||||
|                     refresh_flag = False |  | ||||||
|  |  | ||||||
|             if refresh_flag: |  | ||||||
|                 results = func(*args, **kwargs) |  | ||||||
|                 if cache_filepath is not None: |  | ||||||
|                     if results is None: |  | ||||||
|                         raise RuntimeError("The return value is None. Delete the decorator.") |  | ||||||
|                     with open(cache_filepath, 'wb') as f: |  | ||||||
|                         pickle.dump(results, f) |  | ||||||
|                     logger.info("Save cache to {}.".format(cache_filepath)) |  | ||||||
|  |  | ||||||
|             return results |  | ||||||
|  |  | ||||||
|         return wrapper |  | ||||||
|  |  | ||||||
|     return wrapper_ |  | ||||||
|  |  | ||||||
|  |  | ||||||
| import argparse | import argparse | ||||||
| import csv | import csv | ||||||
| @@ -345,31 +235,6 @@ class DataProcessor(object): | |||||||
|  |  | ||||||
| import copy | import copy | ||||||
|  |  | ||||||
| from collections import deque |  | ||||||
| import threading |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class _LiveState(type): |  | ||||||
|     _instances = {} |  | ||||||
|     _lock = threading.Lock() |  | ||||||
|  |  | ||||||
|     def __call__(cls, *args, **kwargs): |  | ||||||
|         if cls not in cls._instances: |  | ||||||
|             with cls._lock: |  | ||||||
|                 if cls not in cls._instances: |  | ||||||
|                     cls._instances[cls] = super(_LiveState, cls).__call__(*args, **kwargs) |  | ||||||
|          |  | ||||||
|         return cls._instances[cls] |  | ||||||
|  |  | ||||||
| class LiveState(metaclass=_LiveState): |  | ||||||
|  |  | ||||||
|     def __init__(self): |  | ||||||
|         self._pool_size = 4 |  | ||||||
|         self._deq = deque(maxlen=self._pool_size) |  | ||||||
|     def put(self, item): |  | ||||||
|         self._deq.append(item) |  | ||||||
|     def get(self): |  | ||||||
|         return list(self._deq) |  | ||||||
|  |  | ||||||
| def solve_get_knowledge_store(line, set_type="train", pretrain=1): | def solve_get_knowledge_store(line, set_type="train", pretrain=1): | ||||||
|     """ |     """ | ||||||
| @@ -499,58 +364,6 @@ def solve(line,  set_type="train", pretrain=1, max_triplet=32): | |||||||
|             InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_head_seq), label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[line[1], line[2]], en_id = [rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]], text_d_id = list(masked_head_seq_id), graph_inf = masked_head_graph_list)) |             InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_head_seq), label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[line[1], line[2]], en_id = [rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]], text_d_id = list(masked_head_seq_id), graph_inf = masked_head_graph_list)) | ||||||
|         examples.append( |         examples.append( | ||||||
|             InputExample(guid=guid, text_a="[PAD]", text_b="[PAD]", text_c = "[MASK]", text_d = list(masked_tail_seq), label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[line[0], line[1]], en_id = [ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]], text_d_id = list(masked_tail_seq_id), graph_inf = masked_tail_graph_list)) |             InputExample(guid=guid, text_a="[PAD]", text_b="[PAD]", text_c = "[MASK]", text_d = list(masked_tail_seq), label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[line[0], line[1]], en_id = [ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]], text_d_id = list(masked_tail_seq_id), graph_inf = masked_tail_graph_list)) | ||||||
|          |  | ||||||
|         liveState = LiveState() |  | ||||||
|         _prev = liveState.get() |  | ||||||
|  |  | ||||||
|         if (set_type == "train" and len(_prev) > 0): |  | ||||||
|  |  | ||||||
|             for prev_ent in _prev: |  | ||||||
|  |  | ||||||
|                 if (prev_ent == line[0] or prev_ent == line[2]): |  | ||||||
|                     continue |  | ||||||
|  |  | ||||||
|                 z = head_filter_entities["\t".join([prev_ent,line[1]])] |  | ||||||
|                 if (len(z) == 0): |  | ||||||
|                     z.append('[NEG]') |  | ||||||
|                     z.append(line[2]) |  | ||||||
|                     z.append(line[0]) |  | ||||||
|                 masked_neg_seq = set() |  | ||||||
|                 masked_neg_seq_id = set() |  | ||||||
|  |  | ||||||
|                 masked_neg_graph_list = masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), []) if len(masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), [])) < max_triplet else \ |  | ||||||
|                     random.sample(masked_tail_neighbor["\t".join([prev_ent, line[1]])], max_triplet) |  | ||||||
|                 if (len(masked_head_graph_list) == 0): |  | ||||||
|                     masked_head_graph_list.append(['[NEG]', line[1], '[NEG]']) |  | ||||||
|  |  | ||||||
|                 for item in masked_neg_graph_list: |  | ||||||
|                     masked_neg_seq.add(item[0]) |  | ||||||
|                     masked_neg_seq.add(item[1]) |  | ||||||
|                     masked_neg_seq.add(item[2]) |  | ||||||
|                     masked_neg_seq_id.add(ent2id[item[0]]) |  | ||||||
|                     masked_neg_seq_id.add(rel2id[item[1]]) |  | ||||||
|                     masked_neg_seq_id.add(ent2id[item[2]]) |  | ||||||
|                      |  | ||||||
|                 masked_neg_seq = masked_neg_seq.difference({line[0]}) |  | ||||||
|                 masked_neg_seq = masked_neg_seq.difference({line[2]}) |  | ||||||
|                 masked_neg_seq = masked_neg_seq.difference({line[1]}) |  | ||||||
|                 masked_neg_seq = masked_neg_seq.difference({prev_ent}) |  | ||||||
|                 masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[0]]}) |  | ||||||
|                 masked_neg_seq_id = masked_neg_seq_id.difference({rel2id[line[1]]}) |  | ||||||
|                 masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[2]]}) |  | ||||||
|                 masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[prev_ent]}) |  | ||||||
|                 # examples.append( |  | ||||||
|                 #     InputExample(guid=guid, text_a="[MASK]", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[PAD]" + " " + ' '.join(text_c.split(' ')[:16]), text_d = masked_head_seq, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]])) |  | ||||||
|                 # examples.append( |  | ||||||
|                 #     InputExample(guid=guid, text_a="[PAD] ", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[MASK]" +" " + ' '.join(text_a.split(' ')[:16]), text_d = masked_tail_seq, label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]])) |  | ||||||
|                 examples.append( |  | ||||||
|                     InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_neg_seq), label=lmap(lambda x: ent2id[x], z), real_label=ent2id[line[0]], en=[line[1], prev_ent], en_id = [rel2id[line[1]], ent2id[prev_ent]], rel=rel2id[line[1]], text_d_id = list(masked_neg_seq_id), graph_inf = masked_neg_graph_list)) |  | ||||||
|                 examples.append(     |  | ||||||
|                     InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_neg_seq), label=lmap(lambda x: ent2id[x], z), real_label=ent2id[line[2]], en=[line[1], prev_ent], en_id = [rel2id[line[1]], ent2id[prev_ent]], rel=rel2id[line[1]], text_d_id = list(masked_neg_seq_id), graph_inf = masked_neg_graph_list)) |  | ||||||
|          |  | ||||||
|  |  | ||||||
|         liveState.put(line[0]) |  | ||||||
|         liveState.put(line[2]) |  | ||||||
|     return examples |     return examples | ||||||
|  |  | ||||||
| def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_neighbor_, masked_tail_neighbor_, rel2token_): | def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_neighbor_, masked_tail_neighbor_, rel2token_): | ||||||
| @@ -564,7 +377,6 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei | |||||||
|     global masked_head_neighbor |     global masked_head_neighbor | ||||||
|     global masked_tail_neighbor |     global masked_tail_neighbor | ||||||
|     global rel2token |     global rel2token | ||||||
|     # global negativeEntity |  | ||||||
|  |  | ||||||
|     head_filter_entities = head |     head_filter_entities = head | ||||||
|     tail_filter_entities = tail |     tail_filter_entities = tail | ||||||
| @@ -576,23 +388,11 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei | |||||||
|     masked_head_neighbor = masked_head_neighbor_ |     masked_head_neighbor = masked_head_neighbor_ | ||||||
|     masked_tail_neighbor = masked_tail_neighbor_ |     masked_tail_neighbor = masked_tail_neighbor_ | ||||||
|     rel2token = rel2token_ |     rel2token = rel2token_ | ||||||
|     # negativeEntity = ent2id['[NEG]'] |  | ||||||
|     print("Initialized negative entity ID") |  | ||||||
|  |  | ||||||
| def delete_init(ent2text_): | def delete_init(ent2text_): | ||||||
|     global ent2text |     global ent2text | ||||||
|     ent2text = ent2text_ |     ent2text = ent2text_ | ||||||
|  |  | ||||||
| def getEntityIdByName(name): |  | ||||||
|     global ent2id |  | ||||||
|     return ent2id[name] |  | ||||||
|  |  | ||||||
| @cache_results_load_once(_cache_fp="./dataset", _global_var='negativeEntity') |  | ||||||
| def getNegativeEntityId(args): |  | ||||||
|     global negativeEntity |  | ||||||
|     negativeEntity = ent2id['[NEG]'] |  | ||||||
|     return negativeEntity |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class KGProcessor(DataProcessor): | class KGProcessor(DataProcessor): | ||||||
|     """Processor for knowledge graph data set.""" |     """Processor for knowledge graph data set.""" | ||||||
| @@ -643,7 +443,6 @@ class KGProcessor(DataProcessor): | |||||||
|         """Gets all entities in the knowledge graph.""" |         """Gets all entities in the knowledge graph.""" | ||||||
|         with open(self.entity_path, 'r') as f: |         with open(self.entity_path, 'r') as f: | ||||||
|             lines = f.readlines() |             lines = f.readlines() | ||||||
|             lines.append('[NEG]\t') |  | ||||||
|             entities = [] |             entities = [] | ||||||
|             for line in lines: |             for line in lines: | ||||||
|                 entities.append(line.strip().split("\t")[0]) |                 entities.append(line.strip().split("\t")[0]) | ||||||
| @@ -670,7 +469,6 @@ class KGProcessor(DataProcessor): | |||||||
|         ent2text_with_type = {} |         ent2text_with_type = {} | ||||||
|         with open(self.entity_path, 'r') as f: |         with open(self.entity_path, 'r') as f: | ||||||
|             ent_lines = f.readlines() |             ent_lines = f.readlines() | ||||||
|             ent_lines.append('[NEG]\t') |  | ||||||
|             for line in ent_lines: |             for line in ent_lines: | ||||||
|                 temp = line.strip().split('\t') |                 temp = line.strip().split('\t') | ||||||
|                 try: |                 try: | ||||||
| @@ -772,8 +570,6 @@ class KGProcessor(DataProcessor): | |||||||
|         threads = min(1, cpu_count()) |         threads = min(1, cpu_count()) | ||||||
|         filter_init(head_filter_entities, tail_filter_entities,ent2text, rel2text, ent2id, ent2token, rel2id, masked_head_neighbor, masked_tail_neighbor, rel2token |         filter_init(head_filter_entities, tail_filter_entities,ent2text, rel2text, ent2id, ent2token, rel2id, masked_head_neighbor, masked_tail_neighbor, rel2token | ||||||
|             ) |             ) | ||||||
|         # cache this |  | ||||||
|         getNegativeEntityId(args) |  | ||||||
|          |          | ||||||
|         if hasattr(args, "faiss_init") and args.faiss_init: |         if hasattr(args, "faiss_init") and args.faiss_init: | ||||||
|             annotate_ = partial( |             annotate_ = partial( | ||||||
| @@ -783,7 +579,6 @@ class KGProcessor(DataProcessor): | |||||||
|         else: |         else: | ||||||
|             annotate_ = partial( |             annotate_ = partial( | ||||||
|                 solve, |                 solve, | ||||||
|                 set_type=set_type, |  | ||||||
|                 pretrain=self.args.pretrain, |                 pretrain=self.args.pretrain, | ||||||
|                 max_triplet=self.args.max_triplet |                 max_triplet=self.args.max_triplet | ||||||
|             ) |             ) | ||||||
|   | |||||||
| @@ -73,21 +73,63 @@ class TransformerLitModel(BaseLitModel): | |||||||
|     def forward(self, x): |     def forward(self, x): | ||||||
|         return self.model(x) |         return self.model(x) | ||||||
|  |  | ||||||
|  |     def create_negatives(self, batch): | ||||||
|  |         negativeBatches = [] | ||||||
|  |         label = batch['label'] | ||||||
|  |  | ||||||
|  |         for i in range(label.shape[0]): | ||||||
|  |             newBatch = {} | ||||||
|  |             newBatch['attention_mask'] = None | ||||||
|  |             newBatch['input_ids'] = torch.clone(batch['input_ids']) | ||||||
|  |             newBatch['label'] = torch.zeros_like(batch['label']) | ||||||
|  |             negativeBatches.append(newBatch) | ||||||
|  |  | ||||||
|  |         entity_idx = [] | ||||||
|  |         self_label = [] | ||||||
|  |         for idx, l in enumerate(label): | ||||||
|  |             decoded = self.decode([batch['input_ids'][idx]])[0].split(' ') | ||||||
|  |             for j in range(1, len(decoded)): | ||||||
|  |                 if (decoded[j].startswith("[ENTITY_")): | ||||||
|  |                     entity_idx.append(j) | ||||||
|  |                     self_label.append(batch['input_ids'][idx][j])    | ||||||
|  |                     break | ||||||
|  |          | ||||||
|  |         for idx, lbl in enumerate(label): | ||||||
|  |             for i in range(label.shape[0]): | ||||||
|  |                 if (negativeBatches[idx]['input_ids'][i][entity_idx[i]] != lbl): | ||||||
|  |                     negativeBatches[idx]['input_ids'][i][entity_idx[i]] = lbl | ||||||
|  |                 else: | ||||||
|  |                     negativeBatches[idx]['input_ids'][i][entity_idx[i]] = self_label[i] | ||||||
|  |                  | ||||||
|  |         return negativeBatches | ||||||
|  |  | ||||||
|     def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument |     def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument | ||||||
|         # embed();exit() |         # embed();exit() | ||||||
|         # print(self.optimizers().param_groups[1]['lr']) |         # print(self.optimizers().param_groups[1]['lr']) | ||||||
|  |  | ||||||
|  |         negativeBatches = self.create_negatives(batch) | ||||||
|  |  | ||||||
|  |         loss = 0 | ||||||
|  |  | ||||||
|  |         for negativeBatch in negativeBatches: | ||||||
|  |             label = negativeBatch.pop("label") | ||||||
|  |             input_ids = batch['input_ids'] | ||||||
|  |              | ||||||
|  |             logits = self.model(**negativeBatch, return_dict=True, distance_attention=None).logits | ||||||
|  |             _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True) | ||||||
|  |             bs = input_ids.shape[0] | ||||||
|  |             mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed] | ||||||
|  |             loss += self.loss_fn(mask_logits, label) | ||||||
|  |  | ||||||
|         labels = batch.pop("labels") |         labels = batch.pop("labels") | ||||||
|         label = batch.pop("label") |         label = batch.pop("label") | ||||||
|         pos = batch.pop("pos") |         pos = batch.pop("pos") | ||||||
|         try: |         try: | ||||||
|             en = batch.pop("en") |             en = batch.pop("en") | ||||||
|             # self.print("__DEBUG__: en", en) |  | ||||||
|             rel = batch.pop("rel") |             rel = batch.pop("rel") | ||||||
|             # self.print("__DEBUG__: rel", rel) |  | ||||||
|         except KeyError: |         except KeyError: | ||||||
|             pass |             pass | ||||||
|         input_ids = batch['input_ids'] |         input_ids = batch['input_ids'] | ||||||
|         # self.print("__DEBUG__: input_ids", input_ids) |  | ||||||
|  |  | ||||||
|         distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])]) |         distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])]) | ||||||
|         distance = batch.pop("distance_attention") |         distance = batch.pop("distance_attention") | ||||||
| @@ -113,9 +155,9 @@ class TransformerLitModel(BaseLitModel): | |||||||
|  |  | ||||||
|         assert mask_idx.shape[0] == bs, "only one mask in sequence!" |         assert mask_idx.shape[0] == bs, "only one mask in sequence!" | ||||||
|         if self.args.bce: |         if self.args.bce: | ||||||
|             loss = self.loss_fn(mask_logits, labels) |             loss += self.loss_fn(mask_logits, labels) | ||||||
|         else: |         else: | ||||||
|             loss = self.loss_fn(mask_logits, label) |             loss += self.loss_fn(mask_logits, label) | ||||||
|  |  | ||||||
|         if batch_idx == 0: |         if batch_idx == 0: | ||||||
|             print('\n'.join(self.decode(batch['input_ids'][:4]))) |             print('\n'.join(self.decode(batch['input_ids'][:4]))) | ||||||
| @@ -385,17 +427,13 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel): | |||||||
|         self.id2entity = {} |         self.id2entity = {} | ||||||
|         with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file: |         with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file: | ||||||
|             cnt = 0 |             cnt = 0 | ||||||
|             lines = file.readlines() |             for line in file.readlines(): | ||||||
|             lines.append('[NEG]\t') |  | ||||||
|             for line in lines: |  | ||||||
|                 e, d = line.strip().split("\t") |                 e, d = line.strip().split("\t") | ||||||
|                 self.id2entity[cnt] = e |                 self.id2entity[cnt] = e | ||||||
|                 cnt += 1 |                 cnt += 1 | ||||||
|         self.id2entity_t = {} |         self.id2entity_t = {} | ||||||
|         with open("./dataset/FB15k-237/entity2text.txt", 'r') as file: |         with open("./dataset/FB15k-237/entity2text.txt", 'r') as file: | ||||||
|             lines = file.readlines() |             for line in file.readlines(): | ||||||
|             lines.append('[NEG]\t') |  | ||||||
|             for line in lines: |  | ||||||
|                 e, d = line.strip().split("\t") |                 e, d = line.strip().split("\t") | ||||||
|                 self.id2entity_t[e] = d |                 self.id2entity_t[e] = d | ||||||
|         for k, v in self.id2entity.items(): |         for k, v in self.id2entity.items(): | ||||||
|   | |||||||
							
								
								
									
										4
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								main.py
									
									
									
									
									
								
							| @@ -98,7 +98,6 @@ def main(): | |||||||
|     tokenizer = data.tokenizer |     tokenizer = data.tokenizer | ||||||
|  |  | ||||||
|     lit_model = litmodel_class(args=args, model=model, tokenizer=tokenizer, data_config=data.get_config()) |     lit_model = litmodel_class(args=args, model=model, tokenizer=tokenizer, data_config=data.get_config()) | ||||||
|     print("__DEBUG__: Initialized") |  | ||||||
|     if args.checkpoint: |     if args.checkpoint: | ||||||
|         lit_model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["state_dict"], strict=False) |         lit_model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["state_dict"], strict=False) | ||||||
|  |  | ||||||
| @@ -123,10 +122,7 @@ def main(): | |||||||
|     # args.weights_summary = "full"  # Print full summary of the model |     # args.weights_summary = "full"  # Print full summary of the model | ||||||
|     trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs") |     trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs") | ||||||
|      |      | ||||||
|     print('__DEBUG__: Init trainer') |  | ||||||
|  |  | ||||||
|     if "EntityEmbedding" not in lit_model.__class__.__name__: |     if "EntityEmbedding" not in lit_model.__class__.__name__: | ||||||
|         print('__DEBUG__: Fit trainer') |  | ||||||
|         trainer.fit(lit_model, datamodule=data) |         trainer.fit(lit_model, datamodule=data) | ||||||
|         path = model_checkpoint.best_model_path |         path = model_checkpoint.best_model_path | ||||||
|         lit_model.load_state_dict(torch.load(path)["state_dict"], strict=False) |         lit_model.load_state_dict(torch.load(path)["state_dict"], strict=False) | ||||||
|   | |||||||
| @@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding, | |||||||
|                                                   PreTrainedTokenizerBase) |                                                   PreTrainedTokenizerBase) | ||||||
|  |  | ||||||
| from .base_data_module import BaseDataModule | from .base_data_module import BaseDataModule | ||||||
| from .processor import KGProcessor, get_dataset, getNegativeEntityId | from .processor import KGProcessor, get_dataset | ||||||
| import transformers | import transformers | ||||||
| transformers.logging.set_verbosity_error() | transformers.logging.set_verbosity_error() | ||||||
|  |  | ||||||
| @@ -106,7 +106,6 @@ class DataCollatorForSeq2Seq: | |||||||
|                 if isinstance(l, int):  |                 if isinstance(l, int):  | ||||||
|                     new_labels[i][l] = 1 |                     new_labels[i][l] = 1 | ||||||
|                 else: |                 else: | ||||||
|                     if (l[0] != getNegativeEntityId()): |  | ||||||
|                     for j in l: |                     for j in l: | ||||||
|                         new_labels[i][j] = 1 |                         new_labels[i][j] = 1 | ||||||
|             labels = new_labels |             labels = new_labels | ||||||
|   | |||||||
| @@ -314,7 +314,6 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_): | |||||||
|     global ent2id |     global ent2id | ||||||
|     global ent2token |     global ent2token | ||||||
|     global rel2id |     global rel2id | ||||||
|     global negativeEntity |  | ||||||
|  |  | ||||||
|     head_filter_entities = head |     head_filter_entities = head | ||||||
|     tail_filter_entities = tail |     tail_filter_entities = tail | ||||||
| @@ -323,19 +322,11 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_): | |||||||
|     ent2id = ent2id_ |     ent2id = ent2id_ | ||||||
|     ent2token = ent2token_ |     ent2token = ent2token_ | ||||||
|     rel2id = rel2id_ |     rel2id = rel2id_ | ||||||
|     negativeEntity = ent2id['[NEG]'] |  | ||||||
|  |  | ||||||
| def delete_init(ent2text_): | def delete_init(ent2text_): | ||||||
|     global ent2text |     global ent2text | ||||||
|     ent2text = ent2text_ |     ent2text = ent2text_ | ||||||
|  |  | ||||||
| def getEntityIdByName(name): |  | ||||||
|     global ent2id |  | ||||||
|     return ent2id[name] |  | ||||||
|  |  | ||||||
| def getNegativeEntityId(): |  | ||||||
|     global negativeEntity |  | ||||||
|     return negativeEntity |  | ||||||
|  |  | ||||||
| class KGProcessor(DataProcessor): | class KGProcessor(DataProcessor): | ||||||
|     """Processor for knowledge graph data set.""" |     """Processor for knowledge graph data set.""" | ||||||
| @@ -386,7 +377,6 @@ class KGProcessor(DataProcessor): | |||||||
|         """Gets all entities in the knowledge graph.""" |         """Gets all entities in the knowledge graph.""" | ||||||
|         with open(self.entity_path, 'r') as f: |         with open(self.entity_path, 'r') as f: | ||||||
|             lines = f.readlines() |             lines = f.readlines() | ||||||
|             lines.append('[NEG]\t') |  | ||||||
|             entities = [] |             entities = [] | ||||||
|             for line in lines: |             for line in lines: | ||||||
|                 entities.append(line.strip().split("\t")[0]) |                 entities.append(line.strip().split("\t")[0]) | ||||||
| @@ -413,7 +403,6 @@ class KGProcessor(DataProcessor): | |||||||
|         ent2text_with_type = {} |         ent2text_with_type = {} | ||||||
|         with open(self.entity_path, 'r') as f: |         with open(self.entity_path, 'r') as f: | ||||||
|             ent_lines = f.readlines() |             ent_lines = f.readlines() | ||||||
|             ent_lines.append('[NEG]\t') |  | ||||||
|             for line in ent_lines: |             for line in ent_lines: | ||||||
|                 temp = line.strip().split('\t') |                 temp = line.strip().split('\t') | ||||||
|                 try: |                 try: | ||||||
|   | |||||||
| @@ -364,17 +364,13 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel): | |||||||
|         self.id2entity = {} |         self.id2entity = {} | ||||||
|         with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file: |         with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file: | ||||||
|             cnt = 0 |             cnt = 0 | ||||||
|             lines = file.readlines() |             for line in file.readlines(): | ||||||
|             lines.append('[NEG]\t') |  | ||||||
|             for line in lines: |  | ||||||
|                 e, d = line.strip().split("\t") |                 e, d = line.strip().split("\t") | ||||||
|                 self.id2entity[cnt] = e |                 self.id2entity[cnt] = e | ||||||
|                 cnt += 1 |                 cnt += 1 | ||||||
|         self.id2entity_t = {} |         self.id2entity_t = {} | ||||||
|         with open("./dataset/FB15k-237/entity2text.txt", 'r') as file: |         with open("./dataset/FB15k-237/entity2text.txt", 'r') as file: | ||||||
|             lines = file.readlines() |             for line in file.readlines(): | ||||||
|             lines.append('[NEG]\t') |  | ||||||
|             for line in lines: |  | ||||||
|                 e, d = line.strip().split("\t") |                 e, d = line.strip().split("\t") | ||||||
|                 self.id2entity_t[e] = d |                 self.id2entity_t[e] = d | ||||||
|         for k, v in self.id2entity.items(): |         for k, v in self.id2entity.items(): | ||||||
|   | |||||||
| @@ -1,13 +1,13 @@ | |||||||
| nohup python -u main.py --gpus "1" --max_epochs=16  --num_workers=32 \ | nohup python -u main.py --gpus "1," --max_epochs=16  --num_workers=32 \ | ||||||
|    --model_name_or_path  bert-base-uncased \ |    --model_name_or_path  bert-base-uncased \ | ||||||
|    --accumulate_grad_batches 1 \ |    --accumulate_grad_batches 1 \ | ||||||
|    --model_class BertKGC \ |    --model_class BertKGC \ | ||||||
|    --batch_size 128 \ |    --batch_size 64 \ | ||||||
|    --pretrain 1 \ |    --pretrain 1 \ | ||||||
|    --bce 0 \ |    --bce 0 \ | ||||||
|    --check_val_every_n_epoch 1 \ |    --check_val_every_n_epoch 1 \ | ||||||
|    --overwrite_cache \ |    --overwrite_cache \ | ||||||
|    --data_dir /kg_374/Relphormer/dataset/FB15k-237 \ |    --data_dir /root/kg_374/Relphormer/dataset/FB15k-237 \ | ||||||
|    --eval_batch_size 256 \ |    --eval_batch_size 256 \ | ||||||
|    --max_seq_length 64 \ |    --max_seq_length 64 \ | ||||||
|    --lr 1e-4 \ |    --lr 1e-4 \ | ||||||
|   | |||||||
| @@ -1,9 +1,9 @@ | |||||||
| nohup python -u main.py --gpus "2," --max_epochs=16  --num_workers=32 \ | nohup python -u main.py --gpus "1," --max_epochs=16  --num_workers=32 \ | ||||||
|    --model_name_or_path  bert-base-uncased \ |    --model_name_or_path  bert-base-uncased \ | ||||||
|    --accumulate_grad_batches 1 \ |    --accumulate_grad_batches 1 \ | ||||||
|    --model_class BertKGC \ |    --model_class BertKGC \ | ||||||
|    --batch_size 64 \ |    --batch_size 16 \ | ||||||
|    --checkpoint /kg_374/Relphormer/output/FB15k-237/epoch=1-Eval/hits10=Eval/hits1=0.47-Eval/hits1=0.22.ckpt \ |    --checkpoint /root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \ | ||||||
|    --pretrain 0 \ |    --pretrain 0 \ | ||||||
|    --bce 0 \ |    --bce 0 \ | ||||||
|    --check_val_every_n_epoch 1 \ |    --check_val_every_n_epoch 1 \ | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user