From 68a94bd1e215bf64c92d6a7198c925a9279eef03 Mon Sep 17 00:00:00 2001 From: thanhvc3 Date: Sun, 28 Apr 2024 12:05:04 +0700 Subject: [PATCH] try gtp vit --- models.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/models.py b/models.py index d8347a7..e286a6e 100644 --- a/models.py +++ b/models.py @@ -11,6 +11,7 @@ from timm.models.layers import DropPath, trunc_normal_ from timm.models.registry import register_model from timm.layers.helpers import to_2tuple from typing import * +import math class ConvE(torch.nn.Module): @@ -528,6 +529,19 @@ class FouriER(torch.nn.Module): self.network = nn.ModuleList(network) self.norm = norm_layer(embed_dims[-1]) self.graph_type = 'Spatial' + N = (image_h // patch_size)**2 + if self.graph_type in ["Spatial", "Mixed"]: + # Create a range tensor of node indices + indices = torch.arange(N) + # Reshape the indices tensor to create a grid of row and column indices + row_indices = indices.view(-1, 1).expand(-1, N) + col_indices = indices.view(1, -1).expand(N, -1) + # Compute the adjacency matrix + row1, col1 = row_indices // int(math.sqrt(N)), row_indices % int(math.sqrt(N)) + row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N)) + graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float()) + graph = graph - torch.eye(N) + self.spatial_graph = graph.to("cuda") # comment .to("cuda") if the environment is cpu self.class_token = False self.token_scale = False self.head = nn.Linear( @@ -580,7 +594,7 @@ class FouriER(torch.nn.Module): degree = graph.sum(-1) # B, N degree = torch.diag_embed(degree**(-1/2)) graph = degree @ graph @ degree - + for idx, block in enumerate(self.network): x = block(x) # output only the features of last layer for image classification