try modify swin
This commit is contained in:
parent
d79bdd1c3e
commit
48669c72f4
55
models.py
55
models.py
@ -435,6 +435,50 @@ class TuckER(torch.nn.Module):
|
||||
|
||||
return pred
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
r""" Patch Merging Layer.
|
||||
|
||||
Args:
|
||||
input_resolution (tuple[int]): Resolution of input feature.
|
||||
dim (int): Number of input channels.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
"""
|
||||
|
||||
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
||||
self.norm = norm_layer(2 * dim)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: B, C, H, W
|
||||
"""
|
||||
B, C, H, W = x.shape
|
||||
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
||||
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||
|
||||
x = self.reduction(x)
|
||||
x = self.norm(x)
|
||||
|
||||
return x
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
||||
|
||||
def flops(self):
|
||||
H, W = self.input_resolution
|
||||
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
||||
flops += H * W * self.dim // 2
|
||||
return flops
|
||||
|
||||
class FouriER(torch.nn.Module):
|
||||
def __init__(self, params, hid_drop = None, embed_dim = None):
|
||||
@ -519,11 +563,12 @@ class FouriER(torch.nn.Module):
|
||||
if downsamples[i] or embed_dims[i] != embed_dims[i+1]:
|
||||
# downsampling between two stages
|
||||
network.append(
|
||||
PatchEmbed(
|
||||
patch_size=down_patch_size, stride=down_stride,
|
||||
padding=down_pad,
|
||||
in_chans=embed_dims[i], embed_dim=embed_dims[i+1]
|
||||
)
|
||||
# PatchEmbed(
|
||||
# patch_size=down_patch_size, stride=down_stride,
|
||||
# padding=down_pad,
|
||||
# in_chans=embed_dims[i], embed_dim=embed_dims[i+1]
|
||||
# )
|
||||
PatchMerging(dim=embed_dims[i+1])
|
||||
)
|
||||
|
||||
self.network = nn.ModuleList(network)
|
||||
|
Loading…
Reference in New Issue
Block a user