2023-05-04 08:49:41 +00:00
|
|
|
import os
|
|
|
|
import uuid
|
|
|
|
import argparse
|
|
|
|
import logging
|
|
|
|
import logging.config
|
2023-06-24 04:11:17 +00:00
|
|
|
import pandas as pd
|
|
|
|
import sys
|
2023-05-04 08:49:41 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import numpy as np
|
2023-06-08 06:40:16 +00:00
|
|
|
import time
|
2023-05-04 08:49:41 +00:00
|
|
|
|
|
|
|
from collections import defaultdict as ddict
|
|
|
|
from pprint import pprint
|
|
|
|
|
|
|
|
from ordered_set import OrderedSet
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
|
|
from data_loader import TrainDataset, TestDataset
|
|
|
|
from utils import get_logger, get_combined_results, set_gpu, prepare_env, set_seed
|
|
|
|
|
|
|
|
from models import ComplEx, ConvE, HypER, InteractE, FouriER, TuckER
|
2024-04-27 03:18:48 +00:00
|
|
|
import traceback
|
2023-05-04 08:49:41 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Main(object):
|
|
|
|
|
2023-06-08 06:40:16 +00:00
|
|
|
def __init__(self, params, logger):
    """
    Set up the runner: keep the configuration, pick the device, load the data and build the model

    Parameters
    ----------
    params:         Namespace of hyper-parameters of the run (kept as self.p)
    logger:         Logger that receives every progress/result message (kept as self.logger)

    Returns
    -------
    Nothing. Populates self.device, self.model and self.optimizer, plus
    everything that self.load_data() attaches to the instance.

    """
    self.p, self.logger = params, logger
    self.logger.info(vars(self.p))

    # CUDA is used only when a GPU was requested ('-1' means CPU) and one is present.
    use_cuda = self.p.gpu != '-1' and torch.cuda.is_available()
    self.device = torch.device('cuda' if use_cuda else 'cpu')
    if use_cuda:
        torch.cuda.set_rng_state(torch.cuda.get_rng_state())
        torch.backends.cudnn.deterministic = True

    # Data must be loaded first: the model is built from self.p, which load_data fills in.
    self.load_data()
    self.model = self.add_model()
    self.optimizer = self.add_optimizer(self.model.parameters())
