try gtp vit
This commit is contained in:
parent
68a94bd1e2
commit
541c4fa2b3
@ -541,7 +541,7 @@ class FouriER(torch.nn.Module):
|
|||||||
row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N))
|
row2, col2 = col_indices // int(math.sqrt(N)), col_indices % int(math.sqrt(N))
|
||||||
graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float())
|
graph = ((abs(row1 - row2) <= 1).float() * (abs(col1 - col2) <= 1).float())
|
||||||
graph = graph - torch.eye(N)
|
graph = graph - torch.eye(N)
|
||||||
self.spatial_graph = graph.to("cuda") # comment .to("cuda") if the environment is cpu
|
self.spatial_graph = graph.cuda() # comment .to("cuda") if the environment is cpu
|
||||||
self.class_token = False
|
self.class_token = False
|
||||||
self.token_scale = False
|
self.token_scale = False
|
||||||
self.head = nn.Linear(
|
self.head = nn.Linear(
|
||||||
|
Loading…
Reference in New Issue
Block a user