try gtp vit

This commit is contained in:
thanhvc3 2024-04-28 12:05:04 +07:00
parent b01e504874
commit 68a94bd1e2

View File

@ -11,6 +11,7 @@ from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.layers.helpers import to_2tuple
from typing import *
import math
class ConvE(torch.nn.Module):
@ -528,6 +529,19 @@ class FouriER(torch.nn.Module):
self.network = nn.ModuleList(network)
self.norm = norm_layer(embed_dims[-1])
self.graph_type = 'Spatial'
N = (image_h // patch_size)**2
if self.graph_type in ["Spatial", "Mixed"]:
# Create a range tensor of node indices
indices = torch.arange(N)
# Reshape the indices tensor to create a grid of row and column indices
row_indices = indices.view(-1, 1).expand(-1, N)
col_indices = indices.view(1, -1).expand(N, -1)
# Compute the adjacency matrix
row1, col1 = row_indices // int(math.sqrt(N)), row_indices % int(math.sqrt(N))
row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N))
graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float())
graph = graph - torch.eye(N)
self.spatial_graph = graph.to("cuda") # comment .to("cuda") if the environment is cpu
self.class_token = False
self.token_scale = False
self.head = nn.Linear(
@ -580,7 +594,7 @@ class FouriER(torch.nn.Module):
degree = graph.sum(-1) # B, N
degree = torch.diag_embed(degree**(-1/2))
graph = degree @ graph @ degree
for idx, block in enumerate(self.network):
x = block(x)
# output only the features of last layer for image classification