diff --git a/Project.toml b/Project.toml
index 845fcb1ca..725d9a1f7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Zygote"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.7.0"
+version = "0.7.1"
 
 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
diff --git a/docs/src/limitations.md b/docs/src/limitations.md
index 4e9012ced..14a8a8a8e 100644
--- a/docs/src/limitations.md
+++ b/docs/src/limitations.md
@@ -20,7 +20,6 @@ Let's explore this with a more concrete example. Here we define a simple mutatin
 ```julia
 function f!(x)
   x .= 2 .* x
-  return x
 end
 ```
 
@@ -42,43 +41,36 @@ Stacktrace:
 ...
 ```
 We got an error message and a long stacktrace. The error informs us that our code performs array mutation by calling `copyto!` (we might not have directly called this function, but it is being invoked somewhere in the call stack). We see that our code includes `x .= ...` which is given as an example of array mutation. Other examples of mutating operations include:
-- setting values (`x .= ...`)
-- appending/popping values (`push!(x, v)` / `pop!(x)`)
-- calling mutating functions (`mul!(C, A, B)`)
+- setting values (`x[i] = val` or `x .= values`)
+- appending/popping values (`push!(x, v)` or `pop!(x)`)
+- calling mutating functions (such as `LinearAlgebra.mul!(C, A, B)`)
 
 !!! warning
 
     Non-mutating functions might also use mutation under the hood. This can be done for performance reasons or code re-use.
 
 ```julia
-function g!(x, y)
-  x .= 2 .* y
-
+function g_inner!(x, y)
+  for i in eachindex(x, y)
+    x[i] = 2 * y[i]
+  end
   return x
 end
-g(y) = g!(similar(y), y)
-```
-Here `g` is a "non-mutating function," and it indeed does not mutate `y`, its only argument. But it still allocates a new array and calls `g!` on this array which will result in a mutating operation. You may encounter such functions when working with another package.
-
-Specifically for array mutation, we can use [`Zygote.Buffer`](@ref) to re-write our function. For example, let's fix the function `g!` above.
-```julia
-function g!(x, y)
-  x .= 2 .* y
 
-  return x
+function g_outer(y)
+  z = similar(y)
+  g_inner!(z, y)
+  return z
 end
+```
+Here `g_outer` does not mutate `y`, its only argument. But it still allocates a new array `z` and calls `g_inner!` on this array, which will result in a mutating operation. You may encounter such functions when working with another package.
 
-function g(y)
-  x = Zygote.Buffer(y) # Buffer supports syntax like similar
-  g!(x, y)
-  return copy(x) # this step makes the Buffer immutable (w/o actually copying)
-end
+How can you solve this problem?
+* Re-write the code not to use mutation. Here we can obviously write `g_better(y) = 2 .* y` using broadcasting. Many other cases may be solved by writing comprehensions `[f(x, y) for x in xs, y in ys]` or using `map(f, xs, ys)`, instead of explicitly allocating an output array and then writing into it.
+* Write a custom rule, defining `rrule(::typeof(g), y)` using what you know about `g` to derive the right expression.
+* Use another AD package instead of Zygote for part of the calculation. Replacing `g(y)` with `Zygote.forwarddiff(g, y)` will compute the same value, but when it is time to find the gradient, this job is outsourced to [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). ForwardDiff has its own limitations but mutation isn't one of them.
 
-julia> gradient(rand(3)) do y
-         sum(g(y))
-       end
-([2.0, 2.0, 2.0],)
-```
+Finally, there is also [`Zygote.Buffer`](@ref) which aims to handle the pattern of allocating space and then mutating it. But it has many bugs and is not really recommended.
 
 ## Try-catch statements
 
@@ -136,7 +128,8 @@ For all of the errors above, the suggested solutions are similar. You have the f
 2. define a [custom `ChainRulesCore.rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html)
 3. open an [issue on Zygote](https://github.com/FluxML/Zygote.jl/issues)
 
-Avoiding the operation is simple, just don't do it! If you are using a mutating function, try to use a non-mutating variant. If you are using `try`/`catch` statements, try to use more graceful error handling such as returning `nothing` or another sentinel value. Recall that array mutation can also be avoided by using [`Zygote.Buffer`](@ref) as discussed above.
+Avoiding the operation is simple, just don't do it! If you are using a mutating function, try to use a non-mutating variant. Instead of allocating an array and writing into it, try to make the output directly using broadcasting, `map`, or a comprehension.
+If you are using `try`/`catch` statements, try to use more graceful error handling such as returning `nothing` or another sentinel value.
 
 Sometimes, we cannot avoid expressions that Zygote cannot differentiate, but we may be able to manually derive a gradient. In these cases, you can write [a custom `rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html) using ChainRules.jl. Please refer to the linked ChainRules documentation for how to do this. _This solution is the only solution available for foreign call expressions._ Below, we provide a custom `rrule` for `jclock`.
 ```julia