Thesis/main.py

736 lines
32 KiB
Python
Raw Normal View History

2023-05-04 08:49:41 +00:00
import os
import uuid
import argparse
import logging
import logging.config
2023-06-24 04:11:17 +00:00
import pandas as pd
import sys
2023-05-04 08:49:41 +00:00
import torch
import numpy as np
2023-06-08 06:40:16 +00:00
import time
2023-05-04 08:49:41 +00:00
from collections import defaultdict as ddict
from pprint import pprint
from ordered_set import OrderedSet
from torch.utils.data import DataLoader
from data_loader import TrainDataset, TestDataset
from utils import get_logger, get_combined_results, set_gpu, prepare_env, set_seed
from models import ComplEx, ConvE, HypER, InteractE, FouriER, TuckER
2024-04-27 03:18:48 +00:00
import traceback
2023-05-04 08:49:41 +00:00
class Main(object):
    """
    Runner that loads a knowledge-graph dataset from ./data/<dataset>/, builds the
    FouriER model and drives its training, validation and test evaluation.
    """

    def __init__(self, params, logger):
        """
        Constructor of the runner class
        Parameters
        ----------
        params: Namespace of hyper-parameters of the model (the parsed command-line arguments)
        logger: Logger that receives all progress messages
        Returns
        -------
        Loads the data, creates the model and its optimizer
        """
        self.p = params
        self.logger = logger
        self.logger.info(vars(self.p))
        # Defaults so that run_epoch() can log before fit()/load_model() set real values.
        self.best_val_mrr, self.best_val, self.best_epoch = 0., {}, 0
        if self.p.gpu != '-1' and torch.cuda.is_available():
            self.device = torch.device('cuda')
            torch.backends.cudnn.deterministic = True
        else:
            self.device = torch.device('cpu')
        self.load_data()
        self.model = self.add_model()
        self.optimizer = self.add_optimizer(self.model.parameters())

    @staticmethod
    def _split_line(line):
        """
        Normalise one tab-separated data line: drop non-breaking spaces (U+00A0),
        strip surrounding whitespace, split on tabs and lower-case every field.
        """
        return [field.lower() for field in line.replace('\xa0', '').strip().split('\t')]

    def _read_dict(self, filename):
        """
        Read ./data/<dataset>/<filename>, one 'id<TAB>name' pair per line, into {name: int(id)}.
        Blank lines are skipped.
        """
        mapping = {}
        path = './data/{}/{}'.format(self.p.dataset, filename)
        with open(path) as handle:
            for line in handle:
                if not line.strip():
                    continue
                idx, name = self._split_line(line)
                mapping[name] = int(idx)
        return mapping

    def _read_triples(self, split):
        """
        Yield (sub, rel, obj) name triples from ./data/<dataset>/<split>.txt.
        Columns after the third are ignored; blank lines are skipped.
        """
        path = './data/{}/{}.txt'.format(self.p.dataset, split)
        with open(path) as handle:
            for line in handle:
                if not line.strip():
                    continue
                sub, rel, obj, *_ = self._split_line(line)
                yield sub, rel, obj

    def load_data(self):
        """
        Read the raw triples and convert them into id-based data and data loaders.
        Parameters
        ----------
        self.p.dataset: Name of the dataset directory under ./data/ (e.g. FB15k-237, WN18RR)
        Returns
        -------
        self.ent2id: Entity to unique identifier mapping (read from entities.dict)
        self.rel2id: Relation to unique identifier mapping (read from relations.dict),
            extended with '<rel>_reverse' entries for every relation seen in the triples
        self.id2ent: Inverse mapping of self.ent2id
        self.id2rel: Inverse mapping of self.rel2id
        self.p.num_ent: Number of entities in the knowledge graph
        self.p.num_rel: len(self.rel2id) // 2; added to a relation id to obtain its inverse id
        self.p.embed_dim: Embedding dimension used (k_w * k_h when it was None)
        self.data[split]: (sub, rel, obj, nt_rel) id tuples for 'train', 'test' and 'valid'
        self.sr2o: (sub, rel, nt_rel) -> objects seen in the training split
        self.sr2o_all: (sub, rel, nt_rel) -> objects seen in any split
        self.triples: Per-split samples consumed by TrainDataset / TestDataset
        self.data_iter: The dataloader for different data splits
        self.chequer_perm: Stores the Chequer reshaping arrangement
        """
        splits = ['train', 'test', 'valid']

        # Relation names in order of first appearance; used to name the inverse relations.
        rel_set = OrderedSet()
        for split in splits:
            for _, rel, _ in self._read_triples(split):
                rel_set.add(rel)

        self.ent2id = self._read_dict('entities.dict')
        self.rel2id = self._read_dict('relations.dict')
        num_base_rel = len(self.rel2id)
        self.rel2id.update({rel + '_reverse': idx + num_base_rel
                            for idx, rel in enumerate(rel_set)})
        self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}
        self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}
        self.p.num_ent = len(self.ent2id)
        self.p.num_rel = len(self.rel2id) // 2
        if self.p.embed_dim is None:
            self.p.embed_dim = self.p.k_w * self.p.k_h
        num_rel = self.p.num_rel

        self.data = ddict(list)
        sr2o = ddict(set)
        for split in splits:
            for sub_name, rel_name, obj_name in self._read_triples(split):
                # Relation name truncated at its first '[' (the whole name when there is
                # none); it must also be present in relations.dict.
                nt_rel_name = rel_name.split('[')[0]
                sub, obj = self.ent2id[sub_name], self.ent2id[obj_name]
                rel, nt_rel = self.rel2id[rel_name], self.rel2id[nt_rel_name]
                self.data[split].append((sub, rel, obj, nt_rel))
                if split == 'train':
                    sr2o[(sub, rel, nt_rel)].add(obj)
                    sr2o[(obj, rel + num_rel, nt_rel + num_rel)].add(sub)
        self.data = dict(self.data)
        self.sr2o = {k: list(v) for k, v in sr2o.items()}
        for split in ['test', 'valid']:
            for sub, rel, obj, nt_rel in self.data[split]:
                sr2o[(sub, rel, nt_rel)].add(obj)
                sr2o[(obj, rel + num_rel, nt_rel + num_rel)].add(sub)
        self.sr2o_all = {k: list(v) for k, v in sr2o.items()}

        self.triples = ddict(list)
        if self.p.train_strategy == 'one_to_n':
            # One sample per (sub, rel, nt_rel) key; the object slot is a -1 placeholder.
            for (sub, rel, nt_rel), objs in self.sr2o.items():
                self.triples['train'].append(
                    {'triple': (sub, rel, -1, nt_rel), 'label': objs, 'sub_samp': 1})
        else:
            # One sample per direction of every training triple, weighted by
            # 1 / sqrt(number of known answers in both directions).
            for sub, rel, obj, nt_rel in self.data['train']:
                rel_inv, nt_rel_inv = rel + num_rel, nt_rel + num_rel
                fwd_label = self.sr2o[(sub, rel, nt_rel)]
                inv_label = self.sr2o[(obj, rel_inv, nt_rel_inv)]
                sub_samp = np.sqrt(1 / (len(fwd_label) + len(inv_label)))
                self.triples['train'].append(
                    {'triple': (sub, rel, obj, nt_rel), 'label': fwd_label, 'sub_samp': sub_samp})
                self.triples['train'].append(
                    {'triple': (obj, rel_inv, sub, nt_rel_inv), 'label': inv_label, 'sub_samp': sub_samp})
        for split in ['test', 'valid']:
            for sub, rel, obj, nt_rel in self.data[split]:
                rel_inv, nt_rel_inv = rel + num_rel, nt_rel + num_rel
                self.triples['{}_tail'.format(split)].append(
                    {'triple': (sub, rel, obj, nt_rel), 'label': self.sr2o_all[(sub, rel, nt_rel)]})
                self.triples['{}_head'.format(split)].append(
                    {'triple': (obj, rel_inv, sub, nt_rel_inv),
                     'label': self.sr2o_all[(obj, rel_inv, nt_rel_inv)]})
        self.triples = dict(self.triples)

        def get_data_loader(dataset_class, split, batch_size, shuffle=True):
            return DataLoader(
                dataset_class(self.triples[split], self.p),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.p.num_workers),
                collate_fn=dataset_class.collate_fn
            )

        self.data_iter = {
            'train': get_data_loader(TrainDataset, 'train', self.p.batch_size),
            'valid_head': get_data_loader(TestDataset, 'valid_head', self.p.batch_size),
            'valid_tail': get_data_loader(TestDataset, 'valid_tail', self.p.batch_size),
            'test_head': get_data_loader(TestDataset, 'test_head', self.p.batch_size),
            'test_tail': get_data_loader(TestDataset, 'test_tail', self.p.batch_size),
        }
        self.chequer_perm = self.get_chequer_perm()

    def get_chequer_perm(self):
        """
        Generate the chequer permutation required for the InteractE model
        Parameters
        ----------
        self.p.perm, self.p.k_h, self.p.k_w, self.p.embed_dim
        Returns
        -------
        LongTensor of shape (perm, 2 * k_h * k_w) on self.device. Values below embed_dim
        index entity features; values >= embed_dim index relation features (offset by embed_dim).
        Entity and relation features alternate along each of the k_h rows; the
        starting element flips from row to row and from permutation to permutation.
        """
        ent_perm = np.int32([np.random.permutation(self.p.embed_dim)
                             for _ in range(self.p.perm)])
        rel_perm = np.int32([np.random.permutation(self.p.embed_dim)
                             for _ in range(self.p.perm)])
        comb_idx = []
        for k in range(self.p.perm):
            row = []
            pos = 0  # the entity and relation cursors always advance together
            for i in range(self.p.k_h):
                ent_first = (k % 2) == (i % 2)
                for _ in range(self.p.k_w):
                    ent = ent_perm[k, pos]
                    rel = rel_perm[k, pos] + self.p.embed_dim
                    row.extend((ent, rel) if ent_first else (rel, ent))
                    pos += 1
            comb_idx.append(row)
        return torch.as_tensor(np.asarray(comb_idx, dtype=np.int64), device=self.device)

    def add_model(self):
        """
        Creates the computational graph
        Parameters
        ----------
        Returns
        -------
        A FouriER model built from self.p and moved to self.device
        """
        model = FouriER(self.p)
        model.to(self.device)
        return model

    def add_optimizer(self, parameters):
        """
        Creates an optimizer for training the parameters
        Parameters
        ----------
        parameters: The parameters of the model
        Returns
        -------
        torch.optim.Adam when self.p.opt == 'adam', otherwise torch.optim.SGD,
        both with lr=self.p.lr and weight_decay=self.p.l2
        """
        if self.p.opt == 'adam':
            return torch.optim.Adam(parameters, lr=self.p.lr, weight_decay=self.p.l2)
        return torch.optim.SGD(parameters, lr=self.p.lr, weight_decay=self.p.l2)

    def read_batch(self, batch, split):
        """
        Read a batch of data and move its tensors to self.device
        Parameters
        ----------
        batch: the batch to process
        split: (string) 'train', 'valid' or 'test'
        Returns
        -------
        sub, rel, obj, nt_rel: the four columns of the triple tensor
        label: The label for each triple
        For split == 'train' two more values follow: neg_ent and sub_samp
        (both None unless self.p.train_strategy == 'one_to_x')
        """
        if split == 'train':
            if self.p.train_strategy == 'one_to_x':
                triple, label, neg_ent, sub_samp = [t.to(self.device) for t in batch]
                return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, neg_ent, sub_samp
            triple, label = [t.to(self.device) for t in batch]
            return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label, None, None
        triple, label = [t.to(self.device) for t in batch]
        return triple[:, 0], triple[:, 1], triple[:, 2], triple[:, 3], label

    def save_model(self, save_path):
        """
        Save the model parameters, best validation scores, the epoch of the best
        validation score, the optimizer state and all arguments of the run.
        Parameters
        ----------
        save_path: path where the model is saved; missing parent directories are created
        Returns
        -------
        """
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        state = {
            'state_dict': self.model.state_dict(),
            'best_val': self.best_val,
            'best_epoch': self.best_epoch,
            'optimizer': self.optimizer.state_dict(),
            'args': vars(self.p)
        }
        torch.save(state, save_path)

    def load_model(self, load_path):
        """
        Load a checkpoint written by save_model into the current model and optimizer
        Parameters
        ----------
        load_path: path to the saved model
        Returns
        -------
        """
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
        state = torch.load(load_path, map_location=self.device)
        self.best_val = state['best_val']
        self.best_val_mrr = state['best_val']['mrr']
        self.best_epoch = state.get('best_epoch', self.best_epoch)
        self.model.load_state_dict(state['state_dict'])
        self.optimizer.load_state_dict(state['optimizer'])

    def evaluate(self, split, epoch=0):
        """
        Evaluate the model on the validation or test set and log the metrics
        Parameters
        ----------
        split: (string) If split == 'valid' then evaluate on the validation set, else the test set
        epoch: (int) Current epoch count
        Returns
        -------
        results: The combined results returned by get_combined_results, read here as
            'mrr', 'mr', 'hits@1', 'hits@3', 'hits@10' and their 'left_' (tail) / 'right_' (head) variants
        """
        left_results = self.predict(split=split, mode='tail_batch')
        right_results = self.predict(split=split, mode='head_batch')
        results = get_combined_results(left_results, right_results)
        metric_lines = [
            '\t{}: Tail : {:.5}, Head : {:.5}, Avg : {:.5}'.format(
                label, results['left_' + key], results['right_' + key], results[key])
            for label, key in (('MRR', 'mrr'), ('MR', 'mr'), ('Hit-1', 'hits@1'),
                               ('Hit-3', 'hits@3'), ('Hit-10', 'hits@10'))
        ]
        if (epoch + 1) % 10 == 0 or split == 'test':
            # Full report every 10th epoch and on the test split.
            log_res = '\n' + '\n'.join(metric_lines)
        else:
            # Otherwise only the MRR line.
            log_res = '\n' + metric_lines[0] + '\n'
        self.logger.info('[Evaluating Epoch {} {}]: {}'.format(epoch, split, log_res))
        return results

    def predict(self, split='valid', mode='tail_batch'):
        """
        Score one evaluation split in one direction and accumulate ranking metrics
        Parameters
        ----------
        split: (string) 'valid' or 'test'
        mode: (string) 'tail_batch' or 'head_batch'; selects the '<split>_tail' / '<split>_head' loader
        Returns
        -------
        results: Metrics summed over all evaluated triples (presumably normalised by get_combined_results):
            results['count']: Number of evaluated triples
            results['mr']: Sum of ranks
            results['mrr']: Sum of reciprocal ranks
            results['hits@k']: Number of triples ranked within the top k, for k = 1..10
        Also writes per-triple predictions to '<name>_result.csv' in the working directory.
        """
        self.model.eval()
        with torch.no_grad():
            results = {}
            sub_all, rel_all, obj_all = [], [], []
            target_score, target_rank = [], []
            obj_pred, obj_pred_score = [], []
            loader = self.data_iter['{}_{}'.format(split, mode.split('_')[0])]
            for step, batch in enumerate(loader):
                sub, rel, obj, nt_rel, label = self.read_batch(batch, split)
                pred = self.model.forward(sub, rel, nt_rel, None, 'one_to_n')
                b_range = torch.arange(pred.size(0), device=self.device)
                # Entities flagged in `label` get score 0, except the target object,
                # whose own score is restored afterwards.
                target_pred = pred[b_range, obj]
                pred = torch.where(label.bool(), torch.zeros_like(pred), pred)
                pred[b_range, obj] = target_pred
                # Sort once; reuse the order for the top-1 prediction and, via its
                # inverse permutation, for the 1-based rank of the target.
                order = torch.argsort(pred, dim=1, descending=True)
                highest = order[:, 0]
                highest_score = pred[b_range, highest]
                ranks = 1 + torch.argsort(order, dim=1, descending=False)[b_range, obj]

                sub_all.extend(sub.cpu().numpy())
                obj_all.extend(obj.cpu().numpy())
                rel_all.extend(rel.cpu().numpy())
                target_score.extend(target_pred.cpu().numpy())
                target_rank.extend(ranks.cpu().numpy())
                obj_pred.extend(highest.cpu().numpy())
                obj_pred_score.extend(highest_score.cpu().numpy())

                ranks = ranks.float()
                results['count'] = torch.numel(ranks) + results.get('count', 0.0)
                results['mr'] = torch.sum(ranks).item() + results.get('mr', 0.0)
                results['mrr'] = torch.sum(1.0 / ranks).item() + results.get('mrr', 0.0)
                for k in range(1, 11):
                    key = 'hits@{}'.format(k)
                    results[key] = torch.numel(ranks[ranks <= k]) + results.get(key, 0.0)
                if step % 100 == 0:
                    self.logger.info('[{}, {} Step {}]\t{}'.format(
                        split.title(), mode.title(), step, self.p.name))
            # NOTE(review): the same file name is used for every split/mode, so each
            # call overwrites the previous one (only the last predict() survives).
            df = pd.DataFrame({"sub": sub_all, "rel": rel_all, "obj": obj_all, "rank": target_rank,
                               "score": target_score, "pred": obj_pred, "pred_score": obj_pred_score})
            df.to_csv(f"{self.p.name}_result.csv", header=True, index=False)
        return results

    def run_epoch(self, epoch):
        """
        Run one epoch of training
        Parameters
        ----------
        epoch: current epoch count
        Returns
        -------
        loss: Mean training loss over the batches of this epoch
        """
        self.model.train()
        losses = []
        for step, batch in enumerate(self.data_iter['train']):
            self.optimizer.zero_grad()
            sub, rel, obj, nt_rel, label, neg_ent, sub_samp = self.read_batch(batch, 'train')
            pred = self.model.forward(sub, rel, nt_rel, neg_ent, self.p.train_strategy)
            loss = self.model.loss(pred, label, sub_samp)
            loss.backward()
            self.optimizer.step()
            losses.append(loss.item())
            if step % 100 == 0:
                self.logger.info('[E:{}| {}]: Train Loss:{:.5}, Val MRR:{:.5}, \t{}'.format(
                    epoch, step, np.mean(losses), self.best_val_mrr, self.p.name))
        loss = np.mean(losses)
        self.logger.info('[Epoch:{}]: Training Loss:{:.4}\n'.format(epoch, loss))
        return loss

    def fit(self):
        """
        Train for self.p.max_epochs epochs, checkpointing to ./torch_saved/<name> whenever
        the validation MRR improves, then reload the best checkpoint and evaluate on test.
        Parameters
        ----------
        Returns
        -------
        """
        self.best_val_mrr, self.best_val, self.best_epoch = 0., {}, 0
        save_path = os.path.join('./torch_saved', self.p.name)
        if self.p.restore:
            self.load_model(save_path)
            self.logger.info('Successfully Loaded previous model')
        for epoch in range(self.p.max_epochs):
            train_loss = self.run_epoch(epoch)
            val_results = self.evaluate('valid', epoch)
            if val_results['mrr'] > self.best_val_mrr:
                self.best_val = val_results
                self.best_val_mrr = val_results['mrr']
                self.best_epoch = epoch
                self.save_model(save_path)
            self.logger.info('[Epoch {}]: Training Loss: {:.5}, Valid MRR: {:.5}, \n\n\n'.format(
                epoch, train_loss, self.best_val_mrr))
        # Restore the model with the best validation performance and evaluate it on test data.
        self.logger.info('Loading best model, evaluating on test data')
        self.load_model(save_path)
        self.evaluate('test')
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Parser For Arguments", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Dataset and Experiment name
    parser.add_argument('--data', dest="dataset", default='FB15k-237',
                        help='Dataset to use for the experiment')
    parser.add_argument("--name", default='testrun_' + str(uuid.uuid4())[:8],
                        help='Name of the experiment')

    # Training parameters
    parser.add_argument("--gpu", type=str, default='-1',
                        help='GPU to use, set -1 for CPU')
    parser.add_argument("--train_strategy", type=str, default='one_to_n',
                        help='Training strategy to use')
    parser.add_argument("--opt", type=str, default='adam',
                        help='Optimizer to use for training')
    parser.add_argument('--neg_num', dest="neg_num", default=1000, type=int,
                        help='Number of negative samples to use for loss calculation')
    parser.add_argument('--batch', dest="batch_size", default=128, type=int,
                        help='Batch size')
    parser.add_argument("--l2", type=float, default=0.0,
                        help='L2 regularization')
    parser.add_argument("--lr", type=float, default=0.0001,
                        help='Learning Rate')
    parser.add_argument("--epoch", dest='max_epochs', default=500, type=int,
                        help='Maximum number of epochs. Default: 500')
    parser.add_argument("--num_workers", type=int, default=0,
                        help='Maximum number of workers used in DataLoader. Default: 0')
    parser.add_argument('--seed', dest="seed", default=42, type=int,
                        help='Seed to reproduce results. Default: 42')
    parser.add_argument('--restore', dest="restore", action='store_true',
                        help='Restore from the previously saved model')

    # Model parameters
    parser.add_argument("--lbl_smooth", dest='lbl_smooth', default=0.1, type=float,
                        help='Label smoothing for true labels')
    parser.add_argument("--embed_dim", type=int, default=400,
                        help='Embedding dimension for entity and relation; k_w * k_h is used instead only when this is None')

    # Specific setting for embedding vectors: entity embedding vector and relation embedding vector
    parser.add_argument('--ent_vec_dim', type=int, default=400,
                        help="Embedding dimension of entity. Default: 400")
    parser.add_argument('--rel_vec_dim', type=int, default=400,
                        help="Embedding dimension of relation. Default: 400")
    parser.add_argument('--bias', dest="bias", action='store_true',
                        help='Whether to use bias in the model.')
    parser.add_argument('--form', type=str, default='plain',
                        help='The reshaping form to use.')

    # Reshape matrix parameters for InteractE
    parser.add_argument('--k_w', dest="k_w", default=10, type=int,
                        help='Width of the reshaped matrix. Default: 10')
    parser.add_argument('--k_h', dest="k_h", default=20, type=int,
                        help='Height of the reshaped matrix. Default: 20')
    parser.add_argument('--num_filt', dest="num_filt", default=96, type=int,
                        help='Number of filters in convolution. Default: 96. Test: 32, 64, 128')
    parser.add_argument('--ker_sz', dest="ker_sz", default=9, type=int,
                        help='Kernel size to use. Default: 9. Test: 3, 5, 7, 9')
    parser.add_argument('--perm', dest="perm", default=1, type=int,
                        help='Number of Feature rearrangement to use. Default: 1, 2, 3, 4, 5')

    # Configuration for dropout technique
    parser.add_argument('--hid_drop', dest="hid_drop", default=0.5, type=float,
                        help='Dropout for Hidden layer. Default: 0.5. Test: 0.2, 0.3, 0.4, 0.5')
    parser.add_argument('--feat_drop', dest="feat_drop", default=0.2, type=float,
                        help='Dropout for Feature. Default: 0.2. Test: 0.2, 0.3, 0.4, 0.5')
    parser.add_argument('--inp_drop', dest="inp_drop", default=0.2, type=float,
                        help='Dropout for Input layer. Default: 0.2. Test: 0.2, 0.3, 0.4, 0.5')
    parser.add_argument('--drop_path', dest="drop_path", default=0.0, type=float,
                        help='Path dropout. Default: 0.0. Test: 0.2, 0.3, 0.4, 0.5')
    parser.add_argument('--drop', dest="drop", default=0.0, type=float,
                        help='Inner drop. Default: 0.0. Test: 0.2, 0.3, 0.4, 0.5')

    # Configuration for in/output channels for ConvE, HypER, HypE
    parser.add_argument('--in_channels', dest="in_channels", default=1, type=int,
                        help='Input channels. Default: 1')
    parser.add_argument('--out_channels', dest="out_channels", default=32, type=int,
                        help='Output channels. Default: 32. Test: 32, 64, 128. Can be the same with num_filt hyperparameter.')
    parser.add_argument('--filt_h', type=int, default=1,
                        help='Height of filter. This configuration for HypER model. Default: 1. Choice: 1, 2, 3, 5, 7, 9')
    parser.add_argument('--filt_w', type=int, default=9,
                        help='Width of filter. This configuration for HypER model. Default: 9. If filt_h is 1, then filt_w: 1, 2, 3, 5, 7, 9, 11, 12, 13, 15')

    # Configuration for mixer layer
    parser.add_argument('--image_h', dest="image_h", default=128, type=int, help='')
    parser.add_argument('--image_w', dest="image_w", default=128, type=int, help='')
    parser.add_argument('--patch_size', dest="patch_size", default=8, type=int, help='')
    parser.add_argument('--mixer_dim', dest="mixer_dim", default=256, type=int, help='')
    parser.add_argument('--expansion_factor', dest="expansion_factor", default=4, type=int, help='')
    parser.add_argument('--expansion_factor_token', dest="expansion_factor_token",
                        default=0.5, type=float, help='')
    parser.add_argument('--mixer_depth', dest="mixer_depth", default=16, type=int, help='')
    parser.add_argument('--mixer_dropout', dest="mixer_dropout", default=0.2, type=float, help='')

    # Logging parameters
    parser.add_argument('--logdir', dest="log_dir", default='./log/',
                        help='Log directory')
    parser.add_argument('--config', dest="config_dir", default='./config/',
                        help='Config directory')
    parser.add_argument('--test_only', action='store_true', default=False)
    parser.add_argument('--grid_search', action='store_true', default=False)

    args = parser.parse_args()
    prepare_env()
    set_gpu(args.gpu)
    set_seed(args.seed)
    # Every Main(...) needs a logger, so create it before any branch builds one
    # (the grid-search branch used to call Main(args) and fail with a TypeError).
    logger = get_logger(args.name, args.log_dir, args.config_dir)

    if args.grid_search:
        model = Main(args, logger)
        from sklearn.model_selection import GridSearchCV
        from skorch import NeuralNet
        estimator = NeuralNet(
            module=FouriER(model.p),
            criterion=torch.nn.BCELoss,
            optimizer=torch.optim.Adam,
            max_epochs=100,
            batch_size=128,
            verbose=False
        )
        params_grid = {
            'optimizer__lr': [0.0003, 0.001],
            # 'optimizer__weight_decay': [1e-4, 1e-5, 1e-6],
            # 'module__hid_drop': [0.2, 0.5, 0.7],
            # 'module__embed_dim': [300, 400, 500],
        }
        # NOTE(review): GridSearchCV expects `scoring` to be a string or a
        # callable(estimator, X, y) returning a number; the BCELoss class is
        # presumably not usable as-is — confirm before relying on this path.
        grid = GridSearchCV(estimator=estimator, param_grid=params_grid, n_jobs=-1, cv=2,
                            scoring=torch.nn.BCELoss)
        # Search on a random 15% subset of the training samples, loaded as one batch.
        data = np.array(model.triples['train'])
        data = data[np.random.choice(np.arange(len(data)), size=int(len(data) * 0.15), replace=False)]
        dataloader = DataLoader(
            TrainDataset(data, model.p),
            batch_size=len(data),
            shuffle=True,
            num_workers=max(0, model.p.num_workers),
            collate_fn=TrainDataset.collate_fn
        )
        for batch in dataloader:
            sub, rel, obj, nt_rel, label, neg_ent, sub_samp = model.read_batch(batch, 'train')
            if neg_ent is None:
                neg_ent = np.repeat(None, repeats=len(sub))
            else:
                neg_ent = neg_ent.cpu()
            inputs = [{'sub': s, 'rel': r, 'neg_ents': n} for s, r, n in zip(sub, rel, neg_ent)]
            search = grid.fit(inputs, label)
            print("BEST SCORE: ", search.best_score_)
            print("BEST PARAMS: ", search.best_params_)

    if args.test_only:
        model = Main(args, logger)
        save_path = os.path.join('./torch_saved', args.name)
        model.load_model(save_path)
        model.evaluate('test')
    else:
        model = Main(args, logger)
        model.fit()