try modify swin

parent f8e969cbd1
commit 7e6d4982d9

models.py (15 changed lines)
@@ -489,8 +489,9 @@ class FouriER(torch.nn.Module):
                                  embed_dim=self.p.embed_dim, stride=4, padding=2)
         network = []
         layers = [4, 4, 12, 4]
-        embed_dims = [self.p.embed_dim, 128, 320, 128]
-        mlp_ratios = [4, 4, 4, 4]
+        embed_dims = [self.p.embed_dim, 320, 256, 128]
+        mlp_ratios = [4, 4, 8, 12]
+        num_heads = [2, 4, 8, 16]
         downsamples = [True, True, True, True]
         pool_size=3
         act_layer=nn.GELU
@@ -510,7 +511,8 @@ class FouriER(torch.nn.Module):
                                  drop_rate=drop_rate,
                                  drop_path_rate=drop_path_rate,
                                  use_layer_scale=use_layer_scale,
-                                 layer_scale_init_value=layer_scale_init_value)
+                                 layer_scale_init_value=layer_scale_init_value,
+                                 num_heads=num_heads[i])
             network.append(stage)
             if i >= len(layers) - 1:
                 break
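The two hunks above replace the single hard-coded head count with a per-stage num_heads list that is indexed with the stage counter i, alongside embed_dims and mlp_ratios. A minimal, hedged sanity check of those values (self.p.embed_dim is assumed to be 128 purely for illustration; everything else is taken from the diff):

    # Illustrative check, not part of the commit: each stage's channel width must
    # split evenly across its attention heads, since multi-head attention uses
    # head_dim = dim // num_heads.
    embed_dims = [128, 320, 256, 128]   # first entry stands in for self.p.embed_dim
    mlp_ratios = [4, 4, 8, 12]
    num_heads = [2, 4, 8, 16]
    layers = [4, 4, 12, 4]

    for i, (dim, heads, ratio, depth) in enumerate(zip(embed_dims, num_heads, mlp_ratios, layers)):
        assert dim % heads == 0, f"stage {i}: dim {dim} is not divisible by {heads} heads"
        print(f"stage {i}: {depth} blocks, dim {dim}, {heads} heads "
              f"(head_dim {dim // heads}), mlp hidden {dim * ratio}")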
@@ -687,7 +689,7 @@ def basic_blocks(dim, index, layers,
                  pool_size=3, mlp_ratio=4.,
                  act_layer=nn.GELU, norm_layer=GroupNorm,
                  drop_rate=.0, drop_path_rate=0.,
-                 use_layer_scale=True, layer_scale_init_value=1e-5):
+                 use_layer_scale=True, layer_scale_init_value=1e-5, num_heads = 4):
     """
     generate PoolFormer blocks for a stage
     return: PoolFormer blocks
@@ -702,6 +704,7 @@ def basic_blocks(dim, index, layers,
             drop=drop_rate, drop_path=block_dpr,
             use_layer_scale=use_layer_scale,
             layer_scale_init_value=layer_scale_init_value,
+            num_heads=num_heads
             ))
     blocks = nn.Sequential(*blocks)
 
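In this hunk, basic_blocks simply forwards the new num_heads argument unchanged to every block in the stage; block_dpr is, in the reference PoolFormer implementation, a per-block stochastic-depth rate that grows with the block's global index. A hedged sketch of that schedule, assuming models.py follows the reference formula (this commit does not touch it):

    # Hypothetical illustration of the per-block drop-path schedule behind `block_dpr`
    # (reference PoolFormer formula, shown here only for context).
    layers = [4, 4, 12, 4]
    drop_path_rate = 0.1  # example value

    for index in range(len(layers)):
        for block_idx in range(layers[index]):
            block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
            print(f"stage {index}, block {block_idx}: drop_path = {block_dpr:.3f}")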
@@ -884,7 +887,7 @@ class PoolFormerBlock(nn.Module):
     """
     def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                  act_layer=nn.GELU, norm_layer=GroupNorm,
-                 drop=0., drop_path=0.,
+                 drop=0., drop_path=0., num_heads=4,
                  use_layer_scale=True, layer_scale_init_value=1e-5):
 
         super().__init__()
@@ -894,7 +897,7 @@ class PoolFormerBlock(nn.Module):
         # self.token_mixer = FNetBlock()
         self.window_size = 4
         self.attn_mask = None
-        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
+        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, attn_drop=0.1, proj_drop=0.2)
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
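A Swin-style WindowAttention token mixer attends within non-overlapping windows of tokens rather than over the full feature map, so the block's forward pass (not shown in this diff) presumably partitions the input into self.window_size x self.window_size windows before calling the mixer. A self-contained sketch of that partitioning, with shapes chosen only for illustration:

    import torch

    def window_partition(x, window_size):
        # Split a (B, H, W, C) feature map into non-overlapping windows of shape
        # (B * num_windows, window_size * window_size, C), the token layout that
        # Swin-style window attention expects.
        B, H, W, C = x.shape
        x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
        return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size, C)

    # With self.window_size = 4 as above, a 16x16 map with 320 channels becomes
    # 16 windows of 16 tokens per image.
    x = torch.randn(2, 16, 16, 320)              # (B, H, W, C), example shapes only
    windows = window_partition(x, window_size=4)
    print(windows.shape)                         # torch.Size([32, 16, 320])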