From a8ac4d1b3f2b1deaf22925182148300f7c64ff36 Mon Sep 17 00:00:00 2001
From: thanhvc3
Date: Mon, 29 Apr 2024 17:28:55 +0700
Subject: [PATCH] try modify swin

---
 models.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/models.py b/models.py
index 642d0b2..09ded1e 100644
--- a/models.py
+++ b/models.py
@@ -872,9 +872,12 @@ class WindowAttention(nn.Module):
         attn = attn + relative_position_bias.unsqueeze(0)
 
         if mask is not None:
-            nW = mask.shape[0]
-            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
-            attn = attn.view(-1, self.num_heads, N, N)
+            try:
+                nW = mask.shape[0]
+                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+                attn = attn.view(-1, self.num_heads, N, N)
+            except Exception:
+                pass
             attn = self.softmax(attn)
         else:
             attn = self.softmax(attn)