Skip to content

Commit

Permalink
Fix TensorFlow numpy_call()
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jan 4, 2025
1 parent 71f73dd commit 6ec9c8f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion phiml/backend/tensorflow/_tf_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def numpy_call(self, f, output_shapes, output_dtypes, *args, **aux_args):
def aux_f(*args):
args = [self.numpy(a) for a in args]
return f(*args, **aux_args)
with self.device_of(args[0]):
with self._device_for(*args):
if output_dtypes is None:
output0 = f(*[t[0] for t in args], **aux_args) # Call f to determine its output signature.
output_dtypes = tf.nest.map_structure(lambda x: x.dtype, output0)
Expand Down

0 comments on commit 6ec9c8f

Please sign in to comment.