diff --git a/main.py b/main.py
index f6dea4c..b9349b9 100644
--- a/main.py
+++ b/main.py
@@ -716,17 +716,19 @@ if __name__ == "__main__":
         model.load_model(save_path)
         model.evaluate('test')
     else:
-        while True:
-            try:
-                model = Main(args, logger)
-                model.fit()
-            except Exception as e:
-                print(e)
-                traceback.print_exc()
-                try:
-                    del model
-                except Exception:
-                    pass
-                time.sleep(30)
-                continue
-            break
+        model = Main(args, logger)
+        model.fit()
+        # while True:
+        #     try:
+        #         model = Main(args, logger)
+        #         model.fit()
+        #     except Exception as e:
+        #         print(e)
+        #         traceback.print_exc()
+        #         try:
+        #             del model
+        #         except Exception:
+        #             pass
+        #         time.sleep(30)
+        #         continue
+        #     break
diff --git a/models.py b/models.py
index d13bc8a..e603f85 100644
--- a/models.py
+++ b/models.py
@@ -862,7 +862,7 @@ class PoolFormerBlock(nn.Module):
         self.norm1 = norm_layer(dim)
         #self.token_mixer = Pooling(pool_size=pool_size)
         # self.token_mixer = FNetBlock()
-        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(7), num_heads=3, pretrained_window_size=[5,5])
+        self.token_mixer = WindowAttention(dim=dim, window_size=to_2tuple(7), num_heads=1, pretrained_window_size=[5,5])
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,