|
|
|
|
|
|
|
|
def load_data(self):
    """
    Read the id dictionaries and the raw triples of the dataset and turn them into data loaders.

    Every split file (train/test/valid) holds one tab-separated triple per line;
    only the first three fields (subject, relation, object) are used and all
    names are lower-cased. A fourth id, ``nt_rel``, is looked up for the part
    of the relation name that precedes the first '['.

    Parameters
    ----------
    self.p.dataset:         Name of the dataset directory under ./data (FB15k-237, WN18RR, YAGO3-10)

    Returns
    -------
    self.ent2id:            Entity to unique identifier mapping (read from entities.dict)
    self.rel2id:            Relation to unique identifier mapping (read from relations.dict, extended with '<rel>_reverse' entries)
    self.id2ent:            Inverse mapping of self.ent2id
    self.id2rel:            Inverse mapping of self.rel2id
    self.p.num_ent:         Number of entities in the Knowledge graph
    self.p.num_rel:         Half the number of entries of self.rel2id
    self.p.embed_dim:       Embedding dimension used (k_w * k_h when it was None)
    self.data['train']:     Stores the (sub, rel, obj, nt_rel) id tuples of the training dataset
    self.data['valid']:     Stores the (sub, rel, obj, nt_rel) id tuples of the validation dataset
    self.data['test']:      Stores the (sub, rel, obj, nt_rel) id tuples of the test dataset
    self.sr2o:              (sub, rel, nt_rel) -> list of objects, training triples only (both directions)
    self.sr2o_all:          Same mapping built from all three splits
    self.triples:           Per-loader lists of {'triple', 'label'[, 'sub_samp']} records
    self.data_iter:         The dataloader for different data splits
    self.chequer_perm:      Stores the Chequer reshaping arrangement

    Raises
    ------
    ValueError if a line has too few (or, for the dictionaries, too many) fields,
    KeyError if a triple mentions a name that is missing from the dictionaries.

    """
    data_dir = './data/{}'.format(self.p.dataset)
    splits = ['train', 'test', 'valid']

    def read_id_map(file_name, drop_nbsp):
        """Parse a '<id><TAB><name>' file into a lower-cased name -> int id dict."""
        mapping = {}
        # 'with' closes the handle; the bare open() used before leaked it.
        with open('{}/{}'.format(data_dir, file_name)) as handle:
            for line in handle:
                if drop_nbsp:
                    line = line.replace('\xa0', '')
                idx, name = map(str.lower, line.strip().split('\t'))
                mapping[name] = int(idx)
        return mapping

    # First pass: collect the relation names in order of first appearance.
    rel_set = OrderedSet()
    for split in splits:
        with open('{}/{}.txt'.format(data_dir, split)) as handle:
            for line in handle:
                _sub, rel, _obj, *_ = map(str.lower, line.strip().split('\t'))
                rel_set.add(rel)

    # Non-breaking spaces are removed from entity lines only, as before.
    self.ent2id = read_id_map('entities.dict', drop_nbsp=True)
    self.rel2id = read_id_map('relations.dict', drop_nbsp=False)

    # self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
    # self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}

    # NOTE(review): the '<rel>_reverse' ids below follow the order in which
    # relations first appear in the split files, whereas the rest of this
    # method addresses the inverse of relation r as r + num_rel. The two agree
    # only if relations.dict lists exactly the relations seen in the data, in
    # the same order -- confirm against the dataset before relying on id2rel.
    base = len(self.rel2id)
    self.rel2id.update({rel + '_reverse': idx + base
                        for idx, rel in enumerate(rel_set)})

    self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}
    self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}

    self.p.num_ent = len(self.ent2id)
    self.p.num_rel = len(self.rel2id) // 2
    if self.p.embed_dim is None:
        self.p.embed_dim = self.p.k_w * self.p.k_h
    num_rel = self.p.num_rel

    # Second pass: translate names to ids and index the training triples.
    self.data = ddict(list)
    sr2o = ddict(set)

    for split in splits:
        with open('{}/{}.txt'.format(data_dir, split)) as handle:
            for line in handle:
                sub, rel, obj, *_ = map(str.lower, line.replace('\xa0', '').strip().split('\t'))
                nt_rel = rel.split('[')[0]
                sub, rel, obj, nt_rel = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj], self.rel2id[nt_rel]
                self.data[split].append((sub, rel, obj, nt_rel))

                if split == 'train':
                    sr2o[(sub, rel, nt_rel)].add(obj)
                    sr2o[(obj, rel + num_rel, nt_rel + num_rel)].add(sub)
    self.data = dict(self.data)

    # Snapshot taken before test/valid are merged in: labels for training.
    self.sr2o = {k: list(v) for k, v in sr2o.items()}
    for split in ['test', 'valid']:
        for sub, rel, obj, nt_rel in self.data[split]:
            sr2o[(sub, rel, nt_rel)].add(obj)
            sr2o[(obj, rel + num_rel, nt_rel + num_rel)].add(sub)

    # Snapshot over all splits: labels used when building the evaluation records.
    self.sr2o_all = {k: list(v) for k, v in sr2o.items()}

    self.triples = ddict(list)

    if self.p.train_strategy == 'one_to_n':
        # One record per (sub, rel, nt_rel) query; -1 stands in for the object.
        for (sub, rel, nt_rel), objs in self.sr2o.items():
            self.triples['train'].append(
                {'triple': (sub, rel, -1, nt_rel), 'label': objs, 'sub_samp': 1})
    else:
        for sub, rel, obj, nt_rel in self.data['train']:
            rel_inv = rel + num_rel
            nt_rel_inv = nt_rel + num_rel
            fwd_objs = self.sr2o[(sub, rel, nt_rel)]
            inv_objs = self.sr2o[(obj, rel_inv, nt_rel_inv)]
            sub_samp = np.sqrt(1 / (len(fwd_objs) + len(inv_objs)))

            self.triples['train'].append({'triple': (
                sub, rel, obj, nt_rel), 'label': fwd_objs, 'sub_samp': sub_samp})
            self.triples['train'].append({'triple': (
                obj, rel_inv, sub, nt_rel_inv), 'label': inv_objs, 'sub_samp': sub_samp})

    for split in ['test', 'valid']:
        for sub, rel, obj, nt_rel in self.data[split]:
            rel_inv = rel + num_rel
            nt_rel_inv = nt_rel + num_rel
            self.triples['{}_{}'.format(split, 'tail')].append(
                {'triple': (sub, rel, obj, nt_rel), 'label': self.sr2o_all[(sub, rel, nt_rel)]})
            self.triples['{}_{}'.format(split, 'head')].append(
                {'triple': (obj, rel_inv, sub, nt_rel_inv), 'label': self.sr2o_all[(obj, rel_inv, nt_rel_inv)]})

    self.triples = dict(self.triples)

    def get_data_loader(dataset_class, split, batch_size, shuffle=True):
        return DataLoader(
            dataset_class(self.triples[split], self.p),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=max(0, self.p.num_workers),
            collate_fn=dataset_class.collate_fn
        )

    self.data_iter = {
        'train': get_data_loader(TrainDataset, 'train', self.p.batch_size),
        'valid_head': get_data_loader(TestDataset, 'valid_head', self.p.batch_size),
        'valid_tail': get_data_loader(TestDataset, 'valid_tail', self.p.batch_size),
        'test_head': get_data_loader(TestDataset, 'test_head', self.p.batch_size),
        'test_tail': get_data_loader(TestDataset, 'test_tail', self.p.batch_size),
    }

    self.chequer_perm = self.get_chequer_perm()
