try modify swin
This commit is contained in:
parent
f86e27dab7
commit
805d4fb536
@ -965,7 +965,7 @@ class PoolFormerBlock(nn.Module):
|
||||
if self.shift_size > 0:
|
||||
# calculate attention mask for SW-MSA
|
||||
H, W = self.input_resolution
|
||||
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
||||
img_mask = torch.zeros((1, 1, H, W)) # 1 H W 1
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None))
|
||||
@ -975,7 +975,7 @@ class PoolFormerBlock(nn.Module):
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
img_mask[:, :, h, w] = cnt
|
||||
cnt += 1
|
||||
|
||||
print(self.input_resolution)
|
||||
|
Loading…
Reference in New Issue
Block a user