try modify swin

This commit is contained in:
thanhvc3 2024-04-29 17:14:48 +07:00
parent f86e27dab7
commit 805d4fb536

View File

@ -965,7 +965,7 @@ class PoolFormerBlock(nn.Module):
if self.shift_size > 0: if self.shift_size > 0:
# calculate attention mask for SW-MSA # calculate attention mask for SW-MSA
H, W = self.input_resolution H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 img_mask = torch.zeros((1, 1, H, W)) # 1 H W 1
h_slices = (slice(0, -self.window_size), h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size), slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None)) slice(-self.shift_size, None))
@ -975,7 +975,7 @@ class PoolFormerBlock(nn.Module):
cnt = 0 cnt = 0
for h in h_slices: for h in h_slices:
for w in w_slices: for w in w_slices:
img_mask[:, h, w, :] = cnt img_mask[:, :, h, w] = cnt
cnt += 1 cnt += 1
print(self.input_resolution) print(self.input_resolution)