# Thesis/pretrain/main.py
# (Listing exported from a web repository viewer: 142 lines, 5.0 KiB, Python;
#  last commit 2022-12-26 04:54:46 +00:00.)
import argparse
import importlib
from logging import debug
import numpy as np
import torch
import pytorch_lightning as pl
import lit_models
import yaml
import time
from transformers import AutoConfig
import os
# from utils import get_decoder_input_ids
# Turn off the tokenizers library's internal parallelism for this process.
# Presumably done to avoid the fork-after-parallelism warning/deadlock when
# data-loader workers fork — TODO confirm against the data modules.
os.environ.update(TOKENIZERS_PARALLELISM="false")
# NOTE: random seeds are set at the start of main() so that experiments are reproducible.
def _import_class(module_and_class_name: str) -> type:
"""Import class from a module, e.g. 'text_recognizer.models.MLP'"""
module_name, class_name = module_and_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
class_ = getattr(module, class_name)
return class_
def _setup_parser():
    """Build the CLI parser: Lightning Trainer flags, base flags, then class-specific groups.

    Parsing happens in two phases. The base flags are parsed first (with
    ``parse_known_args`` over the real command line) to learn which data, model
    and LitModel classes were selected. Each selected class then registers its
    own argument group. ``--help`` is added only at the very end, and both
    parsers are built with ``add_help=False``. This way an early ``-h`` does not
    print an incomplete help text and exit before the class-specific groups exist.
    """
    # Trainer-specific arguments such as --max_epochs, --gpus, --precision.
    trainer_parser = pl.Trainer.add_argparse_args(argparse.ArgumentParser(add_help=False))
    trainer_parser._action_groups[1].title = "Trainer Args"  # pylint: disable=protected-access
    parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser])

    # Base arguments, registered in a fixed order (this is also the --help order).
    parser.add_argument("--wandb", action="store_true", default=False)
    for flag, value_type, default in (
        ("--litmodel_class", str, "TransformerLitModel"),
        ("--seed", int, 7),
        ("--data_class", str, "KGC"),
        ("--chunk", str, ""),
        ("--model_class", str, "RobertaUseLabelWord"),
        ("--checkpoint", str, None),
    ):
        parser.add_argument(flag, type=value_type, default=default)

    # Peek at the class-selection flags so the chosen classes can add their own options.
    selected, _unparsed = parser.parse_known_args()
    data_class = _import_class(f"data.{selected.data_class}")
    model_class = _import_class(f"models.{selected.model_class}")
    lit_model_class = _import_class(f"lit_models.{selected.litmodel_class}")

    data_class.add_to_argparse(parser.add_argument_group("Data Args"))
    model_group = parser.add_argument_group("Model Args")
    # Only the model class is optional about contributing arguments.
    if hasattr(model_class, "add_to_argparse"):
        model_class.add_to_argparse(model_group)
    lit_model_class.add_to_argparse(parser.add_argument_group("LitModel Args"))

    parser.add_argument("--help", "-h", action="help")
    return parser
def _saved_pretrain(lit_model, tokenizer, path):
lit_model.model.save_pretrained(path)
tokenizer.save_pretrained(path)
def _build_logger(args, run_name):
    """Return the experiment logger: W&B when ``--wandb`` is set, TensorBoard otherwise."""
    if args.wandb:
        wandb_logger = pl.loggers.WandbLogger(project="kgc_bert", name=run_name)
        wandb_logger.log_hyperparams(vars(args))
        return wandb_logger
    return pl.loggers.TensorBoardLogger("training/logs")


def _build_callbacks(args, run_name):
    """Create the early-stopping and checkpoint callbacks; return ``(early_stopping, checkpoint)``."""
    # args.pretrain is presumably registered by a data/LitModel add_to_argparse — TODO confirm.
    metric_name = "Eval/hits1" if args.pretrain else "Eval/mrr"
    # Early stopping watches the same metric as checkpointing. The original code
    # always monitored "Eval/mrr", even in pretrain mode where checkpoints track
    # "Eval/hits1".
    early_stopping = pl.callbacks.EarlyStopping(monitor=metric_name, mode="max", patience=10)
    if args.pretrain:
        filename = run_name + '/{epoch}-{step}-{Eval/hits10:.2f}'
    else:
        filename = run_name + '/{epoch}-{Eval/hits10:.2f}-{Eval/hits1:.2f}'
    model_checkpoint = pl.callbacks.ModelCheckpoint(
        monitor=metric_name,
        mode="max",
        filename=filename,
        dirpath="output",
        save_weights_only=True,  # to be modified
        # Pretraining checkpoints every 100 steps and keeps the 5 best; otherwise keep only the best.
        every_n_train_steps=100 if args.pretrain else None,
        save_top_k=5 if args.pretrain else 1,
    )
    return early_stopping, model_checkpoint


def main():
    """Parse CLI args, build the data/model/LitModel, then train (unless skipped) and test.

    Models whose LitModel class name contains "EntityEmbedding" skip ``fit``
    and are only tested. Otherwise the best checkpoint found during ``fit`` is
    reloaded before ``trainer.test``.
    """
    parser = _setup_parser()
    args = parser.parse_args()

    # Seed the RNGs for reproducible experiments.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    pl.seed_everything(args.seed)

    data_class = _import_class(f"data.{args.data_class}")
    model_class = _import_class(f"models.{args.model_class}")
    litmodel_class = _import_class(f"lit_models.{args.litmodel_class}")

    # model_name_or_path / label_smoothing presumably come from the class-specific
    # argument groups added in _setup_parser — TODO confirm.
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    config.label_smoothing = args.label_smoothing
    model = model_class.from_pretrained(args.model_name_or_path, config=config)
    data = data_class(args, model)
    tokenizer = data.tokenizer
    lit_model = litmodel_class(args=args, model=model, tokenizer=tokenizer, data_config=data.get_config())
    if args.checkpoint:
        lit_model.load_state_dict(torch.load(args.checkpoint, map_location="cpu")["state_dict"])

    # normpath drops a trailing separator, so "data/FB15k/" yields "FB15k" instead of ""
    # (an empty name would make the checkpoint filename start with "/").
    run_name = os.path.basename(os.path.normpath(args.data_dir))
    logger = _build_logger(args, run_name)
    early_callback, model_checkpoint = _build_callbacks(args, run_name)
    callbacks = [early_callback, model_checkpoint]

    # args.weights_summary = "full"  # Print full summary of the model
    trainer = pl.Trainer.from_argparse_args(
        args, callbacks=callbacks, logger=logger, default_root_dir="training/logs"
    )

    should_train = "EntityEmbedding" not in lit_model.__class__.__name__
    path = None  # defined on every branch so the final print can never hit an unbound name
    if should_train:
        trainer.fit(lit_model, datamodule=data)
        path = model_checkpoint.best_model_path
        # best_model_path is "" when no checkpoint was written (e.g. fast_dev_run
        # disables checkpointing); torch.load("") would crash, so fall back to
        # the in-memory weights.
        if path:
            lit_model.load_state_dict(torch.load(path, map_location="cpu")["state_dict"])
        else:
            print("No best checkpoint was saved; testing with the final in-memory weights.")

    result = trainer.test(lit_model, datamodule=data)
    print(result)

    # _saved_pretrain(lit_model, tokenizer, path)
    if should_train:
        print("*path" * 30)
        print(path)
if __name__ == "__main__":
    # Entry point when executed as a script; importing this module runs nothing.
    main()