A slightly less elegant solution than the one proposed by Gil:
I took inspiration from this post on the PyTorch forums and formatted my image tensor to the standard B x C x H x W shape (1 x 1 x 256 x 256). Unfolding:
```python
import torch

# CREATE THE UNFOLDED IMAGE SLICES
I = image                         # shape [256, 256]
kernel_size = bx                  # patch size: 16
stride = bx // 2                  # 50% overlap: 8
I2 = I.unsqueeze(0).unsqueeze(0)  # shape [1, 1, 256, 256]
patches2 = I2.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
# shape [1, 1, 31, 31, 16, 16]
```
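To make the shapes concrete, here is a small self-contained sketch with a random 256 x 256 tensor standing in for `image` and `bx = 16` (both assumptions). `Tensor.unfold(dim, size, step)` appends a trailing dimension of length `size`, and each spatial axis yields (256 - 16) // 8 + 1 = 31 windows:

```python
import torch

# Hypothetical stand-ins for the answer's `image` and `bx`.
H = W = 256
kernel_size = 16
stride = kernel_size // 2

image = torch.rand(H, W)
I2 = image.unsqueeze(0).unsqueeze(0)            # [1, 1, 256, 256]

# Each unfold replaces one spatial dim by the window count and appends the window size.
rows = I2.unfold(2, kernel_size, stride)        # [1, 1, 31, 256, 16]
patches2 = rows.unfold(3, kernel_size, stride)  # [1, 1, 31, 31, 16, 16]

# Windows per axis: (256 - 16) // 8 + 1 = 31
n_h = (H - kernel_size) // stride + 1
n_w = (W - kernel_size) // stride + 1
print(rows.shape)      # torch.Size([1, 1, 31, 256, 16])
print(patches2.shape)  # torch.Size([1, 1, 31, 31, 16, 16])

# Patch (i, j) is the 16 x 16 block starting at row i*stride, column j*stride.
i, j = 3, 5
assert torch.equal(
    patches2[0, 0, i, j],
    image[i * stride : i * stride + kernel_size, j * stride : j * stride + kernel_size],
)
```

Note that the last two dimensions are (row offset, column offset) within the patch, which matters when the patches are flattened again before folding.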
Following this, I apply some transforms and filtering to the stack of patches. Before that, I normalise each patch (subtract its mean) and apply a cosine window:
```python
# NORMALISE AND WINDOW
# `win` is the 16 x 16 cosine window; `noise_std` is defined elsewhere
Pvv = torch.mean(torch.pow(win, 2)) * torch.numel(win) * (noise_std**2)
Pvv = Pvv.double()
n_h, n_w = patches2.shape[2], patches2.shape[3]  # 31, 31
mean_patches = torch.mean(patches2, (4, 5), keepdim=True)
mean_patches = mean_patches.repeat(1, 1, 1, 1, kernel_size, kernel_size)
window_patches = win[None, None, None, None].repeat(1, 1, n_h, n_w, 1, 1)
zero_mean = patches2 - mean_patches
windowed_patches = zero_mean * window_patches
# SOME FILTERING ....
# ADD MEAN AND WINDOW BEFORE FOLDING BACK TOGETHER
filt_data_block = (filt_data_block + mean_patches * window_patches) * window_patches
```
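As an aside on the windowing step: PyTorch broadcasts size-1 dimensions, so the two `repeat` calls are not strictly needed. A minimal sketch comparing both forms, using a random image and assuming (hypothetically) that `win` is the outer product of two periodic Hann windows:

```python
import torch

kernel_size, stride = 16, 8
image = torch.rand(256, 256, dtype=torch.float64)
patches2 = (
    image[None, None]
    .unfold(2, kernel_size, stride)
    .unfold(3, kernel_size, stride)
)  # [1, 1, 31, 31, 16, 16]

# Assumption: a separable 2-D cosine (Hann) window of size 16 x 16.
h = torch.hann_window(kernel_size, periodic=True, dtype=torch.float64)
win = torch.outer(h, h)  # [16, 16]

# Explicit copies, as in the answer.
mean_rep = patches2.mean((4, 5), keepdim=True).repeat(1, 1, 1, 1, kernel_size, kernel_size)
win_rep = win[None, None, None, None].repeat(1, 1, patches2.shape[2], patches2.shape[3], 1, 1)
windowed_rep = (patches2 - mean_rep) * win_rep

# Broadcasting: size-1 dims are stretched on the fly, nothing is copied.
mean_b = patches2.mean((4, 5), keepdim=True)                # [1, 1, 31, 31, 1, 1]
win_b = win.view(1, 1, 1, 1, kernel_size, kernel_size)      # [1, 1, 1, 1, 16, 16]
windowed_b = (patches2 - mean_b) * win_b

assert torch.allclose(windowed_rep, windowed_b)
```

The broadcast form gives the same result without materialising 31 x 31 copies of the window and of the per-patch means.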
My windowing code works, but a mask would be simpler. Next, I reshape the [1, 1, 31, 31, 16, 16] tensor into the [N, C*k*k, L] = [1, 256, 961] layout that F.fold expects and fold it back into the original [1, 1, 256, 256] image:
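```python
import torch.nn.functional as F

# REASSEMBLE THE IMAGE USING FOLD
# [1, 1, 31, 31, 16, 16] -> [1, 1, 961, 256]: one row per patch
patches = filt_data_block.contiguous().view(1, 1, -1, kernel_size * kernel_size)
# -> [1, 1, 256, 961]: pixel-within-patch first, patch index last
patches = patches.permute(0, 1, 3, 2)
# -> [1, 256, 961]: the (N, C*k*k, L) layout F.fold expects
patches = patches.contiguous().view(1, kernel_size * kernel_size, -1)
# Overlapping pixels are summed back into a [1, 1, 256, 256] image
IR = F.fold(patches, output_size=tuple(I.shape), kernel_size=kernel_size, stride=stride)
IR = IR.squeeze()  # back to [256, 256]
```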
This let me use an overlapping sliding window and stitch the image back together seamlessly. With the filtering step removed, the output is identical to the input image.
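To check the reassembly independently, the sketch below uses a random 256 x 256 image in place of the real one (an assumption). It compares the view/permute/view reshaping against `F.unfold`, which produces exactly the `[N, C*k*k, L]` layout that `F.fold` inverts, and shows that `F.fold` adds overlapping pixels rather than averaging them. The Hann window in step 3 is a hypothetical choice; the answer does not say exactly which cosine window `win` is.

```python
import torch
import torch.nn.functional as F

kernel_size, stride = 16, 8
H = W = 256
image = torch.rand(H, W, dtype=torch.float64)   # stand-in for the real image
I2 = image[None, None]                          # [1, 1, 256, 256]

def fold(cols):
    """Overlap-add [1, k*k, L] columns back into a [1, 1, H, W] image."""
    return F.fold(cols, output_size=(H, W), kernel_size=kernel_size, stride=stride)

# 1) The view/permute/view reshaping gives exactly F.unfold's layout.
patches2 = I2.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
p = patches2.contiguous().view(1, 1, -1, kernel_size * kernel_size)
p = p.permute(0, 1, 3, 2).contiguous().view(1, kernel_size * kernel_size, -1)
cols = F.unfold(I2, kernel_size=kernel_size, stride=stride)  # [1, 256, 961]
assert torch.equal(p, cols)

# 2) F.fold sums overlaps: each pixel is counted 1, 2 or 4 times,
#    so dividing by the folded ones reconstructs the image exactly.
counts = fold(F.unfold(torch.ones_like(I2), kernel_size=kernel_size, stride=stride))
assert torch.allclose(fold(cols) / counts, I2)

# 3) Assumption: the overall window (w * w in the answer) is a periodic Hann window.
#    With 50% overlap it sums to 1, so the interior is reproduced exactly, while
#    the outer `stride` pixels, covered by one patch per axis, are attenuated.
h = torch.hann_window(kernel_size, periodic=True, dtype=torch.float64)
w2 = torch.outer(h, h).reshape(1, -1, 1)         # broadcast over the 961 blocks
recon = fold(cols * w2)
inner = slice(stride, H - stride)
assert torch.allclose(recon[..., inner, inner], I2[..., inner, inner])
```

Under that assumption, padding the image by `stride` pixels on each side before unfolding (and cropping afterwards) is one way to keep the borders intact as well.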