try modify swin
This commit is contained in:
		@@ -965,7 +965,7 @@ class PoolFormerBlock(nn.Module):
 | 
			
		||||
        if self.shift_size > 0:
 | 
			
		||||
            # calculate attention mask for SW-MSA
 | 
			
		||||
            H, W = self.input_resolution
 | 
			
		||||
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
 | 
			
		||||
            img_mask = torch.zeros((1, 1, H, W))  # 1 H W 1
 | 
			
		||||
            h_slices = (slice(0, -self.window_size),
 | 
			
		||||
                        slice(-self.window_size, -self.shift_size),
 | 
			
		||||
                        slice(-self.shift_size, None))
 | 
			
		||||
@@ -975,7 +975,7 @@ class PoolFormerBlock(nn.Module):
 | 
			
		||||
            cnt = 0
 | 
			
		||||
            for h in h_slices:
 | 
			
		||||
                for w in w_slices:
 | 
			
		||||
                    img_mask[:, h, w, :] = cnt
 | 
			
		||||
                    img_mask[:, :, h, w] = cnt
 | 
			
		||||
                    cnt += 1
 | 
			
		||||
 | 
			
		||||
            print(self.input_resolution)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user