|
|
|
|
|
|
|
|
def get_chequer_perm(self):
    """
    Function to generate the chequer permutation required for InteractE model

    For each of the self.p.perm arrangements one random permutation of the
    entity dimensions and one of the relation dimensions are interleaved pair
    by pair. Relation indices are shifted by self.p.embed_dim. Within a grid
    row the pair order is fixed; it flips from one row to the next, and the
    whole pattern is inverted for odd-numbered arrangements.

    Parameters
    ----------
    self.p.embed_dim:       Length of every random permutation and offset of relation indices
    self.p.perm:            Number of arrangements to generate
    self.p.k_h, self.p.k_w: Grid size; k_h * k_w index pairs are emitted per arrangement

    Returns
    -------
    LongTensor on self.device with one row of 2 * k_h * k_w indices per arrangement

    """
    dim = self.p.embed_dim
    n_perm = self.p.perm

    # Entity permutations are drawn before relation permutations, so the
    # global NumPy random state is consumed in a fixed order.
    ent_perm = np.int32([np.random.permutation(dim) for _ in range(n_perm)])
    rel_perm = np.int32([np.random.permutation(dim) for _ in range(n_perm)])

    comb_idx = []
    for k in range(n_perm):
        interleaved = []
        cursor = 0  # entity and relation positions always advance together
        for i in range(self.p.k_h):
            # Entity leads when the arrangement index and the row index share parity.
            ent_first = (k % 2) == (i % 2)
            for _ in range(self.p.k_w):
                if ent_first:
                    interleaved.append(ent_perm[k, cursor])
                    interleaved.append(rel_perm[k, cursor] + dim)
                else:
                    interleaved.append(rel_perm[k, cursor] + dim)
                    interleaved.append(ent_perm[k, cursor])
                cursor += 1
        comb_idx.append(interleaved)

    return torch.LongTensor(np.int32(comb_idx)).to(self.device)
|
|
|
|
|
|
|
|
def add_model(self):
    """
    Build the network for this run

    Parameters
    ----------
    self.p:         Hyper-parameters handed to the FouriER constructor
    self.device:    Device the network is moved to

    Returns
    -------
    The FouriER instance, already placed on self.device

    """
    net = FouriER(self.p)
    net.to(self.device)
    return net
|
|
|
|
|
|
|
|
def add_optimizer(self, parameters):
    """
    Build the optimizer that updates the given parameters

    Parameters
    ----------
    parameters:     The parameters of the model

    Returns
    -------
    torch.optim.Adam when self.p.opt == 'adam', torch.optim.SGD for any other
    value; both use self.p.lr as learning rate and self.p.l2 as weight decay

    """
    optimizer_cls = torch.optim.Adam if self.p.opt == 'adam' else torch.optim.SGD
    return optimizer_cls(parameters, lr=self.p.lr, weight_decay=self.p.l2)
