Skip to content

Commit

Permalink
typo
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Dec 6, 2024
1 parent 79b6f0d commit 3a15ffb
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "finch-tensor"
version = "0.2.0"
version = "0.2.1"
description = ""
authors = ["Willow Ahrens <[email protected]>"]
readme = "README.md"
Expand Down
3 changes: 3 additions & 0 deletions src/finch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@
"SparseHash",
"Storage",
"DenseStorage",
"DefaultScheduler",
"GalleyScheduler",
"asarray",
"astype",
"random",
Expand Down Expand Up @@ -269,6 +271,7 @@
"empty_like",
"arange",
"linspace",
"set_optimizer",
]

__array_api_version__: str = "2023.12"
32 changes: 16 additions & 16 deletions src/finch/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,23 @@
from .tensor import Tensor


def compiled(opt=""):
def inner(func):
@wraps(func)
def wrapper_func(*args, **kwargs):
new_args = []
for arg in args:
if isinstance(arg, Tensor) and not jl.isa(arg._obj, jl.Finch.LazyTensor):
new_args.append(Tensor(jl.Finch.LazyTensor(arg._obj)))
def compiled(opt=None):
def inner(func):
@wraps(func)
def wrapper_func(*args, **kwargs):
new_args = []
for arg in args:
if isinstance(arg, Tensor) and not jl.isa(arg._obj, jl.Finch.LazyTensor):
new_args.append(Tensor(jl.Finch.LazyTensor(arg._obj)))
else:
new_args.append(arg)
result = func(*new_args, **kwargs)
kwargs = {"ctx": get_scheduler(name=opt)} if opt != "" else {}
result_tensor = Tensor(jl.Finch.compute(result._obj, **kwargs))
return result_tensor
return wrapper_func

return inner
new_args.append(arg)
result = func(*new_args, **kwargs)
kwargs = {"ctx": opt.get_julia_scheduler()} if opt is not None else {}
result_tensor = Tensor(jl.Finch.compute(result._obj, **kwargs))
return result_tensor
return wrapper_func

return inner

def lazy(tensor: Tensor):
if tensor.is_computed():
Expand Down

0 comments on commit 3a15ffb

Please sign in to comment.