try modify swin
This commit is contained in:
		@@ -872,9 +872,12 @@ class WindowAttention(nn.Module):
 | 
				
			|||||||
        attn = attn + relative_position_bias.unsqueeze(0)
 | 
					        attn = attn + relative_position_bias.unsqueeze(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if mask is not None:
 | 
					        if mask is not None:
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
                nW = mask.shape[0]
 | 
					                nW = mask.shape[0]
 | 
				
			||||||
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
 | 
					                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
 | 
				
			||||||
                attn = attn.view(-1, self.num_heads, N, N)
 | 
					                attn = attn.view(-1, self.num_heads, N, N)
 | 
				
			||||||
 | 
					            except:
 | 
				
			||||||
 | 
					                pass
 | 
				
			||||||
            attn = self.softmax(attn)
 | 
					            attn = self.softmax(attn)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            attn = self.softmax(attn)
 | 
					            attn = self.softmax(attn)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user