Model Architectures and Loss Functions
GAN Architectures
ganslate provides implementations of several popular image translation GANs, which you can use out of the box in your projects. The currently supported GAN architectures are:
- Pix2Pix
  - Class: ganslate.nn.gans.paired.pix2pix.Pix2PixConditionalGAN
  - Data requirements: Paired, pixel-wise aligned domain A and domain B images
  - Original paper: Isola et al. - Image-to-Image Translation with Conditional Adversarial Networks (arXiv)
- CycleGAN
  - Class: ganslate.nn.gans.unpaired.cyclegan.CycleGAN
  - Data requirements: Unpaired domain A and domain B images
  - Original paper: Zhu et al. - Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks (arXiv)
- RevGAN
  - Class: ganslate.nn.gans.unpaired.revgan.RevGAN
  - Data requirements: Unpaired domain A and domain B images
  - Original paper: Ouderaa et al. - Reversible GANs for Memory-efficient Image-to-Image Translation (arXiv)
- CUT
  - Class: ganslate.nn.gans.unpaired.cut.CUT
  - Data requirements: Unpaired domain A and domain B images
  - Original paper: Park et al. - Contrastive Learning for Unpaired Image-to-Image Translation (arXiv)
ganslate defines an abstract base class ganslate.nn.gans.base.BaseGAN (source) that implements the basic functionality common to all of the GAN architectures listed above, such as model setup, saving, loading, and learning rate updates. It also declares abstract methods whose implementation differs across GAN architectures, such as the forward pass and the backpropagation logic. Each of the architectures listed above inherits from BaseGAN and implements these abstract methods.
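The base-class pattern described here can be illustrated with a minimal, framework-free sketch. The names SketchBaseGAN and ToyGAN, and every method on them, are hypothetical and chosen only for illustration; they are not ganslate's actual API, and the real BaseGAN may be organized differently.

```python
from abc import ABC, abstractmethod


class SketchBaseGAN(ABC):
    """Hypothetical stand-in for a GAN base class (not ganslate's code)."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate
        self.steps_done = 0

    def update_learning_rate(self, factor):
        # Shared functionality: inherited unchanged by every subclass.
        self.learning_rate *= factor

    def optimize(self, batch):
        # Shared training step that delegates the architecture-specific parts.
        output = self.forward(batch)
        loss = self.backward(batch, output)
        self.steps_done += 1
        return loss

    @abstractmethod
    def forward(self, batch):
        """Architecture-specific forward pass."""

    @abstractmethod
    def backward(self, batch, output):
        """Architecture-specific loss and backpropagation logic."""


class ToyGAN(SketchBaseGAN):
    """Toy subclass that fills in the two abstract methods."""

    def forward(self, batch):
        # Placeholder "generator": doubles every input value.
        return [2.0 * x for x in batch]

    def backward(self, batch, output):
        # Placeholder "loss": mean absolute difference to the input.
        return sum(abs(o - x) for o, x in zip(output, batch)) / len(batch)
```

Because the abstract methods are declared with abstractmethod, instantiating SketchBaseGAN directly raises a TypeError, while ToyGAN can be created and reuses the shared optimize and update_learning_rate logic.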
The BaseGAN class has an associated dataclass at ganslate.configs.base.BaseGANConfig that defines all of its basic settings, including those for the optimizer, generator, and discriminator. Since each GAN architecture has its own specific settings, each one also has an associated configuration dataclass that inherits from ganslate.configs.base.BaseGANConfig and defines additional architecture-specific settings.
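As a rough illustration of this configuration inheritance, the sketch below uses plain Python dataclasses. All class names, field names, and default values are hypothetical (only lambda_AB and lambda_BA echo names used later on this page); the actual fields of BaseGANConfig and its subclasses may differ.

```python
from dataclasses import dataclass, field


@dataclass
class SketchOptimizerConfig:
    # Hypothetical optimizer setting with an assumed default.
    learning_rate: float = 0.0002


@dataclass
class SketchBaseGANConfig:
    # Basic settings shared by every architecture (placeholder values).
    generator: str = "generator-settings-placeholder"
    discriminator: str = "discriminator-settings-placeholder"
    optimizer: SketchOptimizerConfig = field(default_factory=SketchOptimizerConfig)


@dataclass
class SketchCycleGANConfig(SketchBaseGANConfig):
    # Architecture-specific settings added on top of the inherited ones.
    lambda_AB: float = 10.0
    lambda_BA: float = 10.0


# Inherited and architecture-specific fields live side by side.
config = SketchCycleGANConfig(lambda_AB=5.0)
```

The derived dataclass keeps every base field and only adds what is specific to one architecture, so shared settings are declared in a single place.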
As a result of its extensible design, ganslate
additionally enables users to modify the existing GANs by overriding certain functionalities or to define their own custom image translation GAN from scratch. The former is discussed in the context of loss functions as part of the basic tutorial Your First Project, whereas the latter is covered in the advanced tutorial Writing Your Own GAN Class from Scratch.
Generator and Discriminator Architectures
Generators and discriminators are defined in ganslate as regular PyTorch modules derived from torch.nn.Module.
The following generator architectures are available:
- ResNet variants (Original ResNet paper - arXiv):
  - 2D ResNet: ganslate.nn.generators.resnet.resnet2d.Resnet2D
  - 3D ResNet: ganslate.nn.generators.resnet.resnet3d.Resnet3D
  - Partially-invertible ResNet generator: ganslate.nn.generators.resnet.piresnet3d.Piresnet3D
- U-Net variants (Original U-Net paper - arXiv):
  - 2D U-Net: ganslate.nn.generators.unet.unet2d.Unet2D
  - 3D U-Net: ganslate.nn.generators.unet.unet3d.Unet3D
- V-Net variants (Original V-Net paper - arXiv):
  - 2D V-Net: ganslate.nn.generators.vnet.vnet2d.Vnet2D
  - 3D V-Net: ganslate.nn.generators.vnet.vnet3d.Vnet3D
  - Partially-invertible 3D V-Net generator with Self-Attention: ganslate.nn.generators.vnet.sa_vnet3d.SAVnet3D
The following discriminator architectures are available:
- PatchGAN discriminator variants (PatchGAN originally described in the Pix2Pix paper - arXiv):
  - 2D PatchGAN: ganslate.nn.discriminators.patchgan.patchgan2d.PatchGAN2D
  - 3D PatchGAN: ganslate.nn.discriminators.patchgan.patchgan3d.PatchGAN3D
  - Multiscale 3D PatchGAN: ganslate.nn.discriminators.patchgan.ms_patchgan3d.MSPatchGAN3D
  - 3D PatchGAN with Self-Attention: ganslate.nn.discriminators.patchgan.sa_patchgan3d.SAPatchGAN3D
Loss Functions
Several loss function classes are provided in the ganslate package. These include different flavors of the adversarial loss as well as various GAN architecture-specific losses.
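To make the idea of adversarial loss flavors concrete, here is a minimal pure-Python sketch of three common textbook formulations applied to raw discriminator scores. It is not ganslate's AdversarialLoss implementation: the function names and the "wasserstein" key are assumptions made for this example, the gradient penalty of WGAN-GP is deliberately left out, and the non-saturating variant is not covered.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def adversarial_loss(scores, target_is_real, variant="vanilla"):
    """Mean adversarial loss over raw discriminator scores (logits)."""
    if variant == "vanilla":
        # Binary cross-entropy on logits:
        # -log(sigmoid(s)) for real targets, -log(1 - sigmoid(s)) for fake.
        per_score = [softplus(-s) if target_is_real else softplus(s) for s in scores]
    elif variant == "lsgan":
        # Least-squares loss against a target label of 1 (real) or 0 (fake).
        target = 1.0 if target_is_real else 0.0
        per_score = [(s - target) ** 2 for s in scores]
    elif variant == "wasserstein":
        # Critic term only: real scores are pushed up, fake scores down.
        # The gradient penalty is omitted in this sketch.
        per_score = [-s if target_is_real else s for s in scores]
    else:
        raise ValueError(f"Unknown variant: {variant}")
    return sum(per_score) / len(per_score)
```

For example, a score of 0.0 judged against a real target gives a vanilla loss of log(2), since the discriminator is maximally undecided at that point.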
- Adversarial loss
  - Class: ganslate.nn.losses.adversarial_loss.AdversarialLoss
  - Variants: 'vanilla' (original adversarial loss based on cross-entropy), 'lsgan' (least-squares loss), 'wgangp' (Wasserstein-1 distance with gradient penalty), and 'nonsaturating'
- Pix2Pix loss
  - Class: ganslate.nn.losses.pix2pix_losses.Pix2PixLoss
  - Components:
    - Pixel-to-pixel L1 loss between the synthetic image and the ground truth (weighted by the scalar lambda_pix2pix)
- CycleGAN losses
  - Class: ganslate.nn.losses.cyclegan_losses.CycleGANLosses
  - Components:
    - Cycle-consistency loss based on L1 distance (A-B-A and B-A-B components separately weighted by lambda_AB and lambda_BA, respectively). Optionally, cycle-consistency can be computed as a weighted sum of L1 and SSIM losses (weights defined by the hyperparameter proportion_ssim).
    - Identity loss implemented with L1 distance
- CUT losses
  - Class: ganslate.nn.losses.cut_losses.PatchNCELoss
  - Components:
    - PatchNCE loss
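The weighted reconstruction terms listed above can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not ganslate's code: images are represented as flat lists of pixel values, the function names are hypothetical, the SSIM value is passed in precomputed, and the way proportion_ssim blends the two terms (SSIM entering as 1 - SSIM) is an assumption about how such a weighted sum is typically formed.

```python
def l1_distance(a, b):
    # Mean absolute difference between two equally sized pixel lists.
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def pix2pix_pixel_loss(fake_B, real_B, lambda_pix2pix):
    # L1 term between the synthetic image and the ground truth,
    # scaled by lambda_pix2pix.
    return lambda_pix2pix * l1_distance(fake_B, real_B)


def cycle_consistency_loss(real, reconstructed, lambda_cycle,
                           ssim_value=None, proportion_ssim=0.0):
    # lambda_cycle plays the role of lambda_AB or lambda_BA,
    # depending on the cycle direction.
    l1 = l1_distance(real, reconstructed)
    if ssim_value is None:
        blended = l1
    else:
        # Assumed blend: SSIM is a similarity, so (1 - SSIM) acts as the loss.
        blended = (1.0 - proportion_ssim) * l1 + proportion_ssim * (1.0 - ssim_value)
    return lambda_cycle * blended
```

A full objective would add these terms to the adversarial loss of the respective architecture; the scalars only control how strongly each reconstruction term contributes.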