From 41a5c7b05a43802ae647ea0896e6e7a3767d750d Mon Sep 17 00:00:00 2001
From: thanhvc3
Date: Sun, 28 Apr 2024 11:57:17 +0700
Subject: [PATCH] try gtp vit

---
 models.py | 308 +++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 307 insertions(+), 1 deletion(-)

diff --git a/models.py b/models.py
index b262ffd..947ce5a 100644
--- a/models.py
+++ b/models.py
@@ -543,6 +543,44 @@ class FouriER(torch.nn.Module):
 
     def forward_tokens(self, x):
         outs = []
+        B, C, H, W = x.shape
+        N = H * W
+        if self.graph_type in ["Semantic", "Mixed"]:
+            # Generate the semantic graph w.r.t. the cosine similarity between tokens.
+            # Flatten the feature map into a token matrix first: B, C, H, W -> B, N, C
+            x_tokens = x.flatten(2).transpose(1, 2)
+            if self.class_token:
+                x_normed = x_tokens[:, 1:] / x_tokens[:, 1:].norm(dim=-1, keepdim=True)
+            else:
+                x_normed = x_tokens / x_tokens.norm(dim=-1, keepdim=True)
+            x_cossim = x_normed @ x_normed.transpose(-1, -2)
+            threshold = torch.kthvalue(x_cossim, N - 1 - self.num_neighbours, dim=-1, keepdim=True)[0]  # B, N, 1
+            semantic_graph = torch.where(x_cossim >= threshold, 1.0, 0.0)
+            if self.class_token:
+                semantic_graph = semantic_graph - torch.eye(N - 1, device=semantic_graph.device).unsqueeze(0)
+            else:
+                semantic_graph = semantic_graph - torch.eye(N, device=semantic_graph.device).unsqueeze(0)
+
+        if self.graph_type == "None":
+            graph = None
+        else:
+            if self.graph_type == "Spatial":
+                graph = self.spatial_graph.unsqueeze(0).expand(B, -1, -1)  # .to(x.device)
+            elif self.graph_type == "Semantic":
+                graph = semantic_graph
+            elif self.graph_type == "Mixed":
+                # Integrate the spatial graph and the semantic graph
+                spatial_graph = self.spatial_graph.unsqueeze(0).expand(B, -1, -1).to(x.device)
+                graph = torch.bitwise_or(semantic_graph.int(), spatial_graph.int()).float()
+
+            # Symmetrically normalize the graph: D^-1/2 A D^-1/2
+            degree = graph.sum(-1)  # B, N
+            degree = torch.diag_embed(degree ** (-1 / 2))
+            graph = degree @ graph @ degree
+
+        if self.token_scale:
+            token_scales = self.token_scales.unsqueeze(0).expand(B, -1).to(x.device)
+        else:
+            token_scales = None
         for idx, block in enumerate(self.network):
             x = block(x)
         # output only the features of last layer for image classification
@@ -868,6 +906,270 @@ def window_reverse(windows, window_size, H, W):
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
     return x
 
+def propagate(x: torch.Tensor, weight: torch.Tensor,
+              index_kept: torch.Tensor, index_prop: torch.Tensor,
+              standard: str = "None", alpha: float = 0.,
+              token_scales: Optional[torch.Tensor] = None,
+              cls_token=True):
+    """
+    Propagate tokens based on the selection results.
+    ================================================
+    Args:
+    - x: Tensor([B, N, C]): the feature map of N tokens, including the [CLS] token.
+
+    - weight: Tensor([B, N-1, N-1]): the weight of each token propagated to the other tokens,
+                                     excluding the [CLS] token. weight could be a pre-defined
+                                     graph of the current feature map (by default) or the
+                                     attention map (need to manually modify the Block Module).
+
+    - index_kept: Tensor([B, N-1-num_prop]): the index of kept image tokens in the feature map x
+
+    - index_prop: Tensor([B, num_prop]): the index of propagated image tokens in the feature map x
+
+    - standard: str: the method applied to propagate the tokens, including "None", "Mean" and
+                     "GraphProp"
+
+    - alpha: float: the coefficient of propagated features
+
+    - token_scales: Tensor([B, N]): the scale of tokens, including the [CLS] token. token_scales
+                                    is None by default. If it is not None, then token_scales
+                                    represents the scales of each token and should sum up to N.
+
+    Return:
+    - x: Tensor([B, N-1-num_prop, C]): the feature map after propagation
+
+    - weight: Tensor([B, N-1-num_prop, N-1-num_prop]): the graph of the feature map after propagation
+
+    - token_scales: Tensor([B, N-1-num_prop]): the scale of tokens after propagation
+    """
+
+    B, N, C = x.shape
+
+    # Step 1: divide tokens
+    if cls_token:
+        x_cls = x[:, 0:1]  # B, 1, C
+    x_kept = x.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1, -1, C))  # B, N-1-num_prop, C
+    x_prop = x.gather(dim=1, index=index_prop.unsqueeze(-1).expand(-1, -1, C))  # B, num_prop, C
+
+    # Step 2: divide token_scales if it is not None
+    if token_scales is not None:
+        if cls_token:
+            token_scales_cls = token_scales[:, 0:1]  # B, 1
+        token_scales_kept = token_scales.gather(dim=1, index=index_kept)  # B, N-1-num_prop
+        token_scales_prop = token_scales.gather(dim=1, index=index_prop)  # B, num_prop
+
+    # Step 3: propagate tokens
+    if standard == "None":
+        """
+        No further propagation
+        """
+        pass
+
+    elif standard == "Mean":
+        """
+        Calculate the mean of all the propagated tokens,
+        and concatenate the resulting token back to the kept tokens.
+        """
+        # naive average
+        x_prop = x_prop.mean(1, keepdim=True)  # B, 1, C
+        # Concatenate the average token
+        x_kept = torch.cat((x_kept, x_prop), dim=1)  # B, N-num_prop, C
+
+    elif standard == "GraphProp":
+        """
+        Propagate all the propagated tokens to the kept tokens
+        with respect to the weights and token scales.
+        """
+        assert weight is not None, "The graph weight is needed for graph propagation"
+
+        # Step 3.1: divide propagation weights.
+        if cls_token:
+            index_kept = index_kept - 1  # since weights do not include the [CLS] token
+            index_prop = index_prop - 1  # since weights do not include the [CLS] token
+            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1, -1, N - 1))  # B, N-1-num_prop, N-1
+            weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1, weight.shape[1], -1))  # B, N-1-num_prop, num_prop
+            weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1, weight.shape[1], -1))  # B, N-1-num_prop, N-1-num_prop
+        else:
+            weight = weight.gather(dim=1, index=index_kept.unsqueeze(-1).expand(-1, -1, N))  # B, N-num_prop, N
+            weight_prop = weight.gather(dim=2, index=index_prop.unsqueeze(1).expand(-1, weight.shape[1], -1))  # B, N-num_prop, num_prop
+            weight = weight.gather(dim=2, index=index_kept.unsqueeze(1).expand(-1, weight.shape[1], -1))  # B, N-num_prop, N-num_prop
+
+        # Step 3.2: generate the broadcast message and propagate it to the corresponding kept tokens
+        # Simple implementation
+        x_prop = weight_prop @ x_prop  # B, N-1-num_prop, C
+        x_kept = x_kept + alpha * x_prop  # B, N-1-num_prop, C
+
+        """ scatter_reduce implementation for batched inputs
+        # Get the non-zero values
+        non_zero_indices = torch.nonzero(weight_prop, as_tuple=True)
+        non_zero_values = weight_prop[non_zero_indices]
+
+        # Sparse multiplication
+        batch_indices, row_indices, col_indices = non_zero_indices
+        sparse_matmul = alpha * non_zero_values[:, None] * x_prop[batch_indices, col_indices, :]
+        reduce_indices = batch_indices * x_kept.shape[1] + row_indices
+
+        x_kept = x_kept.reshape(-1, C).scatter_reduce(dim=0,
+                                                      index=reduce_indices[:, None].expand(-1, C),
+                                                      src=sparse_matmul,
+                                                      reduce="sum",
+                                                      include_self=True)
+        x_kept = x_kept.reshape(B, -1, C)
+        """
+
+        # Step 3.3: calculate the scale of each token if token_scales is not None
+        if token_scales is not None:
+            if cls_token:
+                token_scales_cls = token_scales[:, 0:1]  # B, 1
+                token_scales = token_scales[:, 1:]
+            token_scales_kept = token_scales.gather(dim=1, index=index_kept)  # B, N-1-num_prop
+            token_scales_prop = token_scales.gather(dim=1, index=index_prop)  # B, num_prop
+            token_scales_prop = weight_prop @ token_scales_prop.unsqueeze(-1)  # B, N-1-num_prop, 1
+            token_scales = token_scales_kept + alpha * token_scales_prop.squeeze(-1)  # B, N-1-num_prop
+            if cls_token:
+                token_scales = torch.cat((token_scales_cls, token_scales), dim=1)  # B, N-num_prop
+    else:
+        assert False, "Propagation method '%s' has not been supported yet." % standard
+
+    if cls_token:
+        # Step 4: concatenate the [CLS] token and generate the returned value
+        x = torch.cat((x_cls, x_kept), dim=1)  # B, N-num_prop, C
+    else:
+        x = x_kept
+    return x, weight, token_scales
+
+
+def select(weight: torch.Tensor, standard: str = "None", num_prop: int = 0, cls_token=True):
+    """
+    Select image tokens to be propagated. The [CLS] token will be ignored.
+    ======================================================================
+    Args:
+    - weight: Tensor([B, H, N, N]): used for selecting the kept tokens. Only the attention
+                                    map of tokens is supported at the moment.
+
+    - standard: str: the method applied to select the tokens
+
+    - num_prop: int: the number of tokens to be propagated
+
+    - cls_token: bool: whether the input map includes the [CLS] token
+
+    Return:
+    - index_kept: Tensor([B, N-1-num_prop]): the index of kept tokens
+
+    - index_prop: Tensor([B, num_prop]): the index of propagated tokens
+    """
+
+    assert len(weight.shape) == 4, "Selection methods on tensors other than the attention map are not supported yet."
+    B, H, N1, N2 = weight.shape
+    assert N1 == N2, "Selection methods on tensors other than the attention map are not supported yet."
+    N = N1
+    assert num_prop >= 0, "The number of propagated/pruned tokens must be non-negative."
+
+    if cls_token:
+        if standard == "CLSAttnMean":
+            token_rank = weight[:, :, 0, 1:].mean(1)
+
+        elif standard == "CLSAttnMax":
+            token_rank = weight[:, :, 0, 1:].max(1)[0]
+
+        elif standard == "IMGAttnMean":
+            token_rank = weight[:, :, :, 1:].sum(-2).mean(1)
+
+        elif standard == "IMGAttnMax":
+            token_rank = weight[:, :, :, 1:].sum(-2).max(1)[0]
+
+        elif standard == "DiagAttnMean":
+            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].mean(1)
+
+        elif standard == "DiagAttnMax":
+            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].max(1)[0]
+
+        elif standard == "MixedAttnMean":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].mean(1)
+            token_rank_2 = weight[:, :, :, 1:].sum(-2).mean(1)
+            token_rank = token_rank_1 * token_rank_2
+
+        elif standard == "MixedAttnMax":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].max(1)[0]
+            token_rank_2 = weight[:, :, :, 1:].sum(-2).max(1)[0]
+            token_rank = token_rank_1 * token_rank_2
+
+        elif standard == "SumAttnMax":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1)[:, :, 1:].max(1)[0]
+            token_rank_2 = weight[:, :, :, 1:].sum(-2).max(1)[0]
+            token_rank = token_rank_1 + token_rank_2
+
+        elif standard == "CosSimMean":
+            weight = weight[:, :, 1:, :].mean(1)
+            weight = weight / weight.norm(dim=-1, keepdim=True)
+            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
+
+        elif standard == "CosSimMax":
+            weight = weight[:, :, 1:, :].max(1)[0]
+            weight = weight / weight.norm(dim=-1, keepdim=True)
+            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
+
+        elif standard == "Random":
+            token_rank = torch.randn((B, N - 1), device=weight.device)
+
+        else:
+            assert False, "Type '%s' selection not supported." % standard
+
+        token_rank = torch.argsort(token_rank, dim=1, descending=True)  # B, N-1
+        # Use explicit bounds so that num_prop == 0 keeps all tokens
+        index_kept = token_rank[:, :N - 1 - num_prop] + 1  # B, N-1-num_prop
+        index_prop = token_rank[:, N - 1 - num_prop:] + 1  # B, num_prop
+
+    else:
+        if standard == "IMGAttnMean":
+            token_rank = weight.sum(-2).mean(1)
+
+        elif standard == "IMGAttnMax":
+            token_rank = weight.sum(-2).max(1)[0]
+
+        elif standard == "DiagAttnMean":
+            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
+
+        elif standard == "DiagAttnMax":
+            token_rank = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
+
+        elif standard == "MixedAttnMean":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).mean(1)
+            token_rank_2 = weight.sum(-2).mean(1)
+            token_rank = token_rank_1 * token_rank_2
+
+        elif standard == "MixedAttnMax":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
+            token_rank_2 = weight.sum(-2).max(1)[0]
+            token_rank = token_rank_1 * token_rank_2
+
+        elif standard == "SumAttnMax":
+            token_rank_1 = torch.diagonal(weight, dim1=-2, dim2=-1).max(1)[0]
+            token_rank_2 = weight.sum(-2).max(1)[0]
+            token_rank = token_rank_1 + token_rank_2
+
+        elif standard == "CosSimMean":
+            weight = weight.mean(1)
+            weight = weight / weight.norm(dim=-1, keepdim=True)
+            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
+
+        elif standard == "CosSimMax":
+            weight = weight.max(1)[0]
+            weight = weight / weight.norm(dim=-1, keepdim=True)
+            token_rank = -(weight @ weight.transpose(-1, -2)).sum(-1)
+
+        elif standard == "Random":
+            token_rank = torch.randn((B, N), device=weight.device)
+
+        else:
+            assert False, "Type '%s' selection not supported." % standard
+
+        token_rank = torch.argsort(token_rank, dim=1, descending=True)  # B, N
+        # No [CLS] token here: indices are not shifted, and num_prop == 0 keeps all tokens
+        index_kept = token_rank[:, :N - num_prop]  # B, N-num_prop
+        index_prop = token_rank[:, N - num_prop:]  # B, num_prop
+    return index_kept, index_prop
+
 class PoolFormerBlock(nn.Module):
     """
     Implementation of one PoolFormer block.
@@ -910,13 +1212,17 @@ class PoolFormerBlock(nn.Module):
         self.layer_scale_2 = nn.Parameter(
             layer_scale_init_value * torch.ones((dim)), requires_grad=True)
 
-    def forward(self, x):
+    def forward(self, x, graph):
         B, C, H, W = x.shape
         x_windows = window_partition(x, self.window_size)
         x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
         attn_windows = self.token_mixer(x_windows, mask=self.attn_mask)
         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
         x_attn = window_reverse(attn_windows, self.window_size, H, W)
+        # x_attn is a feature map rather than an attention map, so score tokens by
+        # the pairwise similarity of the mixed features (select expects B, heads, N, N)
+        x_tokens = x_attn.flatten(2).transpose(1, 2)                # B, N, C
+        sim = (x_tokens @ x_tokens.transpose(-1, -2)).unsqueeze(1)  # B, 1, N, N
+        index_kept, index_prop = select(sim, standard="MixedAttnMax", num_prop=0,
+                                        cls_token=False)
+        index_kept = index_kept.sort(dim=1)[0]  # keep spatial order for the reshape below
+        x_new, graph, _ = propagate(x.flatten(2).transpose(1, 2), graph, index_kept,
+                                    index_prop, standard="GraphProp", alpha=0.1,
+                                    token_scales=None, cls_token=False)
+        x = x_new.transpose(1, 2).reshape(B, C, H, W)
         if self.use_layer_scale:
             x = x + self.drop_path(
                 self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
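
Reviewer note, not part of the patch: the sketch below exercises select() and
propagate() the way the new PoolFormerBlock.forward wires them up
(cls_token=False, standard="GraphProp"). It assumes models.py exposes both
functions as defined above; the tensors are random stand-ins and the shapes
follow the docstrings.

    import torch
    from models import select, propagate  # assumed importable from this repo

    B, heads, N, C, num_prop = 2, 4, 16, 32, 4
    x = torch.randn(B, N, C)                               # token features
    attn = torch.softmax(torch.randn(B, heads, N, N), -1)  # stand-in attention map
    graph = (torch.rand(B, N, N) > 0.5).float()            # stand-in propagation graph

    index_kept, index_prop = select(attn, standard="MixedAttnMax",
                                    num_prop=num_prop, cls_token=False)
    x_out, graph_out, _ = propagate(x, graph, index_kept, index_prop,
                                    standard="GraphProp", alpha=0.1,
                                    token_scales=None, cls_token=False)
    print(x_out.shape)  # torch.Size([2, 12, 32]): N - num_prop tokens remain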
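A similar stand-alone view of the semantic-graph construction added to
forward_tokens, on a toy token matrix; B, N, C and num_neighbours are stand-in
values here. Each row of the adjacency links a token to its most-similar
tokens under cosine similarity, with the self-loop removed, before the
symmetric normalization:

    import torch

    B, N, C, num_neighbours = 2, 16, 32, 4
    x = torch.randn(B, N, C)
    x_normed = x / x.norm(dim=-1, keepdim=True)
    cossim = x_normed @ x_normed.transpose(-1, -2)              # B, N, N
    threshold = torch.kthvalue(cossim, N - 1 - num_neighbours,
                               dim=-1, keepdim=True)[0]         # B, N, 1
    adj = torch.where(cossim >= threshold, 1.0, 0.0)
    adj = adj - torch.eye(N).unsqueeze(0)                       # drop self-loops
    degree = torch.diag_embed(adj.sum(-1) ** (-0.5))
    graph = degree @ adj @ degree                               # D^-1/2 A D^-1/2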
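Two observations on the wiring, for whoever picks this up: with num_prop=0 the
select()/propagate() pair keeps every token, so the block's output shape is
unchanged and GraphProp aggregates an empty set of pruned tokens (a no-op
pass that mainly validates the plumbing); raising num_prop inside
PoolFormerBlock would break the reshape back to (B, C, H, W) until the reduced
token count is handled spatially. The symmetric normalization D^-1/2 A D^-1/2
in forward_tokens bounds the row sums of the graph, which keeps the propagated
features at a scale comparable to the kept features for small alpha.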