From 47bc661a91b7cd512eac1023652cdf4835ca5d31 Mon Sep 17 00:00:00 2001 From: thanhvc3 Date: Sun, 28 Apr 2024 15:40:31 +0700 Subject: [PATCH] try gtp vit --- models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models.py b/models.py index 8bb9df3..70ed953 100644 --- a/models.py +++ b/models.py @@ -1247,7 +1247,7 @@ class PoolFormerBlock(nn.Module): x_attn = window_reverse(attn_windows, self.window_size, H, W) index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0, cls_token=False) - original_shape = x.attn.shape + original_shape = x_attn.shape x_attn = x_attn.view(-1, self.window_size * self.window_size, C) x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp", alpha=0.1, token_scales=token_scales, cls_token=False)