try gtp vit
This commit is contained in:
parent
b01e504874
commit
68a94bd1e2
14
models.py
14
models.py
@ -11,6 +11,7 @@ from timm.models.layers import DropPath, trunc_normal_
|
|||||||
from timm.models.registry import register_model
|
from timm.models.registry import register_model
|
||||||
from timm.layers.helpers import to_2tuple
|
from timm.layers.helpers import to_2tuple
|
||||||
from typing import *
|
from typing import *
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
class ConvE(torch.nn.Module):
|
class ConvE(torch.nn.Module):
|
||||||
@ -528,6 +529,19 @@ class FouriER(torch.nn.Module):
|
|||||||
self.network = nn.ModuleList(network)
|
self.network = nn.ModuleList(network)
|
||||||
self.norm = norm_layer(embed_dims[-1])
|
self.norm = norm_layer(embed_dims[-1])
|
||||||
self.graph_type = 'Spatial'
|
self.graph_type = 'Spatial'
|
||||||
|
N = (image_h // patch_size)**2
|
||||||
|
if self.graph_type in ["Spatial", "Mixed"]:
|
||||||
|
# Create a range tensor of node indices
|
||||||
|
indices = torch.arange(N)
|
||||||
|
# Reshape the indices tensor to create a grid of row and column indices
|
||||||
|
row_indices = indices.view(-1, 1).expand(-1, N)
|
||||||
|
col_indices = indices.view(1, -1).expand(N, -1)
|
||||||
|
# Compute the adjacency matrix
|
||||||
|
row1, col1 = row_indices // int(math.sqrt(N)), row_indices % int(math.sqrt(N))
|
||||||
|
row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N))
|
||||||
|
graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float())
|
||||||
|
graph = graph - torch.eye(N)
|
||||||
|
self.spatial_graph = graph.to("cuda") # comment .to("cuda") if the environment is cpu
|
||||||
self.class_token = False
|
self.class_token = False
|
||||||
self.token_scale = False
|
self.token_scale = False
|
||||||
self.head = nn.Linear(
|
self.head = nn.Linear(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user