try modify swin
This commit is contained in:
parent
d79bdd1c3e
commit
48669c72f4
55
models.py
55
models.py
@ -435,6 +435,50 @@ class TuckER(torch.nn.Module):
|
|||||||
|
|
||||||
return pred
|
return pred
|
||||||
|
|
||||||
|
class PatchMerging(nn.Module):
|
||||||
|
r""" Patch Merging Layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_resolution (tuple[int]): Resolution of input feature.
|
||||||
|
dim (int): Number of input channels.
|
||||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
||||||
|
self.norm = norm_layer(2 * dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
x: B, C, H, W
|
||||||
|
"""
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
||||||
|
|
||||||
|
x = x.view(B, H, W, C)
|
||||||
|
|
||||||
|
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||||
|
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||||
|
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||||
|
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||||
|
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||||
|
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||||
|
|
||||||
|
x = self.reduction(x)
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
||||||
|
|
||||||
|
def flops(self):
|
||||||
|
H, W = self.input_resolution
|
||||||
|
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
||||||
|
flops += H * W * self.dim // 2
|
||||||
|
return flops
|
||||||
|
|
||||||
class FouriER(torch.nn.Module):
|
class FouriER(torch.nn.Module):
|
||||||
def __init__(self, params, hid_drop = None, embed_dim = None):
|
def __init__(self, params, hid_drop = None, embed_dim = None):
|
||||||
@ -519,11 +563,12 @@ class FouriER(torch.nn.Module):
|
|||||||
if downsamples[i] or embed_dims[i] != embed_dims[i+1]:
|
if downsamples[i] or embed_dims[i] != embed_dims[i+1]:
|
||||||
# downsampling between two stages
|
# downsampling between two stages
|
||||||
network.append(
|
network.append(
|
||||||
PatchEmbed(
|
# PatchEmbed(
|
||||||
patch_size=down_patch_size, stride=down_stride,
|
# patch_size=down_patch_size, stride=down_stride,
|
||||||
padding=down_pad,
|
# padding=down_pad,
|
||||||
in_chans=embed_dims[i], embed_dim=embed_dims[i+1]
|
# in_chans=embed_dims[i], embed_dim=embed_dims[i+1]
|
||||||
)
|
# )
|
||||||
|
PatchMerging(dim=embed_dims[i+1])
|
||||||
)
|
)
|
||||||
|
|
||||||
self.network = nn.ModuleList(network)
|
self.network = nn.ModuleList(network)
|
||||||
|
Loading…
Reference in New Issue
Block a user