From 294b897190573be38c9517e112054a24913b12ba Mon Sep 17 00:00:00 2001 From: anakinxc <103552181+anakinxc@users.noreply.github.com> Date: Tue, 22 Aug 2023 16:18:44 +0800 Subject: [PATCH] Fix compilation cache with weakref object (#324) # Pull Request ## What problem does this PR solve? Issue Number: #306 Fixed compilation cache error when Callable contains weakref ## Possible side effects? - Performance: N/A - Backward compatibility: N/A --- spu/utils/frontend.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py index 9556b6dd..b91ca562 100644 --- a/spu/utils/frontend.py +++ b/spu/utils/frontend.py @@ -16,7 +16,6 @@ from enum import Enum from typing import Callable, Dict, Iterable, List -import cloudpickle from cachetools import LRUCache, cached from .. import api as spu_api @@ -27,12 +26,13 @@ def _jax_compilation_key( fn: Callable, static_argnums, static_argnames, args: List, kwargs: Dict ): import jax + from jax._src.util import weakref_lru_cache + + wrapped_fn = weakref_lru_cache(fn) flat_args, _ = jax.tree_util.tree_flatten((args, kwargs)) types = [(a.dtype, a.shape) if hasattr(a, 'dtype') else type(a) for a in flat_args] - hash_str = ( - f'{hash(cloudpickle.dumps(fn))}-{static_argnums}-{static_argnames}-{types}' - ) + hash_str = f'{hash(wrapped_fn)}-{static_argnums}-{static_argnames}-{types}' return hash_str