From d4ac470c54fd5d1804c27b4a8ad7c78c80fcd012 Mon Sep 17 00:00:00 2001 From: thanhvc3 Date: Sat, 27 Apr 2024 11:07:48 +0700 Subject: [PATCH] try swin --- models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/models.py b/models.py index 1c15490..0457965 100644 --- a/models.py +++ b/models.py @@ -788,6 +788,8 @@ class WindowAttention(nn.Module): """ print(x.shape) B_, C, N, _ = x.shape + x = x.reshape(B_, C, N * N) + B_, C, N = x.shape qkv_bias = None if self.q_bias is not None: qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))