try gtp vit
This commit is contained in:
		@@ -527,6 +527,9 @@ class FouriER(torch.nn.Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        self.network = nn.ModuleList(network)
 | 
					        self.network = nn.ModuleList(network)
 | 
				
			||||||
        self.norm = norm_layer(embed_dims[-1])
 | 
					        self.norm = norm_layer(embed_dims[-1])
 | 
				
			||||||
 | 
					        self.graph_type = 'Spatial'
 | 
				
			||||||
 | 
					        self.class_token = False
 | 
				
			||||||
 | 
					        self.token_scale = False
 | 
				
			||||||
        self.head = nn.Linear(
 | 
					        self.head = nn.Linear(
 | 
				
			||||||
                embed_dims[-1], num_classes) if num_classes > 0 \
 | 
					                embed_dims[-1], num_classes) if num_classes > 0 \
 | 
				
			||||||
                else nn.Identity()
 | 
					                else nn.Identity()
 | 
				
			||||||
@@ -578,10 +581,6 @@ class FouriER(torch.nn.Module):
 | 
				
			|||||||
            degree = torch.diag_embed(degree**(-1/2))
 | 
					            degree = torch.diag_embed(degree**(-1/2))
 | 
				
			||||||
            graph = degree @ graph @ degree
 | 
					            graph = degree @ graph @ degree
 | 
				
			||||||
            
 | 
					            
 | 
				
			||||||
        if self.token_scale:
 | 
					 | 
				
			||||||
            token_scales = self.token_scales.unsqueeze(0).expand(B,-1).to(x.device)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            token_scales = None
 | 
					 | 
				
			||||||
        for idx, block in enumerate(self.network):
 | 
					        for idx, block in enumerate(self.network):
 | 
				
			||||||
            x = block(x)
 | 
					            x = block(x)
 | 
				
			||||||
        # output only the features of last layer for image classification
 | 
					        # output only the features of last layer for image classification
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user