diff --git a/models.py b/models.py index a143111..ca11dc6 100644 --- a/models.py +++ b/models.py @@ -965,7 +965,7 @@ class PoolFormerBlock(nn.Module): if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + img_mask = torch.zeros((1, 1, H, W)) # 1 1 H W h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) @@ -975,7 +975,7 @@ class PoolFormerBlock(nn.Module): cnt = 0 for h in h_slices: for w in w_slices: - img_mask[:, h, w, :] = cnt + img_mask[:, :, h, w] = cnt cnt += 1 print(self.input_resolution)