diff --git a/models.py b/models.py index 0457965..deb2338 100644 --- a/models.py +++ b/models.py @@ -790,6 +790,7 @@ class WindowAttention(nn.Module): B_, C, N, _ = x.shape x = x.reshape(B_, C, N * N) B_, C, N = x.shape + x = x.reshape(B_, N, C) qkv_bias = None if self.q_bias is not None: qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))