I am implementing DCGANs using PyTorch.
It works well, in that I get generated images of reasonable quality, but now I want to evaluate the health of the GAN models using metrics, mainly the ones introduced in this guide: https://machinelearningmastery.com/practical-guide-to-gan-failure-modes/
Their implementation uses Keras, whose API lets you specify the metrics you want when you compile the model (see https://keras.io/api/models/model/). In this case the metric is the accuracy of the discriminator, i.e. the percentage of images it correctly identifies as real or generated.
With PyTorch, I can't find a similar feature that would let me easily obtain this metric from my model.
Does PyTorch provide functionality to define and extract common metrics from a model?
Pure PyTorch does not provide metrics out of the box, but it is easy to define them yourself.
Also, there is no such thing as "extracting metrics from a model". A metric measures something (in this case the accuracy of the discriminator) from the model's outputs and the targets; it is not inherent to the model.
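To make that concrete for the GAN case: the guide tracks the discriminator's accuracy on real and generated images separately, and both numbers are computed purely from the discriminator's outputs. A minimal sketch, assuming a discriminator that returns one logit per image; the toy `discriminator` and the random `real_images`/`fake_images` are placeholders for your own model and data, and `binary_accuracy` is a hypothetical helper, not a PyTorch API.

```python
import torch
from torch import nn


def binary_accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
    # Hypothetical helper: a logit above 0 (probability above 0.5) means "real"
    predicted = logits.reshape(targets.shape) > 0
    return (predicted == targets.bool()).float().mean().item()


torch.manual_seed(0)
# Placeholder discriminator: maps a 1x28x28 image to a single logit
discriminator = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))
real_images = torch.randn(64, 1, 28, 28)  # stand-in for a batch of real data
fake_images = torch.randn(64, 1, 28, 28)  # stand-in for generator(noise)

with torch.no_grad():  # measuring does not need gradients
    real_logits = discriminator(real_images)
    fake_logits = discriminator(fake_images)

# Real images should be classified as 1, generated images as 0
acc_real = binary_accuracy(real_logits, torch.ones(64))
acc_fake = binary_accuracy(fake_logits, torch.zeros(64))
print(f"D accuracy on real: {acc_real:.2f}, on fake: {acc_fake:.2f}")
```

Inside a real training loop you could instead reuse the logits already computed for the discriminator loss (detached), so no extra forward pass is needed.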
In your case, you are looking for a binary accuracy metric. The code below works with either logits (the unnormalized scores output by the discriminator, probably from a final `nn.Linear` layer without an activation) or probabilities (a final `nn.Linear` followed by a `sigmoid` activation):
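```python
import typing

import torch


class BinaryAccuracy:
    def __init__(
        self,
        logits: bool = True,
        reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.mean,
    ):
        self.logits = logits
        # A logit of 0 corresponds to a probability of 0.5 after sigmoid
        self.threshold = 0.0 if logits else 0.5
        self.reduction = reduction

    def __call__(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # Reshape predictions such as (N, 1) to the target shape (N,);
        # otherwise the comparison below would broadcast to (N, N)
        predicted = y_pred.reshape(y_true.shape) > self.threshold
        return self.reduction((predicted == y_true.bool()).float())
```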
Usage:
```python
metric = BinaryAccuracy()
target = torch.randint(2, size=(64,))
outputs = torch.randn(size=(64, 1))
print(metric(outputs, target))
```
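The two default thresholds are interchangeable: sigmoid(x) > 0.5 holds exactly when x > 0, so thresholding logits at 0 gives the same predictions as thresholding probabilities at 0.5 (up to floating-point rounding for logits extremely close to 0), and you can skip computing the sigmoid. A quick check on a few hand-picked values:

```python
import torch

# Hand-picked logits, including the boundary value 0.0
logits = torch.tensor([-3.0, -0.5, 0.0, 0.2, 4.0])

from_logits = logits > 0                    # threshold on raw logits
from_probs = torch.sigmoid(logits) > 0.5    # threshold on probabilities

print(from_logits.tolist())  # [False, False, False, True, True]
print(torch.equal(from_logits, from_probs))  # True
```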
You can also use PyTorch Lightning or another framework built on top of PyTorch that defines metrics such as accuracy.
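One thing such frameworks handle for you is accumulating a metric over many batches. If you stay with plain PyTorch, a minimal sketch is to keep running counts outside the model, so the epoch-level accuracy is exact even when batch sizes differ. `RunningBinaryAccuracy` and its `update`/`compute`/`reset` methods are hypothetical names, and this version assumes probabilities with a 0.5 threshold.

```python
import torch


class RunningBinaryAccuracy:
    """Hypothetical helper: accumulates correct/total counts across batches."""

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold
        self.reset()

    def reset(self) -> None:
        # Call at the start of each epoch
        self.correct = 0
        self.total = 0

    def update(self, probs: torch.Tensor, targets: torch.Tensor) -> None:
        # Flatten so that (N, 1) outputs line up with (N,) targets
        predicted = probs.reshape(-1) > self.threshold
        self.correct += (predicted == targets.reshape(-1).bool()).sum().item()
        self.total += targets.numel()

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0


metric = RunningBinaryAccuracy()
metric.update(torch.tensor([[0.9], [0.2], [0.7]]), torch.tensor([1, 0, 0]))
metric.update(torch.tensor([[0.4]]), torch.tensor([0]))
print(metric.compute())  # 3 of 4 correct -> 0.75
```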