diff --git a/src/serialization/onnx/net/layers/core/activation_onnx.cpp b/src/serialization/onnx/net/layers/core/activation_onnx.cpp index e80ca8457..1ced4e0ea 100644 --- a/src/serialization/onnx/net/layers/core/activation_onnx.cpp +++ b/src/serialization/onnx/net/layers/core/activation_onnx.cpp @@ -202,11 +202,20 @@ Layer* build_softmax_layer(onnx::NodeProto *node, } int parent_dims = parent->output->getShape().size(); + if (axis < 0) // Check if the target axis is a negative index axis = parent_dims + axis; // Get the target axis index if (axis < 0 || axis >= parent_dims) // Check for invalid axis index msg("The target axis for Softmax is not valid: axis = " + to_string(axis), "ONNX::ImportNet"); + if (axis == 0 && parent_dims == 2) + axis = 1; // let us correct the problem of axis = 0 when importing a model in ONNX format where input shape does not contain the batch_size + + if (axis == 0) { // Second check for invalid axis index + std::cerr << __FILE__ << "(" << __LINE__ << ") axis = " << axis << " shape.size = " << parent_dims << endl; + msg("The target axis for Softmax is not valid: axis = " + to_string(axis), "ONNX::ImportNet"); + } + return new LActivation(parent, "softmax", {static_cast(axis)}, node->name(), dev, mem); }