try sep vit

Author: thanhvc3
Date:   2024-04-28 11:00:08 +07:00
parent  3243b1d963
commit  ab5c1d0b4b

@@ -906,7 +906,7 @@ class PEG(nn.Module):
 # feedforward
-class FeedForward(nn.Module):
+class FeedForwardDSSA(nn.Module):
     def __init__(self, dim, mult = 4, dropout = 0.):
         super().__init__()
         inner_dim = int(dim * mult)
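
The hunk above only shows the first lines of the renamed class; its body is untouched by this commit. For orientation, a feed-forward with this signature (dim, a width multiplier mult, dropout) typically expands to the standard two-layer MLP sketched below. The exact layers are an assumption for illustration, not the file's actual code.

# Sketch only (assumed layout, not taken from the file): a standard transformer
# feed-forward matching the signature shown in the hunk above.
import torch.nn as nn

class FeedForwardDSSA(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)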

@@ -1094,7 +1094,10 @@ class PoolFormerBlock(nn.Module):
         self.attn_heads = 4
         self.attn_mask = None
         # self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
-        self.token_mixer = DSSA(dim, heads=self.attn_heads, window_size=self.window_size)
+        self.token_mixer = nn.ModuleList([
+            DSSA(dim, heads=self.attn_heads, window_size=self.window_size),
+            FeedForwardDSSA(dim)
+        ])
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
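
nn.ModuleList is a plain container and is not itself callable, so turning token_mixer into a list implies that the block's forward pass has to run the two stages explicitly. The snippet below is a minimal sketch of how that could look; the norm1/residual wiring is assumed from the usual PoolFormerBlock layout and is not part of this diff.

# Minimal sketch (assumption, not taken from this commit): apply the token-mixer
# stages in order, since an nn.ModuleList cannot be called like a single module.
def apply_token_mixer(token_mixer, x):
    # token_mixer == nn.ModuleList([DSSA(...), FeedForwardDSSA(...)])
    for stage in token_mixer:        # DSSA first, then the feed-forward
        x = stage(x)
    return x

# Inside PoolFormerBlock.forward the residual wiring would then look roughly like
# (self.norm1 is assumed from the usual PoolFormerBlock layout, not shown here):
#   x = x + apply_token_mixer(self.token_mixer, self.norm1(x))
#   x = x + self.mlp(self.norm2(x))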

@@ -1213,7 +1216,7 @@ class LayerNormChannel(nn.Module):
             + self.bias.unsqueeze(-1).unsqueeze(-1)
         return x
 
-class FeedForward(nn.Module):
+class FeedForwardFNet(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
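
Before this commit the file defined two module-level classes both named FeedForward with different signatures; Python simply rebinds the name, so the later (FNet-style) definition replaced the earlier (DSSA-style) one. The renames to FeedForwardDSSA and FeedForwardFNet keep the two usable side by side, which matters now that PoolFormerBlock instantiates the DSSA-style variant. A toy illustration of the clash (hypothetical example, not taken from the file):

# Two same-named classes: the second definition silently replaces the first.
class FeedForward:                                   # DSSA-style signature
    def __init__(self, dim, mult = 4, dropout = 0.):
        self.kind = "dssa"

class FeedForward:                                   # rebinds the name; FNet-style wins
    def __init__(self, dim, hidden_dim, dropout = 0.):
        self.kind = "fnet"

ff = FeedForward(64, 128)
print(ff.kind)                                       # prints "fnet" -- the first class is gone
# FeedForward(64, mult = 4) would now raise a TypeError, because only the
# second definition still exists.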

@@ -1249,7 +1252,7 @@ class FNet(nn.Module):
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 PreNorm(dim, FNetBlock()),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                PreNorm(dim, FeedForwardFNet(dim, mlp_dim, dropout = dropout))
             ]))
     def forward(self, x):
         for attn, ff in self.layers:
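
The hunk cuts off at the start of FNet.forward; the loop body itself is not changed by this commit. In the common FNet implementation each PreNorm-wrapped block is applied with a residual connection, roughly as sketched below (an assumption about the unchanged code, shown only for context).

# Presumed continuation of FNet.forward (not part of this diff); shown as it
# would appear inside the FNet class:
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x      # FNetBlock (Fourier token mixing) + residual
            x = ff(x) + x        # FeedForwardFNet + residual
        return x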