Issue
Unfortunately, I get the following RuntimeError:
This error happens at epoch 1 during the last batch (so all other batches run through), and I don't know what causes it in my code. Here is a code snippet of my function:
import torch

def gradient_penalty(critic, real, fake, device):
    BATCH_SIZE, C, H, W = real.shape
    epsilon = torch.rand(size=(BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    # generate a tensor filled only with ones
    x = torch.ones(size=(BATCH_SIZE, C, H, W), dtype=int)
    # interpolate images
    interpolated_images = real * epsilon + fake * (x - epsilon)
The variable real stands for the images and has the shape (128, 3, 64, 64).
I have to admit that I don't fully understand the error message, i.e., where exactly do the shapes of the tensors fail to coincide?
Any help would be appreciated!
Solution
You can discard incomplete batches when instantiating a DataLoader with the drop_last argument:
torch.utils.data.DataLoader(trainset, batch_size=128, drop_last=True)
However, this is a somewhat drastic measure, since up to 127 elements of your dataset (one incomplete batch) will go to waste.
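For illustration, here is a minimal sketch of the effect of drop_last, with a hypothetical random dataset standing in for trainset:

import torch
from torch.utils.data import DataLoader, TensorDataset

# hypothetical stand-in for trainset: 1000 images shaped like those in the question
dataset = TensorDataset(torch.randn(1000, 3, 64, 64))

# without drop_last, the final batch would hold only the 1000 % 128 = 104 leftover samples
loader = DataLoader(dataset, batch_size=128, drop_last=True)

for (batch,) in loader:
    print(batch.shape)  # every batch is now exactly (128, 3, 64, 64)

With drop_last=True, the incomplete final batch never reaches gradient_penalty, so real and fake always share the same batch dimension.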
Answered By - Ivan