Issue
I am trying to run Pix2Pix, but my training function stops silently within the first 1k steps, with no errors. I built the discriminator and the generator in PyTorch. Below are the two functions responsible for training: one performs a single training step and the other fits the model.
Training Step Function:
def train_step(input_image, target, step):
    generator.train()
    discriminator.train()
    # Forward pass
    gen_output = generator(input_image)
    disc_real_output = discriminator(input_image, target)
    disc_generated_output = discriminator(input_image, gen_output)
    # Compute losses
    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(
        disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
    # Backward pass
    generator_optimizer.zero_grad()
    discriminator_optimizer.zero_grad()
    gen_total_loss.backward(retain_graph=True)
    # Clear the generator gradients for the discriminator backward pass
    discriminator_optimizer.zero_grad()
    disc_loss.backward()
    # Update weights
    generator_optimizer.step()
    discriminator_optimizer.step()
    # Logging
    with torch.no_grad():
        writer.add_scalar('gen_total_loss', gen_total_loss.item(), global_step=step // 1000)
        writer.add_scalar('gen_gan_loss', gen_gan_loss.item(), global_step=step // 1000)
        writer.add_scalar('gen_l1_loss', gen_l1_loss.item(), global_step=step // 1000)
        writer.add_scalar('disc_loss', disc_loss.item(), global_step=step // 1000)
Fitting Function:
def fit(train_loader, test_loader, steps):
    example_target, example_input = next(iter(test_loader))
    start = time.time()
    for step, (target, input_image) in enumerate(train_loader):
        if (step) % 1000 == 0:
            display.clear_output(wait=True)
            if step != 0:
                print(f'Time taken for 1000 steps: {time.time()-start:.2f} sec\n')
            start = time.time()
            generate_images(generator, example_input, example_target)
            print(f"Step: {step//1000}k")
        train_step(input_image, target, step)
        # Training step
        if (step+1) % 10 == 0:
            print('.', end='', flush=True)
        # Save (checkpoint) the model every 5k steps
        if (step + 1) % 5000 == 0:
            torch.save({
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'generator_optimizer_state_dict': generator_optimizer.state_dict(),
                'discriminator_optimizer_state_dict': discriminator_optimizer.state_dict(),
            }, f'checkpoint_{step + 1}.pt')
I am new to using GANs and I am not sure what the issue is here. I have tried to check if there is any exception that occurs during the training loop and print it but nothing is printed.
Solution
The problem is with your for loop that iterates over the training data:
for step, (target, input_image) in enumerate(train_loader):
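The way it's written, it will iterate over the data in train_loader exactly once (one epoch) and then stop, without raising any error. Note also that the steps argument of fit is never used.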
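To see why the loop can end before 1k steps: one pass over a DataLoader yields a fixed number of batches, determined by the dataset size and batch size. The sketch below reproduces that count in plain Python, assuming a map-style dataset and the default batch sampler. The helper name batches_per_epoch and the numbers (400 samples, batch size 1) are hypothetical and only for illustration.

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches one pass over the data yields.

    Mirrors what len(train_loader) reports for a map-style dataset:
    an incomplete final batch is kept unless drop_last is True.
    """
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


# Hypothetical example: 400 training images with batch size 1.
# A single `for ... in enumerate(train_loader)` loop then runs 400 steps
# and exits normally, well before step 1000.
steps_in_one_pass = batches_per_epoch(num_samples=400, batch_size=1)
print(steps_in_one_pass)
```

In practice, printing len(train_loader) before training gives the same number directly.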
Instead, you want something like:
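total_steps = 0
max_steps = 1000000  # some large value
while total_steps < max_steps:
    # each pass of this for loop is one epoch; `step` restarts at 0 every time
    for step, (target, input_image) in enumerate(train_loader):
        # do something, using total_steps (not step) as the global step counter
        total_steps += 1
The while condition is only checked between passes over train_loader, so this will terminate at the end of the pass in which total_steps reaches max_steps, that is, after max_steps plus possibly some leftover steps.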
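If the loop should stop at exactly max_steps rather than at the end of a pass, one option is a small generator that restarts the loader whenever it runs out. This is a minimal sketch, not part of the original answer: iterate_steps is a hypothetical helper name, and the list toy_loader stands in for a real DataLoader (any object that can be iterated repeatedly works, and a DataLoader starts a fresh pass each time it is iterated).

```python
def iterate_steps(loader, max_steps):
    """Yield (global_step, batch) pairs, restarting `loader` whenever it is
    exhausted, and stop after exactly `max_steps` batches."""
    step = 0
    while step < max_steps:
        made_progress = False
        for batch in loader:
            yield step, batch
            made_progress = True
            step += 1
            if step >= max_steps:
                return
        if not made_progress:
            # Guard against an empty loader, which would otherwise loop forever.
            raise ValueError("loader yielded no batches")


# Stand-in for a DataLoader with 3 batches of (target, input_image) pairs.
toy_loader = [("t0", "x0"), ("t1", "x1"), ("t2", "x2")]

seen = []
for step, (target, input_image) in iterate_steps(toy_loader, max_steps=7):
    # In the real fit function this is where train_step(input_image, target, step)
    # and the logging/checkpoint logic would go.
    seen.append((step, target, input_image))

print(len(seen))
```

With this shape, the body of the original fit loop can stay unchanged, since step now counts across passes instead of restarting at 0.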
Answered By - Yakov Dan