From 352f5f9da9c7d4aea1086348200fe644cf0c67eb Mon Sep 17 00:00:00 2001 From: thanhvc3 Date: Sun, 28 Apr 2024 15:31:58 +0700 Subject: [PATCH] try gtp vit --- models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/models.py b/models.py index 2bdc97b..1a89eb3 100644 --- a/models.py +++ b/models.py @@ -969,7 +969,7 @@ def propagate(x: torch.Tensor, weight: torch.Tensor, - token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation """ - B, C, N = x.shape + B, N, C = x.shape # Step 1: divide tokens if cls_token: @@ -1247,8 +1247,11 @@ class PoolFormerBlock(nn.Module): x_attn = window_reverse(attn_windows, self.window_size, H, W) index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0, cls_token=False) - x, weight, token_scales = propagate(x, weight, index_kept, index_prop, standard="GraphProp", + original_shape = x_attn.shape + x_attn = x_attn.view(-1, self.window_size * self.window_size, C) + x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp", alpha=0.1, token_scales=token_scales, cls_token=False) + x_attn = x_attn.view(*original_shape) if self.use_layer_scale: x = x + self.drop_path( self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)