diff --git a/src/brevitas/backport/fx/immutable_collections.py b/src/brevitas/backport/fx/immutable_collections.py index 0144e4701..b848bf1cf 100644 --- a/src/brevitas/backport/fx/immutable_collections.py +++ b/src/brevitas/backport/fx/immutable_collections.py @@ -43,7 +43,12 @@ from typing import Any, Dict, List, Tuple -from torch.utils._pytree import _register_pytree_node +try: + from torch.utils._pytree import register_pytree_node +except: + # Deprecated as of 2.3, but keeping for backportability + from torch.utils._pytree import _register_pytree_node + register_pytree_node = _register_pytree_node from torch.utils._pytree import Context from ._compatibility import compatibility @@ -111,5 +116,5 @@ def _immutable_list_unflatten(values: List[Any], context: Context) -> List[Any]: return immutable_list(values) -_register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten) -_register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten) +register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten) +register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten)