try modify swin
This commit is contained in:
		@@ -965,7 +965,7 @@ class PoolFormerBlock(nn.Module):
 | 
				
			|||||||
        if self.shift_size > 0:
 | 
					        if self.shift_size > 0:
 | 
				
			||||||
            # calculate attention mask for SW-MSA
 | 
					            # calculate attention mask for SW-MSA
 | 
				
			||||||
            H, W = self.input_resolution
 | 
					            H, W = self.input_resolution
 | 
				
			||||||
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
 | 
					            img_mask = torch.zeros((1, 1, H, W))  # 1 H W 1
 | 
				
			||||||
            h_slices = (slice(0, -self.window_size),
 | 
					            h_slices = (slice(0, -self.window_size),
 | 
				
			||||||
                        slice(-self.window_size, -self.shift_size),
 | 
					                        slice(-self.window_size, -self.shift_size),
 | 
				
			||||||
                        slice(-self.shift_size, None))
 | 
					                        slice(-self.shift_size, None))
 | 
				
			||||||
@@ -975,7 +975,7 @@ class PoolFormerBlock(nn.Module):
 | 
				
			|||||||
            cnt = 0
 | 
					            cnt = 0
 | 
				
			||||||
            for h in h_slices:
 | 
					            for h in h_slices:
 | 
				
			||||||
                for w in w_slices:
 | 
					                for w in w_slices:
 | 
				
			||||||
                    img_mask[:, h, w, :] = cnt
 | 
					                    img_mask[:, :, h, w] = cnt
 | 
				
			||||||
                    cnt += 1
 | 
					                    cnt += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            print(self.input_resolution)
 | 
					            print(self.input_resolution)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user