diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index 80e4084c478e..4b6ee2e1dc0d 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -67,6 +67,7 @@ template
 inline void GetAttrFromForwardNode(const uint32_t nid,
                                    const nnvm::IndexedGraph &idx,
                                    std::vector* rshape_ptr,
+                                   std::vector* inference_finished,
                                    IsNone fis_none) {
   std::vector& rshape = *rshape_ptr;
   const nnvm::IndexedGraph::Node& inode = idx[nid];
@@ -83,18 +84,23 @@ inline void GetAttrFromForwardNode(const uint32_t nid,
   // input gradient list
   const std::vector& igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
   const nnvm::Node* igrad_node = nullptr;
+  bool all_attrs_known = true;
   // Input gradient assignement
   for (size_t i = 0; i < igrad.size(); ++i) {
     if (igrad[i].node->op() == inode.source->op()) {
       uint32_t eid = idx.entry_id(nid, igrad[i].index);
-      if (fis_none(rshape[eid])) {
-        rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
-      } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
+      if (fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
         // Need to skip empty forward shape, because it may not be
         // available now and it is possible to infer the forward
         // shape in one of the next a few passes
-        CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
-          << "Backward shape inconsistent with the forward shape";
+        all_attrs_known = false;
+      } else {
+        if (fis_none(rshape[eid])) {
+          rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
+        } else {
+          CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
+            << "Backward shape inconsistent with the forward shape";
+        }
       }
       if (igrad_node == nullptr) {
         igrad_node = igrad[i].node.get();
@@ -113,14 +119,20 @@ inline void GetAttrFromForwardNode(const uint32_t nid,
       if (fis_none(rshape[eid])) {
         rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];
       }
+      if (fis_none(rshape[eid])) {
+        // If the attr is still unknown
+        all_attrs_known = false;
+      }
     }
   }
+  (*inference_finished)[nid] = all_attrs_known;
 }
 
 template
 void GetAttrFromFusedNode(uint32_t nid,
                           const nnvm::IndexedGraph& idx,
                           std::vector* rshape_ptr,
+                          std::vector* inference_finished,
                           IsNone fis_none,
                           const std::string& infer_fusion_name) {
   std::vector& rshape = *rshape_ptr;
@@ -147,19 +159,24 @@ void GetAttrFromFusedNode(uint32_t nid,
   // input gradient list
   const std::vector& igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
   const nnvm::Node* igrad_node = nullptr;
+  bool all_attrs_known = true;
   // Set the attributes of output gradients
   // using attributes of forward node inputs
   for (size_t i = 0; i < igrad.size(); ++i) {
     if (igrad[i].node->op() == inode.source->op()) {
       uint32_t eid = idx.entry_id(nid, igrad[i].index);
-      if (fis_none(rshape[eid])) {
-        rshape[eid] = input_attrs[i];
-      } else if (!fis_none(input_attrs[i])) {
+      if (fis_none(input_attrs[i])) {
         // Need to skip empty forward shape, because it may not be
         // available now and it is possible to infer the forward
         // shape in one of the next a few passes
-        CHECK_EQ(rshape[eid], input_attrs[i])
-          << "Backward shape inconsistent with the forward shape";
+        all_attrs_known = false;
+      } else {
+        if (fis_none(rshape[eid])) {
+          rshape[eid] = input_attrs[i];
+        } else {
+          CHECK_EQ(rshape[eid], input_attrs[i])
+            << "Backward shape inconsistent with the forward shape";
+        }
       }
       if (igrad_node == nullptr) {
         igrad_node = igrad[i].node.get();
@@ -180,8 +197,13 @@ void GetAttrFromFusedNode(uint32_t nid,
       if (fis_none(rshape[eid])) {
         rshape[eid] = output_attrs[e.index];
       }
+      if (fis_none(rshape[eid])) {
+        // If the attr is still unknown
+        all_attrs_known = false;
+      }
     }
   }
+  (*inference_finished)[nid] = all_attrs_known;
 }
 
 template
@@ -270,6 +292,9 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
       Op::GetAttr("TIsBackward");
   // reshape shape vector
   AttrVector rshape;
+  // vector holding information which operators
+  // finished attribute inference
+  std::vector inference_finished(idx.num_nodes(), false);
   // dispatch mode vector
   DispatchModeVector dispatch_modes;
   if (ret.attrs.count(attr_name) != 0) {
@@ -340,6 +365,7 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
 
   // inference step function for nid
   auto infer_step = [&](uint32_t nid, bool last_iter) {
+    if (inference_finished[nid]) return;
     const auto& inode = idx[nid];
     const uint32_t num_inputs = inode.inputs.size();
     const uint32_t num_outputs = inode.source->num_outputs();
@@ -355,6 +381,9 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
           CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
         }
       }
+      if (!fis_none(rshape[out_ent_id])) {
+        inference_finished[nid] = true;
+      }
       // assign a default value to node attribute
       if (dispatch_mode_name != nullptr) {
         op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val);
@@ -370,47 +399,66 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
 
       static auto& is_fusion_helper = Op::GetAttr("TIsFusionHelper");
       if (!is_fusion_helper.get(fwd_ptr->op(), false)) {
-        GetAttrFromForwardNode(nid, idx, &rshape, fis_none);
+        GetAttrFromForwardNode(nid, idx, &rshape, &inference_finished, fis_none);
       } else {
-        GetAttrFromFusedNode(nid, idx, &rshape, fis_none, infer_fusion_name);
+        GetAttrFromFusedNode(nid, idx, &rshape, &inference_finished,
+                             fis_none, infer_fusion_name);
       }
     } else {
       DispatchMode* dispatch_mode = nullptr;
-      bool forward_known = true;
       // Forward operator inference.
       ishape.resize(num_inputs, empty_val);
       for (uint32_t i = 0; i < ishape.size(); ++i) {
         ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
-        if (fis_none(ishape[i])) forward_known = false;
       }
       oshape.resize(num_outputs, empty_val);
       for (uint32_t i = 0; i < oshape.size(); ++i) {
         oshape[i] = rshape[idx.entry_id(nid, i)];
-        if (fis_none(oshape[i])) forward_known = false;
       }
       if (dispatch_mode_name != nullptr) {
         dispatch_mode = &dispatch_modes[nid];
-        if (dispatch_modes[nid] == DispatchMode::kUndefined) forward_known = false;
       }
       auto finfer = finfer_shape.get(inode.source->op(), fdefault);
-      if (!forward_known) {
-        if (finfer != nullptr) {
-          // Call inference function of the operator.
-          try {
-            static auto& is_fusion = Op::GetAttr("TIsFusion");
-            if (is_fusion.get(inode.source->op(), false)) {
-              ProvideAttrToFusion(nid, idx, rshape, provide_fusion_name);
-            }
-            forward_known = ApplyOpInferAttr(ret, finfer, inode.source->attrs,
-                                             nid, &ishape, &oshape, dispatch_mode);
-          } catch (const std::exception& e) {
-            throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
+      if (finfer != nullptr) {
+        // Call inference function of the operator.
+        try {
+          static auto& is_fusion = Op::GetAttr("TIsFusion");
+          if (is_fusion.get(inode.source->op(), false)) {
+            ProvideAttrToFusion(nid, idx, rshape, provide_fusion_name);
           }
-        } else {
+          ApplyOpInferAttr(ret, finfer, inode.source->attrs,
+                           nid, &ishape, &oshape, dispatch_mode);
+          bool finished = true;
+          for (const auto& attr : ishape) {
+            if (fis_none(attr)) finished = false;
+          }
+          for (const auto& attr : oshape) {
+            if (fis_none(attr)) finished = false;
+          }
+          inference_finished[nid] = finished;
+        } catch (const std::exception& e) {
+          throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
+        }
+      } else {
+        // Operator does not provide attribute inference function,
+        // so we need to test if everything was inferred by other operators
+        bool all_attrs_known = true;
+        for (const auto& attr : ishape) {
+          if (fis_none(attr)) {
+            all_attrs_known = false;
+          }
+        }
+        for (const auto& attr : oshape) {
+          if (fis_none(attr)) {
+            all_attrs_known = false;
+          }
+        }
+        inference_finished[nid] = all_attrs_known;
+        if (!all_attrs_known) {
           CHECK(!last_iter)
             << "Attribute " << infer_name
-            << " is not registed by op " << inode.source->op()->name
-            << " we are not able to complete the inference because of this";
+            << " is not registered by op " << inode.source->op()->name
+            << ". We are not able to complete the inference because of this";
         }
       }
       // Save to the result map.
@@ -427,16 +475,18 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
   size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0;
   size_t num_unknown_entry_attr = entry_end - entry_start;
   size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode;
+  bool last_iter = false;
+  bool do_next_iteration = true;
   int i = 0;
   do {
     if (i % 2 == 0) {
       for (uint32_t nid = node_start; nid < node_end; ++nid) {
-        infer_step(nid, false);
+        infer_step(nid, last_iter);
       }
     } else {
       // backward inference
       for (uint32_t i = node_end; i != node_start; --i) {
-        infer_step(i - 1, false);
+        infer_step(i - 1, last_iter);
       }
     }
     last_num_unknown = num_unknown;
@@ -451,8 +501,18 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
         if (dispatch_modes[i] == DispatchMode::kUndefined) ++num_unknown;
       }
     }
+    do_next_iteration = num_unknown > 0 && last_num_unknown > num_unknown;
+    if (!do_next_iteration && !last_iter) {
+      // Check if every op agrees that it should be
+      // the end of attribute inference. If not,
+      // perform one final step
+      for (const bool done : inference_finished) {
+        do_next_iteration = do_next_iteration || !done;
+      }
+      last_iter = true;
+    }
     ++i;
-  } while (num_unknown > 0 && last_num_unknown > num_unknown);
+  } while (do_next_iteration);
   // set the shapes
   ret.attrs[attr_name] = std::make_shared(std::move(rshape));
   // set the shapes
@@ -517,6 +577,9 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
       Op::GetAttr("TIsBackward");
   // reshape shape vector
   AttrVector rshape;
+  // vector holding information which operators
+  // finished attribute inference
+  std::vector inference_finished(idx.num_nodes(), false);
   // dispatch mode vector
   DispatchModeVector dispatch_modes;
   if (ret.attrs.count(attr_name) != 0) {
@@ -594,6 +657,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
 
   // inference step function for nid
   auto infer_step = [&](uint32_t nid, bool last_iter) {
+    if (inference_finished[nid]) return;
     const auto& inode = idx[nid];
     const std::string name = inode.source->attrs.name;
     const uint32_t num_inputs = inode.inputs.size();
@@ -613,6 +677,9 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
           }
         }
       }
+      if (!fis_none(rshape[out_ent_id])) {
+        inference_finished[nid] = true;
+      }
       // assign a default value to node attribute
       if (dispatch_mode_name != nullptr) {
         op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val);
@@ -628,14 +695,15 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
 
       static auto& is_fusion_helper = Op::GetAttr("TIsFusionHelper");
       if (!is_fusion_helper.get(fwd_ptr->op(), false)) {
-        GetAttrFromForwardNode(nid, idx, &rshape, fis_none);
+        GetAttrFromForwardNode(nid, idx, &rshape, &inference_finished, fis_none);
       } else {
-        GetAttrFromFusedNode(nid, idx, &rshape, fis_none,
+        GetAttrFromFusedNode(nid, idx, &rshape,
+                             &inference_finished,
+                             fis_none,
                              "FAccessSubgraphShape");
       }
     } else {
       DispatchMode* dispatch_mode = nullptr;
-      bool forward_known = true;
       // Forward operator inference.
       ishape.resize(num_inputs, empty_val);
       bool is_input_dynamic_shape = false;
@@ -644,16 +712,13 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
         if (!mxnet::ndim_is_known(ishape[i]) && is_dynamic[idx.entry_id(inode.inputs[i])]) {
           is_input_dynamic_shape = true;
         }
-        if (fis_none(ishape[i])) forward_known = false;
       }
       oshape.resize(num_outputs, empty_val);
       for (uint32_t i = 0; i < oshape.size(); ++i) {
         oshape[i] = rshape[idx.entry_id(nid, i)];
-        if (fis_none(oshape[i])) forward_known = false;
       }
       if (dispatch_mode_name != nullptr) {
         dispatch_mode = &dispatch_modes[nid];
-        if (dispatch_modes[nid] == DispatchMode::kUndefined) forward_known = false;
       }
       auto finfer = finfer_shape.get(inode.source->op(), fdefault);
       if (finfer == nullptr || is_input_dynamic_shape) {
@@ -662,25 +727,27 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
             is_dynamic[idx.entry_id(nid, i)] = 1;
           }
         }
-      } else if (!forward_known) {
-        if (finfer != nullptr) {
-          // Call inference function of the operator.
-          try {
-            static auto& is_fusion = Op::GetAttr("TIsFusion");
-            if (is_fusion.get(inode.source->op(), false)) {
-              ProvideAttrToFusion(nid, idx, rshape,
-                                  "FProvideSubgraphShape");
-            }
-            forward_known = ApplyOpInferAttr(ret, finfer, inode.source->attrs,
-                                             nid, &ishape, &oshape, dispatch_mode);
-          } catch (const std::exception& e) {
-            throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
+        inference_finished[nid] = true;
+      } else {
+        // Call inference function of the operator.
+        try {
+          static auto& is_fusion = Op::GetAttr("TIsFusion");
+          if (is_fusion.get(inode.source->op(), false)) {
+            ProvideAttrToFusion(nid, idx, rshape,
+                                "FProvideSubgraphShape");
           }
-        } else {
-          CHECK(!last_iter)
-            << "Attribute " << infer_name
-            << " is not registed by op " << inode.source->op()->name
-            << " we are not able to complete the inference because of this";
+          ApplyOpInferAttr(ret, finfer, inode.source->attrs,
+                           nid, &ishape, &oshape, dispatch_mode);
+          bool finished = true;
+          for (const auto& attr : ishape) {
+            if (fis_none(attr)) finished = false;
+          }
+          for (const auto& attr : oshape) {
+            if (fis_none(attr)) finished = false;
+          }
+          inference_finished[nid] = finished;
+        } catch (const std::exception& e) {
+          throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
         }
       }
       // Save to the result map.
@@ -695,18 +762,20 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
 
   size_t last_num_unknown;
   size_t num_unknown = static_cast(-1);  // Infinity
+  bool last_iter = false;
+  bool do_next_iteration = true;
   int i = 0;
   do {
     if (i % 2 == 0) {
       // forward inference
       for (uint32_t nid = node_start; nid < node_end; ++nid) {
-        infer_step(nid, false);
+        infer_step(nid, last_iter);
       }
     } else {
       // backward inference
       for (uint32_t i = node_end; i != node_start; --i) {
-        infer_step(i - 1, false);
+        infer_step(i - 1, last_iter);
       }
     }
     last_num_unknown = num_unknown;
@@ -723,8 +792,18 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
         }
       }
     }
+    do_next_iteration = num_unknown > 0 && last_num_unknown > num_unknown;
+    if (!do_next_iteration && !last_iter) {
+      // Check if every op agrees that it should be
+      // the end of attribute inference. If not,
+      // perform one final step
+      for (const bool done : inference_finished) {
+        do_next_iteration = do_next_iteration || !done;
+      }
+      last_iter = true;
+    }
     ++i;
-  } while (num_unknown > 0 && last_num_unknown > num_unknown);
+  } while (do_next_iteration);
   // set the shapes
   ret.attrs[attr_name] = std::make_shared(std::move(rshape));
   // set the shapes
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index fd087ef39679..a9e9038e6c51 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -761,7 +761,7 @@ static bool WhileLoopType(const nnvm::NodeAttrs& attrs,
   std::vector func_in_type;
   extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type);
   extract_by_loc(*in_type, params.func_input_locs, &func_in_type);
-  std::vector cond_out_type = {0};
+  std::vector cond_out_type = {-1};
   CHECK(params.sync_in_out(in_type, out_type, is_udf));
   bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type);
   CHECK(params.sync_in_out(in_type, out_type, is_udf));
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 801e4e7126b4..0fee2a26c0ed 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -420,11 +420,9 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape& shp = (*in_attrs)[0];
   mxnet::TShape& out_shp = (*out_attrs)[0];
   CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
-  CHECK_NE(shp.ndim(), 0) << "Number of dimensions cannot be 0";
-  CHECK_NE(out_shp.ndim(), 0) << "Number of dimensions cannot be 0";
   if (shp.ndim() == -1 && out_shp.ndim() == -1)
     return false;  // none of the shapes is known
-  if (out_shp.ndim() > 0 && shp.ndim() > 0)
+  if (out_shp.ndim() >= 0 && shp.ndim() >= 0)
     CHECK_EQ(out_shp.ndim(), shp.ndim());
   mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1);
   mxnet::TShape ret(std::max(shp.ndim(), out_shp.ndim()), -1);
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 693336f22496..24e33019f617 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -172,7 +172,7 @@ def check_binary_ops():
     check_fused_symbol(3-a, a=arr1)
     check_fused_symbol(a*b, a=arr1, b=arr2)
     check_fused_symbol(a*3, a=arr1)
-    check_fused_symbol(a/b, a=arr1, b=arr2)
+    check_fused_symbol(a/(b+1), a=arr1, b=arr2)
     check_fused_symbol(a/3, a=arr1)
     check_fused_symbol(3/a, a=arr1)
     check_fused_symbol(a**b, a=arr1, b=arr2)
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index a2aad2c079fc..8e4fe11905cf 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -413,7 +413,7 @@ def test_gen_atomic_symbol_multiple_outputs():
     p = mx.sym.Variable('param')
     h0 = mx.sym.Variable('h0')
    h1 = mx.sym.Variable('h1')
-    s = mx.sym.RNN(data, p, h0, h1, state_size=10, num_layers=2,
+    s = mx.sym.RNN(data, p, h0, h1, state_size=10, num_layers=2, bidirectional=True,
                    state_outputs=True, mode='lstm')
     atomic_sym = s._gen_atomic_symbol()
 
@@ -542,6 +542,21 @@ def get_net():
     assert out_shapes[0] == (batch_size, num_hdidden)  # output
     assert len(aux_shapes) == 0
 
+def test_infershape_happens_for_all_ops_in_graph():
+    v = mx.sym.Variable('V')
+    s = mx.sym.transpose(v)
+    x = mx.sym.Variable('x')
+    s2 = x + v
+    s3 = s + s2
+    with discard_stderr():
+        try:
+            # This should throw an exception as you cannot add arrays
+            # with shapes [2,3] and [3,2]
+            e = s3.simple_bind(ctx=mx.cpu(), x=(2,3), grad_req='null')
+        except:
+            return
+
+    assert False
 
 if __name__ == '__main__':
     import nose