Thesis/dataset/create_neighbor.py

152 lines
4.6 KiB
Python
Raw Permalink Normal View History

2022-12-26 04:54:46 +00:00
from collections import defaultdict
import time
import argparse
# Lookup tables from integer id to name. defaultdict(str) means an id that
# never appeared in the *2id.txt files resolves to "" instead of raising.
id2entity_name = defaultdict(str)
id2relation_name = defaultdict(str)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    type=str,
    default=None,
    help="name of the dataset directory, e.g. FB15k-237",
)
args = parser.parse_args()
if args.dataset is None:
    # Without this check the path concatenation below dies with an opaque
    # TypeError about concatenating str and NoneType.
    parser.error("--dataset is required")

# Every input file lives under ./<dataset>/get_neighbor/.
neighbor_dir = './' + args.dataset + '/get_neighbor/'

# Name files: one "<name>\t<id>" record per line.
for name_file, id2name in (
    ('entity2id.txt', id2entity_name),
    ('relation2id.txt', id2relation_name),
):
    with open(neighbor_dir + name_file, 'r') as file:
        for line in file:
            line = line.strip()
            if not line:
                # Tolerate blank / trailing empty lines.
                continue
            _name, _id = line.split("\t")
            id2name[int(_id)] = _name

# Triple files: three whitespace-separated integers per line, read in the
# order (head, relation, tail).
# NOTE(review): the test and valid splits are pooled into train_triplet, so
# every later step sees test/valid edges as well - confirm this is intended.
train_triplet = []
for split_file in ('train2id.txt', 'test2id.txt', 'valid2id.txt'):
    # `with` closes each handle; the bare open() in a for-loop never did.
    with open(neighbor_dir + split_file, 'r') as file:
        for line in file:
            fields = line.split()
            if not fields:
                continue
            head, relation, tail = fields
            train_triplet.append([int(head), int(relation), int(tail)])
# graph[head][tail] = relation; reverse_graph[tail][head] = relation.
graph = {}
reverse_graph = {}


def init_graph(graph_triplet):
    """Fill the module-level ``graph`` and ``reverse_graph`` adjacency maps.

    Args:
        graph_triplet: iterable of indexable triples; positions 0, 1 and 2
            are read as head, relation and tail.

    Returns:
        None. Both maps are updated in place.

    Only one relation is kept per (head, tail) pair: a later triple with the
    same endpoints overwrites the relation stored by an earlier one.
    """
    for triple in graph_triplet:
        head, rela, tail = triple[0], triple[1], triple[2]
        graph.setdefault(head, {})[tail] = rela
        reverse_graph.setdefault(tail, {})[head] = rela
# Populate graph / reverse_graph from every triple collected in train_triplet.
init_graph(train_triplet)
import random
def random_delete(triplet, reserved_num):
    """Keep a random subset of ``triplet``.

    Args:
        triplet: sequence to sample from.
        reserved_num: how many elements to keep.

    Returns:
        A new list of ``reserved_num`` elements drawn without replacement.

    Raises:
        ValueError: propagated from ``random.sample`` when ``reserved_num``
            is negative or larger than ``len(triplet)``.
    """
    return random.sample(triplet, reserved_num)
def get_onestep_neighbors(graph, source, sample_num):
    """List every edge stored under ``source`` in ``graph``.

    Args:
        graph: mapping ``node -> {neighbor: relation}``.
        source: node whose adjacency is looked up.
        sample_num: accepted for interface compatibility but not used; the
            sampling step that consumed it is disabled, so all neighbors are
            always returned.

    Returns:
        A list of ``(source, relation, neighbor)`` tuples in the adjacency
        dict's iteration order, or ``[]`` when ``source`` is not in ``graph``.
    """
    try:
        adjacency = graph[source]
    except KeyError:
        # A node with no recorded edges simply has no neighbors.
        return []
    # The former ``except ValueError`` fallback only guarded random.sample,
    # which is no longer called, so that branch was unreachable and is gone.
    return [(source, rela, neighbor) for neighbor, rela in adjacency.items()]
