Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sorting out issue with calling predict on matrix #219

Merged
merged 13 commits into from
Apr 24, 2023
1 change: 1 addition & 0 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ by `model.batch_size`.)

"""
function collate(model, X, y)
y = y isa Matrix ? Tables.table(y) : y
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying to understand why this fix was needed. The nrows function is supposed to work for matrices as well as tables:

using MLJBase
julia> y = rand(2, 6)
julia> nrows(y)
2

julia> nrows(y')
6

And line 230 below should already take care of the conversion of y to a matrix, no?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was giving me an error previously, because the nrows method defined in the same file expects a table (or else it throws an ArgumentError. I've adjusted that and tests are passing.

row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size)
Xmatrix = reformat(X)
ymatrix = reformat(y)
Expand Down