try gtp vit
This commit is contained in:
parent
352f5f9da9
commit
3b6db89be1
@ -1249,7 +1249,7 @@ class PoolFormerBlock(nn.Module):
|
|||||||
cls_token=False)
|
cls_token=False)
|
||||||
original_shape = x.attn.shape
|
original_shape = x.attn.shape
|
||||||
x_attn = x_attn.view(-1, self.window_size * self.window_size, C)
|
x_attn = x_attn.view(-1, self.window_size * self.window_size, C)
|
||||||
x_attn, weight, token_scales = propagate(x, weight, index_kept, index_prop, standard="GraphProp",
|
x_attn, weight, token_scales = propagate(x_attn, weight, index_kept, index_prop, standard="GraphProp",
|
||||||
alpha=0.1, token_scales=token_scales, cls_token=False)
|
alpha=0.1, token_scales=token_scales, cls_token=False)
|
||||||
x_attn = x_attn.view(*original_shape)
|
x_attn = x_attn.view(*original_shape)
|
||||||
if self.use_layer_scale:
|
if self.use_layer_scale:
|
||||||
|
Loading…
Reference in New Issue
Block a user