From ba388148d4e71598cce1946cc750ce4a320a0c40 Mon Sep 17 00:00:00 2001
From: thanhvc3
Date: Sat, 27 Apr 2024 11:27:38 +0700
Subject: [PATCH] try swin

---
 models.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/models.py b/models.py
index 9e199fb..8441dea 100644
--- a/models.py
+++ b/models.py
@@ -808,6 +808,8 @@ class WindowAttention(nn.Module):
             self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
         relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
+        print(attn.shape)
+        print(relative_position_bias.shape)
         attn = attn + relative_position_bias.unsqueeze(0)
 
         if mask is not None:
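
A minimal sketch of the tensor shapes these debug prints are expected to report, assuming the standard Swin WindowAttention forward pass; the concrete values (batch-of-windows size, head count, window size) below are illustrative and not taken from the patch:

import torch

# Illustrative values only; the real ones depend on the model config and input.
B_ = 64         # num_windows * batch size
num_heads = 4
N = 49          # Wh * Ww for a 7x7 window

# attn comes out of (q @ k.transpose(-2, -1)), shape (B_, num_heads, N, N)
attn = torch.randn(B_, num_heads, N, N)

# relative_position_bias, after permute(2, 0, 1), has shape (num_heads, N, N)
relative_position_bias = torch.randn(num_heads, N, N)

print(attn.shape)                    # torch.Size([64, 4, 49, 49])
print(relative_position_bias.shape)  # torch.Size([4, 49, 49])

# unsqueeze(0) prepends a singleton batch dim so the bias broadcasts over B_
attn = attn + relative_position_bias.unsqueeze(0)
print(attn.shape)                    # torch.Size([64, 4, 49, 49])

If the two printed shapes disagree in their trailing (N, N) dimensions, the broadcast on the following line would fail, which is presumably what the added prints are meant to check.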