try modify swin
This commit is contained in:
parent
f86e27dab7
commit
805d4fb536
@ -965,7 +965,7 @@ class PoolFormerBlock(nn.Module):
|
|||||||
if self.shift_size > 0:
|
if self.shift_size > 0:
|
||||||
# calculate attention mask for SW-MSA
|
# calculate attention mask for SW-MSA
|
||||||
H, W = self.input_resolution
|
H, W = self.input_resolution
|
||||||
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
img_mask = torch.zeros((1, 1, H, W)) # 1 H W 1
|
||||||
h_slices = (slice(0, -self.window_size),
|
h_slices = (slice(0, -self.window_size),
|
||||||
slice(-self.window_size, -self.shift_size),
|
slice(-self.window_size, -self.shift_size),
|
||||||
slice(-self.shift_size, None))
|
slice(-self.shift_size, None))
|
||||||
@ -975,7 +975,7 @@ class PoolFormerBlock(nn.Module):
|
|||||||
cnt = 0
|
cnt = 0
|
||||||
for h in h_slices:
|
for h in h_slices:
|
||||||
for w in w_slices:
|
for w in w_slices:
|
||||||
img_mask[:, h, w, :] = cnt
|
img_mask[:, :, h, w] = cnt
|
||||||
cnt += 1
|
cnt += 1
|
||||||
|
|
||||||
print(self.input_resolution)
|
print(self.input_resolution)
|
||||||
|
Loading…
Reference in New Issue
Block a user