|
|
|
|
|
|
|
|
|
|
|
|
def read_batch(self, batch, split):
    """
    Move the tensors of one batch to self.device and split the triple tensor into its columns

    Parameters
    ----------
    batch:          Sequence of tensors: (triple, label, neg_ent, sub_samp) for training with
                    the 'one_to_x' strategy, (triple, label) in every other case
    split:          (string) 'train' for training batches, anything else for evaluation batches

    Returns
    -------
    For split == 'train': columns 0-3 of triple, label, neg_ent, sub_samp
    (the last two are None unless the strategy is 'one_to_x').
    Otherwise:            columns 0-3 of triple, label

    """
    on_device = [tensor.to(self.device) for tensor in batch]
    training = split == 'train'

    if training and self.p.train_strategy == 'one_to_x':
        triple, label, neg_ent, sub_samp = on_device
    else:
        triple, label = on_device
        neg_ent = sub_samp = None

    columns = (triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3])
    if training:
        return (*columns, label, neg_ent, sub_samp)
    return (*columns, label)
|
2023-05-04 08:49:41 +00:00
|
|
|
|
|
|
|
def save_model(self, save_path):
    """
    Write a checkpoint of the current run to disk.

    The checkpoint is a dict with the model weights ('state_dict'), the best
    validation results so far ('best_val'), the epoch they were reached in
    ('best_epoch'), the optimizer state ('optimizer') and all arguments of the
    run ('args').

    Parameters
    ----------
    save_path:      path where the model is saved

    Returns
    -------
    """
    checkpoint = dict(
        state_dict=self.model.state_dict(),
        best_val=self.best_val,
        best_epoch=self.best_epoch,
        optimizer=self.optimizer.state_dict(),
        args=vars(self.p),
    )
    torch.save(checkpoint, save_path)
|
|
|
|
|
|
|
|
def load_model(self, load_path):
    """
    Function to load a saved model

    Restores the model weights, the optimizer state and the bookkeeping of
    the best validation result from a checkpoint written by save_model.

    Parameters
    ----------
    load_path:      path to the saved model

    Returns
    -------
    Nothing. Updates self.model and self.optimizer in place and sets
    self.best_val, self.best_val_mrr and self.best_epoch.

    Raises
    ------
    KeyError if the checkpoint lacks 'state_dict', 'optimizer', 'best_val'
    or the 'mrr' entry of 'best_val'.

    """
    # map_location lets a checkpoint written on a GPU be restored on the
    # device this run actually uses (e.g. a CPU-only machine).
    state = torch.load(load_path, map_location=self.device)

    self.best_val = state['best_val']
    self.best_val_mrr = self.best_val['mrr']
    # 'best_epoch' is written by save_model; keep the current value if it is missing.
    self.best_epoch = state.get('best_epoch', getattr(self, 'best_epoch', 0))

    self.model.load_state_dict(state['state_dict'])
    self.optimizer.load_state_dict(state['optimizer'])
|
|
|
|
|
|
|
|
# def evaluate(self, split, epoch=0):
|
|
|
|
# """
|
|
|
|
# Function to evaluate the model on validation or test set
|
|
|
|
# Parameters
|
|
|
|
# ----------
|
|
|
|
# split: (string) If split == 'valid' then evaluate on the validation set, else the test set
|
|
|
|
# epoch: (int) Current epoch count
|
|
|
|
|
|
|
|
# Returns
|
|
|
|
# -------
|
|
|
|
# results: The evaluation results containing the following:
|
|
|
|
# results['mr']: Average of ranks_left and ranks_right
|
|
|
|
# results['mrr']: Mean Reciprocal Rank
|
|
|
|
# results['hits@k']: Probability of getting the correct prediction in top-k ranks based on predicted score
|
|
|
|
# """
|
|
|
|
# left_results = self.predict(split=split, mode='tail_batch')
|
|
|
|
# right_results = self.predict(split=split, mode='head_batch')
|
|
|
|
# results = get_combined_results(left_results, right_results)
|
|
|
|
# self.logger.info('[Epoch {} {}]: MRR: Tail : {:.5}, Head : {:.5}, Avg : {:.5}'.format(
|
|
|
|
# epoch, split, results['left_mrr'], results['right_mrr'], results['mrr']))
|
|
|
|
# return results
|
|
|
|
|
|
|
|
def evaluate(self, split, epoch=0):
    """
    Score the model on one split in both prediction directions and log a summary

    Parameters
    ----------
    split:          (string) If split == 'valid' then evaluate on the validation set, else the test set
    epoch:          (int) Current epoch count

    Returns
    -------
    results:        Whatever get_combined_results builds from the tail and head predictions.
                    The entries 'mrr', 'mr', 'hits@1', 'hits@3' and 'hits@10' (each also with
                    a 'left_' and a 'right_' prefix) are read here for the log message.

    The full report is logged on every 10th epoch and for the test split;
    otherwise only the MRR line is logged.

    """
    tail_results = self.predict(split=split, mode='tail_batch')
    head_results = self.predict(split=split, mode='head_batch')
    results = get_combined_results(tail_results, head_results)

    # One line per metric; all of them are formatted up front, whichever is logged.
    line = '\t{}: Tail : {:.5}, Head : {:.5}, Avg : {:.5}'
    metrics = (('MRR', 'mrr'), ('MR', 'mr'), ('Hit-1', 'hits@1'),
               ('Hit-3', 'hits@3'), ('Hit-10', 'hits@10'))
    lines = [line.format(title, results['left_' + key], results['right_' + key], results[key])
             for title, key in metrics]

    mrr_only = '\n' + lines[0] + '\n'
    full_report = '\n' + '\n'.join(lines)

    verbose = (epoch + 1) % 10 == 0 or split == 'test'
    self.logger.info(
        '[Evaluating Epoch {} {}]: {}'.format(epoch, split, full_report if verbose else mrr_only))

    return results
