From 805d4fb53653b28127550949e26cba612105f575 Mon Sep 17 00:00:00 2001 From: thanhvc3 Date: Mon, 29 Apr 2024 17:14:48 +0700 Subject: [PATCH] try modify swin --- models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models.py b/models.py index a143111..ca11dc6 100644 --- a/models.py +++ b/models.py @@ -965,7 +965,7 @@ class PoolFormerBlock(nn.Module): if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + img_mask = torch.zeros((1, 1, H, W)) # 1 H W 1 h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) @@ -975,7 +975,7 @@ class PoolFormerBlock(nn.Module): cnt = 0 for h in h_slices: for w in w_slices: - img_mask[:, h, w, :] = cnt + img_mask[:, :, h, w] = cnt cnt += 1 print(self.input_resolution)