I can augment my data during training by applying a random transform (rotation/translation/rescaling), but I don't know the values that were selected.
I need to know what values were applied. I can set these values manually, but then I lose a lot of the benefits that torchvision transforms provide.
Is there an easy way to get these values, and to implement this in a sensible way to apply during training?
Here is an example. I would love to be able to print out the rotation angle and the translation/rescaling applied to each image:
```python
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms

RandAffine = transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.8, 1.2))
rotate = transforms.RandomRotation(degrees=45)
shift = RandAffine
composed = transforms.Compose([rotate, shift])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = np.zeros((28, 28))
sample[5:15, 7:20] = 255
sample = transforms.ToPILImage()(sample.astype(np.uint8))
title = ['None', 'Rot', 'Aff', 'Comp']
for i, tsfrm in enumerate([None, rotate, shift, composed]):
    if tsfrm:
        t_sample = tsfrm(sample)
    else:
        t_sample = sample
    ax = plt.subplot(1, 4, i + 1)  # one panel per transform
    ax.set_title(title[i])
    ax.imshow(np.array(t_sample), cmap='gray')
plt.tight_layout()
plt.show()
```
I'm afraid there is no easy way around it. Torchvision's random transforms are built so that the transform parameters are sampled each time the transform is called. Each call is a one-off random transformation, in the sense that (1) the parameters used are not exposed to the user and (2) the same random transformation cannot simply be reapplied, because the next call samples new parameters.
As of Torchvision 0.8.0, random transforms are generally built with two main functions:

- `get_params`: samples the parameters based on the transform's hyperparameters (what you provided when you initialized the transform operator, namely the ranges of values for each parameter).
- `forward`: the function that gets executed when the transform is applied. The important part is that it gets its parameters from `get_params`, then applies them to the input using the associated deterministic function. For `RandomRotation`, `F.rotate` gets called; similarly, `RandomAffine` uses `F.affine`.
One solution to your problem is to sample the parameters with `get_params` yourself and call the functional (deterministic) API instead. That way you wouldn't be using `RandomRotation`, `RandomAffine`, or any other `Random*` transformation for that matter.
For instance, let's look at `T.RandomRotation` (I have removed the comments for conciseness):
```python
class RandomRotation(torch.nn.Module):
    def __init__(
        self, degrees, interpolation=InterpolationMode.NEAREST, expand=False,
        center=None, fill=None, resample=None):
        ...

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        angle = float(torch.empty(1).uniform_(float(degrees[0]),
                                              float(degrees[1])).item())
        return angle

    def forward(self, img):
        fill = self.fill
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F._get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]
        angle = self.get_params(self.degrees)
        return F.rotate(img, angle, self.resample, self.expand, self.center, fill)

    def __repr__(self):
        ...
```
With that in mind, here is a possible override of `T.RandomRotation`:
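```python
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torch import Tensor


class RandomRotation(T.RandomRotation):
    def __init__(self, *args, **kwargs):
        super(RandomRotation, self).__init__(*args, **kwargs)  # let super do all the work
        self.angle = self.get_params(self.degrees)  # sample the random angle once

    def forward(self, img):  # override T.RandomRotation's forward
        fill = self.fill
        if isinstance(img, Tensor):
            # one fill value per channel (CHW tensor, or a single channel for HW)
            num_channels = img.shape[-3] if img.ndim > 2 else 1
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * num_channels
            elif fill is not None:
                fill = [float(f) for f in fill]
        # self.interpolation exists since torchvision 0.9 (it was self.resample in 0.8)
        return F.rotate(img, self.angle, self.interpolation, self.expand, self.center, fill)
```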
I've essentially copied `T.RandomRotation`'s `forward` function. The only difference is that the angle is sampled in `__init__` (i.e. once) instead of inside `forward` (i.e. on every call), and it can be read back through `self.angle`. Torchvision's implementation covers all cases, so you generally won't need to copy the full `forward`. In some cases, you can call the functional version almost directly. For example, if you don't need to set the `fill` parameter, you can drop that part and only use:
```python
import torchvision.transforms as T
import torchvision.transforms.functional as F


class RandomRotation(T.RandomRotation):
    def __init__(self, *args, **kwargs):
        super(RandomRotation, self).__init__(*args, **kwargs)  # let super do all the work
        self.angle = self.get_params(self.degrees)  # sample the random angle once

    def forward(self, img):  # override T.RandomRotation's forward
        return F.rotate(img, self.angle, self.interpolation, self.expand, self.center)
```
If you want to override other random transforms, you can look at the source code. The API is fairly self-explanatory, and you shouldn't have much trouble implementing an override for each transform.