diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 170798d14..ef2db6b42 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -403,13 +403,13 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) function (::Core.kwftype(typeof(ChainRulesCore.frule)))( @nospecialize($kwargs::Any), frule::typeof(ChainRulesCore.frule), - ::$RuleConfig, + ::Tuple, $(map(esc, primal_sig_parts)...), ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end function ChainRulesCore.frule( - ::$RuleConfig, $(map(esc, primal_sig_parts)...) + ::Tuple, $(map(esc, primal_sig_parts)...) ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 05c4d8389..43863a915 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -220,14 +220,14 @@ end foo_ndc1(x) = string(x) @non_differentiable foo_ndc1(x) - @test frule(AllConfig(), foo_ndc1, 2.0) == (string(2.0), NoTangent()) + @test frule(AllConfig(), (NoTangent(), NoTangent()), foo_ndc1, 2.0) == (string(2.0), NoTangent()) r1, pb1 = rrule(AllConfig(), foo_ndc1, 2.0) @test r1 == string(2.0) @test pb1(NoTangent()) == (NoTangent(), NoTangent()) foo_ndc2(x; y=0) = string(x + y) @non_differentiable foo_ndc2(x) - @test frule(AllConfig(), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent()) + @test frule(AllConfig(), (NoTangent(), NoTangent()), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent()) r2, pb2 = rrule(AllConfig(), foo_ndc2, 2.0; y=4.0) @test r2 == string(6.0) @test pb2(NoTangent()) == (NoTangent(), NoTangent())