diff --git a/models.py b/models.py index b262ffd..272f53a 100644 --- a/models.py +++ b/models.py @@ -489,8 +489,9 @@ class FouriER(torch.nn.Module): embed_dim=self.p.embed_dim, stride=4, padding=2) network = [] layers = [4, 4, 12, 4] - embed_dims = [self.p.embed_dim, 128, 320, 128] - mlp_ratios = [4, 4, 4, 4] + embed_dims = [self.p.embed_dim, 320, 256, 128] + mlp_ratios = [4, 4, 8, 12] + num_heads = [2, 4, 8, 16] downsamples = [True, True, True, True] pool_size=3 act_layer=nn.GELU @@ -510,7 +511,8 @@ class FouriER(torch.nn.Module): drop_rate=drop_rate, drop_path_rate=drop_path_rate, use_layer_scale=use_layer_scale, - layer_scale_init_value=layer_scale_init_value) + layer_scale_init_value=layer_scale_init_value, + num_heads=num_heads[i]) network.append(stage) if i >= len(layers) - 1: break @@ -687,7 +689,7 @@ def basic_blocks(dim, index, layers, pool_size=3, mlp_ratio=4., act_layer=nn.GELU, norm_layer=GroupNorm, drop_rate=.0, drop_path_rate=0., - use_layer_scale=True, layer_scale_init_value=1e-5): + use_layer_scale=True, layer_scale_init_value=1e-5, num_heads = 4): """ generate PoolFormer blocks for a stage return: PoolFormer blocks @@ -702,6 +704,7 @@ def basic_blocks(dim, index, layers, drop=drop_rate, drop_path=block_dpr, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, + num_heads=num_heads )) blocks = nn.Sequential(*blocks) @@ -884,7 +887,7 @@ class PoolFormerBlock(nn.Module): """ def __init__(self, dim, pool_size=3, mlp_ratio=4., act_layer=nn.GELU, norm_layer=GroupNorm, - drop=0., drop_path=0., + drop=0., drop_path=0., num_heads=4, use_layer_scale=True, layer_scale_init_value=1e-5): super().__init__() @@ -894,7 +897,7 @@ class PoolFormerBlock(nn.Module): # self.token_mixer = FNetBlock() self.window_size = 4 self.attn_mask = None - self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4) + self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, attn_drop=0.1, proj_drop=0.2) self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,