try modify swin
This commit is contained in:
parent
8866ea448e
commit
a8ac4d1b3f
@ -872,9 +872,12 @@ class WindowAttention(nn.Module):
|
|||||||
attn = attn + relative_position_bias.unsqueeze(0)
|
attn = attn + relative_position_bias.unsqueeze(0)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
|
try:
|
||||||
nW = mask.shape[0]
|
nW = mask.shape[0]
|
||||||
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||||
attn = attn.view(-1, self.num_heads, N, N)
|
attn = attn.view(-1, self.num_heads, N, N)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
attn = self.softmax(attn)
|
attn = self.softmax(attn)
|
||||||
else:
|
else:
|
||||||
attn = self.softmax(attn)
|
attn = self.softmax(attn)
|
||||||
|
Loading…
Reference in New Issue
Block a user