try gtp vit
This commit is contained in:
		
							
								
								
									
										13
									
								
								models.py
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								models.py
									
									
									
									
									
								
							@@ -597,7 +597,7 @@ class FouriER(torch.nn.Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        for idx, block in enumerate(self.network):
 | 
					        for idx, block in enumerate(self.network):
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                x = block(x, graph)
 | 
					                x = block((x, graph))
 | 
				
			||||||
            except:
 | 
					            except:
 | 
				
			||||||
                x = block(x)
 | 
					                x = block(x)
 | 
				
			||||||
        # output only the features of last layer for image classification
 | 
					        # output only the features of last layer for image classification
 | 
				
			||||||
@@ -758,7 +758,7 @@ def basic_blocks(dim, index, layers,
 | 
				
			|||||||
            use_layer_scale=use_layer_scale, 
 | 
					            use_layer_scale=use_layer_scale, 
 | 
				
			||||||
            layer_scale_init_value=layer_scale_init_value, 
 | 
					            layer_scale_init_value=layer_scale_init_value, 
 | 
				
			||||||
            ))
 | 
					            ))
 | 
				
			||||||
    blocks = nn.Sequential(*blocks)
 | 
					    blocks = SeqModel(*blocks)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return blocks
 | 
					    return blocks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -923,6 +923,15 @@ def window_reverse(windows, window_size, H, W):
 | 
				
			|||||||
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
 | 
					    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, H, W)
 | 
				
			||||||
    return x
 | 
					    return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SeqModel(nn.Sequential):
 | 
				
			||||||
 | 
						def forward(self, *inputs):
 | 
				
			||||||
 | 
							for module in self._modules.values():
 | 
				
			||||||
 | 
								if type(inputs) == tuple:
 | 
				
			||||||
 | 
									inputs = module(*inputs)
 | 
				
			||||||
 | 
								else:
 | 
				
			||||||
 | 
									inputs = module(inputs)
 | 
				
			||||||
 | 
							return inputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def propagate(x: torch.Tensor, weight: torch.Tensor, 
 | 
					def propagate(x: torch.Tensor, weight: torch.Tensor, 
 | 
				
			||||||
              index_kept: torch.Tensor, index_prop: torch.Tensor, 
 | 
					              index_kept: torch.Tensor, index_prop: torch.Tensor, 
 | 
				
			||||||
              standard: str = "None", alpha: Optional[float] = 0, 
 | 
					              standard: str = "None", alpha: Optional[float] = 0, 
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user