-
-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add similar
/setindex!
, update argmax
/argmin
#10
base: main
Are you sure you want to change the base?
Conversation
…ng conversion methods
function Base.similar(::OneHotArray{T, L}, ::Type{Bool}, dims::Dims) where {T, L} | ||
if first(dims) == L | ||
indices = ones(T, Base.tail(dims)) | ||
return OneHotArray(indices, first(dims)) | ||
else | ||
return BitArray(undef, dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you say what this similar
method is for?
It seems surprising to me to return something which isn't writable. Elsewhere the pattern is that the method which takes a size returns a full matrix, but without the size, a structured one:
julia> similar(Diagonal(1:3), Float32)
3×3 Diagonal{Float32, Vector{Float32}}:
0.0 ⋅ ⋅
⋅ 0.0 ⋅
⋅ ⋅ 2.09389f-37
julia> similar(Diagonal(1:3), Float32, (3,3))
3×3 Matrix{Float32}:
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
function OneHotArray(x::AbstractVector) | ||
!_onehot_compatible(x) && error("Array is not onehot compatible") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What arrays do you want this to accept, which aren't AbstractVector{Bool}
?
And could "not onehot compatible" explain a bit more what it means by compatible?
cart_inds = CartesianIndex.(_indices(x), CartesianIndices(_indices(x))) | ||
return reshape(cart_inds, (1, size(_indices(x))...)) | ||
else | ||
return argmax(BitArray(x); dims=dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is copying here better than invoke
ing the more generic method?
@@ -69,6 +87,30 @@ Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, | |||
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N} | |||
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N} | |||
|
|||
_onehot_compatible(x::OneHotLike) = _isonehot(x) | |||
_onehot_compatible(x::AbstractVector{Bool}) = count(x) == 1 | |||
_onehot_compatible(x::AbstractArray{Bool}) = all(isone, reduce(+, x; dims=1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Slightly tidier, maybe:
_onehot_compatible(x::AbstractArray{Bool}) = all(isone, reduce(+, x; dims=1)) | |
_onehot_compatible(x::AbstractArray{Bool}) = all(isone, count(x; dims=1)) |
This addresses #6. Also adds some conversion methods, and expands the arrays that the fast
hcat
should work with (as far as I can tell this should still be valid).