Skip to content

Commit

Permalink
Fix TensorFlow numpy_call
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Oct 27, 2023
1 parent 1d39fe7 commit b27234f
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion phiml/backend/tensorflow/_tf_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ def aux_f(*args):
output_dtypes = tf.nest.map_structure(lambda x: x.dtype, output0)
else:
output_dtypes = tf.nest.map_structure(lambda dtype: to_numpy_dtype(dtype), output_dtypes)
return tf.py_function(aux_f, args, output_dtypes)
result = tf.py_function(aux_f, args, output_dtypes)
self.set_shapes_tree(result, output_shapes)
return result

def jit_compile(self, f: Callable) -> Callable:
compiled = tf.function(f)
Expand Down Expand Up @@ -345,6 +347,16 @@ def while_loop(self, loop: Callable, values: tuple, max_iter: Union[int, Tuple[i
def stop_gradient_tree(self, tree):
return tf.nest.map_structure(tf.stop_gradient, tree)

def set_shapes_tree(self, values, shapes):
if isinstance(values, (tuple, list)):
for e, s in zip(values, shapes):
self.set_shapes_tree(e, s)
elif self.is_tensor(values, only_native=False):
if self.is_tensor(values, only_native=True):
values.set_shape(shapes)
else:
raise NotImplementedError(type(values))

def abs(self, x):
with tf.device(x.device):
return tf.abs(x)
Expand Down

0 comments on commit b27234f

Please sign in to comment.