Compare commits
	
		
			2 Commits
		
	
	
		
			sep_vit
			...
			negative_s
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 45cd8e1396 | |||
| fcfeae2bd3 | 
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -7,3 +7,4 @@ dataset/FB15k-237/masked_*.txt | ||||
| dataset/FB15k-237/cached_*.pkl | ||||
| **/__pycache__/ | ||||
| **/.DS_Store | ||||
| nohup.out | ||||
|   | ||||
							
								
								
									
										38
									
								
								.vscode/launch.json
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								.vscode/launch.json
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,38 @@ | ||||
| { | ||||
|     // Use IntelliSense to learn about possible attributes. | ||||
|     // Hover to view descriptions of existing attributes. | ||||
|     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 | ||||
|     "version": "0.2.0", | ||||
|     "configurations": [ | ||||
|         { | ||||
|             "name": "Python: Current File", | ||||
|             "type": "python", | ||||
|             "request": "launch", | ||||
|             "program": "${file}", | ||||
|             "console": "integratedTerminal", | ||||
|             "justMyCode": true, | ||||
|             "args": [ | ||||
|                 "--gpus", "1",  | ||||
|                 "--max_epochs=16",   | ||||
|                 "--num_workers=32",  | ||||
|                 "--model_name_or_path",  "bert-base-uncased", | ||||
|                 "--accumulate_grad_batches", "1",  | ||||
|                 "--model_class", "BertKGC", | ||||
|                 "--batch_size", "64", | ||||
|                 "--checkpoint", "/kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt", | ||||
|                 "--pretrain", "0", | ||||
|                 "--bce", "0", | ||||
|                 "--check_val_every_n_epoch", "1", | ||||
|                 "--overwrite_cache", | ||||
|                 "--data_dir", "dataset/FB15k-237",  | ||||
|                 "--eval_batch_size", "128", | ||||
|                 "--max_seq_length", "128", | ||||
|                 "--lr", "3e-5", | ||||
|                 "--max_triplet", "64", | ||||
|                 "--add_attn_bias", "True", | ||||
|                 "--use_global_node", "True", | ||||
|                 "--fast_dev_run", "True", | ||||
|             ] | ||||
|         } | ||||
|     ] | ||||
| } | ||||
| @@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding, | ||||
|                                                   PreTrainedTokenizerBase) | ||||
|  | ||||
| from .base_data_module import BaseDataModule | ||||
| from .processor import KGProcessor, get_dataset | ||||
| from .processor import KGProcessor, get_dataset, getNegativeEntityId | ||||
| import transformers | ||||
| transformers.logging.set_verbosity_error() | ||||
|  | ||||
| @@ -79,6 +79,7 @@ class DataCollatorForSeq2Seq: | ||||
|     label_pad_token_id: int = -100 | ||||
|     return_tensors: str = "pt" | ||||
|     num_labels: int = 0 | ||||
|     args: Any = None | ||||
|  | ||||
|     def __call__(self, features, return_tensors=None): | ||||
|  | ||||
| @@ -105,8 +106,9 @@ class DataCollatorForSeq2Seq: | ||||
|                 if isinstance(l, int):  | ||||
|                     new_labels[i][l] = 1 | ||||
|                 else: | ||||
|                     for j in l: | ||||
|                         new_labels[i][j] = 1 | ||||
|                     if (l[0] != getNegativeEntityId(self.args)): | ||||
|                         for j in l: | ||||
|                             new_labels[i][j] = 1 | ||||
|             labels = new_labels | ||||
|  | ||||
|         features = self.tokenizer.pad( | ||||
| @@ -141,6 +143,7 @@ class KGC(BaseDataModule): | ||||
|             padding="longest", | ||||
|             max_length=self.args.max_seq_length, | ||||
|             num_labels = len(entity_list), | ||||
|             args=args | ||||
|         ) | ||||
|         relations_tokens = self.processor.get_relations(args.data_dir) | ||||
|         self.num_relations = len(relations_tokens) | ||||
|   | ||||
| @@ -5,7 +5,7 @@ import contextlib | ||||
| import sys | ||||
|  | ||||
| from collections import Counter | ||||
| from multiprocessing import Pool | ||||
| from multiprocessing import Pool, synchronize | ||||
| from torch._C import HOIST_CONV_PACKED_PARAMS | ||||
| from torch.utils.data import Dataset, Sampler, IterableDataset | ||||
| from collections import defaultdict | ||||
| @@ -110,7 +110,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | ||||
|              | ||||
|             model_name = my_args.model_name_or_path.split("/")[-1] | ||||
|             is_pretrain = my_args.pretrain | ||||
|             cache_filepath = os.path.join(my_args.data_dir, f"cached_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl") | ||||
|             cache_filepath = os.path.join(my_args.data_dir, f"cached_{func.__name__}_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl") | ||||
|             refresh = my_args.overwrite_cache | ||||
|  | ||||
|             if cache_filepath is not None and refresh is False: | ||||
| @@ -137,6 +137,116 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | ||||
|  | ||||
|     return wrapper_ | ||||
|  | ||||
| def cache_results_load_once(_cache_fp, _refresh=False, _verbose=1, _global_var=None): | ||||
|     r""" | ||||
|     === USE IN TRAINING MODE ONLY === | ||||
|     cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用:: | ||||
|  | ||||
|         import time | ||||
|         import numpy as np | ||||
|         from fastNLP import cache_results | ||||
|          | ||||
|         @cache_results('cache.pkl') | ||||
|         def process_data(): | ||||
|             # 一些比较耗时的工作,比如读取数据,预处理数据等,这里用time.sleep()代替耗时 | ||||
|             time.sleep(1) | ||||
|             return np.random.randint(10, size=(5,)) | ||||
|          | ||||
|         start_time = time.time() | ||||
|         print("res =",process_data()) | ||||
|         print(time.time() - start_time) | ||||
|          | ||||
|         start_time = time.time() | ||||
|         print("res =",process_data()) | ||||
|         print(time.time() - start_time) | ||||
|          | ||||
|         # 输出内容如下,可以看到两次结果相同,且第二次几乎没有花费时间 | ||||
|         # Save cache to cache.pkl. | ||||
|         # res = [5 4 9 1 8] | ||||
|         # 1.0042750835418701 | ||||
|         # Read cache from cache.pkl. | ||||
|         # res = [5 4 9 1 8] | ||||
|         # 0.0040721893310546875 | ||||
|  | ||||
|     可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理:: | ||||
|  | ||||
|         # 还是以上面的例子为例,如果需要重新生成另一个cache,比如另一个数据集的内容,通过如下的方式调用即可 | ||||
|         process_data(_cache_fp='cache2.pkl')  # 完全不影响之前的‘cache.pkl' | ||||
|  | ||||
|     上面的_cache_fp是cache_results会识别的参数,它将从'cache2.pkl'这里缓存/读取数据,即这里的'cache2.pkl'覆盖默认的 | ||||
|     'cache.pkl'。如果在你的函数前面加上了@cache_results()则你的函数会增加三个参数[_cache_fp, _refresh, _verbose]。 | ||||
|     上面的例子即为使用_cache_fp的情况,这三个参数不会传入到你的函数中,当然你写的函数参数名也不可能包含这三个名称:: | ||||
|  | ||||
|         process_data(_cache_fp='cache2.pkl', _refresh=True)  # 这里强制重新生成一份对预处理的cache。 | ||||
|         #  _verbose是用于控制输出信息的,如果为0,则不输出任何内容;如果为1,则会提醒当前步骤是读取的cache还是生成了新的cache | ||||
|  | ||||
|     :param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在 | ||||
|         函数调用的时候传入_cache_fp这个参数。 | ||||
|     :param bool _refresh: 是否重新生成cache。 | ||||
|     :param int _verbose: 是否打印cache的信息。 | ||||
|     :return: | ||||
|     """ | ||||
|  | ||||
|     def wrapper_(func): | ||||
|         signature = inspect.signature(func) | ||||
|         for key, _ in signature.parameters.items(): | ||||
|             if key in ('_cache_fp', '_refresh', '_verbose', '_global_var'): | ||||
|                 raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | ||||
|         v = globals().get(_global_var, None) | ||||
|         if (v is not None): | ||||
|             return v | ||||
|  | ||||
|         def wrapper(*args, **kwargs): | ||||
|              | ||||
|             my_args = args[0] | ||||
|             mode = "train" | ||||
|             if '_cache_fp' in kwargs: | ||||
|                 cache_filepath = kwargs.pop('_cache_fp') | ||||
|                 assert isinstance(cache_filepath, str), "_cache_fp can only be str." | ||||
|             else: | ||||
|                 cache_filepath = _cache_fp | ||||
|             if '_refresh' in kwargs: | ||||
|                 refresh = kwargs.pop('_refresh') | ||||
|                 assert isinstance(refresh, bool), "_refresh can only be bool." | ||||
|             else: | ||||
|                 refresh = _refresh | ||||
|             if '_verbose' in kwargs: | ||||
|                 verbose = kwargs.pop('_verbose') | ||||
|                 assert isinstance(verbose, int), "_verbose can only be integer." | ||||
|             else: | ||||
|                 verbose = _verbose | ||||
|             refresh_flag = True | ||||
|              | ||||
|             model_name = my_args.model_name_or_path.split("/")[-1] | ||||
|             is_pretrain = my_args.pretrain | ||||
|             cache_filepath = os.path.join(my_args.data_dir, f"cached_{func.__name__}_{mode}_features{model_name}_pretrain{is_pretrain}.pkl") | ||||
|             refresh = my_args.overwrite_cache | ||||
|  | ||||
|             if cache_filepath is not None and refresh is False: | ||||
|                 # load data | ||||
|                 if os.path.exists(cache_filepath): | ||||
|                     with open(cache_filepath, 'rb') as f: | ||||
|                         results = pickle.load(f) | ||||
|                         globals()[_global_var] = results | ||||
|                     if verbose == 1: | ||||
|                         logger.info("Read cache from {}.".format(cache_filepath)) | ||||
|                     refresh_flag = False | ||||
|  | ||||
|             if refresh_flag: | ||||
|                 results = func(*args, **kwargs) | ||||
|                 if cache_filepath is not None: | ||||
|                     if results is None: | ||||
|                         raise RuntimeError("The return value is None. Delete the decorator.") | ||||
|                     with open(cache_filepath, 'wb') as f: | ||||
|                         pickle.dump(results, f) | ||||
|                     logger.info("Save cache to {}.".format(cache_filepath)) | ||||
|  | ||||
|             return results | ||||
|  | ||||
|         return wrapper | ||||
|  | ||||
|     return wrapper_ | ||||
|  | ||||
|  | ||||
| import argparse | ||||
| import csv | ||||
| @@ -235,6 +345,31 @@ class DataProcessor(object): | ||||
|  | ||||
| import copy | ||||
|  | ||||
| from collections import deque | ||||
| import threading | ||||
|  | ||||
|  | ||||
| class _LiveState(type): | ||||
|     _instances = {} | ||||
|     _lock = threading.Lock() | ||||
|  | ||||
|     def __call__(cls, *args, **kwargs): | ||||
|         if cls not in cls._instances: | ||||
|             with cls._lock: | ||||
|                 if cls not in cls._instances: | ||||
|                     cls._instances[cls] = super(_LiveState, cls).__call__(*args, **kwargs) | ||||
|          | ||||
|         return cls._instances[cls] | ||||
|  | ||||
| class LiveState(metaclass=_LiveState): | ||||
|  | ||||
|     def __init__(self): | ||||
|         self._pool_size = 4 | ||||
|         self._deq = deque(maxlen=self._pool_size) | ||||
|     def put(self, item): | ||||
|         self._deq.append(item) | ||||
|     def get(self): | ||||
|         return list(self._deq) | ||||
|  | ||||
| def solve_get_knowledge_store(line, set_type="train", pretrain=1): | ||||
|     """ | ||||
| @@ -364,6 +499,58 @@ def solve(line,  set_type="train", pretrain=1, max_triplet=32): | ||||
|             InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_head_seq), label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[line[1], line[2]], en_id = [rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]], text_d_id = list(masked_head_seq_id), graph_inf = masked_head_graph_list)) | ||||
|         examples.append( | ||||
|             InputExample(guid=guid, text_a="[PAD]", text_b="[PAD]", text_c = "[MASK]", text_d = list(masked_tail_seq), label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[line[0], line[1]], en_id = [ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]], text_d_id = list(masked_tail_seq_id), graph_inf = masked_tail_graph_list)) | ||||
|          | ||||
|         liveState = LiveState() | ||||
|         _prev = liveState.get() | ||||
|  | ||||
|         if (set_type == "train" and len(_prev) > 0): | ||||
|  | ||||
|             for prev_ent in _prev: | ||||
|  | ||||
|                 if (prev_ent == line[0] or prev_ent == line[2]): | ||||
|                     continue | ||||
|  | ||||
|                 z = head_filter_entities["\t".join([prev_ent,line[1]])] | ||||
|                 if (len(z) == 0): | ||||
|                     z.append('[NEG]') | ||||
|                     z.append(line[2]) | ||||
|                     z.append(line[0]) | ||||
|                 masked_neg_seq = set() | ||||
|                 masked_neg_seq_id = set() | ||||
|  | ||||
|                 masked_neg_graph_list = masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), []) if len(masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), [])) < max_triplet else \ | ||||
|                     random.sample(masked_tail_neighbor["\t".join([prev_ent, line[1]])], max_triplet) | ||||
|                 if (len(masked_head_graph_list) == 0): | ||||
|                     masked_head_graph_list.append(['[NEG]', line[1], '[NEG]']) | ||||
|  | ||||
|                 for item in masked_neg_graph_list: | ||||
|                     masked_neg_seq.add(item[0]) | ||||
|                     masked_neg_seq.add(item[1]) | ||||
|                     masked_neg_seq.add(item[2]) | ||||
|                     masked_neg_seq_id.add(ent2id[item[0]]) | ||||
|                     masked_neg_seq_id.add(rel2id[item[1]]) | ||||
|                     masked_neg_seq_id.add(ent2id[item[2]]) | ||||
|                      | ||||
|                 masked_neg_seq = masked_neg_seq.difference({line[0]}) | ||||
|                 masked_neg_seq = masked_neg_seq.difference({line[2]}) | ||||
|                 masked_neg_seq = masked_neg_seq.difference({line[1]}) | ||||
|                 masked_neg_seq = masked_neg_seq.difference({prev_ent}) | ||||
|                 masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[0]]}) | ||||
|                 masked_neg_seq_id = masked_neg_seq_id.difference({rel2id[line[1]]}) | ||||
|                 masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[2]]}) | ||||
|                 masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[prev_ent]}) | ||||
|                 # examples.append( | ||||
|                 #     InputExample(guid=guid, text_a="[MASK]", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[PAD]" + " " + ' '.join(text_c.split(' ')[:16]), text_d = masked_head_seq, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]])) | ||||
|                 # examples.append( | ||||
|                 #     InputExample(guid=guid, text_a="[PAD] ", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[MASK]" +" " + ' '.join(text_a.split(' ')[:16]), text_d = masked_tail_seq, label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]])) | ||||
|                 examples.append( | ||||
|                     InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_neg_seq), label=lmap(lambda x: ent2id[x], z), real_label=ent2id[line[0]], en=[line[1], prev_ent], en_id = [rel2id[line[1]], ent2id[prev_ent]], rel=rel2id[line[1]], text_d_id = list(masked_neg_seq_id), graph_inf = masked_neg_graph_list)) | ||||
|                 examples.append(     | ||||
|                     InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_neg_seq), label=lmap(lambda x: ent2id[x], z), real_label=ent2id[line[2]], en=[line[1], prev_ent], en_id = [rel2id[line[1]], ent2id[prev_ent]], rel=rel2id[line[1]], text_d_id = list(masked_neg_seq_id), graph_inf = masked_neg_graph_list)) | ||||
|          | ||||
|  | ||||
|         liveState.put(line[0]) | ||||
|         liveState.put(line[2]) | ||||
|     return examples | ||||
|  | ||||
| def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_neighbor_, masked_tail_neighbor_, rel2token_): | ||||
| @@ -377,6 +564,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei | ||||
|     global masked_head_neighbor | ||||
|     global masked_tail_neighbor | ||||
|     global rel2token | ||||
|     # global negativeEntity | ||||
|  | ||||
|     head_filter_entities = head | ||||
|     tail_filter_entities = tail | ||||
| @@ -388,11 +576,23 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei | ||||
|     masked_head_neighbor = masked_head_neighbor_ | ||||
|     masked_tail_neighbor = masked_tail_neighbor_ | ||||
|     rel2token = rel2token_ | ||||
|     # negativeEntity = ent2id['[NEG]'] | ||||
|     print("Initialized negative entity ID") | ||||
|  | ||||
| def delete_init(ent2text_): | ||||
|     global ent2text | ||||
|     ent2text = ent2text_ | ||||
|  | ||||
| def getEntityIdByName(name): | ||||
|     global ent2id | ||||
|     return ent2id[name] | ||||
|  | ||||
| @cache_results_load_once(_cache_fp="./dataset", _global_var='negativeEntity') | ||||
| def getNegativeEntityId(args): | ||||
|     global negativeEntity | ||||
|     negativeEntity = ent2id['[NEG]'] | ||||
|     return negativeEntity | ||||
|  | ||||
|  | ||||
| class KGProcessor(DataProcessor): | ||||
|     """Processor for knowledge graph data set.""" | ||||
| @@ -443,6 +643,7 @@ class KGProcessor(DataProcessor): | ||||
|         """Gets all entities in the knowledge graph.""" | ||||
|         with open(self.entity_path, 'r') as f: | ||||
|             lines = f.readlines() | ||||
|             lines.append('[NEG]\t') | ||||
|             entities = [] | ||||
|             for line in lines: | ||||
|                 entities.append(line.strip().split("\t")[0]) | ||||
| @@ -469,6 +670,7 @@ class KGProcessor(DataProcessor): | ||||
|         ent2text_with_type = {} | ||||
|         with open(self.entity_path, 'r') as f: | ||||
|             ent_lines = f.readlines() | ||||
|             ent_lines.append('[NEG]\t') | ||||
|             for line in ent_lines: | ||||
|                 temp = line.strip().split('\t') | ||||
|                 try: | ||||
| @@ -570,6 +772,8 @@ class KGProcessor(DataProcessor): | ||||
|         threads = min(1, cpu_count()) | ||||
|         filter_init(head_filter_entities, tail_filter_entities,ent2text, rel2text, ent2id, ent2token, rel2id, masked_head_neighbor, masked_tail_neighbor, rel2token | ||||
|             ) | ||||
|         # cache this | ||||
|         getNegativeEntityId(args) | ||||
|          | ||||
|         if hasattr(args, "faiss_init") and args.faiss_init: | ||||
|             annotate_ = partial( | ||||
| @@ -579,6 +783,7 @@ class KGProcessor(DataProcessor): | ||||
|         else: | ||||
|             annotate_ = partial( | ||||
|                 solve, | ||||
|                 set_type=set_type, | ||||
|                 pretrain=self.args.pretrain, | ||||
|                 max_triplet=self.args.max_triplet | ||||
|             ) | ||||
|   | ||||
| @@ -81,10 +81,13 @@ class TransformerLitModel(BaseLitModel): | ||||
|         pos = batch.pop("pos") | ||||
|         try: | ||||
|             en = batch.pop("en") | ||||
|             # self.print("__DEBUG__: en", en) | ||||
|             rel = batch.pop("rel") | ||||
|             # self.print("__DEBUG__: rel", rel) | ||||
|         except KeyError: | ||||
|             pass | ||||
|         input_ids = batch['input_ids'] | ||||
|         # self.print("__DEBUG__: input_ids", input_ids) | ||||
|  | ||||
|         distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])]) | ||||
|         distance = batch.pop("distance_attention") | ||||
| @@ -382,13 +385,17 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel): | ||||
|         self.id2entity = {} | ||||
|         with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file: | ||||
|             cnt = 0 | ||||
|             for line in file.readlines(): | ||||
|             lines = file.readlines() | ||||
|             lines.append('[NEG]\t') | ||||
|             for line in lines: | ||||
|                 e, d = line.strip().split("\t") | ||||
|                 self.id2entity[cnt] = e | ||||
|                 cnt += 1 | ||||
|         self.id2entity_t = {} | ||||
|         with open("./dataset/FB15k-237/entity2text.txt", 'r') as file: | ||||
|             for line in file.readlines(): | ||||
|             lines = file.readlines() | ||||
|             lines.append('[NEG]\t') | ||||
|             for line in lines: | ||||
|                 e, d = line.strip().split("\t") | ||||
|                 self.id2entity_t[e] = d | ||||
|         for k, v in self.id2entity.items(): | ||||
|   | ||||
							
								
								
									
										6
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								main.py
									
									
									
									
									
								
							| @@ -98,6 +98,7 @@ def main(): | ||||
|     tokenizer = data.tokenizer | ||||
|  | ||||
|     lit_model = litmodel_class(args=args, model=model, tokenizer=tokenizer, data_config=data.get_config()) | ||||
|     print("__DEBUG__: Initialized") | ||||
|     if args.checkpoint: | ||||
|         lit_model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["state_dict"], strict=False) | ||||
|  | ||||
| @@ -120,9 +121,12 @@ def main(): | ||||
|     callbacks = [early_callback, model_checkpoint] | ||||
|  | ||||
|     # args.weights_summary = "full"  # Print full summary of the model | ||||
|     trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs", accelerator="ddp") | ||||
|     trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs") | ||||
|      | ||||
|     print('__DEBUG__: Init trainer') | ||||
|  | ||||
|     if "EntityEmbedding" not in lit_model.__class__.__name__: | ||||
|         print('__DEBUG__: Fit trainer') | ||||
|         trainer.fit(lit_model, datamodule=data) | ||||
|         path = model_checkpoint.best_model_path | ||||
|         lit_model.load_state_dict(torch.load(path)["state_dict"], strict=False) | ||||
|   | ||||
| @@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding, | ||||
|                                                   PreTrainedTokenizerBase) | ||||
|  | ||||
| from .base_data_module import BaseDataModule | ||||
| from .processor import KGProcessor, get_dataset | ||||
| from .processor import KGProcessor, get_dataset, getNegativeEntityId | ||||
| import transformers | ||||
| transformers.logging.set_verbosity_error() | ||||
|  | ||||
| @@ -106,8 +106,9 @@ class DataCollatorForSeq2Seq: | ||||
|                 if isinstance(l, int):  | ||||
|                     new_labels[i][l] = 1 | ||||
|                 else: | ||||
|                     for j in l: | ||||
|                         new_labels[i][j] = 1 | ||||
|                     if (l[0] != getNegativeEntityId()): | ||||
|                         for j in l: | ||||
|                             new_labels[i][j] = 1 | ||||
|             labels = new_labels | ||||
|  | ||||
|         features = self.tokenizer.pad( | ||||
|   | ||||
| @@ -314,6 +314,7 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_): | ||||
|     global ent2id | ||||
|     global ent2token | ||||
|     global rel2id | ||||
|     global negativeEntity | ||||
|  | ||||
|     head_filter_entities = head | ||||
|     tail_filter_entities = tail | ||||
| @@ -322,11 +323,19 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_): | ||||
|     ent2id = ent2id_ | ||||
|     ent2token = ent2token_ | ||||
|     rel2id = rel2id_ | ||||
|     negativeEntity = ent2id['[NEG]'] | ||||
|  | ||||
| def delete_init(ent2text_): | ||||
|     global ent2text | ||||
|     ent2text = ent2text_ | ||||
|  | ||||
| def getEntityIdByName(name): | ||||
|     global ent2id | ||||
|     return ent2id[name] | ||||
|  | ||||
| def getNegativeEntityId(): | ||||
|     global negativeEntity | ||||
|     return negativeEntity | ||||
|  | ||||
| class KGProcessor(DataProcessor): | ||||
|     """Processor for knowledge graph data set.""" | ||||
| @@ -377,6 +386,7 @@ class KGProcessor(DataProcessor): | ||||
|         """Gets all entities in the knowledge graph.""" | ||||
|         with open(self.entity_path, 'r') as f: | ||||
|             lines = f.readlines() | ||||
|             lines.append('[NEG]\t') | ||||
|             entities = [] | ||||
|             for line in lines: | ||||
|                 entities.append(line.strip().split("\t")[0]) | ||||
| @@ -403,6 +413,7 @@ class KGProcessor(DataProcessor): | ||||
|         ent2text_with_type = {} | ||||
|         with open(self.entity_path, 'r') as f: | ||||
|             ent_lines = f.readlines() | ||||
|             ent_lines.append('[NEG]\t') | ||||
|             for line in ent_lines: | ||||
|                 temp = line.strip().split('\t') | ||||
|                 try: | ||||
|   | ||||
| @@ -364,13 +364,17 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel): | ||||
|         self.id2entity = {} | ||||
|         with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file: | ||||
|             cnt = 0 | ||||
|             for line in file.readlines(): | ||||
|             lines = file.readlines() | ||||
|             lines.append('[NEG]\t') | ||||
|             for line in lines: | ||||
|                 e, d = line.strip().split("\t") | ||||
|                 self.id2entity[cnt] = e | ||||
|                 cnt += 1 | ||||
|         self.id2entity_t = {} | ||||
|         with open("./dataset/FB15k-237/entity2text.txt", 'r') as file: | ||||
|             for line in file.readlines(): | ||||
|             lines = file.readlines() | ||||
|             lines.append('[NEG]\t') | ||||
|             for line in lines: | ||||
|                 e, d = line.strip().split("\t") | ||||
|                 self.id2entity_t[e] = d | ||||
|         for k, v in self.id2entity.items(): | ||||
|   | ||||
							
								
								
									
										0
									
								
								pretrain/scripts/pretrain_fb15k-237.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							
							
						
						
									
										0
									
								
								pretrain/scripts/pretrain_fb15k-237.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							
							
								
								
									
										5
									
								
								scripts/fb15k-237/fb15k-237.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							
							
						
						
									
										5
									
								
								scripts/fb15k-237/fb15k-237.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							| @@ -1,13 +1,12 @@ | ||||
| nohup python -u main.py --gpus "1" --max_epochs=16  --num_workers=32 \ | ||||
| nohup python -u main.py --gpus "2," --max_epochs=16  --num_workers=32 \ | ||||
|    --model_name_or_path  bert-base-uncased \ | ||||
|    --accumulate_grad_batches 1 \ | ||||
|    --model_class BertKGC \ | ||||
|    --batch_size 64 \ | ||||
|    --checkpoint /kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt \ | ||||
|    --checkpoint /kg_374/Relphormer/output/FB15k-237/epoch=1-Eval/hits10=Eval/hits1=0.47-Eval/hits1=0.22.ckpt \ | ||||
|    --pretrain 0 \ | ||||
|    --bce 0 \ | ||||
|    --check_val_every_n_epoch 1 \ | ||||
|    --overwrite_cache \ | ||||
|    --data_dir dataset/FB15k-237 \ | ||||
|    --eval_batch_size 128 \ | ||||
|    --max_seq_length 128 \ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user