diff --git a/python/ray/data/tests/test_streaming_executor.py b/python/ray/data/tests/test_streaming_executor.py index 81f04873673bd..1a4b21d409000 100644 --- a/python/ray/data/tests/test_streaming_executor.py +++ b/python/ray/data/tests/test_streaming_executor.py @@ -665,12 +665,13 @@ def after_execution_fails(self, error: Exception): self._execution_error = error # Test the success case. - ctx = DataContext.get_current() + ds = ray.data.range(10) + ctx = ds.context callback = CustomExecutionCallback() add_execution_callback(callback, ctx) assert get_execution_callbacks(ctx) == [callback] - ray.data.range(10).take_all() + ds.take_all() assert callback._before_execution_starts_called assert callback._after_execution_succeeds_called @@ -680,18 +681,22 @@ def after_execution_fails(self, error: Exception): assert get_execution_callbacks(ctx) == [] # Test the failure case. + ds = ray.data.range(10) + ctx = ds.context + ctx.raise_original_map_exception = True callback = CustomExecutionCallback() add_execution_callback(callback, ctx) - def map(_): + def map_fn(_): raise ValueError("") - ray.data.range(10).map(map).take_all() + with pytest.raises(ValueError): + ds.map(map_fn).take_all() assert callback._before_execution_starts_called assert not callback._after_execution_succeeds_called error = callback._execution_error - assert isinstance(error, ValueError) + assert isinstance(error, ValueError), error if __name__ == "__main__":