Skip to content

Commit

Permalink
Fix compilation cache with weakref object (#324)
Browse files Browse the repository at this point in the history
# Pull Request

## What problem does this PR solve?

Issue Number: #306 

Fixed compilation cache error when Callable contains weakref

## Possible side effects?

- Performance: N/A

- Backward compatibility: N/A
  • Loading branch information
anakinxc authored Aug 22, 2023
1 parent 72b8cd5 commit 294b897
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions spu/utils/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from enum import Enum
from typing import Callable, Dict, Iterable, List

import cloudpickle
from cachetools import LRUCache, cached

from .. import api as spu_api
Expand All @@ -27,12 +26,13 @@ def _jax_compilation_key(
fn: Callable, static_argnums, static_argnames, args: List, kwargs: Dict
):
import jax
from jax._src.util import weakref_lru_cache

wrapped_fn = weakref_lru_cache(fn)

flat_args, _ = jax.tree_util.tree_flatten((args, kwargs))
types = [(a.dtype, a.shape) if hasattr(a, 'dtype') else type(a) for a in flat_args]
hash_str = (
f'{hash(cloudpickle.dumps(fn))}-{static_argnums}-{static_argnames}-{types}'
)
hash_str = f'{hash(wrapped_fn)}-{static_argnums}-{static_argnames}-{types}'
return hash_str


Expand Down

0 comments on commit 294b897

Please sign in to comment.