diff --git a/models.py b/models.py index 212f3f6..6f47287 100644 --- a/models.py +++ b/models.py @@ -535,7 +535,7 @@ class FouriER(torch.nn.Module): layers = [2, 2, 6, 2] embed_dims = [self.p.embed_dim, 320, 256, 128] mlp_ratios = [4, 4, 8, 12] - num_heads = [2, 4, 8, 16] + num_heads = [4, 4, 4, 4] downsamples = [True, True, True, True] pool_size=3 act_layer=nn.GELU