diff --git a/models.py b/models.py index c2fb1c9..a143111 100644 --- a/models.py +++ b/models.py @@ -978,6 +978,7 @@ class PoolFormerBlock(nn.Module): img_mask[:, h, w, :] = cnt cnt += 1 + print(self.input_resolution) mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)