Compare commits
13 Commits
Author | SHA1 | Date
---|---|---
 | a14267d96c |
 | d3a6cfe041 |
 | 2b6e356e60 |
 | a8ac4d1b3f |
 | 8866ea448e |
 | b661823661 |
 | 805d4fb536 |
 | f86e27dab7 |
 | 65963bf46b |
 | 5494206a04 |
 | 48669c72f4 |
 | d79bdd1c3e |
 | 7e6d4982d9 |
models.py | 110 changed lines
@@ -435,6 +435,50 @@ class TuckER(torch.nn.Module):
 
         return pred
 
+class PatchMerging(nn.Module):
+    r""" Patch Merging Layer.
+
+    Args:
+        input_resolution (tuple[int]): Resolution of input feature.
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(2 * dim)
+
+    def forward(self, x):
+        """
+        x: B, C, H, W
+        """
+        B, C, H, W = x.shape
+        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+
+        x = x.view(B, H, W, C)
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.reduction(x)
+        x = self.norm(x)
+
+        return x
+
+    def extra_repr(self) -> str:
+        return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+    def flops(self):
+        H, W = self.input_resolution
+        flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+        flops += H * W * self.dim // 2
+        return flops
 
 class FouriER(torch.nn.Module):
     def __init__(self, params, hid_drop = None, embed_dim = None):
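For reference, a minimal shape check of the PatchMerging block added above; this is a sketch that is not part of the diff and assumes torch and torch.nn as nn are imported, as elsewhere in models.py:

# Hypothetical smoke test: verify the (B, C, H, W) -> (B, H/2*W/2, 2C) mapping of PatchMerging.
merge = PatchMerging(dim=64)
x = torch.randn(2, 64, 32, 32)   # B=2, C=64, H=W=32
out = merge(x)
print(out.shape)                 # torch.Size([2, 256, 128]): 16*16 merged tokens, channels doubled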
@@ -488,9 +532,10 @@ class FouriER(torch.nn.Module):
         self.patch_embed = PatchEmbed(in_chans=channels, patch_size=self.p.patch_size,
                                       embed_dim=self.p.embed_dim, stride=4, padding=2)
         network = []
-        layers = [4, 4, 12, 4]
-        embed_dims = [self.p.embed_dim, 128, 320, 128]
-        mlp_ratios = [4, 4, 4, 4]
+        layers = [2, 2, 6, 2]
+        embed_dims = [self.p.embed_dim, 320, 256, 128]
+        mlp_ratios = [4, 4, 8, 12]
+        num_heads = [4, 4, 4, 4]
         downsamples = [True, True, True, True]
         pool_size=3
         act_layer=nn.GELU
@@ -502,6 +547,7 @@ class FouriER(torch.nn.Module):
         down_patch_size=3
         down_stride=2
         down_pad=1
+        window_size = 4
         num_classes=self.p.embed_dim
         for i in range(len(layers)):
             stage = basic_blocks(embed_dims[i], i, layers,
@@ -510,7 +556,9 @@ class FouriER(torch.nn.Module):
                                  drop_rate=drop_rate,
                                  drop_path_rate=drop_path_rate,
                                  use_layer_scale=use_layer_scale,
-                                 layer_scale_init_value=layer_scale_init_value)
+                                 layer_scale_init_value=layer_scale_init_value,
+                                 num_heads=num_heads[i], input_resolution=(image_h // (2**i), image_w // (2**i)),
+                                 window_size=window_size, shift_size=0)
             network.append(stage)
             if i >= len(layers) - 1:
                 break
@@ -522,6 +570,7 @@ class FouriER(torch.nn.Module):
                         padding=down_pad,
                         in_chans=embed_dims[i], embed_dim=embed_dims[i+1]
                     )
+                    # PatchMerging(dim=embed_dims[i+1])
                 )
 
         self.network = nn.ModuleList(network)
@@ -687,7 +736,7 @@ def basic_blocks(dim, index, layers,
                  pool_size=3, mlp_ratio=4.,
                  act_layer=nn.GELU, norm_layer=GroupNorm,
                  drop_rate=.0, drop_path_rate=0.,
-                 use_layer_scale=True, layer_scale_init_value=1e-5):
+                 use_layer_scale=True, layer_scale_init_value=1e-5, num_heads = 4, input_resolution = None, window_size = 4, shift_size = 2):
     """
     generate PoolFormer blocks for a stage
     return: PoolFormer blocks
@@ -702,6 +751,8 @@ def basic_blocks(dim, index, layers,
             drop=drop_rate, drop_path=block_dpr,
             use_layer_scale=use_layer_scale,
             layer_scale_init_value=layer_scale_init_value,
+            num_heads=num_heads, input_resolution = input_resolution,
+            window_size=window_size, shift_size=shift_size
             ))
     blocks = nn.Sequential(*blocks)
 
@@ -821,9 +872,12 @@ class WindowAttention(nn.Module):
         attn = attn + relative_position_bias.unsqueeze(0)
 
         if mask is not None:
-            nW = mask.shape[0]
-            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
-            attn = attn.view(-1, self.num_heads, N, N)
+            try:
+                nW = mask.shape[0]
+                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+                attn = attn.view(-1, self.num_heads, N, N)
+            except:
+                pass
             attn = self.softmax(attn)
         else:
            attn = self.softmax(attn)
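The masked branch above adds a per-window mask of shape (nW, N, N) to attention scores of shape (B_, num_heads, N, N), where B_ = B * nW and N = window_size * window_size. A shape-only sketch of that broadcast, with hypothetical sizes not taken from the diff, assuming torch is imported:

# Hypothetical shapes: B_ = B * nW windows in the batch, N tokens per window.
B, nW, num_heads, N = 2, 4, 4, 16
attn = torch.randn(B * nW, num_heads, N, N)
mask = torch.randn(nW, N, N)
attn = attn.view(B, nW, num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)  # mask broadcasts as (1, nW, 1, N, N)
attn = attn.view(-1, num_heads, N, N)  # back to (B * nW, num_heads, N, N)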
@@ -884,17 +938,18 @@ class PoolFormerBlock(nn.Module):
     """
     def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                  act_layer=nn.GELU, norm_layer=GroupNorm,
-                 drop=0., drop_path=0.,
-                 use_layer_scale=True, layer_scale_init_value=1e-5):
+                 drop=0., drop_path=0., num_heads=4,
+                 use_layer_scale=True, layer_scale_init_value=1e-5, input_resolution = None, window_size = 4, shift_size = 2):
 
         super().__init__()
 
         self.norm1 = norm_layer(dim)
         #self.token_mixer = Pooling(pool_size=pool_size)
         # self.token_mixer = FNetBlock()
-        self.window_size = 4
-        self.attn_mask = None
-        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=4)
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.input_resolution = input_resolution
+        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, attn_drop=0.2, proj_drop=0.1)
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
@@ -910,6 +965,31 @@ class PoolFormerBlock(nn.Module):
             self.layer_scale_2 = nn.Parameter(
                 layer_scale_init_value * torch.ones((dim)), requires_grad=True)
 
+        if self.shift_size > 0:
+            # calculate attention mask for SW-MSA
+            H, W = self.input_resolution
+            img_mask = torch.zeros((1, 1, H, W))  # 1 H W 1
+            h_slices = (slice(0, -self.window_size),
+                        slice(-self.window_size, -self.shift_size),
+                        slice(-self.shift_size, None))
+            w_slices = (slice(0, -self.window_size),
+                        slice(-self.window_size, -self.shift_size),
+                        slice(-self.shift_size, None))
+            cnt = 0
+            for h in h_slices:
+                for w in w_slices:
+                    img_mask[:, :, h, w] = cnt
+                    cnt += 1
+
+            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+        else:
+            attn_mask = None
+
+        self.register_buffer("attn_mask", attn_mask)
+
     def forward(self, x):
         B, C, H, W = x.shape
         x_windows = window_partition(x, self.window_size)
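The SW-MSA mask built above works by pairwise differences of per-window region labels: token pairs from the same region get 0, pairs from different regions get -100, which suppresses their attention weights after softmax. A small standalone sketch of that step with hypothetical labels (window_partition itself is defined elsewhere in models.py and is not needed here), assuming torch is imported:

# Hypothetical mask_windows for one window of 4 tokens, labelled by region id.
mask_windows = torch.tensor([[0., 0., 1., 1.]])                      # (nW=1, N=4)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)    # (1, 4, 4) pairwise label differences
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
# attn_mask[0] is 0 between tokens {0,1} and between {2,3}, and -100 across the two regions.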
@@ -917,6 +997,10 @@ class PoolFormerBlock(nn.Module):
         attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
         x_attn = window_reverse(attn_windows, self.window_size, H, W)
+        if self.shift_size > 0:
+            x = torch.roll(x_attn, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            x = x_attn
         if self.use_layer_scale:
             x = x + self.drop_path(
                 self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
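In the shifted branch added above, torch.roll cyclically shifts the feature map along its spatial dims. A tiny illustration with a hypothetical tensor, assuming torch is imported:

# Hypothetical 1x4x4 map: rolling by 2 along dims (1, 2) wraps rows and columns around cyclically.
t = torch.arange(16).view(1, 4, 4)
rolled = torch.roll(t, shifts=(2, 2), dims=(1, 2))
# rolled[0, 0, 0] == t[0, 2, 2]; no values are lost, only repositioned.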