diff --git a/models.py b/models.py index 1a89eb3..8bb9df3 100644 --- a/models.py +++ b/models.py @@ -1249,7 +1249,7 @@ class PoolFormerBlock(nn.Module): cls_token=False) original_shape = x.attn.shape x_attn = x_attn.view(-1, self.window_size * self.window_size, C) - x_attn, weight, token_scales = propagate(x, weight, index_kept, index_prop, standard="GraphProp", + x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp", alpha=0.1, token_scales=token_scales, cls_token=False) x_attn = x_attn.view(*original_shape) if self.use_layer_scale: