What are the differences between `torch.flatten()` and `torch.nn.Flatten()`?
2 Answers
Flattening is available in three forms in PyTorch:

- As a tensor method (OOP style): `torch.Tensor.flatten`, applied directly on a tensor: `x.flatten()`.
- As a function (functional form): `torch.flatten`, applied as `torch.flatten(x)`.
- As a module (a layer, i.e. an `nn.Module`): `nn.Flatten()`. Generally used in a model definition.

All three are identical and share the same implementation. The only difference is that `nn.Flatten` has `start_dim` set to `1` by default, to avoid flattening the first axis (usually the batch axis), while the other two flatten from `axis=0` to `axis=-1` (i.e. the entire tensor) if no arguments are given.
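A quick illustration of that default, assuming a hypothetical batch of 32 RGB 28x28 images (the shapes are illustrative only):

import torch
import torch.nn as nn

x = torch.randn(32, 3, 28, 28)      # hypothetical batch of 32 RGB 28x28 images

print(x.flatten().shape)            # torch.Size([75264])    - method form, flattens everything
print(torch.flatten(x).shape)       # torch.Size([75264])    - functional form, same result
print(nn.Flatten()(x).shape)        # torch.Size([32, 2352]) - start_dim=1 keeps the batch axis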

- is it still true that flatten uses reshape beneath in C code? – prosti Feb 02 '21 at 17:11
- Found this: [`torch/csrc/utils/tensor_flatten.h`](https://github.com/pytorch/pytorch/blob/c9cae1446f9a9509406238a946b1735fe36a73e3/torch/csrc/utils/tensor_flatten.h#L10). Seems like it uses `view`, which is a reshape! – Ivan Feb 02 '21 at 18:22
You can think of the job of `torch.flatten()` as simply performing a flattening operation on a tensor, with no strings attached: you give it a tensor, it flattens it and returns the result. That's all there is to it.

`nn.Flatten()`, on the contrary, is more sophisticated (i.e., it's a neural net layer). Being object oriented, it inherits from `nn.Module`, although internally it uses the plain `tensor.flatten()` op in its `forward()` method to flatten the tensor. You can think of it as syntactic sugar over `torch.flatten()`.
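A rough, simplified sketch of what such a module could look like (the class name `MyFlatten` is hypothetical; the actual `nn.Flatten` source differs in its details):

import torch
import torch.nn as nn

class MyFlatten(nn.Module):
    # simplified stand-in for nn.Flatten; name and structure are illustrative only
    def __init__(self, start_dim=1, end_dim=-1):
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input):
        # delegate to the plain tensor op
        return input.flatten(self.start_dim, self.end_dim)

x = torch.randn(4, 3, 8)
assert torch.equal(MyFlatten()(x), nn.Flatten()(x))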
Important difference: a notable distinction is that `torch.flatten()`, with its default arguments, always returns a 1D tensor, provided the input is at least 1D, whereas `nn.Flatten()` always returns a 2D tensor, provided the input is at least 2D. (With a 1D tensor as input, it throws an IndexError; see the sketch below.)
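A small sketch of that difference (the tensors are illustrative only):

import torch
import torch.nn as nn

v = torch.arange(6)                   # 1D input
print(torch.flatten(v).shape)         # torch.Size([6])    - stays 1D
# nn.Flatten()(v)                     # raises IndexError: start_dim=1 is out of range for a 1D input

m = torch.arange(6).reshape(2, 3)     # 2D input
print(torch.flatten(m).shape)         # torch.Size([6])    - collapsed to 1D
print(nn.Flatten()(m).shape)          # torch.Size([2, 3]) - stays 2D, batch axis preserved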
Comparisons:
- `torch.flatten()` is an API whereas `nn.Flatten()` is a neural net layer.
- `torch.flatten()` is a Python function whereas `nn.Flatten()` is a Python class.
- Because of the above point, `nn.Flatten()` comes with a lot of methods and attributes.
- `torch.flatten()` can be used in the wild (e.g., for simple tensor ops) whereas `nn.Flatten()` is expected to be used in an `nn.Sequential()` block as one of the layers (see the sketch after the code demo below).
- `torch.flatten()` is a plain tensor op; it becomes part of the computation graph only when its input has the `tensor.requires_grad` flag set to `True`, whereas `nn.Flatten()` is an `nn.Module` registered as part of the model like any other layer.
- `torch.flatten()` cannot accept layers (e.g., linear/conv1D) as inputs whereas `nn.Flatten()` is typically used to flatten the outputs of such layers inside a model.
- Both `torch.flatten()` and `nn.Flatten()` return views of the input tensor (when the memory layout allows; otherwise a copy is made). Thus, any modification to the result also affects the input tensor. (See the code below.)
Code demo:
# input tensors to work with
In [109]: t1 = torch.arange(12).reshape(3, -1)
In [110]: t2 = torch.arange(12, 24).reshape(3, -1)
In [111]: t3 = torch.arange(12, 36).reshape(3, 2, -1) # 3D tensor
Flattening with `torch.flatten()`:
In [113]: t1flat = torch.flatten(t1)
In [114]: t1flat
Out[114]: tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
# modification to the flattened tensor
In [115]: t1flat[-1] = -1
# input tensor is also modified; thus flattening is a view.
In [116]: t1
Out[116]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, -1]])
Flattening with `nn.Flatten()`:
In [123]: nnfl = nn.Flatten()
In [124]: t3flat = nnfl(t3)
# note that the result is 2D, as opposed to 1D with torch.flatten
In [125]: t3flat
Out[125]:
tensor([[12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27],
        [28, 29, 30, 31, 32, 33, 34, 35]])
# modification to the result
In [126]: t3flat[-1, -1] = -1
# input tensor also modified. Thus, flattened result is a view.
In [127]: t3
Out[127]:
tensor([[[12, 13, 14, 15],
         [16, 17, 18, 19]],

        [[20, 21, 22, 23],
         [24, 25, 26, 27]],

        [[28, 29, 30, 31],
         [32, 33, 34, -1]]])
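To make the "layer inside a model" point from the comparison list concrete, here is a minimal, hypothetical ConvNet sketch (the layer sizes are illustrative only):

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3),   # (N, 1, 28, 28) -> (N, 8, 26, 26)
    nn.ReLU(),
    nn.Flatten(),                     # (N, 8, 26, 26) -> (N, 8 * 26 * 26)
    nn.Linear(8 * 26 * 26, 10),       # classifier head
)

out = model(torch.randn(4, 1, 28, 28))
print(out.shape)                      # torch.Size([4, 10])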
Tidbit: `torch.flatten()` is the precursor to `nn.Flatten()` and its brethren `nn.Unflatten()`, since it has existed from the very beginning. A legitimate use case for `nn.Flatten()` then emerged, since flattening is a common requirement for almost all ConvNets (just before the softmax, or elsewhere), so it was added later on in PR #22245. There have also been recent proposals to use `nn.Flatten()` in ResNets for model surgery.
