try modify swin

This commit is contained in:
thanhvc3 2024-04-29 17:13:31 +07:00
parent 65963bf46b
commit f86e27dab7

View File

@ -978,6 +978,7 @@ class PoolFormerBlock(nn.Module):
img_mask[:, h, w, :] = cnt
cnt += 1
print(self.input_resolution)
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)