When I activate 4 GPUs at the same time, I get the following error #31
Comments
Hi, can you try to make the CTC loss parallel: optimizer = nn.DataParallel(optimizer, device_ids=device_ids) |
In the code there's already a GPU number (params.ngpu): crnn = torch.nn.DataParallel(crnn, device_ids=range(params.ngpu)). Do I change it to something like cuda:0, cuda:1, ...? |
device_ids = [0, 1, 2, 3]
model = nn.DataParallel(model, device_ids=device_ids)
optimizer = nn.DataParallel(optimizer, device_ids=device_ids) I forgot to make the |
Alright, I will check it, but can you check out this link about GPU parallelism? https://discuss.pytorch.org/t/multi-gpu-training-pipeline-in-0-4-1/32199 To recap what's written in that link: when we have an optimizer, the GPUs have one leader GPU that organizes the data first and then passes it to the other three; after all the work, one of the four gathers the whole output, and it must be the first one (the leader) that gathers it (or else we will probably get an error). Do you think this assumption, i.e. what's written in the link, is right? Thanks. |
Alright, I got the following error: AttributeError: 'RMSprop' object has no attribute 'cuda' |
I only have one P40, so I can't test the code. I can only give you some suggestions. |
I haven't tested it yet either, but the error appears before execution: the optimizer doesn't have a cuda attribute. I thought about deleting the optimizers (lol). |
Now I removed the optimizer wrapping (just to see the result), and the error says the RNN cannot be divided and that I should add the flatten_parameters function to it. |
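For reference, a minimal sketch of where flatten_parameters() usually goes (an assumption based on the warning text, not the repository's actual fix; it only silences the contiguous-memory warning and does not solve the batch-dimension problem discussed next):

import torch.nn as nn

class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        self.rnn.flatten_parameters()     # re-compact the LSTM weights on each DataParallel replica
        recurrent, _ = self.rnn(input)    # input: [T, b, nIn]
        T, b, h = recurrent.size()
        output = self.embedding(recurrent.view(T * b, h)).view(T, b, -1)
        return output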
The problem here is that DataParallel parallelizes data along a specific dimension (0 by default). However, RNNs, unless you use batch_first=True, treat the 2nd dimension (dim=1) as the batch, while the convolutional part uses the 1st dimension (dim=0) for the batch. DataParallel does a good job with the CNN part but cannot work with the RNN part. The only way to solve this is to change the RNN part to work in batch_first=True mode. |
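For reference, a minimal self-contained sketch of the mismatch described above (the tensor sizes are illustrative assumptions, not values from the repository):

import torch
import torch.nn as nn

x = torch.randn(8, 32, 512)  # what DataParallel scatters along dim=0

# Default LSTM expects [T, b, C]: dim=0 is time, so a dim=0 split breaks the batch.
rnn_time_first = nn.LSTM(512, 256, bidirectional=True)
# With batch_first=True the LSTM expects [b, T, C]: dim=0 stays the batch dimension,
# which matches both the CNN output layout and DataParallel's scatter/gather.
rnn_batch_first = nn.LSTM(512, 256, bidirectional=True, batch_first=True)

out_tf, _ = rnn_time_first(x)      # interprets x as 8 time steps of a batch of 32
out_bf, _ = rnn_batch_first(x)     # interprets x as a batch of 8 sequences of length 32
print(out_tf.size(), out_bf.size())  # same numbers, different meaning of dim=0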
Thanks for the answer, I will try it. So the problem is about the dimensions and not about the optimizer? In the models/crnn.py file, do we change it to: self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True, batch_first=True)? |
Yes, that's the first step. But then you need to adjust internal calculations such that the first dimension is batch. |
Okay, thanks. By internal calculations, you mean that the batch should be the 1st dimension of the RNN input? I don't really know where to change that. |
@furkankirac Thank you so much. I will fix it later. |
@mariembenslama It's not very straightforward. I don't know the details of the code. Maybe @Holmeyoung can help you. Best. |
Dear mariembenslama: |
@JasonBoy1 Hello, actually no, I have been waiting for @Holmeyoung since he said he would be editing the code. |
Hi, I am so sorry I haven't fixed the bug yet. I have been doing recommendation work these days. Can you try this code~~~

class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True, batch_first=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        b, T, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output |
I have already solved this problem. The problem occurs because the RNN output was not batch-first; after the change the batch_size is at dim=0, so DataParallel would not cat your output wrong. |
Thank you! I will try your solution and tell you :) |
Could you tell me what we change in self.rnn exactly? |
That's the same as the first one, and it's not as simple as that. I don't think this will make any difference. |
I cloned the project, changed it directly as you told me, and I got the following error: Traceback (most recent call last): |
I did that in www.github.com/meijieru/crnn.pytorch's code and used warp_ctc loss. torch.nn.CTCLoss is different from warp_ctc. And I think you should use IPython.embed to debug and watch the sizes of the tensors. |
@mineshmathew |
@mineshmathew |
I have solved the problem. It seems that we cannot modify the tensor's shape easily, so I just modify the input to the CTC loss. The prediction matrix of my model is preds = model(inp).cpu(). Then I use preds.permute(1, 0, 2) as the first input of the CTC loss function. The point is to keep the batch_size in the first dimension while your GPUs combine the outputs. |
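For reference, a minimal self-contained sketch of that fix (the shapes, class count, and label lengths are made-up assumptions): keep the model output batch-first so DataParallel gathers it on dim=0, and permute to [T, N, C] only when calling torch.nn.CTCLoss.

import torch
import torch.nn as nn

N, T, C = 4, 26, 37                                   # batch, time steps, classes (assumed)
criterion = nn.CTCLoss(blank=0, zero_infinity=True)

preds = torch.randn(N, T, C).log_softmax(2)           # model output, batch-first: [N, T, C]
log_probs = preds.permute(1, 0, 2)                    # CTCLoss expects [T, N, C]

targets = torch.randint(1, C, (N, 10), dtype=torch.long)   # dummy labels
input_lengths = torch.full((N,), T, dtype=torch.long)      # one entry per sample
target_lengths = torch.full((N,), 10, dtype=torch.long)

loss = criterion(log_probs, targets, input_lengths, target_lengths)
print(loss.item())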
@JasonBoy1's fix doesn't work for me (I use pytorch CTC, not warp CTC). It is even more complicated when you have more than one RNN layer. |
Here is a solution if you use pytorch CTC; my pytorch version is 1.10.2.

class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True, batch_first=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        b, T, h = recurrent.size()
        # t_rec = recurrent.view(T * b, h)
        t_rec = recurrent.reshape(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(b, T, -1)
        return output

nn.LSTM behaves differently when batch_first is specified, so the output should be modified too.

class CRNN(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        ....

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        # conv = conv.permute(2, 0, 1)  # [w, b, c]
        conv = conv.permute(0, 2, 1)  # [b, w, c]
        # rnn features
        output = self.rnn(conv)
        # add log_softmax to converge output
        output = F.log_softmax(output, dim=2)
        return output

And before the loss is calculated, according to the format of CTCLoss: preds = crnn(image).permute(1, 0, 2) |
Okay, I referred to this issue and modified the code according to @JasonBoy1 and @Rabbit19731, and it works fine. In conclusion, it needs the batch_first changes above, then the permute before the loss; plus, I added a line for using multiple GPUs. Thanks all. I can stop wondering about multi-GPU training thanks to this issue. |
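For reference, a hedged sketch of how the pieces above fit together (CRNN is the batch_first version sketched in the previous comment; ngpu stands in for params.ngpu, and the batch, class count, and label tensors are dummy stand-ins for the repository's image, text, and length variables):

import torch
import torch.nn as nn

crnn = CRNN(imgH=32, nc=1, nclass=37, nh=256)        # the batch_first CRNN sketched above
criterion = nn.CTCLoss(zero_infinity=True)

ngpu = torch.cuda.device_count()                     # stands in for params.ngpu
if ngpu > 1:
    crnn = nn.DataParallel(crnn, device_ids=range(ngpu)).cuda()
    criterion = criterion.cuda()

image = torch.randn(4, 1, 32, 100)                   # dummy batch: [b, c, H, W]
if ngpu > 1:
    image = image.cuda()

preds = crnn(image)                                  # [b, T, nclass], gathered on dim=0
preds = preds.permute(1, 0, 2)                       # [T, b, nclass] for CTCLoss
preds_size = torch.full((image.size(0),), preds.size(0), dtype=torch.long)
text = torch.randint(1, 37, (image.size(0), 5), dtype=torch.long)   # dummy labels
length = torch.full((image.size(0),), 5, dtype=torch.long)
cost = criterion(preds, text, preds_size, length) / image.size(0)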
Hello, I'm using an AWS instance with 4 GPUs, and when multi-GPU is activated (in the params.py file: multigpu set to True and 4 for the number) I get the following error (P.S.: for 4, 3, 2 and even 1, which is incomprehensible even for 1):
CRNN(
(cnn): Sequential(
(conv0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu0): ReLU(inplace=True)
(pooling0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu1): ReLU(inplace=True)
(pooling1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(batchnorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU(inplace=True)
(conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu3): ReLU(inplace=True)
(pooling2): MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)
(conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(batchnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu4): ReLU(inplace=True)
(conv5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu5): ReLU(inplace=True)
(pooling3): MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)
(conv6): Conv2d(512, 512, kernel_size=(2, 2), stride=(1, 1))
(batchnorm6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu6): ReLU(inplace=True)
)
(rnn): Sequential(
(0): BidirectionalLSTM(
(rnn): LSTM(512, 256, bidirectional=True)
(embedding): Linear(in_features=512, out_features=256, bias=True)
)
(1): BidirectionalLSTM(
(rnn): LSTM(256, 256, bidirectional=True)
(embedding): Linear(in_features=512, out_features=7116, bias=True)
)
)
)
/opt/conda/conda-bld/pytorch_1565272279342/work/aten/src/ATen/native/cudnn/RNN.cpp:1266: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().
/opt/conda/conda-bld/pytorch_1565272279342/work/aten/src/ATen/native/cudnn/RNN.cpp:1266: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().
/opt/conda/conda-bld/pytorch_1565272279342/work/aten/src/ATen/native/cudnn/RNN.cpp:1266: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().
/opt/conda/conda-bld/pytorch_1565272279342/work/aten/src/ATen/native/cudnn/RNN.cpp:1266: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().
/opt/conda/conda-bld/pytorch_1565272279342/work/aten/src/ATen/native/cudnn/RNN.cpp:1266: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().
/opt/conda/conda-bld/pytorch_1565272279342/work/aten/src/ATen/native/cudnn/RNN.cpp:1266: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().
Traceback (most recent call last):
File "train.py", line 253, in
cost = train(crnn, criterion, optimizer, train_iter)
File "train.py", line 241, in train
cost = criterion(preds, text, preds_size, length) / batch_size
File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in call
result = self.forward(*input, **kwargs)
File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/nn/modules/loss.py", line 1295, in forward
self.zero_infinity)
File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py", line 1767, in ctc_loss
zero_infinity)
RuntimeError: input_lengths must be of size batch_size