try gtp vit

This commit is contained in:
thanhvc3 2024-04-28 15:35:24 +07:00
parent 352f5f9da9
commit 3b6db89be1

View File

@ -1249,7 +1249,7 @@ class PoolFormerBlock(nn.Module):
cls_token=False) cls_token=False)
original_shape = x.attn.shape original_shape = x.attn.shape
x_attn = x_attn.view(-1, self.window_size * self.window_size, C) x_attn = x_attn.view(-1, self.window_size * self.window_size, C)
x_attn, weight, token_scales = propagate(x, weight, index_kept, index_prop, standard="GraphProp", x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp",
alpha=0.1, token_scales=token_scales, cls_token=False) alpha=0.1, token_scales=token_scales, cls_token=False)
x_attn = x_attn.view(*original_shape) x_attn = x_attn.view(*original_shape)
if self.use_layer_scale: if self.use_layer_scale: