Problem
- existing models are both inefficient and ineffective in multi-domain image-to-image translation tasks: to learn all mappings among k domains, k(k−1) generators have to be trained
- existing models are incapable of jointly training on domains from different datasets
New method
- StarGAN, a novel and scalable approach that can perform image-to-image translations for multiple domains using only a single model
- adding a mask vector to the domain label enables joint training between domains of different datasets
Star Generative Adversarial Networks
1. Multi-Domain Image-to-Image Translation
notation | meaning
---|---
x | input image
y | output image
c | target domain label
c' | original domain label
Dsrc(x) | a probability distribution over sources given by D
Dcls(c'\|x) | a probability distribution over domain labels computed by D
λcls | hyper-parameter that controls the relative importance of the domain classification loss
λrec | hyper-parameter that controls the relative importance of the reconstruction loss
m | a mask vector
[·] | concatenation
ci | a vector for the labels of the i-th dataset
x̂ | sampled uniformly along a straight line between a pair of a real and a generated image
λgp | hyper-parameter that controls the gradient penalty
- Goal: to train a single generator G that learns mappings among multiple domains
- train G to translate an input image x into an output image y conditioned on the target domain label c, G(x, c) → y
- the discriminator produces probability distributions over both sources and domain labels, D : x → {Dsrc(x), Dcls(x)}, which allows a single discriminator to control multiple domains (a minimal conditioning sketch follows this list)
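A minimal sketch of how the label conditioning can be implemented, assuming a PyTorch setup; the helper name concat_label is hypothetical (the notes only state that G is conditioned on c):

```python
import torch

def concat_label(x, c):
    """Spatially replicate the domain label c and concatenate it with the
    image x along the channel axis, so G can condition on the target domain.
    x: (N, 3, H, W) image batch; c: (N, n_labels) label batch."""
    n, _, h, w = x.size()
    c_map = c.view(n, c.size(1), 1, 1).expand(n, c.size(1), h, w)
    return torch.cat([x, c_map], dim=1)

# G then consumes the (3 + n_labels)-channel tensor: y = G(concat_label(x, c))
```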
Adversarial Loss
Dsrc(x) is a probability distribution over sources given by D. The adversarial loss is written as L_adv = Ex[log Dsrc(x)] + Ex,c[log(1 − Dsrc(G(x, c)))] (Eq. (1)). The generator G tries to minimize this objective, while the discriminator D tries to maximize it.
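A hedged PyTorch sketch of this min-max game; D_src is a hypothetical handle to the source head of D, and G uses the common non-saturating surrogate:

```python
import torch
import torch.nn.functional as F

# Discriminator side: maximize Eq. (1) <=> minimize this binary cross-entropy.
out_real = D_src(x)                    # logits for real images
out_fake = D_src(G(x, c).detach())     # detach: this pass updates only D
d_loss_adv = (F.binary_cross_entropy_with_logits(out_real, torch.ones_like(out_real))
              + F.binary_cross_entropy_with_logits(out_fake, torch.zeros_like(out_fake)))

# Generator side: non-saturating surrogate for minimizing log(1 - Dsrc(G(x, c))).
out_gen = D_src(G(x, c))               # fresh pass so gradients reach G
g_loss_adv = F.binary_cross_entropy_with_logits(out_gen, torch.ones_like(out_gen))
```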
Domain Classification Loss
- add an auxiliary classifier on top of D and impose the domain classification loss when optimizing both D and G
- decompose the objective into two terms: a domain classification loss of real images used to optimize D, and a domain classification loss of fake images used to optimize G (see the sketch below)
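Both terms reduce to a standard classification loss on the auxiliary head; a sketch assuming single-label domains and class-index labels (multi-label attribute sets such as CelebA would use binary cross-entropy instead), with D_cls a hypothetical handle to the classifier logits:

```python
import torch.nn.functional as F

# L_cls^r: classify a real image x into its original domain c_org (updates D).
d_loss_cls = F.cross_entropy(D_cls(x), c_org)

# L_cls^f: classify the fake image G(x, c_trg) into the target domain c_trg (updates G).
g_loss_cls = F.cross_entropy(D_cls(G(x, c_trg)), c_trg)
```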
Reconstruction Loss
- problem: minimizing the losses (Eqs. (1) and (3)) does not guarantee that translated images preserve the content of their input images while changing only the domain-related part of the inputs
- method: apply a cycle consistency loss to the generator, L_rec = Ex,c,c'[||x − G(G(x, c), c')||1]
G takes in the translated image G(x, c) and the original domain label c' as input and tries to reconstruct the original image x. We adopt the L1 norm as our reconstruction loss.
Note that we use a single generator twice: first to translate an original image into an image in the target domain, and then to reconstruct the original image from the translated image.
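In code, the cycle reads directly off the definition (same hypothetical G as above):

```python
import torch

x_fake = G(x, c_trg)       # original image -> target domain
x_rec = G(x_fake, c_org)   # translated image + original label -> reconstruction
g_loss_rec = torch.mean(torch.abs(x - x_rec))   # L1 reconstruction loss
```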
Full Objective
The full objectives to optimize D and G are written, respectively, as L_D = −L_adv + λcls·L_cls^r and L_G = L_adv + λcls·L_cls^f + λrec·L_rec. We use λcls = 1 and λrec = 10 in all of our experiments.
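Combining the loss terms sketched in the previous sections with the stated weights:

```python
lambda_cls, lambda_rec = 1.0, 10.0   # values used in all experiments

d_loss = d_loss_adv + lambda_cls * d_loss_cls
g_loss = g_loss_adv + lambda_cls * g_loss_cls + lambda_rec * g_loss_rec
```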
2. Training with Multiple Datasets
- Problem: the complete information on the label vector c' is required when reconstructing the input image x from the translated image G(x, c), but each dataset only knows its own labels
Mask Vector
- introduce a mask vector m that allows StarGAN to ignore unspecified labels and focus on the explicitly known label provided by a particular dataset
- use an n-dimensional one-hot vector to represent m, with n being the number of datasets
- define a unified version of the label as a vector c̃ = [c1, ..., cn, m], where [·] refers to concatenation and ci represents a vector for the labels of the i-th dataset; for the remaining n − 1 unknown labels we simply assign zero values
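A sketch of assembling c̃ for two datasets; the label sizes in dims are illustrative (matching the 5 CelebA attributes and 8 RaFD expressions of the paper's setup), not part of the definition:

```python
import torch

def unified_label(c_known, dataset_idx, dims=(5, 8)):
    """Build c~ = [c1, ..., cn, m]: zero vectors for every dataset whose
    labels are unknown, the given labels for the active dataset, and a
    one-hot mask vector m marking which dataset the labels came from."""
    n = c_known.size(0)
    parts = [c_known if i == dataset_idx else torch.zeros(n, d)
             for i, d in enumerate(dims)]
    m = torch.zeros(n, len(dims))
    m[:, dataset_idx] = 1.0
    return torch.cat(parts + [m], dim=1)
```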
Training Strategy
- use the unified domain label c̃ as input to the generator
- the generator learns to ignore the unspecified labels, which are zero vectors, and to focus on the explicitly given label
- extend the auxiliary classifier of the discriminator to generate probability distributions over the labels of all datasets
- train the model in a multi-task learning setting, where the discriminator tries to minimize only the classification error associated with the known label (a masked-loss sketch follows this list)
- Under these settings, by alternating between CelebA and RaFD, the discriminator learns all of the discriminative features for both datasets, and the generator learns to control all the labels of both datasets.
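One way to realize the masked classification error, assuming the discriminator's auxiliary classifier emits a single concatenated logit vector covering all datasets; the slice boundaries and per-dataset loss types mirror the CelebA/RaFD setup and are assumptions:

```python
import torch.nn.functional as F

def masked_cls_loss(cls_logits, label, dataset_idx, dims=(5, 8)):
    """Only the logit slice of the currently sampled dataset contributes to
    the classification error; other datasets' logits are simply ignored."""
    start = sum(dims[:dataset_idx])
    logits = cls_logits[:, start:start + dims[dataset_idx]]
    if dataset_idx == 0:
        # multi-label binary attributes (CelebA-style), label in {0,1}^5
        return F.binary_cross_entropy_with_logits(logits, label)
    # single-label categories (RaFD-style), label given as class indices
    return F.cross_entropy(logits, label)
```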
Implementation
Improved GAN Training
- replace Eq. (1) with the Wasserstein GAN objective with gradient penalty, defined as L_adv = Ex[Dsrc(x)] − Ex,c[Dsrc(G(x, c))] − λgp·Ex̂[(||∇x̂ Dsrc(x̂)||2 − 1)²], where x̂ is sampled uniformly along a straight line between a pair of a real and a generated image. We use λgp = 10 for all experiments.
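The penalty term itself is the standard WGAN-GP construction; a PyTorch sketch (D_src as before is a hypothetical handle to the source head):

```python
import torch

def gradient_penalty(D_src, x_real, x_fake):
    """Penalize deviation of ||grad Dsrc(x_hat)||_2 from 1 at points x_hat
    sampled uniformly on straight lines between real and generated images."""
    alpha = torch.rand(x_real.size(0), 1, 1, 1, device=x_real.device)
    x_hat = (alpha * x_real + (1 - alpha) * x_fake).requires_grad_(True)
    out = D_src(x_hat)
    grad = torch.autograd.grad(outputs=out, inputs=x_hat,
                               grad_outputs=torch.ones_like(out),
                               create_graph=True)[0]
    grad_norm = grad.view(grad.size(0), -1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()

# Wasserstein critic loss for D, with lambda_gp = 10:
# d_loss_adv = D_src(x_fake).mean() - D_src(x_real).mean() \
#              + 10.0 * gradient_penalty(D_src, x_real, x_fake)
```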
Network Architecture
- the generator network is composed of two convolutional layers with a stride size of two for downsampling, six residual blocks, and two transposed convolutional layers with a stride size of two for upsampling
- use instance normalization for the generator but no normalization for the discriminator
- leverage PatchGANs for the discriminator network, which classifies whether local image patches are real or fake (a condensed generator sketch follows this list)
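A condensed sketch of that generator topology; the base channel width (64), kernel sizes, and 7×7 boundary convolutions are assumptions drawn from common public implementations, not from the notes above:

```python
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Conv-IN-ReLU-Conv-IN with an identity skip connection."""
    def __init__(self, ch):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(ch, ch, 3, 1, 1), nn.InstanceNorm2d(ch, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, 3, 1, 1), nn.InstanceNorm2d(ch, affine=True))

    def forward(self, x):
        return x + self.body(x)

def make_generator(in_ch, base=64, n_res=6):
    """in_ch = 3 image channels + the length of the (unified) label vector."""
    c = base
    layers = [nn.Conv2d(in_ch, c, 7, 1, 3),
              nn.InstanceNorm2d(c, affine=True), nn.ReLU(inplace=True)]
    for _ in range(2):                      # two stride-2 convs: downsampling
        layers += [nn.Conv2d(c, c * 2, 4, 2, 1),
                   nn.InstanceNorm2d(c * 2, affine=True), nn.ReLU(inplace=True)]
        c *= 2
    layers += [ResidualBlock(c) for _ in range(n_res)]   # six residual blocks
    for _ in range(2):                      # two stride-2 deconvs: upsampling
        layers += [nn.ConvTranspose2d(c, c // 2, 4, 2, 1),
                   nn.InstanceNorm2d(c // 2, affine=True), nn.ReLU(inplace=True)]
        c //= 2
    layers += [nn.Conv2d(c, 3, 7, 1, 3), nn.Tanh()]      # back to an RGB image
    return nn.Sequential(*layers)
```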