71 lines
2.3 KiB
Python
71 lines
2.3 KiB
Python
|
"""Base DataModule class."""
|
||
|
from pathlib import Path
|
||
|
from typing import Dict
|
||
|
import argparse
|
||
|
import os
|
||
|
|
||
|
import pytorch_lightning as pl
|
||
|
from torch.utils.data import DataLoader
|
||
|
|
||
|
|
||
|
class Config(dict):
|
||
|
def __getattr__(self, name):
|
||
|
return self.get(name)
|
||
|
|
||
|
def __setattr__(self, name, val):
|
||
|
self[name] = val
|
||
|
|
||
|
|
||
|
BATCH_SIZE = 8
|
||
|
NUM_WORKERS = 8
|
||
|
|
||
|
|
||
|
class BaseDataModule(pl.LightningDataModule):
|
||
|
"""
|
||
|
Base DataModule.
|
||
|
Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html
|
||
|
"""
|
||
|
|
||
|
def __init__(self, args: argparse.Namespace = None) -> None:
|
||
|
super().__init__()
|
||
|
self.args = Config(vars(args)) if args is not None else {}
|
||
|
self.batch_size = self.args.get("batch_size", BATCH_SIZE)
|
||
|
self.num_workers = self.args.get("num_workers", NUM_WORKERS)
|
||
|
|
||
|
|
||
|
@staticmethod
|
||
|
def add_to_argparse(parser):
|
||
|
parser.add_argument(
|
||
|
"--batch_size", type=int, default=BATCH_SIZE, help="Number of examples to operate on per forward step."
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--num_workers", type=int, default=0, help="Number of additional processes to load data."
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--dataset", type=str, default="./dataset/NELL", help="Number of additional processes to load data."
|
||
|
)
|
||
|
return parser
|
||
|
|
||
|
def prepare_data(self):
|
||
|
"""
|
||
|
Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`).
|
||
|
"""
|
||
|
pass
|
||
|
|
||
|
def setup(self, stage=None):
|
||
|
"""
|
||
|
Split into train, val, test, and set dims.
|
||
|
Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.
|
||
|
"""
|
||
|
self.data_train = None
|
||
|
self.data_val = None
|
||
|
self.data_test = None
|
||
|
|
||
|
def train_dataloader(self):
|
||
|
return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
|
||
|
|
||
|
def val_dataloader(self):
|
||
|
return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
|
||
|
|
||
|
def test_dataloader(self):
|
||
|
return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
|