I'm afraid there is no easy way around it: Torchvision's random transforms are built in such a way that the transform parameters are sampled every time the transform is called. They are one-off random transforms, in the sense that (1) the parameters used are not accessible to the user and (2) the same random transformation cannot be repeated.
As of Torchvision 0.8.0, random transforms are generally built with two main functions:
get_params: which will sample based on the transform's hyperparameters (what you have provided when you initialized the transform operator, namely the parameters' range of values)
forward: the function that gets executed when applying the transform. The important part is that it gets its parameters from get_params, then applies them to the input using the associated deterministic function. For RandomRotation, F.rotate will get called. Similarly, RandomAffine will use F.affine.
One solution to your problem is sampling the parameters from get_params yourself and calling the functional (deterministic) API instead. So you wouldn't be using RandomRotation, RandomAffine, nor any other Random* transformation for that matter.
For instance, let's look at T.RandomRotation (I have removed the comments for conciseness).
class RandomRotation(torch.nn.Module):
    def __init__(
            self, degrees, interpolation=InterpolationMode.NEAREST, expand=False,
            center=None, fill=None, resample=None):
        # ...

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        angle = float(torch.empty(1).uniform_(float(degrees[0]),
                                              float(degrees[1])).item())
        return angle

    def forward(self, img):
        fill = self.fill
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F._get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]
        angle = self.get_params(self.degrees)
        return F.rotate(img, angle, self.resample, self.expand, self.center, fill)

    def __repr__(self):
        # ...
With that in mind, here is a possible override to modify T.RandomRotation:
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torch import Tensor

class RandomRotation(T.RandomRotation):
    def __init__(self, *args, **kwargs):
        super(RandomRotation, self).__init__(*args, **kwargs)  # let super do all the work
        self.angle = self.get_params(self.degrees)  # sample the random parameters once

    def forward(self, img):  # override T.RandomRotation's forward
        # self.resample and F._get_image_num_channels match the torchvision
        # version quoted above; other releases may name them differently.
        fill = self.fill
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F._get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]
        return F.rotate(img, self.angle, self.resample, self.expand, self.center, fill)
I've essentially copied T.RandomRotation's forward function, the only difference being that the parameters are sampled in __init__ (i.e. once) instead of inside forward (i.e. on every call). Torchvision's implementation covers all cases, so you generally won't need to copy the full forward. In some cases, you can just call the functional version pretty much straight away. For example, if you don't need to set the fill parameter, you can discard that part and only use:
class RandomRotation(T.RandomRotation):
    def __init__(self, *args, **kwargs):
        super(RandomRotation, self).__init__(*args, **kwargs)  # let super do all the work
        self.angle = self.get_params(self.degrees)  # sample the random parameters once

    def forward(self, img):  # override T.RandomRotation's forward
        return F.rotate(img, self.angle, self.resample, self.expand, self.center)
If you want to override other random transforms you can look at the source code. The API is fairly self-explanatory and you shouldn't have too many issues implementing an override for each transform.