From d3a6cfe0414e246a261b1a9e7ae4760135eeb836 Mon Sep 17 00:00:00 2001
From: thanhvc3
Date: Mon, 29 Apr 2024 19:57:20 +0700
Subject: [PATCH] try modify swin

---
 models.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/models.py b/models.py
index fa6ad20..212f3f6 100644
--- a/models.py
+++ b/models.py
@@ -532,9 +532,9 @@ class FouriER(torch.nn.Module):
         self.patch_embed = PatchEmbed(in_chans=channels, patch_size=self.p.patch_size,
                                       embed_dim=self.p.embed_dim, stride=4, padding=2)
         network = []
-        layers = [4, 4, 12, 4]
-        embed_dims = [self.p.embed_dim, 128, 320, 128]
-        mlp_ratios = [4, 4, 4, 4]
+        layers = [2, 2, 6, 2]
+        embed_dims = [self.p.embed_dim, 320, 256, 128]
+        mlp_ratios = [4, 4, 8, 12]
         num_heads = [2, 4, 8, 16]
         downsamples = [True, True, True, True]
         pool_size=3
@@ -558,7 +558,7 @@ class FouriER(torch.nn.Module):
                                  use_layer_scale=use_layer_scale,
                                  layer_scale_init_value=layer_scale_init_value,
                                  num_heads=num_heads[i], input_resolution=(image_h // (2**i), image_w // (2**i)),
-                                 window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2)
+                                 window_size=window_size, shift_size=0)
             network.append(stage)
             if i >= len(layers) - 1:
                 break
@@ -949,7 +949,7 @@ class PoolFormerBlock(nn.Module):
         self.window_size = window_size
         self.shift_size = shift_size
         self.input_resolution = input_resolution
-        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=num_heads)
+        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, attn_drop=0.2, proj_drop=0.1)
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,