try sep vit
This commit is contained in:
parent
3243b1d963
commit
ab5c1d0b4b
11
models.py
11
models.py
@@ -906,7 +906,7 @@ class PEG(nn.Module):
|
||||
|
||||
# feedforward
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
class FeedForwardDSSA(nn.Module):
|
||||
def __init__(self, dim, mult = 4, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
@@ -1094,7 +1094,10 @@ class PoolFormerBlock(nn.Module):
|
||||
self.attn_heads = 4
|
||||
self.attn_mask = None
|
||||
# self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
|
||||
self.token_mixer = DSSA(dim, heads=self.attn_heads, window_size=self.window_size)
|
||||
self.token_mixer = nn.ModuleList([
|
||||
DSSA(dim, heads=self.attn_heads, window_size=self.window_size),
|
||||
FeedForwardDSSA(dim)
|
||||
])
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
||||
@@ -1213,7 +1216,7 @@ class LayerNormChannel(nn.Module):
|
||||
+ self.bias.unsqueeze(-1).unsqueeze(-1)
|
||||
return x
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
class FeedForwardFNet(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
@@ -1249,7 +1252,7 @@ class FNet(nn.Module):
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, FNetBlock()),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
PreNorm(dim, FeedForwardFNet(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
|
Loading…
Reference in New Issue
Block a user