diff --git a/pyproject.toml b/pyproject.toml index 9e1d84a..04cde8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "finch-tensor" -version = "0.2.0" +version = "0.2.1" description = "" authors = ["Willow Ahrens "] readme = "README.md" diff --git a/src/finch/__init__.py b/src/finch/__init__.py index 1cff467..8ebf847 100644 --- a/src/finch/__init__.py +++ b/src/finch/__init__.py @@ -156,6 +156,8 @@ "SparseHash", "Storage", "DenseStorage", + "DefaultScheduler", + "GalleyScheduler", "asarray", "astype", "random", @@ -269,6 +271,7 @@ "empty_like", "arange", "linspace", + "set_optimizer", ] __array_api_version__: str = "2023.12" diff --git a/src/finch/compiled.py b/src/finch/compiled.py index 2b7eb6b..6c757dd 100644 --- a/src/finch/compiled.py +++ b/src/finch/compiled.py @@ -4,23 +4,23 @@ from .tensor import Tensor -def compiled(opt=""): - def inner(func): - @wraps(func) - def wrapper_func(*args, **kwargs): - new_args = [] - for arg in args: - if isinstance(arg, Tensor) and not jl.isa(arg._obj, jl.Finch.LazyTensor): - new_args.append(Tensor(jl.Finch.LazyTensor(arg._obj))) +def compiled(opt=None): + def inner(func): + @wraps(func) + def wrapper_func(*args, **kwargs): + new_args = [] + for arg in args: + if isinstance(arg, Tensor) and not jl.isa(arg._obj, jl.Finch.LazyTensor): + new_args.append(Tensor(jl.Finch.LazyTensor(arg._obj))) else: - new_args.append(arg) - result = func(*new_args, **kwargs) - kwargs = {"ctx": get_scheduler(name=opt)} if opt != "" else {} - result_tensor = Tensor(jl.Finch.compute(result._obj, **kwargs)) - return result_tensor - return wrapper_func - - return inner + new_args.append(arg) + result = func(*new_args, **kwargs) + kwargs = {"ctx": opt.get_julia_scheduler()} if opt is not None else {} + result_tensor = Tensor(jl.Finch.compute(result._obj, **kwargs)) + return result_tensor + return wrapper_func + + return inner def lazy(tensor: Tensor): if tensor.is_computed():