try swin
commit ba388148d4 (parent 1b816fed50)
@@ -808,6 +808,8 @@ class WindowAttention(nn.Module):
             self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
         relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
+        print(attn.shape)
+        print(relative_position_bias.shape)
         attn = attn + relative_position_bias.unsqueeze(0)
 
         if mask is not None:
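
For context, the two added print statements trace tensor shapes right before the relative position bias is added to the attention logits. Below is a minimal, self-contained sketch of that same shape flow, including the Swin-V2-style 16 * sigmoid bounding seen in this hunk; it is not the repository's actual module, and the window size, head count, and batch values are placeholder assumptions:

    # Sketch of the shape flow the two debug prints inspect (assumed values).
    import torch

    num_heads = 4
    window_size = (7, 7)
    N = window_size[0] * window_size[1]  # Wh*Ww tokens per window
    B_ = 8                               # num_windows * batch (assumed)

    # Stand-in for the bias-table lookup in the diff: (Wh*Ww, Wh*Ww, nH)
    relative_position_bias = torch.randn(N, N, num_heads)
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    relative_position_bias = 16 * torch.sigmoid(relative_position_bias)            # bound bias to (0, 16)

    attn = torch.randn(B_, num_heads, N, N)  # raw attention logits
    print(attn.shape)                        # torch.Size([8, 4, 49, 49])
    print(relative_position_bias.shape)      # torch.Size([4, 49, 49])

    # unsqueeze(0) makes the per-head bias broadcast over the window/batch dim
    attn = attn + relative_position_bias.unsqueeze(0)

The unsqueeze(0) in the final line is what lets the per-head bias of shape (nH, Wh*Ww, Wh*Ww) broadcast across the leading window/batch dimension of attn, which is what the two shape prints confirm.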