def get_entity_neighbors(traget_entity, max_triplet):
    """Collect the one-step edges touching ``traget_entity`` in both maps.

    Args:
        traget_entity: entity id to look up.
        max_triplet: budget whose half (``max_triplet // 2``) is forwarded to
            ``get_onestep_neighbors`` for each direction.

    Returns:
        Edges found in ``graph`` followed by edges found in ``reverse_graph``.
        Every tuple has ``traget_entity`` in first position, so edges taken
        from ``reverse_graph`` read ``(tail, relation, head)``.
    """
    per_direction = max_triplet // 2
    outgoing = get_onestep_neighbors(graph, traget_entity, per_direction)
    incoming = get_onestep_neighbors(reverse_graph, traget_entity, per_direction)
    return outgoing + incoming
def get_triplet(triplet):
    """Return the de-duplicated neighbor triples around both ends of ``triplet``.

    Args:
        triplet: indexable ``(head, relation, tail)``; only positions 0-2 are
            read.

    Returns:
        A list of unique ``(entity, relation, neighbor)`` tuples gathered for
        the head and the tail entity, with ``triplet`` itself removed. The
        order is whatever set iteration yields.
    """
    head_entity, relation, tail_entity = triplet[0], triplet[1], triplet[2]
    query = (head_entity, relation, tail_entity)

    around_head = get_entity_neighbors(head_entity, 4)
    around_tail = get_entity_neighbors(tail_entity, 4)

    # Two-stage set round-trip kept on purpose so the resulting order matches
    # what this function has always produced.
    unique_neighbors = list(set(around_head + around_tail))
    # NOTE(review): when the caller passes a masked triple, only that masked
    # tuple is subtracted here - the original unmasked triple, if present
    # among the neighbors, stays in the result. Confirm this is intended.
    return list(set(unique_neighbors) - {query})
import copy
def change_(triplet_list):
    """Translate id triples into name triples.

    Args:
        triplet_list: iterable of indexable ``(head_id, relation_id, tail_id)``.

    Returns:
        A list of ``[head_name, relation_name, tail_name]`` lists, resolved
        through ``id2entity_name`` and ``id2relation_name``.
    """
    return [
        [
            id2entity_name[triple[0]],
            id2relation_name[triple[1]],
            id2entity_name[triple[2]],
        ]
        for triple in triplet_list
    ]
# Placeholder id written into the head or tail slot of a query triple.
mask_idx = 99999999

# "<entity name>\t<relation name>" -> list of [head, relation, tail] names.
masked_tail_neighbor = defaultdict(list)
masked_head_neighbor = defaultdict(list)

# Neighbor lists depend only on the (entity id, relation id) pair, so they are
# computed once per pair instead of once per triple. The assignment below
# still runs for every triple, which keeps the resulting dicts identical.
_tail_query_cache = {}
_head_query_cache = {}

for triplet in train_triplet:
    head_id, relation_id, tail_id = triplet[0], triplet[1], triplet[2]
    relation_name = id2relation_name[relation_id]

    # Tail prediction query: (head, relation, MASK).
    tail_query = (head_id, relation_id)
    if tail_query not in _tail_query_cache:
        # A fresh flat list replaces copy.deepcopy of a list of ints.
        _tail_query_cache[tail_query] = change_(
            get_triplet([head_id, relation_id, mask_idx])
        )
    masked_tail_neighbor[
        '\t'.join([id2entity_name[head_id], relation_name])
    ] = _tail_query_cache[tail_query]

    # Head prediction query: (MASK, relation, tail).
    head_query = (tail_id, relation_id)
    if head_query not in _head_query_cache:
        _head_query_cache[head_query] = change_(
            get_triplet([mask_idx, relation_id, tail_id])
        )
    masked_head_neighbor[
        '\t'.join([id2entity_name[tail_id], relation_name])
    ] = _head_query_cache[head_query]
import json
# Dump both neighbor maps as indented JSON next to the dataset's input files.
for output_name, neighbor_map in (
    ("/masked_tail_neighbor.txt", masked_tail_neighbor),
    ("/masked_head_neighbor.txt", masked_head_neighbor),
):
    with open("./" + args.dataset + output_name, "w") as file:
        file.write(json.dumps(neighbor_map, indent=1))