diff --git a/models.py b/models.py
index aac4772..bc93c7f 100644
--- a/models.py
+++ b/models.py
@@ -800,6 +800,7 @@ class WindowAttention(nn.Module):
             x: input features with shape of (num_windows*B, N, C)
             mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
         """
+        B_, N, C = x.shape
         qkv_bias = None
         if self.q_bias is not None:
             qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))