import argparse
from typing import Dict, Any

import pytorch_lightning as pl
import torch
import torchmetrics  # assumed provider of the Accuracy metrics initialized in BaseLitModel

OPTIMIZER = "AdamW"
LR = 5e-5
LOSS = "cross_entropy"
ONE_CYCLE_TOTAL_STEPS = 100


class Config(dict):
    """Dict that also exposes its keys as attributes; missing keys read as None."""

    def __getattr__(self, name):
        return self.get(name)

    def __setattr__(self, name, val):
        self[name] = val


class BaseLitModel(pl.LightningModule):
    """
    Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
    """

    def __init__(self, model, args: argparse.Namespace = None):
        super().__init__()
        self.model = model
        self.args = Config(vars(args)) if args is not None else {}

        optimizer = self.args.get("optimizer", OPTIMIZER)
        self.optimizer_class = getattr(torch.optim, optimizer)
        self.lr = self.args.get("lr", LR)

        # Restored: the methods below rely on these attributes, which were missing
        # from this excerpt; defaults follow the module-level constants and the
        # same self.args lookup pattern used above.
        loss = self.args.get("loss", LOSS)
        self.loss_fn = getattr(torch.nn.functional, loss)
        self.one_cycle_max_lr = self.args.get("one_cycle_max_lr", None)
        self.one_cycle_total_steps = self.args.get("one_cycle_total_steps", ONE_CYCLE_TOTAL_STEPS)

        # Note: recent torchmetrics releases require e.g. Accuracy(task="multiclass", num_classes=...).
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim")
        parser.add_argument("--lr", type=float, default=LR)
        parser.add_argument("--weight_decay", type=float, default=0.01)
        # Expose the remaining config keys read in __init__.
        parser.add_argument("--loss", type=str, default=LOSS, help="loss function from torch.nn.functional")
        parser.add_argument("--one_cycle_max_lr", type=float, default=None)
        parser.add_argument("--one_cycle_total_steps", type=int, default=ONE_CYCLE_TOTAL_STEPS)
        return parser

    def configure_optimizers(self):
        optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
        if self.one_cycle_max_lr is None:
            return optimizer
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def forward(self, x):
        return self.model(x)
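
    # The three *_step hooks below share the same pattern: forward pass,
    # loss via self.loss_fn (train/val only), and an Accuracy metric object
    # passed to self.log so Lightning accumulates and resets it per epoch.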
    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log("train_loss", loss)
        self.train_acc(logits, y)
        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log("val_loss", loss, prog_bar=True)
        self.val_acc(logits, y)
        self.log("val_acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        x, y = batch
        logits = self(x)
        self.test_acc(logits, y)
        self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)
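
    # Note: this helper relies on Trainer attributes from PyTorch Lightning 1.x
    # (num_gpus, num_processes, tpu_cores); it can be used, for example, to
    # derive a total_steps value for the OneCycleLR scheduler above.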
    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from datamodule and devices."""
        if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
            dataset_size = self.trainer.limit_train_batches
        elif isinstance(self.trainer.limit_train_batches, float):
            # limit_train_batches is a percentage of batches
            dataset_size = len(self.trainer.datamodule.train_dataloader())
            dataset_size = int(dataset_size * self.trainer.limit_train_batches)
        else:
            dataset_size = len(self.trainer.datamodule.train_dataloader())

        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)

        effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
        max_estimated_steps = (dataset_size // effective_batch_size) * self.trainer.max_epochs

        if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
            return self.trainer.max_steps
        return max_estimated_steps
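

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module. The toy linear
    # classifier, batch shapes, and class count below are assumptions chosen
    # only to exercise the wrapper end to end with default CLI arguments.
    parser = argparse.ArgumentParser()
    BaseLitModel.add_to_argparse(parser)
    cli_args = parser.parse_args([])

    lit_model = BaseLitModel(torch.nn.Linear(28 * 28, 10), args=cli_args)

    x = torch.randn(8, 28 * 28)
    y = torch.randint(0, 10, (8,))
    logits = lit_model(x)  # forward() delegates to the wrapped module
    print("loss on random data:", float(lit_model.loss_fn(logits, y)))
    print("configure_optimizers ->", lit_model.configure_optimizers())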