try gtp vit
This commit is contained in:
		@@ -969,7 +969,7 @@ def propagate(x: torch.Tensor, weight: torch.Tensor,
 | 
				
			|||||||
        - token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation
 | 
					        - token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    B, C, N = x.shape
 | 
					    B, N, C = x.shape
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    # Step 1: divide tokens
 | 
					    # Step 1: divide tokens
 | 
				
			||||||
    if cls_token:
 | 
					    if cls_token:
 | 
				
			||||||
@@ -1247,8 +1247,11 @@ class PoolFormerBlock(nn.Module):
 | 
				
			|||||||
        x_attn = window_reverse(attn_windows, self.window_size, H, W)
 | 
					        x_attn = window_reverse(attn_windows, self.window_size, H, W)
 | 
				
			||||||
        index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0,
 | 
					        index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0,
 | 
				
			||||||
                                            cls_token=False)
 | 
					                                            cls_token=False)
 | 
				
			||||||
        x, weight, token_scales = propagate(x, weight, index_kept, index_prop, standard="GraphProp",
 | 
					        original_shape = x.attn.shape
 | 
				
			||||||
 | 
					        x_attn = x_attn.view(-1, self.window_size * self.window_size, C)
 | 
				
			||||||
 | 
					        x_attn, weight, token_scales = propagate(x, weight, index_kept, index_prop, standard="GraphProp",
 | 
				
			||||||
                                        alpha=0.1, token_scales=token_scales, cls_token=False)
 | 
					                                        alpha=0.1, token_scales=token_scales, cls_token=False)
 | 
				
			||||||
 | 
					        x_attn = x_attn.view(*original_shape)
 | 
				
			||||||
        if self.use_layer_scale:
 | 
					        if self.use_layer_scale:
 | 
				
			||||||
            x = x + self.drop_path(
 | 
					            x = x + self.drop_path(
 | 
				
			||||||
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
 | 
					                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user