diff --git a/models.py b/models.py
index dc314da..2fe29be 100644
--- a/models.py
+++ b/models.py
@@ -906,7 +906,7 @@ class PEG(nn.Module):
 
 # feedforward
 
-class FeedForward(nn.Module):
+class FeedForwardDSSA(nn.Module):
     def __init__(self, dim, mult = 4, dropout = 0.):
         super().__init__()
         inner_dim = int(dim * mult)
@@ -1094,7 +1094,10 @@ class PoolFormerBlock(nn.Module):
         self.attn_heads = 4
         self.attn_mask = None
         # self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
-        self.token_mixer = DSSA(dim, heads=self.attn_heads, window_size=self.window_size)
+        self.token_mixer = nn.ModuleList([
+            DSSA(dim, heads=self.attn_heads, window_size=self.window_size),
+            FeedForwardDSSA(dim)
+        ])
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
@@ -1213,7 +1216,7 @@ class LayerNormChannel(nn.Module):
             + self.bias.unsqueeze(-1).unsqueeze(-1)
         return x
 
-class FeedForward(nn.Module):
+class FeedForwardFNet(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
@@ -1249,7 +1252,7 @@ class FNet(nn.Module):
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 PreNorm(dim, FNetBlock()),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                PreNorm(dim, FeedForwardFNet(dim, mlp_dim, dropout = dropout))
             ]))
     def forward(self, x):
         for attn, ff in self.layers:
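
Note on the second hunk: `self.token_mixer` changes from a single `DSSA` module to an `nn.ModuleList` pairing `DSSA` with the renamed `FeedForwardDSSA`. An `nn.ModuleList` is not callable, so `PoolFormerBlock.forward` (not shown in this diff) has to unpack and apply the two sub-modules itself. The snippet below is a minimal, self-contained sketch of that usage pattern only; the stand-in attention and feedforward modules, the class name `ToyBlock`, and the residual wiring are assumptions for illustration, not the repository's actual `forward` code.

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Sketch of applying a ModuleList token mixer; DSSA/FeedForwardDSSA are
    replaced by generic stand-ins (assumption, not the repo's modules)."""
    def __init__(self, dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        # stand-ins for DSSA(dim, ...) and FeedForwardDSSA(dim)
        self.token_mixer = nn.ModuleList([
            nn.MultiheadAttention(dim, num_heads=4, batch_first=True),
            nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)),
        ])

    def forward(self, x):
        # unpack the ModuleList; it cannot be called directly
        attn, ff = self.token_mixer
        h = self.norm1(x)
        h, _ = attn(h, h, h)   # attention stand-in for DSSA
        x = x + h              # residual around the attention mixer
        x = x + ff(x)          # residual around the feedforward mixer
        return x

x = torch.randn(2, 16, 64)
print(ToyBlock(64)(x).shape)   # torch.Size([2, 16, 64])
```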