diff --git a/models.py b/models.py index 947ce5a..60e6eb8 100644 --- a/models.py +++ b/models.py @@ -10,6 +10,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.layers import DropPath, trunc_normal_ from timm.models.registry import register_model from timm.layers.helpers import to_2tuple +from typing import * class ConvE(torch.nn.Module):