try sep vit
This commit is contained in:
parent
3243b1d963
commit
ab5c1d0b4b
11
models.py
11
models.py
@ -906,7 +906,7 @@ class PEG(nn.Module):
|
|||||||
|
|
||||||
# feedforward
|
# feedforward
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
class FeedForwardDSSA(nn.Module):
|
||||||
def __init__(self, dim, mult = 4, dropout = 0.):
|
def __init__(self, dim, mult = 4, dropout = 0.):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = int(dim * mult)
|
inner_dim = int(dim * mult)
|
||||||
@ -1094,7 +1094,10 @@ class PoolFormerBlock(nn.Module):
|
|||||||
self.attn_heads = 4
|
self.attn_heads = 4
|
||||||
self.attn_mask = None
|
self.attn_mask = None
|
||||||
# self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
|
# self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
|
||||||
self.token_mixer = DSSA(dim, heads=self.attn_heads, window_size=self.window_size)
|
self.token_mixer = nn.ModuleList([
|
||||||
|
DSSA(dim, heads=self.attn_heads, window_size=self.window_size),
|
||||||
|
FeedForwardDSSA(dim)
|
||||||
|
])
|
||||||
self.norm2 = norm_layer(dim)
|
self.norm2 = norm_layer(dim)
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
||||||
@ -1213,7 +1216,7 @@ class LayerNormChannel(nn.Module):
|
|||||||
+ self.bias.unsqueeze(-1).unsqueeze(-1)
|
+ self.bias.unsqueeze(-1).unsqueeze(-1)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
class FeedForwardFNet(nn.Module):
|
||||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
@ -1249,7 +1252,7 @@ class FNet(nn.Module):
|
|||||||
for _ in range(depth):
|
for _ in range(depth):
|
||||||
self.layers.append(nn.ModuleList([
|
self.layers.append(nn.ModuleList([
|
||||||
PreNorm(dim, FNetBlock()),
|
PreNorm(dim, FNetBlock()),
|
||||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
PreNorm(dim, FeedForwardFNet(dim, mlp_dim, dropout = dropout))
|
||||||
]))
|
]))
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
for attn, ff in self.layers:
|
for attn, ff in self.layers:
|
||||||
|
Loading…
Reference in New Issue
Block a user