Issue
I have this training loop
def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
X, y = torch.stack(X).to(device), torch.stack(y).to(device)
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
and this lstm:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
class BELT_LSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers):
super (BELT_LSTM, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.input_size = input_size
self.BELT_LSTM = nn.LSTM(input_size, hidden_size, num_layers)
def forward(self, x):
# receive an input, create a new hidden state, return output?
# reset the hidden state?
hidden = (torch.zeros(self.num_layers, self.hidden_size), torch.zeros(self.num_layers, self.hidden_size))
x, _ = self.BELT_LSTM(x, hidden)
#since our observation has several sequences, we only want the output after the last sequence of the observation'''
x = x[:, -1]
return x
and here's the dataset class:
from __future__ import print_function, division
import os
import torch
import pandas as pd
import numpy as np
import math
from torch.utils.data import Dataset, DataLoader
class rcvLSTMDataSet(Dataset):
"""rcv dataset."""
TIMESTEPS = 10
def __init__(self, csv_data_file, annotations_file):
"""
Args:
csv_data_file (string): Path to the csv file with the training data
annotations_file (string): Path to the file with the annotations
"""
self.csv_data_file = csv_data_file
self.annotations_file = annotations_file
self.labels = pd.read_csv(annotations_file)
self.data = pd.read_csv(csv_data_file)
def __len__(self):
return math.floor(len(self.labels) / 10)
def __getitem__(self, idx):
"""
pytorch expects whatever data is returned is in the form of a tensor. Included, it expects the label for the data.
Together, they make a tuple.
"""
# convert every ten indexes and label into one observation
Observation = []
counter = 0
start_pos = self.TIMESTEPS *idx
avg_avg_1 = 0
avg_avg_2 = 0
avg_avg_3 = 0
while counter < self.TIMESTEPS:
Observation.append(self.data.iloc[idx + counter].values)
avg_avg_1 += self.labels.iloc[idx + counter][2]
avg_avg_2 += self.labels.iloc[idx + counter][1]
avg_avg_3 += self.labels.iloc[idx + counter][0]
counter += 1
#average the avg_1, avg_2, avg_3 for TIMESTEPS length
avg_avg_1 = avg_avg_1 / self.TIMESTEPS
avg_avg_2 = avg_avg_2 / self.TIMESTEPS
avg_avg_3 = avg_avg_3 / self.TIMESTEPS
current_labels = [avg_avg_1, avg_avg_2, avg_avg_3]
print(current_labels)
return Observation, current_labels
def main():
loader = rcvDataSet('foo','foo2.csv')
j = 0
while j < len(loader.data % loader.TIMESTEPS):
print(loader.__getitem__(j))
j += 1
if "__main__" == __name__:
main()
When running this, i get:
File "module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "lstm.py", line 21, in forward
x, _ = self.BELT_LSTM(x, hidden)
File "module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "rnn.py", line 747, in forward
raise RuntimeError(msg)
RuntimeError: For batched 3-D input, hx and cx should also be 3-D but got (2-D, 2-D) tensors
but as far as i can tell, i've followed the nn.LSTM instructions for both setting up the layers, and shaping the data properly. What am i doing wrong?
For reference, the incoming data is rows from a csv file, 12 columns wide, and i serve 10 rows per observation
Thanks
Solution
Your code:
hidden = (torch.zeros(self.num_layers, self.hidden_size),
torch.zeros(self.num_layers, self.hidden_size))
x, _ = self.BELT_LSTM(x, hidden)
Here hx and cx are both 2-D tensors. The correct way should be:
h_0 = torch.randn(self.num_directions*self.num_layers,
self.batch_size,
self.hidden_size)
c_0 = torch.randn(self.num_directions*self.num_layers,
self.batch_size,
self.hidden_size)
x, _ = self.BELT_LSTM(x, (h_0, c_0))
Answered By - ki-ljl
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.