diff --git a/python/distributed-ucxx/distributed_ucxx/tests/test_ucxx.py b/python/distributed-ucxx/distributed_ucxx/tests/test_ucxx.py index 0ba84b2a..91af3dad 100644 --- a/python/distributed-ucxx/distributed_ucxx/tests/test_ucxx.py +++ b/python/distributed-ucxx/distributed_ucxx/tests/test_ucxx.py @@ -180,7 +180,7 @@ async def test_ucxx_deserialize(ucxx_loop): [ lambda cudf: cudf.Series([1, 2, 3]), lambda cudf: cudf.Series([], dtype=object), - lambda cudf: cudf.DataFrame([]), + lambda cudf: cudf.DataFrame([], dtype=object), lambda cudf: cudf.DataFrame([1]).head(0), lambda cudf: cudf.DataFrame([1.0]).head(0), lambda cudf: cudf.DataFrame({"a": []}), diff --git a/python/ucxx/benchmarks/backends/ucxx_async.py b/python/ucxx/benchmarks/backends/ucxx_async.py index a21c0246..0f3cfe5b 100644 --- a/python/ucxx/benchmarks/backends/ucxx_async.py +++ b/python/ucxx/benchmarks/backends/ucxx_async.py @@ -91,7 +91,11 @@ async def server_handler(ep): await ep.close() lf.close() - lf = ucxx.create_listener(server_handler, port=self.args.port) + lf = ucxx.create_listener( + server_handler, + port=self.args.port, + endpoint_error_handling=self.args.error_handling, + ) self.queue.put(lf.port) while not lf.closed(): @@ -126,7 +130,11 @@ async def run(self): register_am_allocators(self.args) - ep = await ucxx.create_endpoint(self.server_address, self.port) + ep = await ucxx.create_endpoint( + self.server_address, + self.port, + endpoint_error_handling=self.args.error_handling, + ) if self.args.enable_am: msg = xp.arange(self.args.n_bytes, dtype="u1") diff --git a/python/ucxx/benchmarks/backends/ucxx_core.py b/python/ucxx/benchmarks/backends/ucxx_core.py index 9be1becd..4cf7af1a 100644 --- a/python/ucxx/benchmarks/backends/ucxx_core.py +++ b/python/ucxx/benchmarks/backends/ucxx_core.py @@ -142,7 +142,9 @@ def run(self): def _listener_handler(conn_request): global ep - ep = listener.create_endpoint_from_conn_request(conn_request, True) + ep = listener.create_endpoint_from_conn_request( + conn_request, endpoint_error_handling=self.args.error_handling + ) listener = ucx_api.UCXListener.create( worker=worker, port=self.args.port or 0, cb_func=_listener_handler @@ -236,7 +238,7 @@ def run(self): worker, self.server_address, self.port, - endpoint_error_handling=True, + endpoint_error_handling=self.args.error_handling, ) # Wireup before starting to transfer data diff --git a/python/ucxx/benchmarks/send_recv.py b/python/ucxx/benchmarks/send_recv.py index fa10f644..bfe7ae4a 100644 --- a/python/ucxx/benchmarks/send_recv.py +++ b/python/ucxx/benchmarks/send_recv.py @@ -322,6 +322,12 @@ def parse_args(): help="Only applies to 'ucxx-core' backend: number of maximum outstanding " "operations, see --delay-progress. (Default: 32)", ) + parser.add_argument( + "--error-handling", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable endpoint error handling.", + ) args = parser.parse_args()