A Deep Dive into the ESRGAN (Part 3: Training and Inference)
- Mohamed Benaicha
- Dec 5, 2023
- 4 min read
Training the ESRGAN is not a straightforward process; understanding it draws on various concepts from deep learning and generative adversarial networks. Be sure to review Part 1 and Part 2 of this article series on the ESRGAN, as well as this article on convolutions, and this one on backpropagation.
Training the ESRGAN requires (refer to Figure 1 for a visual representation of these steps):
Having data to train on. This means images with a minimum resolution (say 128 pixels and above) so that the generator can learn some meaningful features (call these image set 1). A copy of these images is made and then downscaled, forming our input image set (call these image set 2); a minimal sketch of this preparation step follows Figure 1. Ultimately, these low-resolution images are required because they are fed through the generator to produce high-resolution images. The resulting predictions (call these image set 3) are compared to the original high-resolution images before they were downscaled (i.e., image set 1).
Having neural network architectures that can take an image and upscale it to a higher resolution. These comprise the discriminator and generator architectures, the former of which is discarded after training is complete (i.e., during deployment).
Having loss functions that tell us how good our predictions are. The generator produces one loss and the discriminator another, for reasons that will become clear. For the generator, the loss is the pixel-by-pixel difference between our predicted upscaled images (image set 3) and the actual high-resolution images (image set 1).
An interface that can update our neural network weights (the convolutional kernels and fully connected layer weights discussed in Parts 1 and 2 of this article series).

Figure 1
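To make step 1 concrete, here is a minimal sketch of preparing the paired image sets, assuming PyTorch tensors; the batch size, crop size, scale factor and interpolation mode are all hypothetical choices:

import torch
import torch.nn.functional as F

# Hypothetical batch of high-resolution training images (image set 1):
# sixteen 3-channel 128x128 crops with pixel values in [0, 1].
high_res = torch.rand(16, 3, 128, 128)

# Image set 2: downscaled copies of image set 1 (here by 4x, using
# bicubic interpolation).
low_res = F.interpolate(high_res, scale_factor=0.25, mode="bicubic")

# Image set 3 would then be the generator's predictions:
# predicted_hr = generator(low_res)  # upscaled back to 128x128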
The last part (4) requires further detail, as the loss is not simply the difference between the original images and the upscaled images; it also incorporates the adversary's feedback - the discriminator's ability to tell whether or not the upscaled image (produced after step 2 in Figure 1) is a real high-resolution image. That works as follows (refer to Figure 2 for a visual representation of these steps):
The downscaling of the high-resolution images was already described above
The generator generates high-resolution predictions
The original high-resolution images are fed into the discriminator, which predicts output values (call them set 1). The high-resolution images predicted by the generator are then fed into the discriminator, which again predicts output values (call them set 2). Unlike the generator's outputs, the values in set 1 and set 2 are not images but scores whose exact meaning hardly concerns us; all that matters is that we have values we can compare against each other.
Set 1 and set 2 are compared; the difference is the loss (illustrated in the short sketch after Figure 2).
The loss is used to adjust both generator and discriminator weights through backpropagation.

Figure 2
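To make steps 3 and 4 concrete, a toy sketch of what the two sets of values might look like; the shapes are an assumption, though a critic typically returns one score per image:

import torch

# Hypothetical critic outputs: one scalar score per image in a batch of 16.
set_1 = torch.randn(16, 1)  # scores for the original high-resolution images
set_2 = torch.randn(16, 1)  # scores for the generator's predicted images

# Step 4: the difference between the two mean scores, negated, is the loss
# (matching the discriminator loss formula given further below).
loss = -(set_1.mean() - set_2.mean())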
The rationale for the discriminator is to produce the values in set 1 and set 2 when it receives the original high-resolution images versus the high-resolution images predicted by the generator. It is a "critic" - one of the discriminator's other names - of the generator. If the discriminator can easily tell the difference between the original and predicted high-resolution images, it generates a larger loss, penalizing the generator so that it does a better job the next time it predicts a high-resolution image - so much so that, eventually, the discriminator can no longer tell the difference between the generator's predicted high-resolution images and the originals, at which point it produces a small loss at step 4 in Figure 2.
The Training Steps
Discriminator
Pass the low-resolution images through the generator to predict high-resolution images
Pass the predicted high-resolution images from step 1 to the discriminator and produce an output (output 1); pass the original high-resolution images through the discriminator to produce an output (output 2).
Calculate the discriminator loss by comparing output 1 and output 2 from step 2. This is done by taking the mean of output 2 and subtracting from it the mean of output 1; the result is then negated.
A gradient penalty term (see appendix) is added to the loss from step 3
Backpropagation is conducted to update the discriminator weights
output 1 = discriminator( predicted high resolution images )
output 2 = discriminator( original high resolution images )
discriminator loss = -( mean( output 2 ) - mean( output 1 ) ) + gradient penalty
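For concreteness, here is a minimal PyTorch sketch of these five steps. It assumes generator, discriminator (a module returning one score per image) and disc_optimizer already exist, reuses high_res and low_res from the earlier snippet, and uses a WGAN-GP-style gradient penalty with a hypothetical weight of 10 (the appendix's exact formulation may differ):

import torch

def gradient_penalty(discriminator, real, fake):
    # WGAN-GP-style penalty: interpolate between real and generated
    # images, then penalize discriminator gradient norms that deviate from 1.
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    mixed = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = discriminator(mixed)
    grads = torch.autograd.grad(
        outputs=scores, inputs=mixed,
        grad_outputs=torch.ones_like(scores), create_graph=True,
    )[0].view(real.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

# One discriminator update:
fake_hr = generator(low_res).detach()   # step 1: predicted high-res images
output_1 = discriminator(fake_hr)       # step 2: scores for the predictions
output_2 = discriminator(high_res)      # step 2: scores for the originals
gp = gradient_penalty(discriminator, high_res, fake_hr)
d_loss = -(output_2.mean() - output_1.mean()) + 10 * gp  # steps 3 and 4
disc_optimizer.zero_grad()
d_loss.backward()                       # step 5: backpropagation
disc_optimizer.step()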
Generator
The generator's loss comprises 3 components: the l1 and VGG losses (Figure 1, step 3), and the adversarial loss (Figure 2, step 5).
The l1 loss is the mean absolute difference, pixel by pixel, between the original high-resolution images and the predicted high-resolution images.
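In PyTorch terms this is a one-liner, reusing the hypothetical fake_hr and high_res tensors from the snippets above:

import torch.nn.functional as F

l1 = F.l1_loss(fake_hr, high_res)  # mean absolute pixel difference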
The VGG loss is simply the MSE between features extracted from the predicted and actual high-resolution images by a pre-trained VGG network: the difference between the two feature maps is squared, producing such a value for each image in a batch of images that is processed together in the network. The mean across all images comprises the VGG loss for the current batch.
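A minimal sketch, assuming torchvision's pre-trained VGG19; which layer to truncate at is a design choice, and the index below is an assumption:

import torch.nn as nn
from torchvision.models import vgg19

# Truncate a pre-trained VGG19 at an intermediate convolutional layer;
# its outputs serve as the feature maps compared by the VGG loss.
# (Requires torchvision >= 0.13 for the weights argument.)
vgg_features = vgg19(weights="IMAGENET1K_V1").features[:35].eval()
for p in vgg_features.parameters():
    p.requires_grad = False  # the VGG network itself is never trained

mse = nn.MSELoss()

def vgg_loss(fake_hr, real_hr):
    # Squared feature-map differences, averaged over the batch.
    return mse(vgg_features(fake_hr), vgg_features(real_hr))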
The adversarial loss for the generator is different from the one for the discriminator that we've seen above. It is simply the mean of the discriminator's outputs (where the input is the predicted output from the generator), negated and scaled by (i.e., multiplied by) a small value. This is logical: the higher the discriminator scores the predicted images (i.e., the more realistic it deems them), the smaller the generator's adversarial loss becomes.
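Continuing the sketch, with 5e-3 as a hypothetical choice for the small scaling value:

# Note: for the generator update, fake_hr must come from a fresh,
# non-detached forward pass: fake_hr = generator(low_res)
adversarial = -5e-3 * discriminator(fake_hr).mean()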
Hence, the training steps for the generator are as follows, assuming we already have the predicted high-resolution images from step 1 of the discriminator steps:
Calculate the l1 loss by using the predicted and original high-resolution images
Calculate the adversarial loss by passing in the predicted high-resolution images to the discriminator
Calculate the VGG loss by using the predicted and original high-resolution images
Add the losses from steps 1, 2 and 3 above.
The combined loss is used to backpropagate and update the generator weights, as in the sketch below.
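Putting it together, a minimal sketch of one generator update, reusing the names defined in the earlier snippets; gen_optimizer is assumed to exist and the loss weights are hypothetical hyperparameters:

fake_hr = generator(low_res)                         # fresh, non-detached forward pass

l1 = F.l1_loss(fake_hr, high_res)                    # step 1: l1 loss
adversarial = -5e-3 * discriminator(fake_hr).mean()  # step 2: adversarial loss
perceptual = vgg_loss(fake_hr, high_res)             # step 3: VGG loss

# Step 4: sum the components; the l1 weight of 1e-2 is hypothetical.
g_loss = perceptual + adversarial + 1e-2 * l1

gen_optimizer.zero_grad()
g_loss.backward()                                    # step 5: update generator weights
gen_optimizer.step()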