Skip to content

Commit

Permalink
add scheduler classes
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebd99 committed Dec 5, 2024
1 parent 44760fa commit 1219c0e
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 16 deletions.
2 changes: 2 additions & 0 deletions src/finch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@
compiled,
compute,
set_optimizer,
DefaultScheduler,
GalleyScheduler,
)
from .dtypes import (
int_,
Expand Down
36 changes: 22 additions & 14 deletions src/finch/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@
from .tensor import Tensor


def get_scheduler(name, verbose=False):
if name == "default":
return jl.Finch.default_scheduler()
elif name == "galley":
return jl.Finch.galley_scheduler(verbose=verbose)

def compiled(opt=""):
def inner(func):
@wraps(func)
Expand All @@ -33,17 +27,31 @@ def lazy(tensor: Tensor):
return Tensor(jl.Finch.LazyTensor(tensor._obj))
return tensor

def set_optimizer(opt="default"):
if opt == "default":
jl.Finch.set_scheduler_b(jl.Finch.default_scheduler())
elif opt == "galley":
jl.Finch.set_scheduler_b(jl.Finch.galley_scheduler())
class AbstractScheduler():
pass

class GalleyScheduler(AbstractScheduler):
def __init__(self, verbose=False):
self.verbose=verbose

class DefaultScheduler(AbstractScheduler):
def __init__(self, verbose=False):
self.verbose=verbose

def get_julia_scheduler(opt):
if isinstance(opt, DefaultScheduler):
return jl.Finch.default_scheduler(verbose=opt.verbose)
elif isinstance(opt, GalleyScheduler):
return jl.Finch.galley_scheduler(verbose=opt.verbose)

def set_optimizer(opt):
jl.Finch.set_scheduler_b(get_julia_scheduler(opt))
return

def compute(tensor: Tensor, *, verbose: bool = False, opt="", tag=-1):
def compute(tensor: Tensor, *, verbose: bool = False, opt=None, tag=-1):
if not tensor.is_computed():
if opt == "":
if opt == None:
return Tensor(jl.Finch.compute(tensor._obj, verbose=verbose, tag=tag))
else:
return Tensor(jl.Finch.compute(tensor._obj, verbose=verbose, tag=tag, ctx=get_scheduler(opt, verbose=verbose)))
return Tensor(jl.Finch.compute(tensor._obj, verbose=verbose, tag=tag, ctx=get_julia_scheduler(opt)))
return tensor
2 changes: 1 addition & 1 deletion tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

arr1d = np.array([1, 1, 2, 3])

parametrize_optimizer = pytest.mark.parametrize("opt", ["default", "galley"])
parametrize_optimizer = pytest.mark.parametrize("opt", [finch.DefaultScheduler(), finch.GalleyScheduler()])

@parametrize_optimizer
def test_eager(arr3d, opt):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import finch

parametrize_optimizer = pytest.mark.parametrize("opt", ["default", "galley"])
parametrize_optimizer = pytest.mark.parametrize("opt", [finch.DefaultScheduler(), finch.GalleyScheduler()])

@pytest.mark.parametrize(
"dtype,jl_dtype",
Expand Down

0 comments on commit 1219c0e

Please sign in to comment.