try gtp vit

This commit is contained in:
thanhvc3 2024-04-28 15:31:58 +07:00
parent b9273b6696
commit 352f5f9da9

View File

@ -969,7 +969,7 @@ def propagate(x: torch.Tensor, weight: torch.Tensor,
- token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation - token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation
""" """
B, C, N = x.shape B, N, C = x.shape
# Step 1: divide tokens # Step 1: divide tokens
if cls_token: if cls_token:
@ -1247,8 +1247,11 @@ class PoolFormerBlock(nn.Module):
x_attn = window_reverse(attn_windows, self.window_size, H, W) x_attn = window_reverse(attn_windows, self.window_size, H, W)
index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0, index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0,
cls_token=False) cls_token=False)
x, weight, token_scales = propagate(x, weight, index_kept, index_prop, standard="GraphProp", original_shape = x.attn.shape
x_attn = x_attn.view(-1, self.window_size * self.window_size, C)
x_attn, weight, token_scales = propagate(x, weight, index_kept, index_prop, standard="GraphProp",
alpha=0.1, token_scales=token_scales, cls_token=False) alpha=0.1, token_scales=token_scales, cls_token=False)
x_attn = x_attn.view(*original_shape)
if self.use_layer_scale: if self.use_layer_scale:
x = x + self.drop_path( x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)