Skip to content

Commit

Permalink
refactor(query): introduce udf runtime pool (#17304)
Browse files Browse the repository at this point in the history
* udf runtime pool

* fix

* fix

* fix

* test
  • Loading branch information
forsaken628 authored Jan 17, 2025
1 parent 37d96b2 commit 355a23a
Show file tree
Hide file tree
Showing 12 changed files with 356 additions and 260 deletions.
9 changes: 1 addition & 8 deletions src/query/service/src/pipelines/builders/builder_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::atomic::AtomicUsize;
use std::sync::Arc;

use databend_common_exception::Result;
use databend_common_pipeline_transforms::processors::TransformPipelineHelper;
use databend_common_sql::executor::physical_plans::Udf;
use databend_common_storages_fuse::TableContext;

use crate::pipelines::processors::transforms::TransformUdfScript;
use crate::pipelines::processors::transforms::TransformUdfServer;
Expand All @@ -29,15 +25,12 @@ impl PipelineBuilder {
self.build_pipeline(&udf.input)?;

if udf.script_udf {
let index_seq = Arc::new(AtomicUsize::new(0));
let runtime_num = self.ctx.get_settings().get_max_threads()? as usize;
let runtimes = TransformUdfScript::init_runtime(&udf.udf_funcs, runtime_num)?;
let runtimes = TransformUdfScript::init_runtime(&udf.udf_funcs)?;
self.main_pipeline.try_add_transformer(|| {
Ok(TransformUdfScript::new(
self.func_ctx.clone(),
udf.udf_funcs.clone(),
runtimes.clone(),
index_seq.clone(),
))
})
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ pub use udaf_script::*;
pub use utils::*;

pub use self::serde::*;
use super::runtime_pool;
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use std::fmt;
use std::io::BufRead;
use std::io::Cursor;
use std::sync::Arc;
use std::sync::Mutex;

use arrow_array::Array;
use arrow_array::RecordBatch;
Expand All @@ -40,8 +39,8 @@ use databend_common_functions::aggregates::AggregateFunction;
use databend_common_sql::plans::UDFLanguage;
use databend_common_sql::plans::UDFScriptCode;

#[cfg(feature = "python-udf")]
use super::super::python_udf::GLOBAL_PYTHON_RUNTIME;
use super::runtime_pool::Pool;
use super::runtime_pool::RuntimeBuilder;

pub struct AggregateUdfScript {
display_name: String,
Expand Down Expand Up @@ -138,6 +137,15 @@ impl AggregateFunction for AggregateUdfScript {
builder.append_column(&result);
Ok(())
}

fn need_manual_drop_state(&self) -> bool {
true
}

unsafe fn drop_state(&self, place: StateAddr) {
let state = place.get::<UdfAggState>();
std::ptr::drop_in_place(state);
}
}

impl fmt::Display for AggregateUdfScript {
Expand Down Expand Up @@ -244,19 +252,19 @@ pub fn create_udaf_script_function(
let UDFScriptCode { language, code, .. } = code;
let runtime = match language {
UDFLanguage::JavaScript => {
let pool = JsRuntimePool::new(
let builder = JsRuntimeBuilder {
name,
String::from_utf8(code.to_vec())?,
ArrowType::Struct(
code: String::from_utf8(code.to_vec())?,
state_type: ArrowType::Struct(
state_fields
.iter()
.map(|f| f.into())
.collect::<Vec<arrow_schema::Field>>()
.into(),
),
output_type,
);
UDAFRuntime::JavaScript(pool)
};
UDAFRuntime::JavaScript(JsRuntimePool::new(builder))
}
UDFLanguage::WebAssembly => unimplemented!(),
#[cfg(not(feature = "python-udf"))]
Expand All @@ -267,22 +275,19 @@ pub fn create_udaf_script_function(
}
#[cfg(feature = "python-udf")]
UDFLanguage::Python => {
let mut runtime = GLOBAL_PYTHON_RUNTIME.write();
let code = String::from_utf8(code.to_vec())?;
runtime.add_aggregate(
&name,
ArrowType::Struct(
let builder = python_pool::PyRuntimeBuilder {
name,
code: String::from_utf8(code.to_vec())?,
state_type: ArrowType::Struct(
state_fields
.iter()
.map(|f| f.into())
.collect::<Vec<arrow_schema::Field>>()
.into(),
),
ArrowType::from(&output_type),
arrow_udf_python::CallMode::CalledOnNullInput,
&code,
)?;
UDAFRuntime::Python(PythonInfo { name, output_type })
output_type,
};
UDAFRuntime::Python(Pool::new(builder))
}
};
let init_state = runtime
Expand All @@ -297,27 +302,17 @@ pub fn create_udaf_script_function(
}))
}

struct JsRuntimePool {
struct JsRuntimeBuilder {
name: String,
code: String,
state_type: ArrowType,
output_type: DataType,

runtimes: Mutex<Vec<arrow_udf_js::Runtime>>,
}

impl JsRuntimePool {
fn new(name: String, code: String, state_type: ArrowType, output_type: DataType) -> Self {
Self {
name,
code,
state_type,
output_type,
runtimes: Mutex::new(vec![]),
}
}
impl RuntimeBuilder<arrow_udf_js::Runtime> for JsRuntimeBuilder {
type Error = ErrorCode;

fn create(&self) -> Result<arrow_udf_js::Runtime> {
fn build(&self) -> std::result::Result<arrow_udf_js::Runtime, Self::Error> {
let mut runtime = match arrow_udf_js::Runtime::new() {
Ok(runtime) => runtime,
Err(e) => {
Expand All @@ -331,78 +326,97 @@ impl JsRuntimePool {
converter.set_arrow_extension_key(EXTENSION_KEY);
converter.set_json_extension_name(ARROW_EXT_TYPE_VARIANT);

let output_type: ArrowType = (&self.output_type).into();
runtime
.add_aggregate(
&self.name,
self.state_type.clone(),
output_type,
// we pass the field instead of the data type because arrow-udf-js
// now takes the field as an argument here so that it can get any
// metadata associated with the field
arrow_field_from_data_type(&self.name, self.output_type.clone()),
arrow_udf_js::CallMode::CalledOnNullInput,
&self.code,
)
.map_err(|e| ErrorCode::UDFDataError(format!("Cannot add aggregate: {e}")))?;

Ok(runtime)
}
}

fn call<T, F>(&self, op: F) -> anyhow::Result<T>
where F: FnOnce(&arrow_udf_js::Runtime) -> anyhow::Result<T> {
let mut runtimes = self.runtimes.lock().unwrap();
let runtime = match runtimes.pop() {
Some(runtime) => runtime,
None => self.create()?,
};
drop(runtimes);
fn arrow_field_from_data_type(name: &str, dt: DataType) -> arrow_schema::Field {
let field = DataField::new(name, dt);
(&field).into()
}

type JsRuntimePool = Pool<arrow_udf_js::Runtime, JsRuntimeBuilder>;

#[cfg(feature = "python-udf")]
mod python_pool {
use super::*;

let result = op(&runtime)?;
pub(super) struct PyRuntimeBuilder {
pub name: String,
pub code: String,
pub state_type: ArrowType,
pub output_type: DataType,
}

let mut runtimes = self.runtimes.lock().unwrap();
runtimes.push(runtime);
impl RuntimeBuilder<arrow_udf_python::Runtime> for PyRuntimeBuilder {
type Error = ErrorCode;

Ok(result)
fn build(&self) -> std::result::Result<arrow_udf_python::Runtime, Self::Error> {
let mut runtime = arrow_udf_python::Builder::default()
.sandboxed(true)
.build()?;
runtime.add_aggregate(
&self.name,
self.state_type.clone(),
arrow_field_from_data_type(&self.name, self.output_type.clone()),
arrow_udf_python::CallMode::CalledOnNullInput,
&self.code,
)?;
Ok(runtime)
}
}

pub type PyRuntimePool = Pool<arrow_udf_python::Runtime, PyRuntimeBuilder>;
}

enum UDAFRuntime {
JavaScript(JsRuntimePool),
#[expect(unused)]
WebAssembly,
#[cfg(feature = "python-udf")]
Python(PythonInfo),
}

#[cfg(feature = "python-udf")]
struct PythonInfo {
name: String,
output_type: DataType,
Python(python_pool::PyRuntimePool),
}

impl UDAFRuntime {
fn name(&self) -> &str {
match self {
UDAFRuntime::JavaScript(pool) => &pool.name,
UDAFRuntime::JavaScript(pool) => &pool.builder.name,
#[cfg(feature = "python-udf")]
UDAFRuntime::Python(info) => &info.name,
UDAFRuntime::Python(info) => &info.builder.name,
_ => unimplemented!(),
}
}

fn return_type(&self) -> DataType {
match self {
UDAFRuntime::JavaScript(pool) => pool.output_type.clone(),
UDAFRuntime::JavaScript(pool) => pool.builder.output_type.clone(),
#[cfg(feature = "python-udf")]
UDAFRuntime::Python(info) => info.output_type.clone(),
UDAFRuntime::Python(info) => info.builder.output_type.clone(),
_ => unimplemented!(),
}
}

fn create_state(&self) -> anyhow::Result<UdfAggState> {
let state = match self {
UDAFRuntime::JavaScript(pool) => pool.call(|runtime| runtime.create_state(&pool.name)),
UDAFRuntime::JavaScript(pool) => {
pool.call(|runtime| runtime.create_state(&pool.builder.name))
}
#[cfg(feature = "python-udf")]
UDAFRuntime::Python(info) => {
let runtime = GLOBAL_PYTHON_RUNTIME.read();
runtime.create_state(&info.name)
UDAFRuntime::Python(pool) => {
pool.call(|runtime| runtime.create_state(&pool.builder.name))
}
_ => unimplemented!(),
}?;
Expand All @@ -412,12 +426,11 @@ impl UDAFRuntime {
fn accumulate(&self, state: &UdfAggState, input: &RecordBatch) -> anyhow::Result<UdfAggState> {
let state = match self {
UDAFRuntime::JavaScript(pool) => {
pool.call(|runtime| runtime.accumulate(&pool.name, &state.0, input))
pool.call(|runtime| runtime.accumulate(&pool.builder.name, &state.0, input))
}
#[cfg(feature = "python-udf")]
UDAFRuntime::Python(info) => {
let runtime = GLOBAL_PYTHON_RUNTIME.read();
runtime.accumulate(&info.name, &state.0, input)
UDAFRuntime::Python(pool) => {
pool.call(|runtime| runtime.accumulate(&pool.builder.name, &state.0, input))
}
_ => unimplemented!(),
}?;
Expand All @@ -426,11 +439,12 @@ impl UDAFRuntime {

fn merge(&self, states: &Arc<dyn Array>) -> anyhow::Result<UdfAggState> {
let state = match self {
UDAFRuntime::JavaScript(pool) => pool.call(|runtime| runtime.merge(&pool.name, states)),
UDAFRuntime::JavaScript(pool) => {
pool.call(|runtime| runtime.merge(&pool.builder.name, states))
}
#[cfg(feature = "python-udf")]
UDAFRuntime::Python(info) => {
let runtime = GLOBAL_PYTHON_RUNTIME.read();
runtime.merge(&info.name, states)
UDAFRuntime::Python(pool) => {
pool.call(|runtime| runtime.merge(&pool.builder.name, states))
}
_ => unimplemented!(),
}?;
Expand All @@ -440,12 +454,11 @@ impl UDAFRuntime {
fn finish(&self, state: &UdfAggState) -> anyhow::Result<Arc<dyn Array>> {
match self {
UDAFRuntime::JavaScript(pool) => {
pool.call(|runtime| runtime.finish(&pool.name, &state.0))
pool.call(|runtime| runtime.finish(&pool.builder.name, &state.0))
}
#[cfg(feature = "python-udf")]
UDAFRuntime::Python(info) => {
let runtime = GLOBAL_PYTHON_RUNTIME.read();
runtime.finish(&info.name, &state.0)
UDAFRuntime::Python(pool) => {
pool.call(|runtime| runtime.finish(&pool.builder.name, &state.0))
}
_ => unimplemented!(),
}
Expand Down Expand Up @@ -495,9 +508,9 @@ mod tests {
Field::new("sum", ArrowType::Int64, false),
Field::new("weight", ArrowType::Int64, false),
];
let pool = JsRuntimePool::new(
agg_name.clone(),
r#"
let builder = JsRuntimeBuilder {
name: agg_name.clone(),
code: r#"
export function create_state() {
return {sum: 0, weight: 0};
}
Expand All @@ -521,9 +534,10 @@ export function finish(state) {
}
"#
.to_string(),
ArrowType::Struct(fields.clone().into()),
Float32Type::data_type(),
);
state_type: ArrowType::Struct(fields.clone().into()),
output_type: Float32Type::data_type(),
};
let pool = JsRuntimePool::new(builder);

let state = pool.call(|runtime| runtime.create_state(&agg_name))?;

Expand Down
14 changes: 1 addition & 13 deletions src/query/service/src/pipelines/processors/transforms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
pub mod aggregator;
mod hash_join;
pub(crate) mod range_join;
mod runtime_pool;
mod transform_add_computed_columns;
mod transform_add_const_columns;
mod transform_add_internal_columns;
Expand Down Expand Up @@ -66,16 +67,3 @@ pub use transform_stream_sort_spill::*;
pub use transform_udf_script::TransformUdfScript;
pub use transform_udf_server::TransformUdfServer;
pub use window::*;

#[cfg(feature = "python-udf")]
mod python_udf {
use std::sync::Arc;
use std::sync::LazyLock;

use arrow_udf_python::Runtime;
use parking_lot::RwLock;

/// python runtime should be only initialized once by gil lock, see: https://github.com/python/cpython/blob/main/Python/pystate.c
pub static GLOBAL_PYTHON_RUNTIME: LazyLock<Arc<RwLock<Runtime>>> =
LazyLock::new(|| Arc::new(RwLock::new(Runtime::new().unwrap())));
}
Loading

0 comments on commit 355a23a

Please sign in to comment.