try gtp vit

models.py: 308 changed lines
@@ -543,6 +543,44 @@ class FouriER(torch.nn.Module):
     def forward_tokens(self, x):
         outs = []
+        B, C, H, W = x.shape
+        N = H*W
+        if self.graph_type in ["Semantic", "Mixed"]:
+            # Generate the semantic graph w.r.t. the cosine similarity between tokens
+            # Compute cosine similarity
+            if self.class_token:
+                x_normed = x[:, 1:] / x[:, 1:].norm(dim=-1, keepdim=True)
+            else:
+                x_normed = x / x.norm(dim=-1, keepdim=True)
+            x_cossim = x_normed @ x_normed.transpose(-1, -2)
+            threshold = torch.kthvalue(x_cossim, N-1-self.num_neighbours, dim=-1, keepdim=True)[0] # B,H,1,1
+            semantic_graph = torch.where(x_cossim>=threshold, 1.0, 0.0)
+            if self.class_token:
+                semantic_graph = semantic_graph - torch.eye(N-1, device=semantic_graph.device).unsqueeze(0)
+            else:
+                semantic_graph = semantic_graph - torch.eye(N, device=semantic_graph.device).unsqueeze(0)
+
+        if self.graph_type == "None":
+            graph = None
+        else:
+            if self.graph_type == "Spatial":
+                graph = self.spatial_graph.unsqueeze(0).expand(B,-1,-1)#.to(x.device)
+            elif self.graph_type == "Semantic":
+                graph = semantic_graph
+            elif self.graph_type == "Mixed":
+                # Integrate the spatial graph and semantic graph
+                spatial_graph = self.spatial_graph.unsqueeze(0).expand(B,-1,-1).to(x.device)
+                graph = torch.bitwise_or(semantic_graph.int(), spatial_graph.int()).float()
+
+            # Symmetrically normalize the graph
+            degree = graph.sum(-1) # B, N
+            degree = torch.diag_embed(degree**(-1/2))
+            graph = degree @ graph @ degree
+
+        if self.token_scale:
+            token_scales = self.token_scales.unsqueeze(0).expand(B,-1).to(x.device)
+        else:
+            token_scales = None
         for idx, block in enumerate(self.network):
             x = block(x)
         # output only the features of last layer for image classification
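Below is a small standalone sketch (not part of the commit) of the graph construction that forward_tokens performs above: a nearest-neighbour semantic graph built from token cosine similarity, optionally merged with a spatial adjacency and then symmetrically normalized. All sizes and the chain-shaped spatial adjacency are toy stand-ins for the model's own buffers.

import torch

B, N, C = 2, 16, 64                 # toy batch, token and channel counts
num_neighbours = 4
x = torch.randn(B, N, C)            # token features, [B, N, C]

# Semantic graph: connect each token to its most cosine-similar tokens
x_normed = x / x.norm(dim=-1, keepdim=True)
x_cossim = x_normed @ x_normed.transpose(-1, -2)                        # B, N, N
threshold = torch.kthvalue(x_cossim, N - num_neighbours, dim=-1, keepdim=True)[0]
semantic_graph = torch.where(x_cossim >= threshold, 1.0, 0.0)
semantic_graph = semantic_graph - torch.eye(N).unsqueeze(0)             # remove self-loops

# Spatial graph: a simple chain adjacency stands in for self.spatial_graph
spatial_graph = torch.diag(torch.ones(N - 1), 1) + torch.diag(torch.ones(N - 1), -1)
spatial_graph = spatial_graph.unsqueeze(0).expand(B, -1, -1)

# "Mixed" graph: union of both adjacencies, then D^(-1/2) A D^(-1/2) normalization
graph = torch.bitwise_or(semantic_graph.int(), spatial_graph.int()).float()
degree = torch.diag_embed(graph.sum(-1) ** (-0.5))
graph = degree @ graph @ degree                                          # B, N, N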
@@ -868,6 +906,270 @@ def window_reverse(windows, window_size, H, W):
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
     return x
+
+def propagate(x: torch.Tensor, weight: torch.Tensor,
+              index_kept: torch.Tensor, index_prop: torch.Tensor,
+              standard: str = "None", alpha: Optional[float] = 0,
+              token_scales: Optional[torch.Tensor] = None,
+              cls_token=True):
+    """
+    Propagate tokens based on the selection results.
+    ================================================
+    Args:
+        - x: Tensor([B, N, C]): the feature map of N tokens, including the [CLS] token.
+
+        - weight: Tensor([B, N-1, N-1]): the weight of each token propagated to the other tokens,
+                                         excluding the [CLS] token. weight could be a pre-defined
+                                         graph of the current feature map (by default) or the
+                                         attention map (need to manually modify the Block Module).
+
+        - index_kept: Tensor([B, N-1-num_prop]): the index of kept image tokens in the feature map X
+
+        - index_prop: Tensor([B, num_prop]): the index of propagated image tokens in the feature map X
+
+        - standard: str: the method applied to propagate the tokens, including "None", "Mean" and
+                         "GraphProp"
+
+        - alpha: float: the coefficient of propagated features
+
+        - token_scales: Tensor([B, N]): the scale of tokens, including the [CLS] token. token_scales
+                                        is None by default. If it is not None, then token_scales
+                                        represents the scales of each token and should sum up to N.
+
+    Return:
+        - x: Tensor([B, N-1-num_prop, C]): the feature map after propagation
+
+        - weight: Tensor([B, N-1-num_prop, N-1-num_prop]): the graph of feature map after propagation
+
+        - token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation
+    """
+
+    B, N, C = x.shape
+
+    # Step 1: divide tokens
+    if cls_token:
+        x_cls = x[:, 0:1] # B, 1, C
+    x_kept = x.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,C)) # B, N-1-num_prop, C
+    x_prop = x.gather(dim=1, index=index_prop.unsqueeze(-1).expand(-1,-1,C)) # B, num_prop, C
+
+    # Step 2: divide token_scales if it is not None
+    if token_scales is not None:
+        if cls_token:
+            token_scales_cls = token_scales[:, 0:1] # B, 1
+        token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop
+        token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop
+
+    # Step 3: propagate tokens
+    if standard == "None":
+        """
+        No further propagation
+        """
+        pass
+
+    elif standard == "Mean":
+        """
+        Calculate the mean of all the propagated tokens,
+        and concatenate the result token back to kept tokens.
+        """
+        # naive average
+        x_prop = x_prop.mean(1, keepdim=True) # B, 1, C
+        # Concatenate the average token
+        x_kept = torch.cat((x_kept, x_prop), dim=1) # B, N-num_prop, C
+
+    elif standard == "GraphProp":
+        """
+        Propagate all the propagated token to kept token
+        with respect to the weights and token scales.
+        """
+        assert weight is not None, "The graph weight is needed for graph propagation"
+
+        # Step 3.1: divide propagation weights.
+        if cls_token:
+            index_kept = index_kept - 1 # since weights do not include the [CLS] token
+            index_prop = index_prop - 1 # since weights do not include the [CLS] token
+            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N-1)) # B, N-1-num_prop, N-1
+            weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop
+            weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop
+        else:
+            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1,-1,N)) # B, N-1-num_prop, N-1
+            weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, num_prop
+            weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1,weight.shape[1],-1)) # B, N-1-num_prop, N-1-num_prop
+
+        # Step 3.2: generate the broadcast message and propagate the message to corresponding kept tokens
+        # Simple implementation
+        x_prop = weight_prop @ x_prop # B, N-1-num_prop, C
+        x_kept = x_kept + alpha * x_prop # B, N-1-num_prop, C
+
+        """ scatter_reduce implementation for batched inputs
+        # Get the non-zero values
+        non_zero_indices = torch.nonzero(weight_prop, as_tuple=True)
+        non_zero_values = weight_prop[non_zero_indices]
+
+        # Sparse multiplication
+        batch_indices, row_indices, col_indices = non_zero_indices
+        sparse_matmul = alpha * non_zero_values[:, None] * x_prop[batch_indices, col_indices, :]
+        reduce_indices = batch_indices * x_kept.shape[1] + row_indices
+
+        x_kept = x_kept.reshape(-1, C).scatter_reduce(dim=0,
+                                                      index=reduce_indices[:, None],
+                                                      src=sparse_matmul,
+                                                      reduce="sum",
+                                                      include_self=True)
+        x_kept = x_kept.reshape(B, -1, C)
+        """
+
+        # Step 3.3: calculate the scale of each token if token_scales is not None
+        if token_scales is not None:
+            if cls_token:
+                token_scales_cls = token_scales[:, 0:1] # B, 1
+                token_scales = token_scales[:, 1:]
+            token_scales_kept = token_scales.gather(dim=1, index=index_kept) # B, N-1-num_prop
+            token_scales_prop = token_scales.gather(dim=1, index=index_prop) # B, num_prop
+            token_scales_prop = weight_prop @ token_scales_prop.unsqueeze(-1) # B, N-1-num_prop, 1
+            token_scales = token_scales_kept + alpha * token_scales_prop.squeeze(-1) # B, N-1-num_prop
+            if cls_token:
+                token_scales = torch.cat((token_scales_cls, token_scales), dim=1) # B, N-num_prop
+    else:
+        assert False, "Propagation method '%s' has not been supported yet." % standard
+
+
+    if cls_token:
+        # Step 4: concatenate the [CLS] token and generate returned value
+        x = torch.cat((x_cls, x_kept), dim=1) # B, N-num_prop, C
+    else:
+        x = x_kept
+    return x, weight, token_scales
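The heart of the "GraphProp" branch is two gathers followed by one weighted merge. The sketch below (not from the commit, toy shapes, no [CLS] token) reproduces Steps 1 and 3.2 on random tensors so the index bookkeeping is easier to follow; all names and sizes here are illustrative only.

import torch

B, N, C, num_prop, alpha = 2, 8, 4, 3, 0.1
x = torch.randn(B, N, C)                      # token features
weight = torch.rand(B, N, N)                  # dense stand-in for the normalized graph
index_prop = torch.topk(torch.rand(B, N), num_prop, dim=1).indices          # tokens to prune
mask = torch.ones(B, N, dtype=torch.bool).scatter(1, index_prop, False)
index_kept = torch.arange(N).expand(B, -1)[mask].view(B, N - num_prop)      # remaining tokens

# Step 1: split the feature map into kept and propagated tokens
x_kept = x.gather(1, index_kept.unsqueeze(-1).expand(-1, -1, C))   # B, N-num_prop, C
x_prop = x.gather(1, index_prop.unsqueeze(-1).expand(-1, -1, C))   # B, num_prop, C

# Steps 3.1/3.2: slice the graph rows/columns and broadcast the pruned tokens
w = weight.gather(1, index_kept.unsqueeze(-1).expand(-1, -1, N))                 # B, N-num_prop, N
w_prop = w.gather(2, index_prop.unsqueeze(1).expand(-1, w.shape[1], -1))         # B, N-num_prop, num_prop
x_kept = x_kept + alpha * (w_prop @ x_prop)                                      # merged features
print(x_kept.shape)  # torch.Size([2, 5, 4])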
+
+
+
+def select(weight: torch.Tensor, standard: str = "None", num_prop: int = 0, cls_token = True):
+    """
+    Select image tokens to be propagated. The [CLS] token will be ignored.
+    ======================================================================
+    Args:
+        - weight: Tensor([B, H, N, N]): used for selecting the kept tokens. Only support the
+                                        attention map of tokens at the moment.
+
+        - standard: str: the method applied to select the tokens
+
+        - num_prop: int: the number of tokens to be propagated
+
+    Return:
+        - index_kept: Tensor([B, N-1-num_prop]): the index of kept tokens
+
+        - index_prop: Tensor([B, num_prop]): the index of propagated tokens
+    """
+
+    assert len(weight.shape) == 4, "Selection methods on tensors other than the attention map haven't been supported yet."
+    B, H, N1, N2 = weight.shape
+    assert N1 == N2, "Selection methods on tensors other than the attention map haven't been supported yet."
+    N = N1
+    assert num_prop >= 0, "The number of propagated/pruned tokens must be non-negative."
+
+    if cls_token:
+        if standard == "CLSAttnMean":
+            token_rank = weight[:,:,0,1:].mean(1)
+
+        elif standard == "CLSAttnMax":
+            token_rank = weight[:,:,0,1:].max(1)[0]
+
+        elif standard == "IMGAttnMean":
+            token_rank = weight[:,:,:,1:].sum(-2).mean(1)
+
+        elif standard == "IMGAttnMax":
+            token_rank = weight[:,:,:,1:].sum(-2).max(1)[0]
+
+        elif standard == "DiagAttnMean":
+            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].mean(1)
+
+        elif standard == "DiagAttnMax":
+            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
+
+        elif standard == "MixedAttnMean":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].mean(1)
+            token_rank_2 = weight[:,:,:,1:].sum(-2).mean(1)
+            token_rank = token_rank_1 * token_rank_2
+
+        elif standard == "MixedAttnMax":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
+            token_rank_2 = weight[:,:,:,1:].sum(-2).max(1)[0]
+            token_rank = token_rank_1 * token_rank_2
+
+        elif standard == "SumAttnMax":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:,:,1:].max(1)[0]
+            token_rank_2 = weight[:,:,:,1:].sum(-2).max(1)[0]
+            token_rank = token_rank_1 + token_rank_2
+
+        elif standard == "CosSimMean":
+            weight = weight[:,:,1:,:].mean(1)
+            weight = weight / weight.norm(dim=-1, keepdim=True)
+            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
+
+        elif standard == "CosSimMax":
+            weight = weight[:,:,1:,:].max(1)[0]
+            weight = weight / weight.norm(dim=-1, keepdim=True)
+            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
+
+        elif standard == "Random":
+            token_rank = torch.randn((B, N-1), device=weight.device)
+
+        else:
+            print("Type\'", standard, "\' selection not supported.")
+            assert False
+
+        token_rank = torch.argsort(token_rank, dim=1, descending=True) # B, N-1
+        index_kept = token_rank[:, :-num_prop]+1 # B, N-1-num_prop
+        index_prop = token_rank[:, -num_prop:]+1 # B, num_prop
+
+    else:
+        if standard == "IMGAttnMean":
+            token_rank = weight.sum(-2).mean(1)
+
+        elif standard == "IMGAttnMax":
+            token_rank = weight.sum(-2).max(1)[0]
+
+        elif standard == "DiagAttnMean":
+            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
+
+        elif standard == "DiagAttnMax":
+            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
+
+        elif standard == "MixedAttnMean":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
+            token_rank_2 = weight.sum(-2).mean(1)
+            token_rank = token_rank_1 * token_rank_2
+
+        elif standard == "MixedAttnMax":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
+            token_rank_2 = weight.sum(-2).max(1)[0]
+            token_rank = token_rank_1 * token_rank_2
+
+        elif standard == "SumAttnMax":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
+            token_rank_2 = weight.sum(-2).max(1)[0]
+            token_rank = token_rank_1 + token_rank_2
+
+        elif standard == "CosSimMean":
+            weight = weight.mean(1)
+            weight = weight / weight.norm(dim=-1, keepdim=True)
+            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
+
+        elif standard == "CosSimMax":
+            weight = weight.max(1)[0]
+            weight = weight / weight.norm(dim=-1, keepdim=True)
+            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
+
+        elif standard == "Random":
+            token_rank = torch.randn((B, N-1), device=weight.device)
+
+        else:
+            print("Type\'", standard, "\' selection not supported.")
+            assert False
+
+        token_rank = torch.argsort(token_rank, dim=1, descending=True) # B, N-1
+        index_kept = token_rank[:, :-num_prop] # B, N-1-num_prop
+        index_prop = token_rank[:, -num_prop:] # B, num_prop
+    return index_kept, index_prop
+
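As a quick illustration of how select ranks tokens, the sketch below (not from the commit) applies the "DiagAttnMean" rule to a toy attention map with a [CLS] token: the self-attention of each image token is averaged over heads, tokens are sorted by that score, and the lowest-ranked num_prop tokens are marked for propagation. The +1 offset maps the ranks back to positions in the full sequence, where [CLS] sits at index 0. All sizes are toy values.

import torch

B, H, N, num_prop = 2, 3, 10, 4               # toy batch, heads, tokens (incl. [CLS])
attn = torch.rand(B, H, N, N).softmax(dim=-1) # stand-in attention map

# "DiagAttnMean": head-averaged self-attention of each image token
token_rank = torch.diagonal(attn, dim1=-2, dim2=-1)[:, :, 1:].mean(1)    # B, N-1
order = torch.argsort(token_rank, dim=1, descending=True)                # B, N-1
index_kept = order[:, :-num_prop] + 1   # B, N-1-num_prop, indices into the full sequence
index_prop = order[:, -num_prop:] + 1   # B, num_prop
print(index_kept.shape, index_prop.shape)  # torch.Size([2, 5]) torch.Size([2, 4])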
 class PoolFormerBlock(nn.Module):
     """
     Implementation of one PoolFormer block.
@@ -910,13 +1212,17 @@ class PoolFormerBlock(nn.Module):
             self.layer_scale_2 = nn.Parameter(
                 layer_scale_init_value * torch.ones((dim)), requires_grad=True)
 
-    def forward(self, x):
+    def forward(self, x, graph):
         B, C, H, W = x.shape
         x_windows = window_partition(x, self.window_size)
         x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
         attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
         x_attn = window_reverse(attn_windows, self.window_size, H, W)
+        index_kept, index_prop = select(x_attn, standard="MixedAttnMax", num_prop=0,
+                                            cls_token=False)
+        x, weight, token_scales = propagate(x, weight, index_kept, index_prop, standard="GraphProp",
+                                        alpha=0.1, token_scales=token_scales, cls_token=False)
         if self.use_layer_scale:
             x = x + self.drop_path(
                 self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)