# 142 lines, 5.0 KiB, Python
import argparse
|
|
import importlib
|
|
from logging import debug
|
|
|
|
import numpy as np
|
|
import torch
|
|
import pytorch_lightning as pl
|
|
import lit_models
|
|
import yaml
|
|
import time
|
|
from transformers import AutoConfig
|
|
import os
|
|
# from utils import get_decoder_input_ids
|
|
|
|
# Exported at import time, i.e. before main() builds the data module and its
# tokenizer.
# NOTE(review): presumably turns off the Hugging Face tokenizers' internal
# parallelism to avoid fork-related warnings with DataLoader workers - confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
|
# To keep experiments reproducible, main() seeds every random number generator from --seed.
|
|
|
|
|
|
def _import_class(module_and_class_name: str) -> type:
    """Import class from a module, e.g. 'text_recognizer.models.MLP'.

    Args:
        module_and_class_name: Dotted path; everything before the last "."
            is the module to import, the remainder is the attribute to fetch.

    Returns:
        The attribute of the imported module named by the last path component.

    Raises:
        ValueError: If the path contains no "." separating module and class.
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_name, separator, class_name = module_and_class_name.rpartition(".")
    if not separator:
        # Same exception type as the old tuple-unpacking failure, but with a
        # message that names the offending value.
        raise ValueError(
            f"Expected a dotted path like 'package.module.ClassName', got {module_and_class_name!r}"
        )
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
|
|
|
|
|
|
def _setup_parser():
    """Build the argument parser used by this training script.

    The returned parser carries the Lightning Trainer flags, the script's own
    basic options, and whatever the selected data, model and LitModel classes
    register through their ``add_to_argparse`` hooks.
    """
    # Trainer flags (--max_epochs, --gpus, --precision, ...) are registered on
    # a bare parser which then serves as parent of the real one.
    trainer_parser = pl.Trainer.add_argparse_args(argparse.ArgumentParser(add_help=False))
    trainer_parser._action_groups[1].title = "Trainer Args"  # pylint: disable=protected-access
    parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser])

    # Basic options of the script itself.
    basic_options = (
        ("--wandb", {"action": "store_true", "default": False}),
        ("--litmodel_class", {"type": str, "default": "TransformerLitModel"}),
        ("--seed", {"type": int, "default": 7}),
        ("--data_class", {"type": str, "default": "KGC"}),
        ("--chunk", {"type": str, "default": ""}),
        ("--model_class", {"type": str, "default": "RobertaUseLabelWord"}),
        ("--checkpoint", {"type": str, "default": None}),
    )
    for flag, options in basic_options:
        parser.add_argument(flag, **options)

    # A first, tolerant parse tells us which classes were selected, so their
    # own options can be registered below.
    known_args, _ = parser.parse_known_args()
    data_class = _import_class("data." + known_args.data_class)
    model_class = _import_class("models." + known_args.model_class)
    lit_model_class = _import_class("lit_models." + known_args.litmodel_class)

    data_class.add_to_argparse(parser.add_argument_group("Data Args"))

    # The model class is the only one allowed to omit the hook; its (possibly
    # empty) group is created either way.
    model_group = parser.add_argument_group("Model Args")
    if hasattr(model_class, "add_to_argparse"):
        model_class.add_to_argparse(model_group)

    lit_model_class.add_to_argparse(parser.add_argument_group("LitModel Args"))

    # Help is registered only now, so the tolerant parse above could not exit
    # on -h before the class-specific options existed.
    parser.add_argument("--help", "-h", action="help")
    return parser
|
|
|
|
def _saved_pretrain(lit_model, tokenizer, path):
    """Save the wrapped model and the tokenizer to ``path``.

    ``save_pretrained(path)`` is called on ``lit_model.model`` first and on
    ``tokenizer`` second.
    """
    for artifact in (lit_model.model, tokenizer):
        artifact.save_pretrained(path)
|
|
|
|
|
|
def main():
    """Train and evaluate a model configured entirely from the command line.

    Steps, in order: parse the arguments, seed the random number generators,
    import the selected data / model / LitModel classes, build the model from
    ``args.model_name_or_path``, optionally restore ``args.checkpoint``, set
    up the logger and callbacks, fit, reload the best checkpoint, run the test
    loop and print its result.

    Fitting (and the checkpoint reload) is skipped when the LitModel class
    name contains "EntityEmbedding"; such models are only tested.
    """
    parser = _setup_parser()
    args = parser.parse_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    pl.seed_everything(args.seed)

    data_class = _import_class(f"data.{args.data_class}")
    model_class = _import_class(f"models.{args.model_class}")
    litmodel_class = _import_class(f"lit_models.{args.litmodel_class}")

    config = AutoConfig.from_pretrained(args.model_name_or_path)
    # update parameters
    config.label_smoothing = args.label_smoothing

    model = model_class.from_pretrained(args.model_name_or_path, config=config)
    data = data_class(args, model)
    tokenizer = data.tokenizer

    lit_model = litmodel_class(args=args, model=model, tokenizer=tokenizer, data_config=data.get_config())
    if args.checkpoint:
        lit_model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["state_dict"])

    # The last component of the data directory names the W&B run and the
    # checkpoint sub-folder. normpath strips a trailing separator first, so
    # "dataset/FB15k-237/" no longer produces an empty name (which made the
    # checkpoint filename start with "/").
    dataset_name = os.path.basename(os.path.normpath(args.data_dir))

    logger = pl.loggers.TensorBoardLogger("training/logs")
    if args.wandb:
        logger = pl.loggers.WandbLogger(project="kgc_bert", name=dataset_name)
        logger.log_hyperparams(vars(args))

    metric_name = "Eval/mrr" if not args.pretrain else "Eval/hits1"

    # NOTE(review): early stopping always watches "Eval/mrr", even in pretrain
    # mode where checkpoints are ranked by "Eval/hits1" - confirm this is
    # intended and that "Eval/mrr" is logged in that mode.
    early_callback = pl.callbacks.EarlyStopping(monitor="Eval/mrr", mode="max", patience=10)

    if args.pretrain:
        checkpoint_name = dataset_name + '/{epoch}-{step}-{Eval/hits10:.2f}'
    else:
        checkpoint_name = dataset_name + '/{epoch}-{Eval/hits10:.2f}-{Eval/hits1:.2f}'
    model_checkpoint = pl.callbacks.ModelCheckpoint(
        monitor=metric_name,
        mode="max",
        filename=checkpoint_name,
        dirpath="output",
        save_weights_only=True,  # to be modified
        every_n_train_steps=100 if args.pretrain else None,
        save_top_k=5 if args.pretrain else 1,
    )
    callbacks = [early_callback, model_checkpoint]

    # args.weights_summary = "full"  # Print full summary of the model
    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")

    needs_training = "EntityEmbedding" not in lit_model.__class__.__name__
    path = None
    if needs_training:
        trainer.fit(lit_model, datamodule=data)
        path = model_checkpoint.best_model_path
        if path:
            # Load onto CPU, like the --checkpoint restore above, and let
            # load_state_dict copy the tensors into the existing parameters.
            lit_model.load_state_dict(torch.load(path, map_location="cpu")["state_dict"])
        else:
            # No checkpoint was written; evaluate the weights currently held
            # in memory instead of failing on torch.load("").
            print("No best checkpoint was saved; testing with the current in-memory weights.")

    result = trainer.test(lit_model, datamodule=data)
    print(result)

    # _saved_pretrain(lit_model, tokenizer, path)
    if needs_training:
        print("*path"*30)
        print(path)
|
|
|
|
|
|
|
|
|
|
|
|
# Run the train/test routine only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|