try swin
This commit is contained in:
		
							
								
								
									
										30
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								main.py
									
									
									
									
									
								
							@@ -716,17 +716,19 @@ if __name__ == "__main__":
 | 
			
		||||
        model.load_model(save_path)
 | 
			
		||||
        model.evaluate('test')
 | 
			
		||||
    else:
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                model = Main(args, logger)
 | 
			
		||||
                model.fit()
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                print(e)
 | 
			
		||||
                traceback.print_exc()
 | 
			
		||||
                try:
 | 
			
		||||
                    del model
 | 
			
		||||
                except Exception:
 | 
			
		||||
                    pass
 | 
			
		||||
                time.sleep(30)
 | 
			
		||||
                continue
 | 
			
		||||
            break
 | 
			
		||||
        model = Main(args, logger)
 | 
			
		||||
        model.fit()
 | 
			
		||||
        # while True:
 | 
			
		||||
        #     try:
 | 
			
		||||
        #         model = Main(args, logger)
 | 
			
		||||
        #         model.fit()
 | 
			
		||||
        #     except Exception as e:
 | 
			
		||||
        #         print(e)
 | 
			
		||||
        #         traceback.print_exc()
 | 
			
		||||
        #         try:
 | 
			
		||||
        #             del model
 | 
			
		||||
        #         except Exception:
 | 
			
		||||
        #             pass
 | 
			
		||||
        #         time.sleep(30)
 | 
			
		||||
        #         continue
 | 
			
		||||
        #     break
 | 
			
		||||
 
 | 
			
		||||
@@ -862,7 +862,7 @@ class PoolFormerBlock(nn.Module):
 | 
			
		||||
        self.norm1 = norm_layer(dim)
 | 
			
		||||
        #self.token_mixer = Pooling(pool_size=pool_size)
 | 
			
		||||
        # self.token_mixer = FNetBlock()
 | 
			
		||||
        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(7), num_heads=3, pretrained_window_size=[5,5])
 | 
			
		||||
        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(7), num_heads=1, pretrained_window_size=[5,5])
 | 
			
		||||
        self.norm2 = norm_layer(dim)
 | 
			
		||||
        mlp_hidden_dim = int(dim * mlp_ratio)
 | 
			
		||||
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user