This commit is contained in:
thanhvc3 2024-04-27 11:52:23 +07:00
parent ae0f43ab4d
commit f8e969cbd1

@@ -818,8 +818,6 @@ class WindowAttention(nn.Module):
             self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
         relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
-        print(attn.shape)
-        print(relative_position_bias.shape)
         attn = attn + relative_position_bias.unsqueeze(0)
         if mask is not None:
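For context, this hunk sits in the forward pass of WindowAttention, where the relative position bias is bounded with 16 * sigmoid (as in Swin Transformer V2) and broadcast-added onto the attention logits; the two deleted print calls were debug inspections of those tensor shapes. Below is a minimal, self-contained sketch of that step, using random placeholder tensors in place of the module's learned bias table; the window size, head count, and batch values are illustrative assumptions, not values from the repository.

import torch

num_heads, window = 4, 7        # nH, Wh (= Ww); assumed values for illustration
n = window * window             # tokens per window, Wh*Ww
batch_windows = 2               # B * num_windows; assumed value

# Stand-in for the learned bias, already gathered per token pair: (Wh*Ww, Wh*Ww, nH)
relative_position_bias = torch.randn(n, n, num_heads)

# The same two context lines that surround the deletion in the hunk
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)            # bounded to (0, 16)

attn = torch.randn(batch_windows, num_heads, n, n)   # attention logits per window
attn = attn + relative_position_bias.unsqueeze(0)    # broadcast over the batch dimension

print(attn.shape)  # torch.Size([2, 4, 49, 49]) -- the shape the removed prints inspected

The unsqueeze(0) adds a leading batch axis so the per-head (nH, Wh*Ww, Wh*Ww) bias broadcasts across every window in the batch; removing the prints changes no behavior, only silences per-forward-pass console output.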