This commit is contained in:
thanhvc3 2024-04-27 11:45:24 +07:00
parent c03e24f4c2
commit c31588cc5f

View File

@ -915,7 +915,7 @@ class PoolFormerBlock(nn.Module):
B, C, H, W = x.shape B, C, H, W = x.shape
x_windows = window_partition(x, self.window_size) x_windows = window_partition(x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
attn_windows = self.token_mixer(self.norm1(x_windows), mask=self.attn_mask) attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
x_attn = window_reverse(attn_windows, self.window_size, H, W) x_attn = window_reverse(attn_windows, self.window_size, H, W)
if self.use_layer_scale: if self.use_layer_scale: