From 4c04f0f6dfc4f424af5f253c0c9938a13840b5ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= Date: Wed, 12 Jun 2024 10:55:48 +0200 Subject: [PATCH 1/2] Enable Finch dep from url --- src/finch/julia.py | 6 ++++++ src/finch/tensor.py | 6 +++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/finch/julia.py b/src/finch/julia.py index 72e074d..12fddef 100644 --- a/src/finch/julia.py +++ b/src/finch/julia.py @@ -6,9 +6,15 @@ _FINCH_VERSION = "0.6.31" _FINCH_HASH = "9177782c-1635-4eb9-9bfb-d9dfa25e6bce" _FINCH_REPO_PATH = os.environ.get("FINCH_REPO_PATH", default=None) +_FINCH_REPO_URL = os.environ.get("FINCH_URL_PATH", default=None) + +if _FINCH_REPO_PATH and _FINCH_REPO_URL: + raise ValueError("FINCH_REPO_PATH and FINCH_URL_PATH can't be set at the same time.") if _FINCH_REPO_PATH: # Also account for empty string juliapkg.add(_FINCH_NAME, _FINCH_HASH, path=_FINCH_REPO_PATH, dev=True) +elif _FINCH_REPO_URL: + juliapkg.add(_FINCH_NAME, _FINCH_HASH, url=_FINCH_REPO_URL, dev=True) else: deps = juliapkg.deps.load_cur_deps() if ( diff --git a/src/finch/tensor.py b/src/finch/tensor.py index 7b45b01..51b6c13 100644 --- a/src/finch/tensor.py +++ b/src/finch/tensor.py @@ -85,7 +85,7 @@ class Tensor(_Display, SparseArray): def __init__( self, - obj: np.ndarray | spmatrix | Storage | JuliaObj, + obj: np.ndarray | spmatrix | Storage | JuliaObj | "Tensor", /, *, fill_value: np.number | None = None, @@ -985,6 +985,10 @@ def eye( def tensordot(x1: Tensor, x2: Tensor, /, *, axes=2) -> Tensor: + if not isinstance(x1, Tensor): + x1 = Tensor(x1) + if not isinstance(x2, Tensor): + x2 = Tensor(x2) if isinstance(axes, Iterable): self_axes = normalize_axis_tuple(axes[0], x1.ndim) other_axes = normalize_axis_tuple(axes[1], x2.ndim) From 3deafad7a88327b8503468c60186b30454fa6f24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= Date: Wed, 12 Jun 2024 11:48:05 +0200 Subject: [PATCH 2/2] Fix type hints --- src/finch/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/finch/tensor.py b/src/finch/tensor.py index 51b6c13..667aa9f 100644 --- a/src/finch/tensor.py +++ b/src/finch/tensor.py @@ -85,7 +85,7 @@ class Tensor(_Display, SparseArray): def __init__( self, - obj: np.ndarray | spmatrix | Storage | JuliaObj | "Tensor", + obj: np.ndarray | spmatrix | Storage | JuliaObj, /, *, fill_value: np.number | None = None,