
[Feature Request] Allow overloading a function by just changing an input argument from read to owned #3925

gabrieldemarmiesse opened this issue Jan 5, 2025 · 13 comments
Labels
enhancement New feature or request mojo-repo Tag all issues with this label

Comments

@gabrieldemarmiesse
Contributor

gabrieldemarmiesse commented Jan 5, 2025

Review Mojo's priorities

What is your request?

I would like the Mojo language to allow overloading a function where the only change is the ownership keyword. For simplicity, only read and owned will be discussed in this proposal.

A concrete example:

Let's look at a real-world example where this allows a library author to perform an optimization, without any intervention from the user:

# Not optimized
fn concat(x: List[Int], y: List[Int]) -> List[Int]:
    # Here x is read-only, so we have to create 
    # a new buffer (List) to allocate memory
    return x + y

# Optimized
fn concat(owned x: List[Int], y: List[Int]) -> List[Int]:
    # We can potentially avoid a malloc if the capacity
    # of x is high enough.
    x += y
    return x

fn main_not_optimized():
    x = List[Int](1, 2, 3)
    y = List[Int](4, 5, 6)
    z = concat(x, y)
    print(x) # x is reused here, so it can't be passed as owned to concat()


fn main_optimized():
    x = List[Int](1, 2, 3)
    y = List[Int](4, 5, 6)
    z = concat(x, y)
    # x could be passed as owned to concat(), so the optimization kicks in.

As you can see, this allows library authors to "grab" values in user code that are no longer used and to reuse their memory and buffers, all without the user having to do anything.

Advantages

This kind of optimization can be added progressively to a library, and it will kick in without the user being aware of it, because adding a new overload that owns the input argument is backward compatible. It also simplifies the API, and the user can always trigger the optimization explicitly with the transfer operator ^.

Without allowing this kind of overload, I would have to, as a library author, create two distinct functions, concat() and concat_consume(), and the user would have to understand which function to use in which case. We can't even automatically hint to the user that an optimization is possible here.

Other real-world examples:

To really show the usefulness of this pattern in library code, here are a few other real-world examples where automatically stealing the input buffer can avoid new allocations:

  • When adding two tensors elementwise, if one of the two tensors is not used after the function call, we can pass it as owned and put the result there. Note that this works for any elementwise operation. NumPy has an out= argument for this, but not everyone understands what it does.
  • When calling my_string.replace(a, b), if b is smaller than a and my_string is not used after the call, the result could be put in the buffer of my_string and then returned.
  • When writing a function that parses a String as YAML or JSON, if the input string is not used after the call to json.loads(), it may be possible to store data in it, as long as the library author is careful to overwrite only bytes that have already been read. For example, transformed strings could be stored there.
  • Functions like map() and filter() could benefit from re-using the input buffer without user intervention.
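The second bullet (a shrinking replace that reuses the caller's buffer) can be sketched concretely. This is an illustrative Python model, not a Mojo API: a bytearray stands in for the string's owned storage, and replace_in_place is a hypothetical helper.

```python
# Hypothetical sketch: when the replacement is no longer than the pattern,
# the result fits in the input buffer, so no new allocation is needed.
# A bytearray models the owned storage of the string.

def replace_in_place(buf: bytearray, old: bytes, new: bytes) -> bytearray:
    """Replace `old` with `new` (len(new) <= len(old)) inside buf,
    writing the result back into the same buffer and trimming the tail."""
    assert len(new) <= len(old)
    write = 0
    read = 0
    while read < len(buf):
        if buf[read:read + len(old)] == old:
            # The write cursor never passes the read cursor, so we only
            # overwrite bytes that were already consumed.
            buf[write:write + len(new)] = new
            write += len(new)
            read += len(old)
        else:
            buf[write] = buf[read]
            write += 1
            read += 1
    del buf[write:]  # shrink in place; the allocation itself is reused
    return buf

s = bytearray(b"aa-bb-aa")
print(replace_in_place(s, b"aa", b"x"))  # bytearray(b'x-bb-x')
```

The invariant that makes this safe is exactly the one described in the bullet above: because the replacement is not longer than the pattern, the writer can never clobber unread input.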

Unexplored subjects

It would be interesting to know how this affects aliasing rules (what about just using an output_buffer= argument? Would the compiler allow it, since it's a reference to another input argument?) or the priority order of overloads. At first glance I don't see any issue there, but feel free to tell me if any conflict with those features arises. This proposal is not so much about giving every detail as about showing what could be possible and why it's so useful for library authors.

@gabrieldemarmiesse gabrieldemarmiesse added enhancement New feature or request mojo-repo Tag all issues with this label labels Jan 5, 2025
@lsh
Contributor

lsh commented Jan 5, 2025

Here is some food for thought:

When a struct has a __copyinit__, the compiler will implicitly copy a value that is used again later. While one can write this as two separate functions (as above), it can be cleaner to just write a single function that takes an owned value. Here's a small example:

struct MyList[T: CollectionElement]:
    fn __init__(out self, *values: T):
        pass

    fn __moveinit__(mut self, owned rhs: Self):
        print("move")

    fn __copyinit__(mut self, rhs: Self):
        print("copy")

    fn __iadd__(mut self, rhs: Self):
        pass

    fn do_nothing(self):
        pass


fn concat(owned x: MyList[Int], y: MyList[Int]) -> MyList[Int]:
    x += y
    return x


fn main_not_optimized():
    x = MyList[Int](1, 2, 3)
    y = MyList[Int](4, 5, 6)
    z = concat(x, y)
    x.do_nothing()  # x is reused here, so it can't be passed as owned to concat()


fn main_optimized():
    x = MyList[Int](1, 2, 3)
    y = MyList[Int](4, 5, 6)
    z = concat(x, y)
    # x could be passed as owned to concat(), so the optimization kicks in.


fn main():
    # prints:
    # move
    main_optimized()
    # prints:
    # copy
    # move
    main_not_optimized()

This does the same optimization as your example, but with a single method.

@gabrieldemarmiesse
Contributor Author

The alternative proposed here is not as efficient as the original example when the first list is not big enough to hold the full result. In that case, a copy is performed, but the buffer will have to be reallocated later anyway. The library author should have control over this behaviour to avoid wasting allocations.

It is for this reason (and others) that the stdlib team is planning to remove __copyinit__ from List, so the alternative given will soon no longer be possible.

@lsh
Contributor

lsh commented Jan 6, 2025

It is for this reason (and others) that the stdlib team is planning to remove __copyinit__ from List, so the alternative given will soon no longer be possible.

This is true, and also as someone on the stdlib team, this is another concern I have with this proposal. Part of the problem with __copyinit__ on List is that it's hard to keep track of all the copies being made in a program. If we look at your call sites:

fn main_not_optimized():
    x = List[Int](1, 2, 3)
    y = List[Int](4, 5, 6)
    z = concat(x, y) # hidden allocation
    print(x) # x is reused here, so it can't be passed as owned to concat()


fn main_optimized():
    x = List[Int](1, 2, 3)
    y = List[Int](4, 5, 6)
    z = concat(x, y)
    # x could be passed as owned to concat(), so the optimization kicks in.

Compared to a type with an explicit copy, which would require:

z = concat(x.copy(), y)
print(x)

In fact, I think that behavior dovetails nicely with implicit and explicit copy, where it would copy just fine for cheap types and require an explicit call for expensive ones.

@gabrieldemarmiesse
Contributor Author

The alternative here of asking the user to call .copy() is, in some cases, not as efficient as the theoretical code given in the motivating example. Let's analyze what happens if x can't be owned by concat and the two input Lists have a capacity of 3:

fn main():
    x = List[Int](1, 2, 3)         # capacity=3
    y = List[Int](4, 5, 6)         # capacity=3
    z = concat(x.copy(), y)   # First alloc() call in .copy(), new vector with capacity=3
    print(x)

fn concat(owned x: List[Int], y: List[Int]) -> List[Int]:
    x += y         # x has a capacity of 3, so __iadd__ does another alloc() to store 6 elements 
    return x

We end up with 2 alloc calls. In the theoretical first example, let's do the same analysis:

fn main():
    x = List[Int](1, 2, 3)         # capacity=3
    y = List[Int](4, 5, 6)         # capacity=3
    z = concat(x, y)
    print(x)

fn concat(x: List[Int], y: List[Int]) -> List[Int]:
    return x + y  # alloc() call here to allocate the result buffer of size 6

# other overload of concat() omitted here because it's not called

In the theoretical example, at most one alloc() is performed for any combination of inputs. In the proposed alternative with .copy(), up to two alloc() calls can occur.
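The allocation counts in this analysis can be checked with a toy model. This is illustrative Python, not Mojo's actual List; CountingList is a hypothetical stand-in that counts buffer allocations.

```python
class CountingList:
    """Toy list that counts buffer allocations (hypothetical model,
    not Mojo's actual List)."""
    allocs = 0  # global allocation counter

    def __init__(self, items):
        self.items = list(items)
        self.capacity = len(self.items)
        CountingList.allocs += 1  # initial buffer allocation

    def copy(self):
        return CountingList(self.items)  # one alloc for the fresh buffer

    def __iadd__(self, other):
        needed = len(self.items) + len(other.items)
        if needed > self.capacity:  # must grow: one more alloc
            self.capacity = needed
            CountingList.allocs += 1
        self.items.extend(other.items)
        return self

    def __add__(self, other):
        return CountingList(self.items + other.items)  # one alloc


def concat_owned(x, y):  # stands in for `fn concat(owned x, ...)`
    x += y
    return x

def concat_read(x, y):   # stands in for `fn concat(x, ...)`
    return x + y

# Route 1: caller keeps x, so it copies first; the owned overload then grows.
x, y = CountingList([1, 2, 3]), CountingList([4, 5, 6])
before = CountingList.allocs
z = concat_owned(x.copy(), y)
print(CountingList.allocs - before)  # 2: one for .copy(), one to grow

# Route 2: the read overload allocates the result buffer directly.
x2, y2 = CountingList([1, 2, 3]), CountingList([4, 5, 6])
before = CountingList.allocs
z2 = concat_read(x2, y2)
print(CountingList.allocs - before)  # 1: just the result buffer
```

Under these assumptions the model reproduces the 2-versus-1 allocation count argued above.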

@christoph-schlumpf

I thought that an implicit copy is done if the argument is owned and the provided list is reused after the function call. So x.copy() just makes explicit what is done implicitly in the proposal. There would be at most two alloc calls in the proposed alternative too. But I might be wrong…

@gabrieldemarmiesse
Contributor Author

Just to be clear, the proposal would work with types that don't have implicit copy too. No implicit copy is performed in the example given in the original post.

@christoph-schlumpf

I see, so the proposal is that fn bla(owned x …) is preferred if x is not used after the function call, and fn bla(read x …) is preferred if x is used after the function call.
And this would work with an arbitrary number of arguments, irrespective of their position in the function signature?

@lattner
Collaborator

lattner commented Jan 7, 2025

This is a very interesting proposal. I think it "could" be done, and half of it would fit into the existing design cleanly, but your example wouldn't work. Let me look at your example:

fn concat(x: List[Int], y: List[Int]) -> List[Int]: ...
fn concat(owned x: List[Int], y: List[Int]) -> List[Int]: ...
fn main_not_optimized():
    x = List[Int](1, 2, 3)
    y = List[Int](4, 5, 6)
    z = concat(x, y)
    print(x) # x is reused here, so it can't be passed as owned to concat()

fn main_optimized():
    x = List[Int](1, 2, 3)
    y = List[Int](4, 5, 6)
    z = concat(x, y)

It would be possible to have overload resolution (at parser time) detect when an operand is an RValue and use that to disambiguate an overload. Note that the parser is context insensitive, resolving things an expression at a time.

This would mean things like:

  concat(x, ..) # gets unoptimized overload
  concat(List[Int](1, 2), ...) # gets optimized overload.
  concat(x^, ...) # gets optimized overload.

This would have value, but your example itself wouldn't work. The problem is that the optimization to remove extraneous copies (which is required for the first example) happens in a later dataflow pass that is aware of the uses and defs of values in a function. That pass isn't part of the parser and doesn't have overload sets. It does know about one specific set of transformations (notably copyinit->moveinit and copy elision), but it can't see the general overload set.

Solving this would be possible in theory, but we'd have to bundle a bunch more metadata into the IR, and it would require engineering work to implement. I think it would also be useful to look at it in the category of optimizations along the lines of x = String.__add__(y, z) -> String.__iadd__(y, z), which overload sets aren't powerful enough to do (though in this case, you could approximate it).
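The parser-time disambiguation on rvalues described above can be modeled roughly in Python. This is a sketch under assumptions: Moved is a hypothetical marker standing in for the ^ transfer operator, and the two private helpers stand in for the two overloads; Mojo would resolve this statically rather than at runtime.

```python
class Moved:
    """Marker wrapper standing in for `x^` (hypothetical, illustration
    only): wrapping a value is the caller's explicit signal that the
    callee may steal its buffer."""
    def __init__(self, value):
        self.value = value

def concat(x, y):
    # Dispatch the way the parser might: an explicit `Moved` (or a fresh
    # temporary the caller never names) selects the owned overload;
    # a plain named value selects the read overload.
    if isinstance(x, Moved):
        return _concat_owned(x.value, y)
    return _concat_read(x, y)

def _concat_owned(x, y):
    x += y           # reuse x's buffer in place
    return x

def _concat_read(x, y):
    return x + y     # allocate a fresh result

a = [1, 2, 3]
b = [4, 5, 6]
print(concat(a, b))          # read overload: a is untouched
print(a)                     # [1, 2, 3]
print(concat(Moved(a), b))   # owned overload: a's buffer is reused
```

The limitation noted above shows up here too: without whole-function use information, a bare `concat(a, b)` can never pick the owned path, even when `a` is dead afterwards.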

@gabrieldemarmiesse
Contributor Author

@christoph-schlumpf Indeed you got that right, and the more arguments you can grab, the more options you have as a library author to re-use memory :)

@gabrieldemarmiesse
Contributor Author

Thanks for your detailed answer @lattner . Even if the proposal can't be implemented with the optimization that removes extraneous copies, it still has value. Let's take the examples you give:

  concat(List[Int](1, 2), ...) # gets optimized overload.
  concat(x^, ...) # gets optimized overload.
  concat(x, ..) # gets unoptimized overload

Let's take them one at a time and see how each is useful for a library author:

concat(List[Int](1, 2), ...) # gets optimized overload.

That's very useful and doesn't require user input. Creating a value with a function and passing it directly to another function is something that happens frequently in real-world code.

  concat(x^, ...) # gets optimized overload.

This is useful too. While it requires user input to add the ^, the fact that the function concat keeps the same name (because it's an overload) communicates to the user that the behavior will be the same as concat(x, ...). It's very nice from a documentation perspective, and it's easy to teach.

  concat(x, ..) # gets unoptimized overload

While it would be great to make this work out of the box, we can imagine linter tools giving information about the possible optimization here, and it would be quite easy to teach. Something along the lines of: "Hey, here is the last use of the variable x, and it also happens that concat has an overload with owned for x; if you add ^ you can get an easy perf boost." Though I'm not an expert in linters, so it may be that such a linter is impossible to create.

But even so, the overload still has value, notably from a documentation point of view. For a library author, it's easy to fuse the docs of two overloads and just say "if you pass the x argument with ^ you get a perf boost".

Many places in the standard library could use such a pattern :)

@christoph-schlumpf

Another option might be to use a different argument convention for optimization instead of function overloading: read_in (or read_owned or read_move).

The semantics of read_in would be:

  • if the argument is used after the function call, it behaves like a read argument
  • if the argument is not used after the function call, it behaves like an owned argument by moving the value into the function
  • the implementer of the function has to provide code that works with both argument conventions

--> it guarantees that no alloc (copy) is done to access the argument

fn concat(read_in x: List[Int], y: List[Int]) -> List[Int]:
    if is_read_only(x):
        # Here x is read-only, so we have to create 
        # a new buffer (List) to allocate memory
        return x + y
    else:
        # We can potentially avoid a malloc if the capacity
        # of x is high enough.
        x += y
        return x

fn main_not_optimized():
    x = List[Int](1, 2, 3)
    y = List[Int](4, 5, 6)
    z = concat(x, y)
    print(x) # x is reused here, so it will be passed as a `read` reference to `concat()`


fn main_optimized():
    x = List[Int](1, 2, 3)
    y = List[Int](4, 5, 6)
    z = concat(x, y)
    # x is not reused, so it will be passed without a copy as `owned` to `concat()`, so the optimization kicks in.

Advantages

  • The same advantages as with the original proposal
  • No overload logic needed. It may avoid the parser issue mentioned by @lattner
  • The read_in argument convention signals to the caller that the call will be optimized if they do not reuse the argument after the function call.

Challenges

  • I don't know if an is_read_only(x) function is possible to implement/provide.
  • I'm not sure if the parser/compiler would be able to enforce that the implementation contains code for both conventions

@christoph-schlumpf

christoph-schlumpf commented Jan 7, 2025

Maybe one could implement something in the same direction as ref arguments, which work with both the read and mut argument conventions. But ref arguments would always mutate x, irrespective of whether it is reused after the function call, because the caller owns x.

Therefore, instead of ref [origin], a read_in [origin] would be needed that works with read and owned arguments. The compiler would somehow need to provide x to the function as an owned value or as a read ref, based on the ownership and reuse of x at the call site. The origin would be mutable only if the caller owns x and does not use it after the function call.

def concat[
      is_mutable: Bool, //,
      origin: Origin[is_mutable]
    ](read_in [origin] x: List[Int], y: List[Int]) -> List[Int]:
    @parameter
    if is_mutable:
        # We can potentially avoid a malloc if the capacity
        # of x is high enough.
        x += y
        return x
    else:
        # Here x is read-only, so we have to create 
        # a new buffer (List) to allocate memory
        return x + y

fn main_not_optimized():
    x = List[Int](1, 2, 3)
    y = List[Int](4, 5, 6)
    z = concat(x, y)
    print(x) # x is reused here, so it will be passed as a `read` reference to `concat()`


fn main_optimized():
    x = List[Int](1, 2, 3)
    y = List[Int](4, 5, 6)
    z = concat(x, y)
    # x is not reused, so it will be passed without a copy as `owned` to `concat()`.
    # So the optimization kicks in.

Collaborator

lattner commented Jan 8, 2025

To confirm, I was agreeing that we "could" do it and that it "would have value in some cases". That isn't enough for us to want to do it, though; we'd want to make sure it is worth the complexity, that it would lead to a good way of writing libraries, and that there aren't better ways to solve the same problem.
