Compare commits
	
		
			1 Commits
		
	
	
		
			negative_s
			...
			negative_s
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 6cc55301ad | 
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -7,4 +7,3 @@ dataset/FB15k-237/masked_*.txt | ||||
| dataset/FB15k-237/cached_*.pkl | ||||
| **/__pycache__/ | ||||
| **/.DS_Store | ||||
| nohup.out | ||||
|   | ||||
							
								
								
									
										8
									
								
								.vscode/launch.json
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.vscode/launch.json
									
									
									
									
										vendored
									
									
								
							| @@ -12,18 +12,17 @@ | ||||
|             "console": "integratedTerminal", | ||||
|             "justMyCode": true, | ||||
|             "args": [ | ||||
|                 "--gpus", "1",  | ||||
|                 "--gpus", "1,",  | ||||
|                 "--max_epochs=16",   | ||||
|                 "--num_workers=32",  | ||||
|                 "--model_name_or_path",  "bert-base-uncased", | ||||
|                 "--accumulate_grad_batches", "1",  | ||||
|                 "--model_class", "BertKGC", | ||||
|                 "--batch_size", "64", | ||||
|                 "--checkpoint", "/kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=19299-Eval/hits10=0.96.ckpt", | ||||
|                 "--batch_size", "32", | ||||
|                 "--checkpoint", "/root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch=15-step=38899-Eval/hits10=0.96.ckpt", | ||||
|                 "--pretrain", "0", | ||||
|                 "--bce", "0", | ||||
|                 "--check_val_every_n_epoch", "1", | ||||
|                 "--overwrite_cache", | ||||
|                 "--data_dir", "dataset/FB15k-237",  | ||||
|                 "--eval_batch_size", "128", | ||||
|                 "--max_seq_length", "128", | ||||
| @@ -31,7 +30,6 @@ | ||||
|                 "--max_triplet", "64", | ||||
|                 "--add_attn_bias", "True", | ||||
|                 "--use_global_node", "True", | ||||
|                 "--fast_dev_run", "True", | ||||
|             ] | ||||
|         } | ||||
|     ] | ||||
|   | ||||
| @@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding, | ||||
|                                                   PreTrainedTokenizerBase) | ||||
|  | ||||
| from .base_data_module import BaseDataModule | ||||
| from .processor import KGProcessor, get_dataset, getNegativeEntityId | ||||
| from .processor import KGProcessor, get_dataset | ||||
| import transformers | ||||
| transformers.logging.set_verbosity_error() | ||||
|  | ||||
| @@ -79,7 +79,6 @@ class DataCollatorForSeq2Seq: | ||||
|     label_pad_token_id: int = -100 | ||||
|     return_tensors: str = "pt" | ||||
|     num_labels: int = 0 | ||||
|     args: Any = None | ||||
|  | ||||
|     def __call__(self, features, return_tensors=None): | ||||
|  | ||||
| @@ -106,7 +105,6 @@ class DataCollatorForSeq2Seq: | ||||
|                 if isinstance(l, int):  | ||||
|                     new_labels[i][l] = 1 | ||||
|                 else: | ||||
|                     if (l[0] != getNegativeEntityId(self.args)): | ||||
|                     for j in l: | ||||
|                         new_labels[i][j] = 1 | ||||
|             labels = new_labels | ||||
| @@ -143,7 +141,6 @@ class KGC(BaseDataModule): | ||||
|             padding="longest", | ||||
|             max_length=self.args.max_seq_length, | ||||
|             num_labels = len(entity_list), | ||||
|             args=args | ||||
|         ) | ||||
|         relations_tokens = self.processor.get_relations(args.data_dir) | ||||
|         self.num_relations = len(relations_tokens) | ||||
|   | ||||
| @@ -5,7 +5,7 @@ import contextlib | ||||
| import sys | ||||
|  | ||||
| from collections import Counter | ||||
| from multiprocessing import Pool, synchronize | ||||
| from multiprocessing import Pool | ||||
| from torch._C import HOIST_CONV_PACKED_PARAMS | ||||
| from torch.utils.data import Dataset, Sampler, IterableDataset | ||||
| from collections import defaultdict | ||||
| @@ -110,7 +110,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | ||||
|              | ||||
|             model_name = my_args.model_name_or_path.split("/")[-1] | ||||
|             is_pretrain = my_args.pretrain | ||||
|             cache_filepath = os.path.join(my_args.data_dir, f"cached_{func.__name__}_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl") | ||||
|             cache_filepath = os.path.join(my_args.data_dir, f"cached_{mode}_features{model_name}_pretrain{is_pretrain}_faiss{my_args.faiss_init}_seqlength{my_args.max_seq_length}_{my_args.litmodel_class}.pkl") | ||||
|             refresh = my_args.overwrite_cache | ||||
|  | ||||
|             if cache_filepath is not None and refresh is False: | ||||
| @@ -137,116 +137,6 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | ||||
|  | ||||
|     return wrapper_ | ||||
|  | ||||
| def cache_results_load_once(_cache_fp, _refresh=False, _verbose=1, _global_var=None): | ||||
|     r""" | ||||
|     === USE IN TRAINING MODE ONLY === | ||||
|     cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用:: | ||||
|  | ||||
|         import time | ||||
|         import numpy as np | ||||
|         from fastNLP import cache_results | ||||
|          | ||||
|         @cache_results('cache.pkl') | ||||
|         def process_data(): | ||||
|             # 一些比较耗时的工作,比如读取数据,预处理数据等,这里用time.sleep()代替耗时 | ||||
|             time.sleep(1) | ||||
|             return np.random.randint(10, size=(5,)) | ||||
|          | ||||
|         start_time = time.time() | ||||
|         print("res =",process_data()) | ||||
|         print(time.time() - start_time) | ||||
|          | ||||
|         start_time = time.time() | ||||
|         print("res =",process_data()) | ||||
|         print(time.time() - start_time) | ||||
|          | ||||
|         # 输出内容如下,可以看到两次结果相同,且第二次几乎没有花费时间 | ||||
|         # Save cache to cache.pkl. | ||||
|         # res = [5 4 9 1 8] | ||||
|         # 1.0042750835418701 | ||||
|         # Read cache from cache.pkl. | ||||
|         # res = [5 4 9 1 8] | ||||
|         # 0.0040721893310546875 | ||||
|  | ||||
|     可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理:: | ||||
|  | ||||
|         # 还是以上面的例子为例,如果需要重新生成另一个cache,比如另一个数据集的内容,通过如下的方式调用即可 | ||||
|         process_data(_cache_fp='cache2.pkl')  # 完全不影响之前的‘cache.pkl' | ||||
|  | ||||
|     上面的_cache_fp是cache_results会识别的参数,它将从'cache2.pkl'这里缓存/读取数据,即这里的'cache2.pkl'覆盖默认的 | ||||
|     'cache.pkl'。如果在你的函数前面加上了@cache_results()则你的函数会增加三个参数[_cache_fp, _refresh, _verbose]。 | ||||
|     上面的例子即为使用_cache_fp的情况,这三个参数不会传入到你的函数中,当然你写的函数参数名也不可能包含这三个名称:: | ||||
|  | ||||
|         process_data(_cache_fp='cache2.pkl', _refresh=True)  # 这里强制重新生成一份对预处理的cache。 | ||||
|         #  _verbose是用于控制输出信息的,如果为0,则不输出任何内容;如果为1,则会提醒当前步骤是读取的cache还是生成了新的cache | ||||
|  | ||||
|     :param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在 | ||||
|         函数调用的时候传入_cache_fp这个参数。 | ||||
|     :param bool _refresh: 是否重新生成cache。 | ||||
|     :param int _verbose: 是否打印cache的信息。 | ||||
|     :return: | ||||
|     """ | ||||
|  | ||||
|     def wrapper_(func): | ||||
|         signature = inspect.signature(func) | ||||
|         for key, _ in signature.parameters.items(): | ||||
|             if key in ('_cache_fp', '_refresh', '_verbose', '_global_var'): | ||||
|                 raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | ||||
|         v = globals().get(_global_var, None) | ||||
|         if (v is not None): | ||||
|             return v | ||||
|  | ||||
|         def wrapper(*args, **kwargs): | ||||
|              | ||||
|             my_args = args[0] | ||||
|             mode = "train" | ||||
|             if '_cache_fp' in kwargs: | ||||
|                 cache_filepath = kwargs.pop('_cache_fp') | ||||
|                 assert isinstance(cache_filepath, str), "_cache_fp can only be str." | ||||
|             else: | ||||
|                 cache_filepath = _cache_fp | ||||
|             if '_refresh' in kwargs: | ||||
|                 refresh = kwargs.pop('_refresh') | ||||
|                 assert isinstance(refresh, bool), "_refresh can only be bool." | ||||
|             else: | ||||
|                 refresh = _refresh | ||||
|             if '_verbose' in kwargs: | ||||
|                 verbose = kwargs.pop('_verbose') | ||||
|                 assert isinstance(verbose, int), "_verbose can only be integer." | ||||
|             else: | ||||
|                 verbose = _verbose | ||||
|             refresh_flag = True | ||||
|              | ||||
|             model_name = my_args.model_name_or_path.split("/")[-1] | ||||
|             is_pretrain = my_args.pretrain | ||||
|             cache_filepath = os.path.join(my_args.data_dir, f"cached_{func.__name__}_{mode}_features{model_name}_pretrain{is_pretrain}.pkl") | ||||
|             refresh = my_args.overwrite_cache | ||||
|  | ||||
|             if cache_filepath is not None and refresh is False: | ||||
|                 # load data | ||||
|                 if os.path.exists(cache_filepath): | ||||
|                     with open(cache_filepath, 'rb') as f: | ||||
|                         results = pickle.load(f) | ||||
|                         globals()[_global_var] = results | ||||
|                     if verbose == 1: | ||||
|                         logger.info("Read cache from {}.".format(cache_filepath)) | ||||
|                     refresh_flag = False | ||||
|  | ||||
|             if refresh_flag: | ||||
|                 results = func(*args, **kwargs) | ||||
|                 if cache_filepath is not None: | ||||
|                     if results is None: | ||||
|                         raise RuntimeError("The return value is None. Delete the decorator.") | ||||
|                     with open(cache_filepath, 'wb') as f: | ||||
|                         pickle.dump(results, f) | ||||
|                     logger.info("Save cache to {}.".format(cache_filepath)) | ||||
|  | ||||
|             return results | ||||
|  | ||||
|         return wrapper | ||||
|  | ||||
|     return wrapper_ | ||||
|  | ||||
|  | ||||
| import argparse | ||||
| import csv | ||||
| @@ -345,31 +235,6 @@ class DataProcessor(object): | ||||
|  | ||||
| import copy | ||||
|  | ||||
| from collections import deque | ||||
| import threading | ||||
|  | ||||
|  | ||||
| class _LiveState(type): | ||||
|     _instances = {} | ||||
|     _lock = threading.Lock() | ||||
|  | ||||
|     def __call__(cls, *args, **kwargs): | ||||
|         if cls not in cls._instances: | ||||
|             with cls._lock: | ||||
|                 if cls not in cls._instances: | ||||
|                     cls._instances[cls] = super(_LiveState, cls).__call__(*args, **kwargs) | ||||
|          | ||||
|         return cls._instances[cls] | ||||
|  | ||||
| class LiveState(metaclass=_LiveState): | ||||
|  | ||||
|     def __init__(self): | ||||
|         self._pool_size = 4 | ||||
|         self._deq = deque(maxlen=self._pool_size) | ||||
|     def put(self, item): | ||||
|         self._deq.append(item) | ||||
|     def get(self): | ||||
|         return list(self._deq) | ||||
|  | ||||
| def solve_get_knowledge_store(line, set_type="train", pretrain=1): | ||||
|     """ | ||||
| @@ -499,58 +364,6 @@ def solve(line,  set_type="train", pretrain=1, max_triplet=32): | ||||
|             InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_head_seq), label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[line[1], line[2]], en_id = [rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]], text_d_id = list(masked_head_seq_id), graph_inf = masked_head_graph_list)) | ||||
|         examples.append( | ||||
|             InputExample(guid=guid, text_a="[PAD]", text_b="[PAD]", text_c = "[MASK]", text_d = list(masked_tail_seq), label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[line[0], line[1]], en_id = [ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]], text_d_id = list(masked_tail_seq_id), graph_inf = masked_tail_graph_list)) | ||||
|          | ||||
|         liveState = LiveState() | ||||
|         _prev = liveState.get() | ||||
|  | ||||
|         if (set_type == "train" and len(_prev) > 0): | ||||
|  | ||||
|             for prev_ent in _prev: | ||||
|  | ||||
|                 if (prev_ent == line[0] or prev_ent == line[2]): | ||||
|                     continue | ||||
|  | ||||
|                 z = head_filter_entities["\t".join([prev_ent,line[1]])] | ||||
|                 if (len(z) == 0): | ||||
|                     z.append('[NEG]') | ||||
|                     z.append(line[2]) | ||||
|                     z.append(line[0]) | ||||
|                 masked_neg_seq = set() | ||||
|                 masked_neg_seq_id = set() | ||||
|  | ||||
|                 masked_neg_graph_list = masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), []) if len(masked_tail_neighbor.get("\t".join([prev_ent, line[1]]), [])) < max_triplet else \ | ||||
|                     random.sample(masked_tail_neighbor["\t".join([prev_ent, line[1]])], max_triplet) | ||||
|                 if (len(masked_head_graph_list) == 0): | ||||
|                     masked_head_graph_list.append(['[NEG]', line[1], '[NEG]']) | ||||
|  | ||||
|                 for item in masked_neg_graph_list: | ||||
|                     masked_neg_seq.add(item[0]) | ||||
|                     masked_neg_seq.add(item[1]) | ||||
|                     masked_neg_seq.add(item[2]) | ||||
|                     masked_neg_seq_id.add(ent2id[item[0]]) | ||||
|                     masked_neg_seq_id.add(rel2id[item[1]]) | ||||
|                     masked_neg_seq_id.add(ent2id[item[2]]) | ||||
|                      | ||||
|                 masked_neg_seq = masked_neg_seq.difference({line[0]}) | ||||
|                 masked_neg_seq = masked_neg_seq.difference({line[2]}) | ||||
|                 masked_neg_seq = masked_neg_seq.difference({line[1]}) | ||||
|                 masked_neg_seq = masked_neg_seq.difference({prev_ent}) | ||||
|                 masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[0]]}) | ||||
|                 masked_neg_seq_id = masked_neg_seq_id.difference({rel2id[line[1]]}) | ||||
|                 masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[line[2]]}) | ||||
|                 masked_neg_seq_id = masked_neg_seq_id.difference({ent2id[prev_ent]}) | ||||
|                 # examples.append( | ||||
|                 #     InputExample(guid=guid, text_a="[MASK]", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[PAD]" + " " + ' '.join(text_c.split(' ')[:16]), text_d = masked_head_seq, label=lmap(lambda x: ent2id[x], b), real_label=ent2id[line[0]], en=[rel2id[line[1]], ent2id[line[2]]], rel=rel2id[line[1]])) | ||||
|                 # examples.append( | ||||
|                 #     InputExample(guid=guid, text_a="[PAD] ", text_b=' '.join(text_b.split(' ')[:16]) + " [PAD]", text_c = "[MASK]" +" " + ' '.join(text_a.split(' ')[:16]), text_d = masked_tail_seq, label=lmap(lambda x: ent2id[x], a), real_label=ent2id[line[2]], en=[ent2id[line[0]], rel2id[line[1]]], rel=rel2id[line[1]])) | ||||
|                 examples.append( | ||||
|                     InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_neg_seq), label=lmap(lambda x: ent2id[x], z), real_label=ent2id[line[0]], en=[line[1], prev_ent], en_id = [rel2id[line[1]], ent2id[prev_ent]], rel=rel2id[line[1]], text_d_id = list(masked_neg_seq_id), graph_inf = masked_neg_graph_list)) | ||||
|                 examples.append(     | ||||
|                     InputExample(guid=guid, text_a="[MASK]", text_b="[PAD]", text_c = "[PAD]", text_d = list(masked_neg_seq), label=lmap(lambda x: ent2id[x], z), real_label=ent2id[line[2]], en=[line[1], prev_ent], en_id = [rel2id[line[1]], ent2id[prev_ent]], rel=rel2id[line[1]], text_d_id = list(masked_neg_seq_id), graph_inf = masked_neg_graph_list)) | ||||
|          | ||||
|  | ||||
|         liveState.put(line[0]) | ||||
|         liveState.put(line[2]) | ||||
|     return examples | ||||
|  | ||||
| def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_neighbor_, masked_tail_neighbor_, rel2token_): | ||||
| @@ -564,7 +377,6 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei | ||||
|     global masked_head_neighbor | ||||
|     global masked_tail_neighbor | ||||
|     global rel2token | ||||
|     # global negativeEntity | ||||
|  | ||||
|     head_filter_entities = head | ||||
|     tail_filter_entities = tail | ||||
| @@ -576,23 +388,11 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_, masked_head_nei | ||||
|     masked_head_neighbor = masked_head_neighbor_ | ||||
|     masked_tail_neighbor = masked_tail_neighbor_ | ||||
|     rel2token = rel2token_ | ||||
|     # negativeEntity = ent2id['[NEG]'] | ||||
|     print("Initialized negative entity ID") | ||||
|  | ||||
| def delete_init(ent2text_): | ||||
|     global ent2text | ||||
|     ent2text = ent2text_ | ||||
|  | ||||
| def getEntityIdByName(name): | ||||
|     global ent2id | ||||
|     return ent2id[name] | ||||
|  | ||||
| @cache_results_load_once(_cache_fp="./dataset", _global_var='negativeEntity') | ||||
| def getNegativeEntityId(args): | ||||
|     global negativeEntity | ||||
|     negativeEntity = ent2id['[NEG]'] | ||||
|     return negativeEntity | ||||
|  | ||||
|  | ||||
| class KGProcessor(DataProcessor): | ||||
|     """Processor for knowledge graph data set.""" | ||||
| @@ -643,7 +443,6 @@ class KGProcessor(DataProcessor): | ||||
|         """Gets all entities in the knowledge graph.""" | ||||
|         with open(self.entity_path, 'r') as f: | ||||
|             lines = f.readlines() | ||||
|             lines.append('[NEG]\t') | ||||
|             entities = [] | ||||
|             for line in lines: | ||||
|                 entities.append(line.strip().split("\t")[0]) | ||||
| @@ -670,7 +469,6 @@ class KGProcessor(DataProcessor): | ||||
|         ent2text_with_type = {} | ||||
|         with open(self.entity_path, 'r') as f: | ||||
|             ent_lines = f.readlines() | ||||
|             ent_lines.append('[NEG]\t') | ||||
|             for line in ent_lines: | ||||
|                 temp = line.strip().split('\t') | ||||
|                 try: | ||||
| @@ -772,8 +570,6 @@ class KGProcessor(DataProcessor): | ||||
|         threads = min(1, cpu_count()) | ||||
|         filter_init(head_filter_entities, tail_filter_entities,ent2text, rel2text, ent2id, ent2token, rel2id, masked_head_neighbor, masked_tail_neighbor, rel2token | ||||
|             ) | ||||
|         # cache this | ||||
|         getNegativeEntityId(args) | ||||
|          | ||||
|         if hasattr(args, "faiss_init") and args.faiss_init: | ||||
|             annotate_ = partial( | ||||
| @@ -783,7 +579,6 @@ class KGProcessor(DataProcessor): | ||||
|         else: | ||||
|             annotate_ = partial( | ||||
|                 solve, | ||||
|                 set_type=set_type, | ||||
|                 pretrain=self.args.pretrain, | ||||
|                 max_triplet=self.args.max_triplet | ||||
|             ) | ||||
|   | ||||
| @@ -73,21 +73,63 @@ class TransformerLitModel(BaseLitModel): | ||||
|     def forward(self, x): | ||||
|         return self.model(x) | ||||
|  | ||||
|     def create_negatives(self, batch): | ||||
|         negativeBatches = [] | ||||
|         label = batch['label'] | ||||
|  | ||||
|         for i in range(label.shape[0]): | ||||
|             newBatch = {} | ||||
|             newBatch['attention_mask'] = None | ||||
|             newBatch['input_ids'] = torch.clone(batch['input_ids']) | ||||
|             newBatch['label'] = torch.zeros_like(batch['label']) | ||||
|             negativeBatches.append(newBatch) | ||||
|  | ||||
|         entity_idx = [] | ||||
|         self_label = [] | ||||
|         for idx, l in enumerate(label): | ||||
|             decoded = self.decode([batch['input_ids'][idx]])[0].split(' ') | ||||
|             for j in range(1, len(decoded)): | ||||
|                 if (decoded[j].startswith("[ENTITY_")): | ||||
|                     entity_idx.append(j) | ||||
|                     self_label.append(batch['input_ids'][idx][j])    | ||||
|                     break | ||||
|          | ||||
|         for idx, lbl in enumerate(label): | ||||
|             for i in range(label.shape[0]): | ||||
|                 if (negativeBatches[idx]['input_ids'][i][entity_idx[i]] != lbl): | ||||
|                     negativeBatches[idx]['input_ids'][i][entity_idx[i]] = lbl | ||||
|                 else: | ||||
|                     negativeBatches[idx]['input_ids'][i][entity_idx[i]] = self_label[i] | ||||
|                  | ||||
|         return negativeBatches | ||||
|  | ||||
|     def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument | ||||
|         # embed();exit() | ||||
|         # print(self.optimizers().param_groups[1]['lr']) | ||||
|  | ||||
|         negativeBatches = self.create_negatives(batch) | ||||
|  | ||||
|         loss = 0 | ||||
|  | ||||
|         for negativeBatch in negativeBatches: | ||||
|             label = negativeBatch.pop("label") | ||||
|             input_ids = batch['input_ids'] | ||||
|              | ||||
|             logits = self.model(**negativeBatch, return_dict=True, distance_attention=None).logits | ||||
|             _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True) | ||||
|             bs = input_ids.shape[0] | ||||
|             mask_logits = logits[torch.arange(bs), mask_idx][:, self.entity_id_st:self.entity_id_ed] | ||||
|             loss += self.loss_fn(mask_logits, label) | ||||
|  | ||||
|         labels = batch.pop("labels") | ||||
|         label = batch.pop("label") | ||||
|         pos = batch.pop("pos") | ||||
|         try: | ||||
|             en = batch.pop("en") | ||||
|             # self.print("__DEBUG__: en", en) | ||||
|             rel = batch.pop("rel") | ||||
|             # self.print("__DEBUG__: rel", rel) | ||||
|         except KeyError: | ||||
|             pass | ||||
|         input_ids = batch['input_ids'] | ||||
|         # self.print("__DEBUG__: input_ids", input_ids) | ||||
|  | ||||
|         distance_attention = torch.stack([pad_distance(len(input_ids[i]) - len(distance) - 1, distance) for i, distance in enumerate(batch['distance_attention'])]) | ||||
|         distance = batch.pop("distance_attention") | ||||
| @@ -113,9 +155,9 @@ class TransformerLitModel(BaseLitModel): | ||||
|  | ||||
|         assert mask_idx.shape[0] == bs, "only one mask in sequence!" | ||||
|         if self.args.bce: | ||||
|             loss = self.loss_fn(mask_logits, labels) | ||||
|             loss += self.loss_fn(mask_logits, labels) | ||||
|         else: | ||||
|             loss = self.loss_fn(mask_logits, label) | ||||
|             loss += self.loss_fn(mask_logits, label) | ||||
|  | ||||
|         if batch_idx == 0: | ||||
|             print('\n'.join(self.decode(batch['input_ids'][:4]))) | ||||
| @@ -385,17 +427,13 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel): | ||||
|         self.id2entity = {} | ||||
|         with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file: | ||||
|             cnt = 0 | ||||
|             lines = file.readlines() | ||||
|             lines.append('[NEG]\t') | ||||
|             for line in lines: | ||||
|             for line in file.readlines(): | ||||
|                 e, d = line.strip().split("\t") | ||||
|                 self.id2entity[cnt] = e | ||||
|                 cnt += 1 | ||||
|         self.id2entity_t = {} | ||||
|         with open("./dataset/FB15k-237/entity2text.txt", 'r') as file: | ||||
|             lines = file.readlines() | ||||
|             lines.append('[NEG]\t') | ||||
|             for line in lines: | ||||
|             for line in file.readlines(): | ||||
|                 e, d = line.strip().split("\t") | ||||
|                 self.id2entity_t[e] = d | ||||
|         for k, v in self.id2entity.items(): | ||||
|   | ||||
							
								
								
									
										4
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								main.py
									
									
									
									
									
								
							| @@ -98,7 +98,6 @@ def main(): | ||||
|     tokenizer = data.tokenizer | ||||
|  | ||||
|     lit_model = litmodel_class(args=args, model=model, tokenizer=tokenizer, data_config=data.get_config()) | ||||
|     print("__DEBUG__: Initialized") | ||||
|     if args.checkpoint: | ||||
|         lit_model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["state_dict"], strict=False) | ||||
|  | ||||
| @@ -123,10 +122,7 @@ def main(): | ||||
|     # args.weights_summary = "full"  # Print full summary of the model | ||||
|     trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs") | ||||
|      | ||||
|     print('__DEBUG__: Init trainer') | ||||
|  | ||||
|     if "EntityEmbedding" not in lit_model.__class__.__name__: | ||||
|         print('__DEBUG__: Fit trainer') | ||||
|         trainer.fit(lit_model, datamodule=data) | ||||
|         path = model_checkpoint.best_model_path | ||||
|         lit_model.load_state_dict(torch.load(path)["state_dict"], strict=False) | ||||
|   | ||||
| @@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import (BatchEncoding, | ||||
|                                                   PreTrainedTokenizerBase) | ||||
|  | ||||
| from .base_data_module import BaseDataModule | ||||
| from .processor import KGProcessor, get_dataset, getNegativeEntityId | ||||
| from .processor import KGProcessor, get_dataset | ||||
| import transformers | ||||
| transformers.logging.set_verbosity_error() | ||||
|  | ||||
| @@ -106,7 +106,6 @@ class DataCollatorForSeq2Seq: | ||||
|                 if isinstance(l, int):  | ||||
|                     new_labels[i][l] = 1 | ||||
|                 else: | ||||
|                     if (l[0] != getNegativeEntityId()): | ||||
|                     for j in l: | ||||
|                         new_labels[i][j] = 1 | ||||
|             labels = new_labels | ||||
|   | ||||
| @@ -314,7 +314,6 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_): | ||||
|     global ent2id | ||||
|     global ent2token | ||||
|     global rel2id | ||||
|     global negativeEntity | ||||
|  | ||||
|     head_filter_entities = head | ||||
|     tail_filter_entities = tail | ||||
| @@ -323,19 +322,11 @@ def filter_init(head, tail, t1,t2, ent2id_, ent2token_, rel2id_): | ||||
|     ent2id = ent2id_ | ||||
|     ent2token = ent2token_ | ||||
|     rel2id = rel2id_ | ||||
|     negativeEntity = ent2id['[NEG]'] | ||||
|  | ||||
| def delete_init(ent2text_): | ||||
|     global ent2text | ||||
|     ent2text = ent2text_ | ||||
|  | ||||
| def getEntityIdByName(name): | ||||
|     global ent2id | ||||
|     return ent2id[name] | ||||
|  | ||||
| def getNegativeEntityId(): | ||||
|     global negativeEntity | ||||
|     return negativeEntity | ||||
|  | ||||
| class KGProcessor(DataProcessor): | ||||
|     """Processor for knowledge graph data set.""" | ||||
| @@ -386,7 +377,6 @@ class KGProcessor(DataProcessor): | ||||
|         """Gets all entities in the knowledge graph.""" | ||||
|         with open(self.entity_path, 'r') as f: | ||||
|             lines = f.readlines() | ||||
|             lines.append('[NEG]\t') | ||||
|             entities = [] | ||||
|             for line in lines: | ||||
|                 entities.append(line.strip().split("\t")[0]) | ||||
| @@ -413,7 +403,6 @@ class KGProcessor(DataProcessor): | ||||
|         ent2text_with_type = {} | ||||
|         with open(self.entity_path, 'r') as f: | ||||
|             ent_lines = f.readlines() | ||||
|             ent_lines.append('[NEG]\t') | ||||
|             for line in ent_lines: | ||||
|                 temp = line.strip().split('\t') | ||||
|                 try: | ||||
|   | ||||
| @@ -364,17 +364,13 @@ class CombineEntityEmbeddingLitModel(UseEntityEmbeddingLitModel): | ||||
|         self.id2entity = {} | ||||
|         with open("./dataset/FB15k-237/entity2textlong.txt", 'r') as file: | ||||
|             cnt = 0 | ||||
|             lines = file.readlines() | ||||
|             lines.append('[NEG]\t') | ||||
|             for line in lines: | ||||
|             for line in file.readlines(): | ||||
|                 e, d = line.strip().split("\t") | ||||
|                 self.id2entity[cnt] = e | ||||
|                 cnt += 1 | ||||
|         self.id2entity_t = {} | ||||
|         with open("./dataset/FB15k-237/entity2text.txt", 'r') as file: | ||||
|             lines = file.readlines() | ||||
|             lines.append('[NEG]\t') | ||||
|             for line in lines: | ||||
|             for line in file.readlines(): | ||||
|                 e, d = line.strip().split("\t") | ||||
|                 self.id2entity_t[e] = d | ||||
|         for k, v in self.id2entity.items(): | ||||
|   | ||||
| @@ -1,13 +1,13 @@ | ||||
| nohup python -u main.py --gpus "1" --max_epochs=16  --num_workers=32 \ | ||||
| nohup python -u main.py --gpus "1," --max_epochs=16  --num_workers=32 \ | ||||
|    --model_name_or_path  bert-base-uncased \ | ||||
|    --accumulate_grad_batches 1 \ | ||||
|    --model_class BertKGC \ | ||||
|    --batch_size 128 \ | ||||
|    --batch_size 64 \ | ||||
|    --pretrain 1 \ | ||||
|    --bce 0 \ | ||||
|    --check_val_every_n_epoch 1 \ | ||||
|    --overwrite_cache \ | ||||
|    --data_dir /kg_374/Relphormer/dataset/FB15k-237 \ | ||||
|    --data_dir /root/kg_374/Relphormer/dataset/FB15k-237 \ | ||||
|    --eval_batch_size 256 \ | ||||
|    --max_seq_length 64 \ | ||||
|    --lr 1e-4 \ | ||||
|   | ||||
| @@ -1,9 +1,9 @@ | ||||
| nohup python -u main.py --gpus "2," --max_epochs=16  --num_workers=32 \ | ||||
| nohup python -u main.py --gpus "1," --max_epochs=16  --num_workers=32 \ | ||||
|    --model_name_or_path  bert-base-uncased \ | ||||
|    --accumulate_grad_batches 1 \ | ||||
|    --model_class BertKGC \ | ||||
|    --batch_size 64 \ | ||||
|    --checkpoint /kg_374/Relphormer/output/FB15k-237/epoch=1-Eval/hits10=Eval/hits1=0.47-Eval/hits1=0.22.ckpt \ | ||||
|    --batch_size 16 \ | ||||
|    --checkpoint /root/kg_374/Relphormer/pretrain/output/FB15k-237/epoch\=15-step\=38899-Eval/hits10=0.96.ckpt \ | ||||
|    --pretrain 0 \ | ||||
|    --bce 0 \ | ||||
|    --check_val_every_n_epoch 1 \ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user