From 541c4fa2b3c57d806935fe054b0497dc82f9037b Mon Sep 17 00:00:00 2001 From: thanhvc3 Date: Sun, 28 Apr 2024 12:05:57 +0700 Subject: [PATCH] try gtp vit --- models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models.py b/models.py index e286a6e..a2fb3f8 100644 --- a/models.py +++ b/models.py @@ -541,7 +541,7 @@ class FouriER(torch.nn.Module): row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N)) graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float()) graph = graph - torch.eye(N) - self.spatial_graph = graph.to("cuda") # comment .to("cuda") if the environment is cpu + self.spatial_graph = graph.cuda() # comment .to("cuda") if the environment is cpu self.class_token = False self.token_scale = False self.head = nn.Linear(