try gtp vit
This commit is contained in:
		
							
								
								
									
										16
									
								
								models.py
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								models.py
									
									
									
									
									
								
							@@ -11,6 +11,7 @@ from timm.models.layers import DropPath, trunc_normal_
 | 
			
		||||
from timm.models.registry import register_model
 | 
			
		||||
from timm.layers.helpers import to_2tuple
 | 
			
		||||
from typing import *
 | 
			
		||||
import math
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ConvE(torch.nn.Module):
 | 
			
		||||
@@ -528,6 +529,19 @@ class FouriER(torch.nn.Module):
 | 
			
		||||
        self.network = nn.ModuleList(network)
 | 
			
		||||
        self.norm = norm_layer(embed_dims[-1])
 | 
			
		||||
        self.graph_type = 'Spatial'
 | 
			
		||||
        N = (image_h // patch_size)**2
 | 
			
		||||
        if self.graph_type in ["Spatial", "Mixed"]:
 | 
			
		||||
            # Create a range tensor of node indices
 | 
			
		||||
            indices = torch.arange(N)
 | 
			
		||||
            # Reshape the indices tensor to create a grid of row and column indices
 | 
			
		||||
            row_indices = indices.view(-1, 1).expand(-1, N)
 | 
			
		||||
            col_indices = indices.view(1, -1).expand(N, -1)
 | 
			
		||||
            # Compute the adjacency matrix
 | 
			
		||||
            row1, col1 = row_indices // int(math.sqrt(N)), row_indices % int(math.sqrt(N))
 | 
			
		||||
            row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N))
 | 
			
		||||
            graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float())
 | 
			
		||||
            graph = graph - torch.eye(N)
 | 
			
		||||
            self.spatial_graph = graph.to("cuda") # comment .to("cuda") if the environment is cpu
 | 
			
		||||
        self.class_token = False
 | 
			
		||||
        self.token_scale = False
 | 
			
		||||
        self.head = nn.Linear(
 | 
			
		||||
@@ -580,7 +594,7 @@ class FouriER(torch.nn.Module):
 | 
			
		||||
            degree = graph.sum(-1) # B, N
 | 
			
		||||
            degree = torch.diag_embed(degree**(-1/2))
 | 
			
		||||
            graph = degree @ graph @ degree
 | 
			
		||||
            
 | 
			
		||||
 | 
			
		||||
        for idx, block in enumerate(self.network):
 | 
			
		||||
            x = block(x)
 | 
			
		||||
        # output only the features of last layer for image classification
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user