try swin
This commit is contained in:
		@@ -808,6 +808,8 @@ class WindowAttention(nn.Module):
 | 
				
			|||||||
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
 | 
					            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
 | 
				
			||||||
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
 | 
					        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
 | 
				
			||||||
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
 | 
					        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
 | 
				
			||||||
 | 
					        print(attn.shape)
 | 
				
			||||||
 | 
					        print(relative_position_bias.shape)
 | 
				
			||||||
        attn = attn + relative_position_bias.unsqueeze(0)
 | 
					        attn = attn + relative_position_bias.unsqueeze(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if mask is not None:
 | 
					        if mask is not None:
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user