|
|
|
|
|
|
|
|
def predict(self, split='valid', mode='tail_batch'):
    """
    Function to run model evaluation for a given mode

    Every batch is scored against all candidates ('one_to_n'). The scores of
    candidates flagged in the label tensor are zeroed, except the target
    itself, before the target is ranked. Per-triple details are also dumped
    to '<self.p.name>_result.csv' in the working directory.

    Parameters
    ----------
    split:          (string) If split == 'valid' then evaluate on the validation set, else the test set
    mode:           (string): Can be 'head_batch' or 'tail_batch'; the part before '_' selects the loader

    Returns
    -------
    results:        Running totals over all batches (sums, not averages):
        results['count']:   Number of ranked triples
        results['mr']:      Sum of the target ranks
        results['mrr']:     Sum of the reciprocal target ranks
        results['hits@k']:  Number of targets ranked within the top k, for k = 1..10
        The dict is empty when the loader yields no batch.

    """
    self.model.eval()

    results = {}
    # Per-triple records collected for the CSV dump.
    sub_all, rel_all, obj_all = [], [], []
    target_score, target_rank = [], []
    obj_pred, obj_pred_score = [], []

    with torch.no_grad():
        loader = self.data_iter['{}_{}'.format(split, mode.split('_')[0])]

        for step, batch in enumerate(loader):
            sub, rel, obj, nt_rel, label = self.read_batch(batch, split)
            pred = self.model.forward(sub, rel, nt_rel, None, 'one_to_n')

            b_range = torch.arange(pred.size()[0], device=self.device)
            target_pred = pred[b_range, obj]
            # torch.where needs a boolean mask; a uint8 mask (.byte()) is
            # deprecated/rejected by current PyTorch.
            # NOTE(review): assumes label holds only 0/1 entries -- confirm in TestDataset.
            pred = torch.where(label.bool(), torch.zeros_like(pred), pred)
            pred[b_range, obj] = target_pred

            # A single descending sort serves both the top prediction and the ranks.
            order = torch.argsort(pred, dim=1, descending=True)
            highest = order[:, 0]
            highest_score = pred[b_range, highest]
            ranks = 1 + torch.argsort(order, dim=1, descending=False)[b_range, obj]

            sub_all.extend(sub.cpu().numpy())
            obj_all.extend(obj.cpu().numpy())
            rel_all.extend(rel.cpu().numpy())
            target_score.extend(target_pred.cpu().numpy())
            target_rank.extend(ranks.cpu().numpy())
            obj_pred.extend(highest.cpu().numpy())
            obj_pred_score.extend(highest_score.cpu().numpy())

            ranks = ranks.float()
            results['count'] = torch.numel(ranks) + results.get('count', 0.0)
            results['mr'] = torch.sum(ranks).item() + results.get('mr', 0.0)
            results['mrr'] = torch.sum(1.0 / ranks).item() + results.get('mrr', 0.0)
            for k in range(1, 11):
                key = 'hits@{}'.format(k)
                results[key] = torch.numel(ranks[ranks <= k]) + results.get(key, 0.0)

            if step % 100 == 0:
                self.logger.info('[{}, {} Step {}]\t{}'.format(
                    split.title(), mode.title(), step, self.p.name))

    # NOTE(review): the file name depends only on self.p.name, so every call
    # (each split, each mode, each epoch) overwrites the previous dump.
    df = pd.DataFrame({"sub": sub_all, "rel": rel_all, "obj": obj_all, "rank": target_rank,
                       "score": target_score, "pred": obj_pred, "pred_score": obj_pred_score})
    df.to_csv(f"{self.p.name}_result.csv", header=True, index=False)

    return results
