diff --git a/docs/src/docs/array_api.md b/docs/src/docs/array_api.md index 06b2091a0..5f59b0ed2 100644 --- a/docs/src/docs/array_api.md +++ b/docs/src/docs/array_api.md @@ -144,4 +144,23 @@ julia> @btime compute(sum(A * B * C), ctx=galley_scheduler()); By taking advantage of the fact that C is highly sparse, Galley can better structure the computation. In the matrix chain multiplication, it always starts with the C,B matmul before multiplying with A. In the summation, it takes advantage of distributivity to pushing the reduction -down to the inputs. It first sums over A and C, then multiplies those vectors with B. \ No newline at end of file +down to the inputs. It first sums over A and C, then multiplies those vectors with B. + +Because Galley adapts to the sparsity patterns of the first input tensor, it can +be useful to distinguish between different uses of the same function using the +`tag` keyword argument to `compute` or `fuse`. For example, we may wish to +distinguish one spmv from another, as follows: + +```jldoctest example2; setup=:(using Finch) +julia> A = rand(1000, 1000); B = rand(1000, 1000); C = fsprand(1000, 1000, 0.0001); + +julia> fused((A, B, C) -> C .* (A * B), A, B, C, tag=:very_sparse_sddmm); + +julia> C = fsprand(1000, 1000, 0.9); + +julia> fused((A, B, C) -> C .* (A * B), A, B, C, tag=:very_dense_sddmm); + +``` + +By distinguishing between the two uses of the same function, Galley can make +better decisions about how to optimize each computation separately. \ No newline at end of file diff --git a/src/Finch.jl b/src/Finch.jl index 811cdb289..e42f13211 100644 --- a/src/Finch.jl +++ b/src/Finch.jl @@ -49,7 +49,7 @@ export diagmask, lotrimask, uptrimask, bandmask, chunkmask export scale, products, offset, permissive, protocolize, swizzle, toeplitz, window export PlusOneVector -export lazy, compute, tensordot, @einsum +export lazy, compute, fused, tensordot, @einsum export choose, minby, maxby, overwrite, initwrite, filterop, d diff --git a/src/FinchLogic/nodes.jl b/src/FinchLogic/nodes.jl index 88ed5246e..341f75b39 100644 --- a/src/FinchLogic/nodes.jl +++ b/src/FinchLogic/nodes.jl @@ -218,7 +218,7 @@ function LogicNode(kind::LogicNodeKind, args::Vector) (kind === subquery && length(args) == 2) || (kind === query && length(args) == 2) || (kind === produces) || - (kind === plan) + (kind === plan) return LogicNode(kind, nothing, Any, args) else error("wrong number of arguments to $kind(...)") diff --git a/src/interface/lazy.jl b/src/interface/lazy.jl index 578dc4a9e..97af5f223 100644 --- a/src/interface/lazy.jl +++ b/src/interface/lazy.jl @@ -479,12 +479,15 @@ default_scheduler(;verbose=false) = LogicExecutor(DefaultLogicOptimizer(LogicCom """ fused(f, args...; kwargs...) -This function decorator modifies `f` to fuse the contained array -operations and optimize the resulting program. The function must return a single -array or tuple of arrays. `kwargs` are passed to [`compute`](@ref) +This function decorator modifies `f` to fuse the contained array operations and +optimize the resulting program. The function must return a single array or tuple +of arrays. Some keyword arguments can be passed to control the execution of the +program: + - `verbose=false`: Print the generated code before execution + - `tag=:global`: A tag to distinguish between different classes of inputs for the same program. """ function fused(f, args...; kwargs...) - compute(f(map(LazyTensor, args...)), kwargs...) + compute(f(map(LazyTensor, args)...); kwargs...) end current_scheduler = Ref{Any}(default_scheduler()) @@ -520,10 +523,13 @@ function with_scheduler(f, scheduler) end """ - compute(args..., ctx=default_scheduler()) -> Any + compute(args...; ctx=default_scheduler(), kwargs...) -> Any Compute the value of a lazy tensor. The result is the argument itself, or a -tuple of arguments if multiple arguments are passed. +tuple of arguments if multiple arguments are passed. Some keyword arguments +can be passed to control the execution of the program: + - `verbose=false`: Print the generated code before execution + - `tag=:global`: A tag to distinguish between different classes of inputs for the same program. """ compute(args...; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), map(lazy, args)) compute(arg; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), (lazy(arg),))[1] diff --git a/src/scheduler/LogicExecutor.jl b/src/scheduler/LogicExecutor.jl index 0ecac5a20..d8d590e9d 100644 --- a/src/scheduler/LogicExecutor.jl +++ b/src/scheduler/LogicExecutor.jl @@ -50,28 +50,30 @@ function logic_executor_code(ctx, prgm) end """ - LogicExecutor(ctx, verbose=false) + LogicExecutor(ctx, tag=:global, verbose=false) Executes a logic program by compiling it with the given compiler `ctx`. Compiled codes are cached, and are only compiled once for each program with the same -structure. +structure. The `tag` argument is used to distinguish between different +use cases for the same program structure. """ @kwdef struct LogicExecutor ctx + tag verbose end Base.:(==)(a::LogicExecutor, b::LogicExecutor) = a.ctx == b.ctx && a.verbose == b.verbose Base.hash(a::LogicExecutor, h::UInt) = hash(LogicExecutor, hash(a.ctx, hash(a.verbose, h))) -LogicExecutor(ctx; verbose = false) = LogicExecutor(ctx, verbose) -function set_options(ctx::LogicExecutor; verbose = ctx.verbose, kwargs...) - LogicExecutor(set_options(ctx.ctx; kwargs...), verbose) +LogicExecutor(ctx; tag = :global, verbose = false) = LogicExecutor(ctx, tag, verbose) +function set_options(ctx::LogicExecutor; tag = ctx.tag, verbose = ctx.verbose, kwargs...) + LogicExecutor(set_options(ctx.ctx; kwargs...), tag, verbose) end codes = Dict() function (ctx::LogicExecutor)(prgm) - (f, code) = get!(codes, (ctx.ctx, get_structure(prgm))) do + (f, code) = get!(codes, (ctx.ctx, ctx.tag, get_structure(prgm))) do thunk = logic_executor_code(ctx.ctx, prgm) (eval(thunk), thunk) end