From 7a03a917a2c89d8f1bad89acd4d436c776f634ff Mon Sep 17 00:00:00 2001 From: = <=> Date: Thu, 21 Nov 2024 10:16:12 -0800 Subject: [PATCH 1/7] add instance node --- src/FinchLogic/nodes.jl | 20 +++++++++++++++++++- src/scheduler/LogicExecutor.jl | 9 ++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/FinchLogic/nodes.jl b/src/FinchLogic/nodes.jl index 88ed5246e..74c59312e 100644 --- a/src/FinchLogic/nodes.jl +++ b/src/FinchLogic/nodes.jl @@ -17,6 +17,7 @@ const ID = 4 query = 11ID | IS_TREE | IS_STATEFUL produces = 12ID | IS_TREE | IS_STATEFUL plan = 13ID | IS_TREE | IS_STATEFUL + instance = 14ID | IS_TREE | IS_STATEFUL end """ @@ -122,6 +123,15 @@ Logical AST statement that executes a sequence of statements `bodies...`. """ plan + +""" + instance(prgm, tag) + +Logical AST statement that executes `prgm` and caches the compilation for this tag. +Allows the user to keep different compiled programs for different classes of inputs. +""" +instance + """ LogicNode @@ -218,7 +228,8 @@ function LogicNode(kind::LogicNodeKind, args::Vector) (kind === subquery && length(args) == 2) || (kind === query && length(args) == 2) || (kind === produces) || - (kind === plan) + (kind === plan) || + (kind === instance && length(args) == 2) return LogicNode(kind, nothing, Any, args) else error("wrong number of arguments to $kind(...)") @@ -257,6 +268,8 @@ function Base.getproperty(node::LogicNode, sym::Symbol) elseif node.kind === query && sym === :rhs node.children[2] elseif node.kind === produces && sym === :args node.children elseif node.kind === plan && sym === :bodies node.children + elseif node.kind === instance && sym === :prgm node.prgm + elseif node.kind === instance && sym === :tag node.tag else error("type LogicNode($(node.kind), ...) has no property $sym") end @@ -309,6 +322,11 @@ function display_statement(io, mime, node, indent) display_expression(io, mime, node.args[end]) end print(io, ")") + elseif operation(node) == instance + println(io, "instance ($node.tag)") + print(io, " " ^ (indent + 2)) + display_statement(io, mime, node.prgm, indent + 2) + println(io) else throw(ArgumentError("Expected statement but got $(operation(node))")) end diff --git a/src/scheduler/LogicExecutor.jl b/src/scheduler/LogicExecutor.jl index 434c59aaf..8c5a8bb71 100644 --- a/src/scheduler/LogicExecutor.jl +++ b/src/scheduler/LogicExecutor.jl @@ -68,9 +68,16 @@ end codes = Dict() function (ctx::LogicExecutor)(prgm) - (f, code) = get!(codes, get_structure(prgm)) do + (f, code) = if prgm.kind === plan + # If no tag is used, default to no caching thunk = logic_executor_code(ctx.ctx, prgm) (eval(thunk), thunk) + else + get!(codes, get_structure(prgm)) do + prgm = prgm.prgm + thunk = logic_executor_code(ctx.ctx, prgm) + (eval(thunk), thunk) + end end if ctx.verbose println("Executing:") From 3db62b7bd507e224ec0c79ea0e917ebe94ff7b07 Mon Sep 17 00:00:00 2001 From: = <=> Date: Thu, 21 Nov 2024 18:42:33 -0800 Subject: [PATCH 2/7] remove instance and use instance_id argument --- src/FinchLogic/nodes.jl | 20 +------------------- src/scheduler/LogicExecutor.jl | 6 +++--- 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/src/FinchLogic/nodes.jl b/src/FinchLogic/nodes.jl index 74c59312e..341f75b39 100644 --- a/src/FinchLogic/nodes.jl +++ b/src/FinchLogic/nodes.jl @@ -17,7 +17,6 @@ const ID = 4 query = 11ID | IS_TREE | IS_STATEFUL produces = 12ID | IS_TREE | IS_STATEFUL plan = 13ID | IS_TREE | IS_STATEFUL - instance = 14ID | IS_TREE | IS_STATEFUL end """ @@ -123,15 +122,6 @@ Logical AST statement that executes a sequence of statements `bodies...`. """ plan - -""" - instance(prgm, tag) - -Logical AST statement that executes `prgm` and caches the compilation for this tag. -Allows the user to keep different compiled programs for different classes of inputs. -""" -instance - """ LogicNode @@ -228,8 +218,7 @@ function LogicNode(kind::LogicNodeKind, args::Vector) (kind === subquery && length(args) == 2) || (kind === query && length(args) == 2) || (kind === produces) || - (kind === plan) || - (kind === instance && length(args) == 2) + (kind === plan) return LogicNode(kind, nothing, Any, args) else error("wrong number of arguments to $kind(...)") @@ -268,8 +257,6 @@ function Base.getproperty(node::LogicNode, sym::Symbol) elseif node.kind === query && sym === :rhs node.children[2] elseif node.kind === produces && sym === :args node.children elseif node.kind === plan && sym === :bodies node.children - elseif node.kind === instance && sym === :prgm node.prgm - elseif node.kind === instance && sym === :tag node.tag else error("type LogicNode($(node.kind), ...) has no property $sym") end @@ -322,11 +309,6 @@ function display_statement(io, mime, node, indent) display_expression(io, mime, node.args[end]) end print(io, ")") - elseif operation(node) == instance - println(io, "instance ($node.tag)") - print(io, " " ^ (indent + 2)) - display_statement(io, mime, node.prgm, indent + 2) - println(io) else throw(ArgumentError("Expected statement but got $(operation(node))")) end diff --git a/src/scheduler/LogicExecutor.jl b/src/scheduler/LogicExecutor.jl index 8c5a8bb71..69eadf319 100644 --- a/src/scheduler/LogicExecutor.jl +++ b/src/scheduler/LogicExecutor.jl @@ -67,13 +67,13 @@ function set_options(ctx::LogicExecutor; verbose = ctx.verbose, kwargs...) end codes = Dict() -function (ctx::LogicExecutor)(prgm) - (f, code) = if prgm.kind === plan +function (ctx::LogicExecutor)(prgm, instance_id=-1) + (f, code) = if instance_id == -1 # If no tag is used, default to no caching thunk = logic_executor_code(ctx.ctx, prgm) (eval(thunk), thunk) else - get!(codes, get_structure(prgm)) do + get!(codes, (instance_id, get_structure(prgm))) do prgm = prgm.prgm thunk = logic_executor_code(ctx.ctx, prgm) (eval(thunk), thunk) From eb77313af619ad44938ee4bf60c9a9d65148c987 Mon Sep 17 00:00:00 2001 From: = <=> Date: Fri, 22 Nov 2024 08:33:17 -0800 Subject: [PATCH 3/7] rename instance_id to tag, make a param, use set_options --- src/scheduler/LogicExecutor.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/scheduler/LogicExecutor.jl b/src/scheduler/LogicExecutor.jl index 69eadf319..d5f98bb08 100644 --- a/src/scheduler/LogicExecutor.jl +++ b/src/scheduler/LogicExecutor.jl @@ -58,22 +58,23 @@ structure. """ struct LogicExecutor ctx + tag verbose end -LogicExecutor(ctx; verbose = false) = LogicExecutor(ctx, verbose) -function set_options(ctx::LogicExecutor; verbose = ctx.verbose, kwargs...) - LogicExecutor(set_options(ctx.ctx; kwargs...), verbose) +LogicExecutor(ctx; tag = -1, verbose = false) = LogicExecutor(ctx, tag, verbose) +function set_options(ctx::LogicExecutor; tag = ctx.tag, verbose = ctx.verbose, kwargs...) + LogicExecutor(set_options(ctx.ctx; kwargs...), tag, verbose) end codes = Dict() -function (ctx::LogicExecutor)(prgm, instance_id=-1) - (f, code) = if instance_id == -1 +function (ctx::LogicExecutor)(prgm) + (f, code) = if tag == -1 # If no tag is used, default to no caching thunk = logic_executor_code(ctx.ctx, prgm) (eval(thunk), thunk) else - get!(codes, (instance_id, get_structure(prgm))) do + get!(codes, (tag, get_structure(prgm))) do prgm = prgm.prgm thunk = logic_executor_code(ctx.ctx, prgm) (eval(thunk), thunk) From ea130e3c0f6b7d27ca4650d46b6351f4cfb8858a Mon Sep 17 00:00:00 2001 From: = <=> Date: Fri, 22 Nov 2024 08:57:04 -0800 Subject: [PATCH 4/7] bug fix --- src/scheduler/LogicExecutor.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/scheduler/LogicExecutor.jl b/src/scheduler/LogicExecutor.jl index d5f98bb08..574ab1991 100644 --- a/src/scheduler/LogicExecutor.jl +++ b/src/scheduler/LogicExecutor.jl @@ -75,7 +75,6 @@ function (ctx::LogicExecutor)(prgm) (eval(thunk), thunk) else get!(codes, (tag, get_structure(prgm))) do - prgm = prgm.prgm thunk = logic_executor_code(ctx.ctx, prgm) (eval(thunk), thunk) end From 0be91a8567e632a23b4a6bb61aef3bc6042b3ae0 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Tue, 26 Nov 2024 20:45:10 -0500 Subject: [PATCH 5/7] fix --- docs/src/docs/array_api.md | 21 ++++++++++++++++++++- src/interface/lazy.jl | 16 +++++++++++----- src/scheduler/LogicExecutor.jl | 5 ++--- 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/docs/src/docs/array_api.md b/docs/src/docs/array_api.md index 06b2091a0..bcd66f9e2 100644 --- a/docs/src/docs/array_api.md +++ b/docs/src/docs/array_api.md @@ -144,4 +144,23 @@ julia> @btime compute(sum(A * B * C), ctx=galley_scheduler()); By taking advantage of the fact that C is highly sparse, Galley can better structure the computation. In the matrix chain multiplication, it always starts with the C,B matmul before multiplying with A. In the summation, it takes advantage of distributivity to pushing the reduction -down to the inputs. It first sums over A and C, then multiplies those vectors with B. \ No newline at end of file +down to the inputs. It first sums over A and C, then multiplies those vectors with B. + +Because Galley adapts to the sparsity patterns of the first input tensor, it can +be useful to distinguish between different uses of the same function using the +`tag` keyword argument to `compute` or `fuse`. For example, we may wish to +distinguish one spmv from another, as follows: + +```jldoctest example2 +julia> A = rand(1000, 1000); B = rand(1000, 1000); C = fsprand(1000, 1000, 0.0001); + +julia> fused((A, B, C) -> C .* (A * B), tag=:very_sparse_sddmm); + +julia> C = fsprand(1000, 1000, 0.9); + +julia> fused((A, B, C) -> C .* (A * B), tag=:very_dense_sddmm); + +``` + +By distinguishing between the two uses of the same function, Galley can make +better decisions about how to optimize each computation separately. \ No newline at end of file diff --git a/src/interface/lazy.jl b/src/interface/lazy.jl index 578dc4a9e..7b9f051bf 100644 --- a/src/interface/lazy.jl +++ b/src/interface/lazy.jl @@ -479,9 +479,12 @@ default_scheduler(;verbose=false) = LogicExecutor(DefaultLogicOptimizer(LogicCom """ fused(f, args...; kwargs...) -This function decorator modifies `f` to fuse the contained array -operations and optimize the resulting program. The function must return a single -array or tuple of arrays. `kwargs` are passed to [`compute`](@ref) +This function decorator modifies `f` to fuse the contained array operations and +optimize the resulting program. The function must return a single array or tuple +of arrays. Some keyword arguments can be passed to control the execution of the +program: + - `verbose=false`: Print the generated code before execution + - `tag=:global`: A tag to distinguish between different classes of inputs for the same program. """ function fused(f, args...; kwargs...) compute(f(map(LazyTensor, args...)), kwargs...) @@ -520,10 +523,13 @@ function with_scheduler(f, scheduler) end """ - compute(args..., ctx=default_scheduler()) -> Any + compute(args...; ctx=default_scheduler(), kwargs...) -> Any Compute the value of a lazy tensor. The result is the argument itself, or a -tuple of arguments if multiple arguments are passed. +tuple of arguments if multiple arguments are passed. Some keyword arguments +can be passed to control the execution of the program: + - `verbose=false`: Print the generated code before execution + - `tag=:global`: A tag to distinguish between different classes of inputs for the same program. """ compute(args...; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), map(lazy, args)) compute(arg; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), (lazy(arg),))[1] diff --git a/src/scheduler/LogicExecutor.jl b/src/scheduler/LogicExecutor.jl index 22e1a3b2e..89a1d26a9 100644 --- a/src/scheduler/LogicExecutor.jl +++ b/src/scheduler/LogicExecutor.jl @@ -74,9 +74,8 @@ end codes = Dict() function (ctx::LogicExecutor)(prgm) (f, code) = get!(codes, (ctx.ctx, ctx.tag, get_structure(prgm))) do - thunk = logic_executor_code(ctx.ctx, prgm) - (eval(thunk), thunk) - end + thunk = logic_executor_code(ctx.ctx, prgm) + (eval(thunk), thunk) end if ctx.verbose println("Executing:") From d06b4b11e79256d3295664a485d36769d9e04a9e Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Tue, 26 Nov 2024 20:49:12 -0500 Subject: [PATCH 6/7] fix --- src/scheduler/LogicExecutor.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/scheduler/LogicExecutor.jl b/src/scheduler/LogicExecutor.jl index 89a1d26a9..d8d590e9d 100644 --- a/src/scheduler/LogicExecutor.jl +++ b/src/scheduler/LogicExecutor.jl @@ -67,8 +67,8 @@ Base.:(==)(a::LogicExecutor, b::LogicExecutor) = a.ctx == b.ctx && a.verbose == Base.hash(a::LogicExecutor, h::UInt) = hash(LogicExecutor, hash(a.ctx, hash(a.verbose, h))) LogicExecutor(ctx; tag = :global, verbose = false) = LogicExecutor(ctx, tag, verbose) -function set_options(ctx::LogicExecutor; verbose = ctx.verbose, kwargs...) - LogicExecutor(set_options(ctx.ctx; kwargs...), verbose) +function set_options(ctx::LogicExecutor; tag = ctx.tag, verbose = ctx.verbose, kwargs...) + LogicExecutor(set_options(ctx.ctx; kwargs...), tag, verbose) end codes = Dict() From 3d59180609d331583e63a5c974b3495f00bb8cdb Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Tue, 26 Nov 2024 21:07:44 -0500 Subject: [PATCH 7/7] testing --- docs/src/docs/array_api.md | 6 +++--- src/Finch.jl | 2 +- src/interface/lazy.jl | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/src/docs/array_api.md b/docs/src/docs/array_api.md index bcd66f9e2..5f59b0ed2 100644 --- a/docs/src/docs/array_api.md +++ b/docs/src/docs/array_api.md @@ -151,14 +151,14 @@ be useful to distinguish between different uses of the same function using the `tag` keyword argument to `compute` or `fuse`. For example, we may wish to distinguish one spmv from another, as follows: -```jldoctest example2 +```jldoctest example2; setup=:(using Finch) julia> A = rand(1000, 1000); B = rand(1000, 1000); C = fsprand(1000, 1000, 0.0001); -julia> fused((A, B, C) -> C .* (A * B), tag=:very_sparse_sddmm); +julia> fused((A, B, C) -> C .* (A * B), A, B, C, tag=:very_sparse_sddmm); julia> C = fsprand(1000, 1000, 0.9); -julia> fused((A, B, C) -> C .* (A * B), tag=:very_dense_sddmm); +julia> fused((A, B, C) -> C .* (A * B), A, B, C, tag=:very_dense_sddmm); ``` diff --git a/src/Finch.jl b/src/Finch.jl index 811cdb289..e42f13211 100644 --- a/src/Finch.jl +++ b/src/Finch.jl @@ -49,7 +49,7 @@ export diagmask, lotrimask, uptrimask, bandmask, chunkmask export scale, products, offset, permissive, protocolize, swizzle, toeplitz, window export PlusOneVector -export lazy, compute, tensordot, @einsum +export lazy, compute, fused, tensordot, @einsum export choose, minby, maxby, overwrite, initwrite, filterop, d diff --git a/src/interface/lazy.jl b/src/interface/lazy.jl index 7b9f051bf..97af5f223 100644 --- a/src/interface/lazy.jl +++ b/src/interface/lazy.jl @@ -487,7 +487,7 @@ program: - `tag=:global`: A tag to distinguish between different classes of inputs for the same program. """ function fused(f, args...; kwargs...) - compute(f(map(LazyTensor, args...)), kwargs...) + compute(f(map(LazyTensor, args)...); kwargs...) end current_scheduler = Ref{Any}(default_scheduler())