Issue
I am trying to implement a GAN called the SimGAN proposed by Apple researchers. The SimGAN is used to refine labelled synthetic images so that they look more like the unlabelled real images.
The link to the paper can be found on arXiv here.
In the paper, the loss function of the combined model, which comprises the generator and the discriminator, has a self-regularization component in the form of an L1 loss that penalizes too great a difference between the synthetic images and the images after refinement. In other words, the refinement should not be too drastic.
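For reference, the refiner objective described above can be written roughly as follows. The notation is reproduced from memory of the paper and should be checked against it: x_i is a synthetic image, R_θ the refiner, D_φ the discriminator (outputting the probability that its input is synthetic), λ the regularization weight, and ψ a feature mapping that is the identity in the simplest case.

```latex
\mathcal{L}_R(\theta) = \sum_i \Big[ -\log\big(1 - D_\phi(R_\theta(x_i))\big)
    + \lambda \, \big\lVert \psi(R_\theta(x_i)) - \psi(x_i) \big\rVert_1 \Big]
```

The second term is the self-regularization loss this question is about.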
I would like to know how I can implement this self-regularization loss in Keras. Here is what I tried:
def self_regularization_loss(refined_images, syn_images):
    def l1loss(y_true, y_pred):
        return keras.metrics.mean_absolute_error(refined_images, syn_images)
    return l1loss
However, I do not think I can compile the model as shown below, because the batches of refined and synthetic images change during training.
model.compile(loss=[self_regularization_loss(current_batch_of_refined, current_batch_of_synthetic),
                    local_adversarial_loss],
              optimizer=opt)
What is the way to implement this loss?
Solution
Try using the tf.function decorator together with tf.GradientTape():
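import tensorflow as tf

@tf.function
def train_step(model, batch):
    # Assumption: `optimizer` is a tf.keras optimizer defined in the enclosing scope
    # (the original snippet used `self.optimizer`, which does not exist in a free function).
    with tf.GradientTape() as tape:
        refined_images, syn_images = batch
        # The loss must be computed inside the tape so that gradients can be traced.
        loss = self_regularization_loss(model, refined_images, syn_images)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))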
Your training loop can then look something like this:
for image_batch in dataset:
    train_step(model, image_batch)
Here it is assumed that model is of type tf.keras.Model. More details on the model class can be found here. Note that model is also passed to self_regularization_loss. In this function, your model receives both images as inputs and gives you the respective output, from which you then calculate your loss.
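As a minimal end-to-end sketch of the approach above, the following puts the pieces together under a few assumptions that are not in the original answer: build_refiner is a hypothetical toy stand-in for the SimGAN refiner, the loss takes only the synthetic batch and computes the refined images itself, lam is a hypothetical weight for the L1 term, and random tensors stand in for real synthetic images.

```python
import tensorflow as tf

# Hypothetical stand-in for the refiner: a tiny conv net that maps an image
# to an image of the same shape. It is not the architecture from the paper.
def build_refiner(image_shape=(8, 8, 1)):
    inputs = tf.keras.Input(shape=image_shape)
    x = tf.keras.layers.Conv2D(4, 3, padding="same", activation="relu")(inputs)
    outputs = tf.keras.layers.Conv2D(image_shape[-1], 1, padding="same")(x)
    return tf.keras.Model(inputs, outputs)

def self_regularization_loss(model, syn_images, lam=1.0):
    # Refine the current batch, then take the mean absolute (L1) difference
    # between the refined images and the synthetic images they came from.
    refined_images = model(syn_images, training=True)
    return lam * tf.reduce_mean(tf.abs(refined_images - syn_images))

refiner = build_refiner()
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

@tf.function
def train_step(syn_images):
    with tf.GradientTape() as tape:
        # The forward pass runs inside the tape, so the loss always uses
        # the batch that is currently being trained on.
        loss = self_regularization_loss(refiner, syn_images)
    gradients = tape.gradient(loss, refiner.trainable_variables)
    optimizer.apply_gradients(zip(gradients, refiner.trainable_variables))
    return loss

# Random data standing in for a dataset of synthetic images.
syn_data = tf.random.uniform((16, 8, 8, 1), seed=0)
dataset = tf.data.Dataset.from_tensor_slices(syn_data).batch(4)

losses = []
for epoch in range(3):
    for syn_batch in dataset:
        losses.append(float(train_step(syn_batch)))
```

Because the refined images are recomputed from each batch inside the tape, nothing batch-specific has to be fixed at compile time. In a full SimGAN setup the adversarial term would presumably be added to this loss inside the same tape before taking gradients.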
Answered By - AloneTogether