This commit is contained in:
thanhvc3 2024-04-27 11:08:46 +07:00
parent d4ac470c54
commit 465f98bef8

View File

@ -790,6 +790,7 @@ class WindowAttention(nn.Module):
B_, C, N, _ = x.shape
x = x.reshape(B_, C, N * N)
B_, C, N = x.shape
x = x.reshape(B_, N, C)
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))