Skip to content

Commit

Permalink
Provisional fix of axis for softmax when importing from ONNX
Browse files Browse the repository at this point in the history
  • Loading branch information
jonandergomez committed Mar 31, 2023
1 parent 69fa02d commit 93e93ad
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/serialization/onnx/net/layers/core/activation_onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,20 @@ Layer* build_softmax_layer(onnx::NodeProto *node,
}

int parent_dims = parent->output->getShape().size();

if (axis < 0) // Check if the target axis is a negative index
axis = parent_dims + axis; // Get the target axis index
if (axis < 0 || axis >= parent_dims) // Check for invalid axis index
msg("The target axis for Softmax is not valid: axis = " + to_string(axis), "ONNX::ImportNet");

if (axis == 0 && parent_dims == 2)
axis = 1; // let us correct the problem of axis = 0 when importing a model in ONNX format where input shape does not contain the batch_size

if (axis == 0) { // Second check for invalid axis index
std::cerr << __FILE__ << "(" << __LINE__ << ") axis = " << axis << " shape.size = " << parent_dims << endl;
msg("The target axis for Softmax is not valid: axis = " + to_string(axis), "ONNX::ImportNet");
}

return new LActivation(parent, "softmax", {static_cast<float>(axis)}, node->name(), dev, mem);
}

Expand Down

0 comments on commit 93e93ad

Please sign in to comment.