To extract (overlapping) patches and to reconstruct the input shape, we can use torch.nn.functional.unfold and the inverse operation torch.nn.functional.fold. These methods only process 4D tensors (batched 2D images), but you can apply them to one dimension at a time to handle higher-dimensional data.
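For intuition, here is a minimal sketch of the basic 4D round trip, using non-overlapping windows so that fold is an exact inverse (the shapes here are just illustrative):

import torch

x = torch.randn(1, 3, 8, 8)                                                  # (B, C, H, W)
patches = torch.nn.functional.unfold(x, kernel_size=2, stride=2)             # (B, C * 2 * 2, L), L = 16 blocks
y = torch.nn.functional.fold(patches, output_size=(8, 8), kernel_size=2, stride=2)
print(torch.equal(x, y))                                                     # True: no overlap, so nothing is summed twice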
A few notes:
This approach requires the fold/unfold methods from PyTorch; unfortunately, I have yet to find a similar method in the TF API.
We can extract patches in two ways; their output is the same. The methods are called extract_patches_Xd and extract_patches_Xds, where X is the number of dimensions. The latter uses torch.Tensor.unfold() and has fewer lines of code. (The output is the same, except it cannot use dilation.)
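As a quick illustration of what torch.Tensor.unfold() does along a single dimension (a toy example, not part of the methods below):

import torch

t = torch.arange(6.)
print(t.unfold(0, 2, 1))    # 5 sliding windows of size 2, step 1
# tensor([[0., 1.],
#         [1., 2.],
#         [2., 3.],
#         [3., 4.],
#         [4., 5.]])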
The methods extract_patches_Xd and combine_patches_Xd are inverse methods, and the combiner reverses the steps of the extractor step by step.
The lines of code are followed by a comment stating the dimensionality, such as (B, C, D, H, W). The following are used:
- B: Batch size
- C: Channels
- D: Depth dimension
- H: Height dimension
- W: Width dimension
- x_dim_in: In the extraction method, this is the number of input pixels in dimension x. In the combining method, this is the number of sliding windows in dimension x.
- x_dim_out: In the extraction method, this is the number of sliding windows in dimension x. In the combining method, this is the number of output pixels in dimension x.
I have a public notebook where you can try out the code.
The get_dim_blocks() method is the function given in the PyTorch docs to compute the output shape of a convolutional layer.
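As a quick sanity check of that formula on the demo input used at the bottom of this answer (a standalone copy of the helper; the numbers are just a worked example):

def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
    # output-size formula from the PyTorch Conv2d/Conv3d docs
    return (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1

print(get_dim_blocks(4, 2, dim_padding=1))  # 5 -> h_dim_out and w_dim_out in the demo below
print(get_dim_blocks(2, 2, dim_padding=1))  # 3 -> d_dim_out; 2 * 3 * 5 * 5 = 150 patches total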
Note that if you have overlapping patches and you combine them, the overlapping elements will be summed. If you would like to recover the original input, there is a way (see the 2D sketch after this list):
- Create a tensor of ones with the same shape as the patches using torch.ones_like(patches_tensor).
- Combine the ones-patches into a full image with the same output shape (this creates a counter for the overlapping elements).
- Divide the combined image by the combined ones; this reverses the summation of the overlapping elements.
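A minimal 2D sketch of this normalization trick (shapes are illustrative; the 3D demo at the bottom does the same thing with combine_patches_3d):

import torch

x = torch.randn(1, 1, 6, 6)
k, s = 3, 1                                                                   # stride < kernel, so patches overlap
patches = torch.nn.functional.unfold(x, kernel_size=k, stride=s)              # (1, 9, 16)
summed = torch.nn.functional.fold(patches, (6, 6), kernel_size=k, stride=s)   # overlaps are summed
counts = torch.nn.functional.fold(torch.ones_like(patches), (6, 6), kernel_size=k, stride=s)
print(torch.allclose(x, summed / counts))                                     # True: division undoes the summation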
(3D):
We need two unfold and two fold passes. For extraction, we first apply unfold to the D dimension and leave H and W untouched by setting their kernel to 1, padding to 0, stride to 1 and dilation to 1. After that, we view the tensor and unfold over the H and W dimensions. The folding happens in reverse, starting with H and W, then D.
import torch

def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)

    channels = x.shape[1]

    x = torch.nn.functional.pad(x, padding)
    # (B, C, D, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
    return x
def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    d_dim_in = x.shape[2]
    h_dim_in = x.shape[3]
    w_dim_in = x.shape[4]
    d_dim_out = get_dim_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)

    # (B, C, D, H, W)
    x = x.view(-1, channels, d_dim_in, h_dim_in * w_dim_in)
    # (B, C, D, H * W)
    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C * kernel_size[0], d_dim_out * H * W)
    x = x.view(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)
    # (B, C * kernel_size[0] * d_dim_out, H, W)
    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_out * kernel_size[1] * kernel_size[2], h_dim_out * w_dim_out)
    x = x.view(-1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)
    # (B, C, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)
    x = x.permute(0, 1, 3, 6, 7, 2, 4, 5)
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
    return x
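To see that the two extraction methods agree when dilation is 1 (a quick check, assuming both functions above have been defined):

x = torch.randn(2, 2, 4, 6, 6)
p1 = extract_patches_3d(x, kernel_size=2, padding=1, stride=2)
p2 = extract_patches_3ds(x, kernel_size=2, padding=1, stride=2)
print(torch.all(p1 == p2))  # tensor(True)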
def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    d_dim_in = get_dim_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)

    x = x.view(-1, channels, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B, C, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
    # (B, C, kernel_size[0], d_dim_in, kernel_size[1], kernel_size[2], h_dim_in, w_dim_in)
    x = x.contiguous().view(-1, channels * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_in, H, W)
    x = x.view(-1, channels * kernel_size[0], d_dim_in * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0], d_dim_in * H * W)
    x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C, D, H * W)
    x = x.view(-1, channels, d_dim_out, h_dim_out, w_dim_out)
    # (B, C, D, H, W)
    return x
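Since fold sums overlapping elements, combining non-overlapping patches is already an exact inverse with no normalization needed (a quick check, assuming the functions above are defined):

x = torch.randn(1, 1, 4, 4, 4)
p = extract_patches_3d(x, kernel_size=2, stride=2)          # stride == kernel: no overlap
y = combine_patches_3d(p, 2, (1, 1, 4, 4, 4), stride=2)
print(torch.allclose(x, y))                                 # True: nothing was summed twice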
a = torch.arange(1, 129, dtype=torch.float).view(2,2,2,4,4)
print(a.shape)
print(a)
b = extract_patches_3d(a, 2, padding=1, stride=1)
bs = extract_patches_3ds(a, 2, padding=1, stride=1)
print(b.shape)
print(b)
c = combine_patches_3d(b, 2, (2,2,2,4,4), padding=1, stride=1)
print(c.shape)
print(c)
ones = torch.ones_like(b)
ones = combine_patches_3d(ones, 2, (2,2,2,4,4), padding=1, stride=1)
print(torch.all(a==c))
print(c.shape, ones.shape)
d = c / ones
print(d)
print(torch.all(a==d))
Output (3D):
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[ 13., 14., 15., 16.]],
[[ 17., 18., 19., 20.],
[ 21., 22., 23., 24.],
[ 25., 26., 27., 28.],
[ 29., 30., 31., 32.]]],
[[[ 33., 34., 35., 36.],
[ 37., 38., 39., 40.],
[ 41., 42., 43., 44.],
[ 45., 46., 47., 48.]],
[[ 49., 50., 51., 52.],
[ 53., 54., 55., 56.],
[ 57., 58., 59., 60.],
[ 61., 62., 63., 64.]]]],
[[[[ 65., 66., 67., 68.],
[ 69., 70., 71., 72.],
[ 73., 74., 75., 76.],
[ 77., 78., 79., 80.]],
[[ 81., 82., 83., 84.],
[ 85., 86., 87., 88.],
[ 89., 90., 91., 92.],
[ 93., 94., 95., 96.]]],
[[[ 97., 98., 99., 100.],
[101., 102., 103., 104.],
[105., 106., 107., 108.],
[109., 110., 111., 112.]],
[[113., 114., 115., 116.],
[117., 118., 119., 120.],
[121., 122., 123., 124.],
[125., 126., 127., 128.]]]]])
torch.Size([150, 2, 2, 2, 2])
tensor([[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 1.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 1., 2.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 2., 3.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 3., 4.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 4., 0.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 1.],
[ 0., 5.]]]],
...,
[[[[124., 0.],
[128., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 0., 125.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[125., 126.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[126., 127.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[127., 128.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[128., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]]])
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[ 8., 16., 24., 32.],
[ 40., 48., 56., 64.],
[ 72., 80., 88., 96.],
[ 104., 112., 120., 128.]],
[[ 136., 144., 152., 160.],
[ 168., 176., 184., 192.],
[ 200., 208., 216., 224.],
[ 232., 240., 248., 256.]]],
[[[ 264., 272., 280., 288.],
[ 296., 304., 312., 320.],
[ 328., 336., 344., 352.],
[ 360., 368., 376., 384.]],
[[ 392., 400., 408., 416.],
[ 424., 432., 440., 448.],
[ 456., 464., 472., 480.],
[ 488., 496., 504., 512.]]]],
[[[[ 520., 528., 536., 544.],
[ 552., 560., 568., 576.],
[ 584., 592., 600., 608.],
[ 616., 624., 632., 640.]],
[[ 648., 656., 664., 672.],
[ 680., 688., 696., 704.],
[ 712., 720., 728., 736.],
[ 744., 752., 760., 768.]]],
[[[ 776., 784., 792., 800.],
[ 808., 816., 824., 832.],
[ 840., 848., 856., 864.],
[ 872., 880., 888., 896.]],
[[ 904., 912., 920., 928.],
[ 936., 944., 952., 960.],
[ 968., 976., 984., 992.],
[1000., 1008., 1016., 1024.]]]]])
tensor(False)
torch.Size([2, 2, 2, 4, 4]) torch.Size([2, 2, 2, 4, 4])
tensor([[[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[ 13., 14., 15., 16.]],
[[ 17., 18., 19., 20.],
[ 21., 22., 23., 24.],
[ 25., 26., 27., 28.],
[ 29., 30., 31., 32.]]],
[[[ 33., 34., 35., 36.],
[ 37., 38., 39., 40.],
[ 41., 42., 43., 44.],
[ 45., 46., 47., 48.]],
[[ 49., 50., 51., 52.],
[ 53., 54., 55., 56.],
[ 57., 58., 59., 60.],
[ 61., 62., 63., 64.]]]],
[[[[ 65., 66., 67., 68.],
[ 69., 70., 71., 72.],
[ 73., 74., 75., 76.],
[ 77., 78., 79., 80.]],
[[ 81., 82., 83., 84.],
[ 85., 86., 87., 88.],
[ 89., 90., 91., 92.],
[ 93., 94., 95., 96.]]],
[[[ 97., 98., 99., 100.],
[101., 102., 103., 104.],
[105., 106., 107., 108.],
[109., 110., 111., 112.]],
[[113., 114., 115., 116.],
[117., 118., 119., 120.],
[121., 122., 123., 124.],
[125., 126., 127., 128.]]]]])
tensor(True)