try modify swin

This commit is contained in:
thanhvc3 2024-04-29 17:28:55 +07:00
parent 8866ea448e
commit a8ac4d1b3f

View File

@ -872,9 +872,12 @@ class WindowAttention(nn.Module):
attn = attn + relative_position_bias.unsqueeze(0) attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None: if mask is not None:
try:
nW = mask.shape[0] nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N) attn = attn.view(-1, self.num_heads, N, N)
except:
pass
attn = self.softmax(attn) attn = self.softmax(attn)
else: else:
attn = self.softmax(attn) attn = self.softmax(attn)