diff --git a/src/onehot.jl b/src/onehot.jl index 986fb9d02c..32a2a19bca 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -190,14 +190,14 @@ end for wrapper in [:Adjoint, :Transpose] @eval begin function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector{<:Any, L}) where {L, T} - size(A, 2) != L || + size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) return A[:, b.ix] end function Base.:*(A::$wrapper{<:Number, <:AbstractVector{T}}, b::OneHotVector{<:Any, L}) where {L, T} - size(A, 2) != L || + size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) return A[b.ix]