|
|
|
|
|
|
|
|
def run_epoch(self, epoch):
    """
    Train the model for one pass over the training loader

    Parameters
    ----------
    epoch:          current epoch count (used for log messages only)

    Returns
    -------
    loss:           Mean of the per-batch loss values of this epoch

    """
    self.model.train()

    batch_losses = []
    for step, batch in enumerate(self.data_iter['train']):
        self.optimizer.zero_grad()

        sub, rel, obj, nt_rel, label, neg_ent, sub_samp = self.read_batch(batch, 'train')

        scores = self.model.forward(sub, rel, nt_rel, neg_ent, self.p.train_strategy)
        batch_loss = self.model.loss(scores, label, sub_samp)

        batch_loss.backward()
        self.optimizer.step()
        batch_losses.append(batch_loss.item())

        # Progress line with the running mean loss, every 100 steps.
        if not step % 100:
            self.logger.info('[E:{}| {}]: Train Loss:{:.5}, Val MRR:{:.5}, \t{}'.format(
                epoch, step, np.mean(batch_losses), self.best_val_mrr, self.p.name))

    epoch_loss = np.mean(batch_losses)
    self.logger.info('[Epoch:{}]: Training Loss:{:.4}\n'.format(epoch, epoch_loss))
    return epoch_loss
