From b27234fe2a003f109b96af26960468bc1856cc86 Mon Sep 17 00:00:00 2001 From: holl- Date: Fri, 27 Oct 2023 13:23:44 +0200 Subject: [PATCH] Fix TensorFlow numpy_call --- phiml/backend/tensorflow/_tf_backend.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/phiml/backend/tensorflow/_tf_backend.py b/phiml/backend/tensorflow/_tf_backend.py index 2b83f485..f8de966e 100644 --- a/phiml/backend/tensorflow/_tf_backend.py +++ b/phiml/backend/tensorflow/_tf_backend.py @@ -137,7 +137,9 @@ def aux_f(*args): output_dtypes = tf.nest.map_structure(lambda x: x.dtype, output0) else: output_dtypes = tf.nest.map_structure(lambda dtype: to_numpy_dtype(dtype), output_dtypes) - return tf.py_function(aux_f, args, output_dtypes) + result = tf.py_function(aux_f, args, output_dtypes) + self.set_shapes_tree(result, output_shapes) + return result def jit_compile(self, f: Callable) -> Callable: compiled = tf.function(f) @@ -345,6 +347,16 @@ def while_loop(self, loop: Callable, values: tuple, max_iter: Union[int, Tuple[i def stop_gradient_tree(self, tree): return tf.nest.map_structure(tf.stop_gradient, tree) + def set_shapes_tree(self, values, shapes): + if isinstance(values, (tuple, list)): + for e, s in zip(values, shapes): + self.set_shapes_tree(e, s) + elif self.is_tensor(values, only_native=False): + if self.is_tensor(values, only_native=True): + values.set_shape(shapes) + else: + raise NotImplementedError(type(values)) + def abs(self, x): with tf.device(x.device): return tf.abs(x)