Model Architectures and Loss Functions


GAN Architectures

ganslate provides implementations of several popular image translation GANs, which you can use out-of-the-box in your projects. Here is the list of currently supported GAN architectures:

  1. Pix2Pix

    • Class: ganslate.nn.gans.paired.pix2pix.Pix2PixConditionalGAN
    • Data requirements: Paired pixel-wise aligned domain A and domain B images
    • Original paper: Isola et al. - Image-to-Image Translation with Conditional Adversarial Networks (arXiv)
  2. CycleGAN

    • Class: ganslate.nn.gans.unpaired.cyclegan.CycleGAN
    • Data requirements: Unpaired domain A and domain B images
    • Original paper: Zhu et al. - Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks (arXiv)
  3. RevGAN

    • Class: ganslate.nn.gans.unpaired.revgan.RevGAN
    • Data requirements: Unpaired domain A and domain B images
    • Original paper: Ouderaa et al. - Reversible GANs for Memory-efficient Image-to-Image Translation (arXiv)
  4. CUT

    • Class: ganslate.nn.gans.unpaired.cut.CUT
    • Data requirements: Unpaired domain A and domain B images
    • Original paper: Park et al. - Contrastive Learning for Unpaired Image-to-Image Translation (arXiv)

ganslate defines an abstract base class ganslate.nn.gans.base.BaseGAN (source) that implements basic functionality common to all the aforementioned GAN architectures, such as methods for model setup, saving, loading, and learning rate updates. It also declares abstract methods whose implementation differs across GAN architectures, such as the forward pass and backpropagation logic. Each of the aforementioned GAN architectures inherits from BaseGAN and implements the necessary abstract methods.
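This abstract-base-class pattern can be sketched as follows. Note that this is a simplified, hypothetical illustration: the method names (setup, update_learning_rate, forward, optimize_parameters) and the toy MyCycleGAN subclass are assumptions for the sake of the example, and the real BaseGAN interface is richer.

```python
from abc import ABC, abstractmethod


class BaseGAN(ABC):
    """Hypothetical sketch of the BaseGAN pattern described above."""

    def setup(self):
        # Shared logic common to all architectures: build networks and
        # optimizers, load checkpoints, move models to the device, etc.
        ...

    def update_learning_rate(self):
        # Shared logic: step the learning-rate schedulers of all optimizers.
        ...

    @abstractmethod
    def forward(self):
        """Forward pass; its logic differs per architecture."""

    @abstractmethod
    def optimize_parameters(self):
        """Loss computation and backpropagation; differs per architecture."""


class MyCycleGAN(BaseGAN):
    # A concrete architecture implements the abstract methods.
    def forward(self):
        return "A->B and B->A translations plus cycle reconstructions"

    def optimize_parameters(self):
        return "compute losses, backpropagate, step optimizers"
```

Because forward and optimize_parameters are declared abstract, Python refuses to instantiate BaseGAN directly; only concrete subclasses that implement both methods can be created.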

The BaseGAN class has an associated dataclass at ganslate.configs.base.BaseGANConfig that defines all its basic settings including the settings for optimizer, generator, and discriminator. Since the different GAN architectures have their own specific settings, each of them also has an associated configuration dataclass that inherits from ganslate.configs.base.BaseGANConfig and defines additional architecture-specific settings.
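The configuration-inheritance pattern can be illustrated with plain dataclasses. All field names below (norm_type, pool_size, lambda_AB, lambda_BA) are illustrative assumptions; the real ganslate.configs.base.BaseGANConfig contains its own set of fields, including nested optimizer, generator, and discriminator settings.

```python
from dataclasses import dataclass


@dataclass
class BaseGANConfig:
    # Illustrative settings shared by all GAN architectures
    norm_type: str = "instance"
    pool_size: int = 50


@dataclass
class CycleGANConfig(BaseGANConfig):
    # Illustrative architecture-specific hyperparameters: an
    # architecture's config inherits the base settings and adds its own.
    lambda_AB: float = 10.0
    lambda_BA: float = 10.0
```

A CycleGANConfig instance then exposes both the inherited base settings and the architecture-specific ones, and any of them can be overridden at construction time.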

As a result of its extensible design, ganslate also enables users to modify the existing GANs by overriding certain functionalities, or to define their own custom image translation GAN from scratch. The former is discussed in the context of loss functions in the basic tutorial Your First Project, while the latter is covered in the advanced tutorial Writing Your Own GAN Class from Scratch.


Generator and Discriminator Architectures

Generators and discriminators are defined in ganslate as regular PyTorch modules derived from torch.nn.Module.
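To make this concrete, here is a minimal, hypothetical 2D generator. It is not one of the ganslate generators listed below (those are more elaborate ResNet, U-Net, and V-Net variants); it only illustrates the convention that a generator is just a regular torch.nn.Module mapping an image tensor to an image tensor.

```python
import torch
import torch.nn as nn


class TinyGenerator2D(nn.Module):
    """Toy generator: a regular torch.nn.Module, as in ganslate."""

    def __init__(self, in_channels=1, out_channels=1, features=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, features, kernel_size=3, padding=1),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, out_channels, kernel_size=3, padding=1),
            nn.Tanh(),  # outputs in [-1, 1], the usual range for GAN images
        )

    def forward(self, x):
        return self.net(x)


generator = TinyGenerator2D()
fake_image = generator(torch.randn(2, 1, 64, 64))  # same spatial size out
```

The same convention applies to discriminators: any torch.nn.Module with a compatible input/output contract can serve as one.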

Here is the list of available generator architectures:

  1. ResNet variants (Original ResNet paper - arXiv):

    • 2D ResNet: ganslate.nn.generators.resnet.resnet2d.Resnet2D
    • 3D ResNet: ganslate.nn.generators.resnet.resnet3d.Resnet3D
    • Partially-invertible ResNet generator: ganslate.nn.generators.resnet.piresnet3d.Piresnet3D
  2. U-Net variants (Original U-Net paper - arXiv):

    • 2D U-Net: ganslate.nn.generators.unet.unet2d.Unet2D
    • 3D U-Net: ganslate.nn.generators.unet.unet3d.Unet3D
  3. V-Net variants (Original V-Net paper - arXiv):

    • 2D V-Net: ganslate.nn.generators.vnet.vnet2d.Vnet2D
    • 3D V-Net: ganslate.nn.generators.vnet.vnet3d.Vnet3D
    • Partially-invertible 3D V-Net generator with Self-Attention: ganslate.nn.generators.vnet.sa_vnet3d.SAVnet3D

And here is the list of the available discriminator architectures:

  1. PatchGAN discriminator variants (PatchGAN originally described in the Pix2Pix paper - arXiv):
    • 2D PatchGAN: ganslate.nn.discriminators.patchgan.patchgan2d.PatchGAN2D
    • 3D PatchGAN: ganslate.nn.discriminators.patchgan.patchgan3d.PatchGAN3D
    • Multiscale 3D PatchGAN: ganslate.nn.discriminators.patchgan.ms_patchgan3d.MSPatchGAN3D
    • 3D PatchGAN with Self-Attention: ganslate.nn.discriminators.patchgan.sa_patchgan3d.SAPatchGAN3D
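The defining idea behind all the PatchGAN variants above is that the discriminator outputs a grid of real/fake scores, each judging a local patch of the input, rather than a single score for the whole image. A minimal, hypothetical sketch (not the actual ganslate implementation):

```python
import torch
import torch.nn as nn


class TinyPatchGAN2D(nn.Module):
    """Toy PatchGAN-style discriminator: one logit per local patch."""

    def __init__(self, in_channels=1, features=16):
        super().__init__()
        self.net = nn.Sequential(
            # Strided convolutions shrink the spatial grid; each output
            # location has a receptive field covering one input patch.
            nn.Conv2d(in_channels, features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(features, 1, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        # Returns a score map, not a single scalar per image.
        return self.net(x)


discriminator = TinyPatchGAN2D()
score_map = discriminator(torch.randn(2, 1, 64, 64))
```

Here a 64x64 input yields a 16x16 map of patch logits; the adversarial loss is then averaged over all patch scores.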

Loss Functions

Several different loss function classes are provided in the ganslate package. These include different flavors of the adversarial loss as well as various GAN architecture-specific losses.

  1. Adversarial loss

    • Class: ganslate.nn.losses.adversarial_loss.AdversarialLoss
    • Variants: 'vanilla' (original adversarial loss based on cross-entropy), 'lsgan' (least-squares loss), 'wgangp' (Wasserstein-1 distance with gradient penalty), and 'nonsaturating'
  2. Pix2Pix loss

    • Class: ganslate.nn.losses.pix2pix_losses.Pix2PixLoss
    • Components:
      • Pixel-to-pixel L1 loss between synthetic image and ground truth (weighted by the scalar lambda_pix2pix)
  3. CycleGAN losses

    • Class: ganslate.nn.losses.cyclegan_losses.CycleGANLosses
    • Components:
      • Cycle-consistency loss based on L1 distance (the A-B-A and B-A-B components are separately weighted by lambda_AB and lambda_BA, respectively). Optionally, the cycle-consistency loss can be computed as a weighted sum of L1 and SSIM losses, with the weighting defined by the hyperparameter proportion_ssim.
      • Identity loss implemented with L1 distance
  4. CUT losses

    • Class: ganslate.nn.losses.cut_losses.PatchNCELoss
    • Components:
      • PatchNCE loss
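Some of the loss flavors above can be sketched in plain PyTorch. The functions below are illustrative stand-ins, not the ganslate classes themselves: 'wgangp' and the SSIM option are omitted for brevity, and the function names and signatures are assumptions made for this example.

```python
import torch
import torch.nn.functional as F


def adversarial_loss(prediction, target_is_real, variant="vanilla"):
    """Toy version of two adversarial loss flavors on raw discriminator logits."""
    target = (torch.ones_like(prediction) if target_is_real
              else torch.zeros_like(prediction))
    if variant == "vanilla":
        # Original cross-entropy adversarial loss.
        return F.binary_cross_entropy_with_logits(prediction, target)
    if variant == "lsgan":
        # Least-squares GAN: squared distance to the target label.
        return F.mse_loss(prediction, target)
    raise ValueError(f"Unknown variant: {variant}")


def cycle_consistency_loss(real, reconstructed, lambda_weight=10.0):
    """Toy L1 cycle-consistency term, e.g. ||G_BA(G_AB(a)) - a||_1, scaled."""
    return lambda_weight * F.l1_loss(reconstructed, real)
```

For instance, for a discriminator emitting all-zero logits on real samples, the 'vanilla' loss equals log 2 (about 0.693) and the 'lsgan' loss equals 1, since sigmoid(0) = 0.5 and the squared distance to the target label 1 is 1.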