From 48669c72f4c2ead8b80f5db87f2cd0a4f03c9087 Mon Sep 17 00:00:00 2001 From: thanhvc3 Date: Mon, 29 Apr 2024 16:38:27 +0700 Subject: [PATCH] try modify swin --- models.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/models.py b/models.py index 3302431..40ea5c1 100644 --- a/models.py +++ b/models.py @@ -435,6 +435,50 @@ class TuckER(torch.nn.Module): return pred +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, x): + """ + x: B, C, H, W + """ + B, C, H, W = x.shape + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.reduction(x) + x = self.norm(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + flops += H * W * self.dim // 2 + return flops class FouriER(torch.nn.Module): def __init__(self, params, hid_drop = None, embed_dim = None): @@ -519,11 +563,12 @@ class FouriER(torch.nn.Module): if downsamples[i] or embed_dims[i] != embed_dims[i+1]: # downsampling between two stages network.append( - PatchEmbed( - patch_size=down_patch_size, stride=down_stride, - padding=down_pad, - in_chans=embed_dims[i], embed_dim=embed_dims[i+1] - ) + # PatchEmbed( + # patch_size=down_patch_size, stride=down_stride, + # padding=down_pad, + # in_chans=embed_dims[i], embed_dim=embed_dims[i+1] + # ) + PatchMerging(dim=embed_dims[i+1]) ) self.network = nn.ModuleList(network)