|
|
|
|
|
|
|
|
def fit(self):
    """
    Function to run training and evaluation of model

    Trains for self.p.max_epochs epochs, validating after each one and saving
    a checkpoint to ./torch_saved/<self.p.name> whenever the validation MRR
    improves. Finally the checkpoint is reloaded and the test split is
    evaluated.

    Parameters
    ----------
    self.p.restore:     When set, training starts from the existing checkpoint
    self.p.max_epochs:  Number of training epochs

    Returns
    -------
    Nothing. Leaves the best validation results in self.best_val,
    self.best_val_mrr and self.best_epoch.

    """
    self.best_val_mrr, self.best_val, self.best_epoch = 0., {}, 0.
    save_path = os.path.join('./torch_saved', self.p.name)

    if self.p.restore:
        self.load_model(save_path)
        self.logger.info('Successfully Loaded previous model')

    for epoch in range(self.p.max_epochs):
        train_loss = self.run_epoch(epoch)
        val_results = self.evaluate('valid', epoch)

        if val_results['mrr'] > self.best_val_mrr:
            self.best_val = val_results
            self.best_val_mrr = val_results['mrr']
            self.best_epoch = epoch
            self.save_model(save_path)
        self.logger.info('[Epoch {}]: Training Loss: {:.5}, Valid MRR: {:.5}, \n\n\n'.format(
            epoch, train_loss, self.best_val_mrr))

    # Restoring model corresponding to the best validation performance and evaluation on test data
    if os.path.exists(save_path):
        self.logger.info('Loading best model, evaluating on test data')
        self.load_model(save_path)
    else:
        # No epoch improved on the initial MRR (or max_epochs was 0), so
        # nothing was saved; evaluate the weights we have instead of crashing.
        self.logger.warning(
            'No checkpoint found at {}, evaluating the current model on test data'.format(save_path))
    self.evaluate('test')
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the hyper-parameters, prepare the
    # environment, optionally run a (still experimental) hyper-parameter grid
    # search, then either evaluate a saved model (--test_only) or train one.
    parser = argparse.ArgumentParser(
        description="Parser For Arguments", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Dataset and Experiment name
    parser.add_argument('--data', dest="dataset", default='FB15k-237',
                        help='Dataset to use for the experiment')
    parser.add_argument("--name", default='testrun_' +
                        str(uuid.uuid4())[:8], help='Name of the experiment')

    # Training parameters
    parser.add_argument("--gpu", type=str, default='-1',
                        help='GPU to use, set -1 for CPU')
    parser.add_argument("--train_strategy", type=str,
                        default='one_to_n', help='Training strategy to use')
    parser.add_argument("--opt", type=str, default='adam',
                        help='Optimizer to use for training')
    parser.add_argument('--neg_num', dest="neg_num", default=1000, type=int,
                        help='Number of negative samples to use for loss calculation')
    parser.add_argument('--batch', dest="batch_size",
                        default=128, type=int, help='Batch size')
    parser.add_argument("--l2", type=float, default=0.0,
                        help='L2 regularization')
    parser.add_argument("--lr", type=float, default=0.0001,
                        help='Learning Rate')
    parser.add_argument("--epoch", dest='max_epochs', default=500,
                        type=int, help='Maximum number of epochs. Default: 500')
    parser.add_argument("--num_workers", type=int, default=0,
                        help='Maximum number of workers used in DataLoader. Default: 0')
    parser.add_argument('--seed', dest="seed", default=42,
                        type=int, help='Seed to reproduce results. Default: 42')
    parser.add_argument('--restore', dest="restore", action='store_true',
                        help='Restore from the previously saved model')

    # Model parameters
    parser.add_argument("--lbl_smooth", dest='lbl_smooth', default=0.1,
                        type=float, help='Label smoothing for true labels')
    parser.add_argument("--embed_dim", type=int, default=400,
                        help='Embedding dimension for entity and relation, ignored if k_h and k_w are set')

    # Specific setting for embedding vectors: entity embedding vector and relation embedding vector
    parser.add_argument('--ent_vec_dim', type=int, default=400,
                        help="Embedding dimension of entity. Default: 400")
    parser.add_argument('--rel_vec_dim', type=int, default=400,
                        help="Embedding dimension of relation. Default: 400")

    parser.add_argument('--bias', dest="bias", action='store_true',
                        help='Whether to use bias in the model.')
    parser.add_argument('--form', type=str, default='plain',
                        help='The reshaping form to use.')

    # Reshape matrix parameters for InteractE
    parser.add_argument('--k_w', dest="k_w", default=10, type=int,
                        help='Width of the reshaped matrix. Default: 10')
    parser.add_argument('--k_h', dest="k_h", default=20, type=int,
                        help='Height of the reshaped matrix. Default: 20')

    parser.add_argument('--num_filt', dest="num_filt", default=96, type=int,
                        help='Number of filters in convolution. Default: 96. Test: 32, 64, 128')

    parser.add_argument('--ker_sz', dest="ker_sz", default=9, type=int,
                        help='Kernel size to use. Default: 9. Test: 3, 5, 7, 9')
    parser.add_argument('--perm', dest="perm", default=1, type=int,
                        help='Number of Feature rearrangement to use. Default: 1, 2, 3, 4, 5')

    # Configuration for dropout technique
    parser.add_argument('--hid_drop', dest="hid_drop", default=0.5, type=float,
                        help='Dropout for Hidden layer. Default: 0.5. Test: 0.2, 0.3, 0.4, 0.5')
    parser.add_argument('--feat_drop', dest="feat_drop", default=0.2, type=float,
                        help='Dropout for Feature. Default: 0.2. Test: 0.2, 0.3, 0.4, 0.5')
    parser.add_argument('--inp_drop', dest="inp_drop", default=0.2, type=float,
                        help='Dropout for Input layer. Default: 0.2. Test: 0.2, 0.3, 0.4, 0.5')
    parser.add_argument('--drop_path', dest="drop_path", default=0.0, type=float,
                        help='Path dropout. Default: 0.0. Test: 0.2, 0.3, 0.4, 0.5')
    parser.add_argument('--drop', dest="drop", default=0.0, type=float,
                        help='Inner drop. Default: 0.0. Test: 0.2, 0.3, 0.4, 0.5')

    # Configuration for in/output channels for ConvE, HypER, HypE
    parser.add_argument('--in_channels', dest="in_channels",
                        default=1, type=int, help='Input channels. Default: 1')
    parser.add_argument('--out_channels', dest="out_channels", default=32, type=int,
                        help='Output channels. Default: 32. Test: 32, 64, 128. Can be the same with num_filt hyperparameter.')
    parser.add_argument('--filt_h', type=int, default=1,
                        help='Height of filter. This configuration for HypER model. Default: 1. Choice: 1, 2, 3, 5, 7, 9')
    parser.add_argument('--filt_w', type=int, default=9,
                        help='Width of filter. This configuration for HypER model. Default: 9. If filt_h is 1, then filt_w: 1, 2, 3, 5, 7, 9, 11, 12, 13, 15')

    # Configuration for mixer layer
    parser.add_argument('--image_h', dest="image_h",
                        default=128, type=int, help='')

    parser.add_argument('--image_w', dest="image_w",
                        default=128, type=int, help='')

    parser.add_argument('--patch_size', dest="patch_size",
                        default=8, type=int, help='')

    parser.add_argument('--mixer_dim', dest="mixer_dim",
                        default=256, type=int, help='')

    parser.add_argument('--expansion_factor', dest="expansion_factor",
                        default=4, type=int, help='')

    parser.add_argument('--expansion_factor_token', dest="expansion_factor_token",
                        default=0.5, type=float, help='')

    parser.add_argument('--mixer_depth', dest="mixer_depth",
                        default=16, type=int, help='')

    parser.add_argument('--mixer_dropout', dest="mixer_dropout",
                        default=0.2, type=float, help='')

    # Logging parameters
    parser.add_argument('--logdir', dest="log_dir",
                        default='./log/', help='Log directory')
    parser.add_argument('--config', dest="config_dir",
                        default='./config/', help='Config directory')

    parser.add_argument('--test_only', action='store_true', default=False)
    parser.add_argument('--grid_search', action='store_true', default=False)

    args = parser.parse_args()

    prepare_env()

    set_gpu(args.gpu)
    set_seed(args.seed)

    # The logger is created before the grid search because Main requires it.
    logger = get_logger(
        args.name, args.log_dir, args.config_dir)

    if args.grid_search:
        # Main takes (params, logger); the earlier call without logger raised TypeError.
        model = Main(args, logger)

        from sklearn.model_selection import GridSearchCV
        from skorch import NeuralNet

        estimator = NeuralNet(
            module=FouriER(model.p),
            criterion=torch.nn.BCELoss,
            optimizer=torch.optim.Adam,
            max_epochs=100,
            batch_size=128,
            verbose=False
        )

        paramsGrid = {
            'optimizer__lr': [0.0003, 0.001],
            # 'optimizer__weight_decay': [1e-4, 1e-5, 1e-6],
            # 'module__hid_drop': [0.2, 0.5, 0.7],
            # 'module__embed_dim': [300, 400, 500],
        }

        # NOTE(review): `scoring` receives the BCELoss class itself rather than
        # a scorer callable/name -- confirm this search ever ran as intended.
        grid = GridSearchCV(estimator=estimator, param_grid=paramsGrid, n_jobs=-1, cv=2, scoring=torch.nn.BCELoss)

        # Search on a random 15% sample of the training records, served as one batch.
        data = np.array(model.triples['train'])
        data = data[np.random.choice(np.arange(len(data)), size=int(len(data) * 0.15), replace=False)]
        dataloader = iter(DataLoader(
            TrainDataset(data, model.p),
            batch_size=len(data),
            shuffle=True,
            num_workers=max(0, model.p.num_workers),
            collate_fn=TrainDataset.collate_fn
        ))

        for step, batch in enumerate(dataloader):
            sub, rel, obj, nt_rel, label, neg_ent, sub_samp = model.read_batch(
                batch, 'train')

            if neg_ent is None:
                neg_ent = np.repeat(None, repeats=len(sub))
            else:
                neg_ent = neg_ent.cpu()

            # One dict per sample ('input' was renamed so the builtin is not shadowed).
            inputs = []
            for i in range(len(sub)):
                sample = {}
                sample['sub'] = sub[i]
                sample['rel'] = rel[i]
                sample['neg_ents'] = neg_ent[i]
                inputs.append(sample)

            search = grid.fit(inputs, label)
            print("BEST SCORE: ", search.best_score_)
            print("BEST PARAMS: ", search.best_params_)

    if args.test_only:
        model = Main(args, logger)
        save_path = os.path.join('./torch_saved', args.name)
        model.load_model(save_path)
        model.evaluate('test')
    else:
        model = Main(args, logger)
        model.fit()
|
|
|
|
# while True:
|
|
|
|
# try:
|
|
|
|
# model = Main(args, logger)
|
|
|
|
# model.fit()
|
|
|
|
# except Exception as e:
|
|
|
|
# print(e)
|
|
|
|
# traceback.print_exc()
|
|
|
|
# try:
|
|
|
|
# del model
|
|
|
|
# except Exception:
|
|
|
|
# pass
|
|
|
|
# time.sleep(30)
|
|
|
|
# continue
|
|
|
|
# break
|