This commit is contained in:
thanhvc3 2024-04-27 11:46:32 +07:00
parent c31588cc5f
commit a1bf2d7389

View File

@ -800,6 +800,7 @@ class WindowAttention(nn.Module):
x: input features with shape of (num_windows*B, N, C) x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
""" """
B_, N, C = x.shape
qkv_bias = None qkv_bias = None
if self.q_bias is not None: if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))