try modify swin

This commit is contained in:
thanhvc3 2024-04-29 16:38:27 +07:00
parent d79bdd1c3e
commit 48669c72f4

View File

@ -435,6 +435,50 @@ class TuckER(torch.nn.Module):
return pred return pred
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(2 * dim)
def forward(self, x):
"""
x: B, C, H, W
"""
B, C, H, W = x.shape
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.reduction(x)
x = self.norm(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
flops += H * W * self.dim // 2
return flops
class FouriER(torch.nn.Module): class FouriER(torch.nn.Module):
def __init__(self, params, hid_drop = None, embed_dim = None): def __init__(self, params, hid_drop = None, embed_dim = None):
@ -519,11 +563,12 @@ class FouriER(torch.nn.Module):
if downsamples[i] or embed_dims[i] != embed_dims[i+1]: if downsamples[i] or embed_dims[i] != embed_dims[i+1]:
# downsampling between two stages # downsampling between two stages
network.append( network.append(
PatchEmbed( # PatchEmbed(
patch_size=down_patch_size, stride=down_stride, # patch_size=down_patch_size, stride=down_stride,
padding=down_pad, # padding=down_pad,
in_chans=embed_dims[i], embed_dim=embed_dims[i+1] # in_chans=embed_dims[i], embed_dim=embed_dims[i+1]
) # )
PatchMerging(dim=embed_dims[i+1])
) )
self.network = nn.ModuleList(network) self.network = nn.ModuleList(network)