try modify swin

This commit is contained in:
thanhvc3 2024-04-29 16:07:22 +07:00
parent f8e969cbd1
commit 7e6d4982d9

View File

@ -489,8 +489,9 @@ class FouriER(torch.nn.Module):
embed_dim=self.p.embed_dim, stride=4, padding=2) embed_dim=self.p.embed_dim, stride=4, padding=2)
network = [] network = []
layers = [4, 4, 12, 4] layers = [4, 4, 12, 4]
embed_dims = [self.p.embed_dim, 128, 320, 128] embed_dims = [self.p.embed_dim, 320, 256, 128]
mlp_ratios = [4, 4, 4, 4] mlp_ratios = [4, 4, 8, 12]
num_heads = [2, 4, 8, 16]
downsamples = [True, True, True, True] downsamples = [True, True, True, True]
pool_size=3 pool_size=3
act_layer=nn.GELU act_layer=nn.GELU
@ -510,7 +511,8 @@ class FouriER(torch.nn.Module):
drop_rate=drop_rate, drop_rate=drop_rate,
drop_path_rate=drop_path_rate, drop_path_rate=drop_path_rate,
use_layer_scale=use_layer_scale, use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value) layer_scale_init_value=layer_scale_init_value,
num_heads=num_heads[i])
network.append(stage) network.append(stage)
if i >= len(layers) - 1: if i >= len(layers) - 1:
break break
@ -687,7 +689,7 @@ def basic_blocks(dim, index, layers,
pool_size=3, mlp_ratio=4., pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU, norm_layer=GroupNorm, act_layer=nn.GELU, norm_layer=GroupNorm,
drop_rate=.0, drop_path_rate=0., drop_rate=.0, drop_path_rate=0.,
use_layer_scale=True, layer_scale_init_value=1e-5): use_layer_scale=True, layer_scale_init_value=1e-5, num_heads = 4):
""" """
generate PoolFormer blocks for a stage generate PoolFormer blocks for a stage
return: PoolFormer blocks return: PoolFormer blocks
@ -702,6 +704,7 @@ def basic_blocks(dim, index, layers,
drop=drop_rate, drop_path=block_dpr, drop=drop_rate, drop_path=block_dpr,
use_layer_scale=use_layer_scale, use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value, layer_scale_init_value=layer_scale_init_value,
num_heads=num_heads
)) ))
blocks = nn.Sequential(*blocks) blocks = nn.Sequential(*blocks)
@ -884,7 +887,7 @@ class PoolFormerBlock(nn.Module):
""" """
def __init__(self, dim, pool_size=3, mlp_ratio=4., def __init__(self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU, norm_layer=GroupNorm, act_layer=nn.GELU, norm_layer=GroupNorm,
drop=0., drop_path=0., drop=0., drop_path=0., num_heads=4,
use_layer_scale=True, layer_scale_init_value=1e-5): use_layer_scale=True, layer_scale_init_value=1e-5):
super().__init__() super().__init__()
@ -894,7 +897,7 @@ class PoolFormerBlock(nn.Module):
# self.token_mixer = FNetBlock() # self.token_mixer = FNetBlock()
self.window_size = 4 self.window_size = 4
self.attn_mask = None self.attn_mask = None
self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4) self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, attn_drop=0.1, proj_drop=0.2)
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio) mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,