Add option to throw error on passing wrong precision floats to layers #2454
Comments
This seems like a good idea. Maybe it should literally be the same switch as `allowscalar`. That function belongs to GPUArraysCore, which Flux doesn't directly load right now, but NNlib does.
When you say "same switch", do you mean defining something like …
I also think it makes sense to have a separate function for each of these checks. But having a …
This was my suggestion! Float precision is a big deal on GPU and not otherwise. There are basically two modes of using it:
Of course we could invent some new switch that we own to control this. But then it's one more mysterious function you have to know about, and one more kind of mutable state.
Note that CUDA.jl has switched the default to disallowing scalar access. Maybe that means using the same switch is a worse idea. So if we own a switch, what's a good name for it?

```julia
julia> using CUDA

julia> first(CUDA.randn(32))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
...

julia> CUDA.allowscalar(true)
┌ Warning: It's not recommended to use allowscalar([true]) to allow scalar indexing.
│ Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.
└ @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:188

julia> first(CUDA.randn(32))
0.37104183f0

(@v1.11) pkg> st CUDA
Status `~/.julia/environments/v1.11/Project.toml`
  [052768ef] CUDA v5.5.2
```
Motivation and description
The warning about wrong-precision inputs is very helpful for pointing at potential performance issues (see `Flux.jl/src/layers/stateless.jl`, line 60 in 2f19e68).
I think that warning is the correct default behavior. However, to find out where the problem is coming from, it would be very helpful to be able to throw an error instead, since an error produces a stacktrace.
Possible Implementation
There could be a Preference or global flag that switches wrong-precision inputs from a warning to an error. This would also be consistent with `CUDA.allowscalar`.
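A minimal sketch of the proposed flag follows. It assumes a plain global `Ref{Bool}` rather than a Preference, because Preferences.jl settings are typically read at load time and so could not be toggled in a running session. All names here (`PRECISION_ERRORS`, `precision_errors!`, `match_eltype`, `with_precision_errors`) are hypothetical and are not Flux's API. `match_eltype` only loosely stands in for the precision warning described above.

```julia
# Hypothetical global switch (not part of Flux): when true, precision
# mismatches throw an error instead of logging a warning.
const PRECISION_ERRORS = Ref(false)

# Hypothetical setter for the switch; returns the new value.
function precision_errors!(flag::Bool)
    PRECISION_ERRORS[] = flag
    return flag
end

# Hypothetical stand-in for a layer's input check: returns `x` unchanged if it
# already has element type `T`; otherwise warns (default) or throws (switch on),
# and in the warning case converts `x` to `T`.
function match_eltype(layer, ::Type{T}, x::AbstractArray{<:AbstractFloat}) where {T<:AbstractFloat}
    eltype(x) === T && return x
    msg = "Layer with $T parameters got $(eltype(x)) input."
    if PRECISION_ERRORS[]
        # Throwing gives a stacktrace that points at the offending caller.
        throw(ArgumentError(msg))
    end
    @warn msg layer input=summary(x) maxlog=1
    return T.(x)
end

# Scoped variant in the spirit of `allowscalar() do ... end`: errors are enabled
# only while `f` runs, and the previous setting is restored even if `f` throws.
function with_precision_errors(f)
    old = PRECISION_ERRORS[]
    PRECISION_ERRORS[] = true
    try
        return f()
    finally
        PRECISION_ERRORS[] = old
    end
end
```

For example, `match_eltype(:dense, Float32, rand(Float64, 3))` logs a warning once and returns a `Vector{Float32}`. Wrapping the same call in `with_precision_errors() do ... end` instead throws an `ArgumentError`, whose stacktrace shows where the Float64 input came from. Because the flag is a single global, toggling it affects every task. A task-local mechanism such as `Base.ScopedValues` (Julia 1.11) would avoid that.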