From 3b6db89be19a9d0475dc9f24a81e70de2e6c1519 Mon Sep 17 00:00:00 2001 From: thanhvc3 Date: Sun, 28 Apr 2024 15:35:24 +0700 Subject: [PATCH] try gtp vit --- models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models.py b/models.py index 1a89eb3..8bb9df3 100644 --- a/models.py +++ b/models.py @@ -1249,7 +1249,7 @@ class PoolFormerBlock(nn.Module): cls_token=False) original_shape = x.attn.shape x_attn = x_attn.view(-1, self.window_size * self.window_size, C) - x_attn, weight, token_scales = propagate(x, weight, index_kept, index_prop, standard="GraphProp", + x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp", alpha=0.1, token_scales=token_scales, cls_token=False) x_attn = x_attn.view(*original_shape) if self.use_layer_scale: