diff --git a/Cargo.lock b/Cargo.lock index ec9b4ae466b2..d156fc51229f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3800,7 +3800,6 @@ dependencies = [ "serde", "serde_json", "sql-query-builder", - "sql-query-connector", "telemetry", "thiserror", "tokio", @@ -3836,6 +3835,7 @@ dependencies = [ "serde", "serde_json", "serial_test", + "sql-query-builder", "sql-query-connector", "structopt", "telemetry", @@ -5133,6 +5133,7 @@ dependencies = [ name = "sql-query-builder" version = "0.1.0" dependencies = [ + "bigdecimal", "chrono", "itertools 0.12.0", "prisma-value", @@ -5160,6 +5161,7 @@ dependencies = [ "prisma-value", "psl", "quaint", + "query-builder", "query-connector", "query-structure", "rand 0.8.5", diff --git a/Cargo.toml b/Cargo.toml index 4658ac4617c9..3d3948494403 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ members = [ [workspace.dependencies] async-trait = { version = "0.1.77" } +bigdecimal = "0.3" enumflags2 = { version = "0.7", features = ["serde"] } futures = "0.3" psl = { path = "./psl/psl" } diff --git a/query-engine/connectors/sql-query-connector/Cargo.toml b/query-engine/connectors/sql-query-connector/Cargo.toml index d53013d15832..0c04ef95f340 100644 --- a/query-engine/connectors/sql-query-connector/Cargo.toml +++ b/query-engine/connectors/sql-query-connector/Cargo.toml @@ -62,6 +62,9 @@ path = "../query-connector" [dependencies.query-structure] path = "../../query-structure" +[dependencies.query-builder] +path = "../../query-builders/query-builder" + [dependencies.sql-query-builder] path = "../../query-builders/sql-query-builder" diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/read.rs b/query-engine/connectors/sql-query-connector/src/database/operations/read.rs index 9a9ac4469ce0..7e3c881f802c 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/read.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/read.rs @@ -8,8 +8,9 @@ use crate::{QueryExt, 
Queryable, SqlError}; use connector_interface::*; use futures::stream::{FuturesUnordered, StreamExt}; use quaint::ast::*; +use query_builder::QueryArgumentsExt; use query_structure::*; -use sql_query_builder::{column_metadata, read, AsColumns, AsTable, Context, QueryArgumentsExt, RelationFieldExt}; +use sql_query_builder::{column_metadata, read, AsColumns, AsTable, Context, RelationFieldExt}; pub(crate) async fn get_single_record( conn: &dyn Queryable, diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/read/process.rs b/query-engine/connectors/sql-query-connector/src/database/operations/read/process.rs index 042dc2815b6a..56332089f975 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/read/process.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/read/process.rs @@ -1,8 +1,8 @@ use std::borrow::Cow; use itertools::{Either, Itertools}; +use query_builder::QueryArgumentsExt; use query_structure::{QueryArguments, Record}; -use sql_query_builder::QueryArgumentsExt; macro_rules! processor_state { ($name:ident $(-> $transition:ident($bound:ident))?) 
=> { diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs index 07a385bab3ce..8fe2bcaac411 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs @@ -2,8 +2,7 @@ use super::update::*; use crate::row::ToSqlRow; use crate::value::to_prisma_value; use crate::{error::SqlError, QueryExt, Queryable}; -use itertools::Itertools; -use quaint::ast::{Insert, Query}; +use quaint::ast::Query; use quaint::prelude::ResultSet; use quaint::{ error::ErrorKind, @@ -12,32 +11,9 @@ use quaint::{ use query_structure::*; use sql_query_builder::{column_metadata, write, Context, FilterBuilder, SelectionResultExt, SqlTraceComment}; use std::borrow::Cow; -use std::{ - collections::{HashMap, HashSet}, - ops::Deref, -}; +use std::collections::HashMap; use user_facing_errors::query_engine::DatabaseConstraint; -#[cfg(target_arch = "wasm32")] -macro_rules! trace { - (target: $target:expr, $($arg:tt)+) => {{ - // No-op in WebAssembly - }}; - ($($arg:tt)+) => {{ - // No-op in WebAssembly - }}; -} - -#[cfg(not(target_arch = "wasm32"))] -macro_rules! trace { - (target: $target:expr, $($arg:tt)+) => { - tracing::log::trace!(target: $target, $($arg)+); - }; - ($($arg:tt)+) => { - tracing::log::trace!($($arg)+); - }; -} - async fn generate_id( conn: &dyn Queryable, id_field: &FieldSelection, @@ -191,49 +167,6 @@ pub(crate) async fn create_record( } } -/// Returns a set of fields that are used in the arguments for the create operation. 
-fn collect_affected_fields(args: &[WriteArgs], model: &Model) -> HashSet { - let mut fields = HashSet::new(); - args.iter().for_each(|arg| fields.extend(arg.keys())); - - fields - .into_iter() - .map(|dsfn| model.fields().scalar().find(|sf| sf.db_name() == dsfn.deref()).unwrap()) - .collect() -} - -/// Generates a list of insert statements to execute. If `selected_fields` is set, insert statements -/// will return the specified columns of inserted rows. -pub fn generate_insert_statements( - model: &Model, - args: Vec, - skip_duplicates: bool, - selected_fields: Option<&ModelProjection>, - ctx: &Context<'_>, -) -> Vec> { - let affected_fields = collect_affected_fields(&args, model); - - if affected_fields.is_empty() { - args.into_iter() - .map(|_| write::create_records_empty(model, skip_duplicates, selected_fields, ctx)) - .collect() - } else { - let partitioned_batches = partition_into_batches(args, ctx); - trace!("Total of {} batches to be executed.", partitioned_batches.len()); - trace!( - "Batch sizes: {:?}", - partitioned_batches.iter().map(|b| b.len()).collect_vec() - ); - - partitioned_batches - .into_iter() - .map(|batch| { - write::create_records_nonempty(model, batch, skip_duplicates, &affected_fields, selected_fields, ctx) - }) - .collect() - } -} - /// Inserts records specified as a list of `WriteArgs`. Returns number of inserted records. 
pub(crate) async fn create_records_count( conn: &dyn Queryable, @@ -242,7 +175,7 @@ pub(crate) async fn create_records_count( skip_duplicates: bool, ctx: &Context<'_>, ) -> crate::Result { - let inserts = generate_insert_statements(model, args, skip_duplicates, None, ctx); + let inserts = write::generate_insert_statements(model, args, skip_duplicates, None, ctx); let mut count = 0; for insert in inserts { count += conn.execute(insert.into()).await?; @@ -265,7 +198,7 @@ pub(crate) async fn create_records_returning( let idents = selected_fields.type_identifiers_with_arities(); let meta = column_metadata::create(&field_names, &idents); let mut records = ManyRecords::new(field_names.clone()); - let inserts = generate_insert_statements(model, args, skip_duplicates, Some(&selected_fields.into()), ctx); + let inserts = write::generate_insert_statements(model, args, skip_duplicates, Some(&selected_fields.into()), ctx); for insert in inserts { let result_set = conn.query(insert.into()).await?; @@ -281,74 +214,6 @@ pub(crate) async fn create_records_returning( Ok(records) } -/// Partitions data into batches, respecting `max_bind_values` and `max_insert_rows` settings from -/// the `Context`. -fn partition_into_batches(args: Vec, ctx: &Context<'_>) -> Vec> { - let batches = if let Some(max_params) = ctx.max_bind_values() { - // We need to split inserts if they are above a parameter threshold, as well as split based on number of rows. - // -> Horizontal partitioning by row number, vertical by number of args. - args.into_iter() - .peekable() - .batching(|iter| { - let mut param_count: usize = 0; - let mut batch = vec![]; - - while param_count < max_params { - // If the param count _including_ the next item doens't exceed the limit, - // we continue filling up the current batch. 
- let proceed = match iter.peek() { - Some(next) => (param_count + next.len()) <= max_params, - None => break, - }; - - if proceed { - match iter.next() { - Some(next) => { - param_count += next.len(); - batch.push(next) - } - None => break, - } - } else { - break; - } - } - - if batch.is_empty() { - None - } else { - Some(batch) - } - }) - .collect_vec() - } else { - vec![args] - }; - - if let Some(max_rows) = ctx.max_insert_rows() { - let capacity = batches.len(); - batches - .into_iter() - .fold(Vec::with_capacity(capacity), |mut batches, next_batch| { - if next_batch.len() > max_rows { - batches.extend( - next_batch - .into_iter() - .chunks(max_rows) - .into_iter() - .map(|chunk| chunk.into_iter().collect_vec()), - ); - } else { - batches.push(next_batch); - } - - batches - }) - } else { - batches - } -} - /// Update one record in a database defined in `conn` and the records /// defined in `args`, resulting the identifiers that were modified in the /// operation. diff --git a/query-engine/connectors/sql-query-connector/src/lib.rs b/query-engine/connectors/sql-query-connector/src/lib.rs index 28ec5862e227..b29085a918d0 100644 --- a/query-engine/connectors/sql-query-connector/src/lib.rs +++ b/query-engine/connectors/sql-query-connector/src/lib.rs @@ -11,8 +11,6 @@ mod value; use self::{query_ext::QueryExt, row::*}; use quaint::prelude::Queryable; -pub use database::operations::write::generate_insert_statements; - pub use database::FromSource; #[cfg(feature = "driver-adapters")] pub use database::Js; diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index cd41c4ccf840..a1d976e71416 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -9,7 +9,7 @@ graphql-protocol = [] [dependencies] async-trait.workspace = true -bigdecimal = "0.3" +bigdecimal.workspace = true chrono.workspace = true connection-string.workspace = true connector = { path = "../connectors/query-connector", package = "query-connector" } @@ -45,10 +45,5 
@@ lru = "0.7.7" enumflags2.workspace = true derive_more.workspace = true -# HACK: query builders need to be a separate crate, and maybe the compiler too -# HACK: we hardcode PostgreSQL as the dialect for now -sql-query-connector = { path = "../connectors/sql-query-connector", features = [ - "postgresql", -] } # HACK: this should not be in core either -quaint.workspace = true +quaint = { workspace = true, features = ["postgresql"] } diff --git a/query-engine/core/src/compiler/mod.rs b/query-engine/core/src/compiler/mod.rs index 878daa2f0e35..f1d0e8c75e37 100644 --- a/query-engine/core/src/compiler/mod.rs +++ b/query-engine/core/src/compiler/mod.rs @@ -4,8 +4,12 @@ pub mod translate; use std::sync::Arc; pub use expression::Expression; -use quaint::connector::ConnectionInfo; +use quaint::{ + prelude::{ConnectionInfo, SqlFamily}, + visitor, +}; use schema::QuerySchema; +use sql_query_builder::{Context, SqlQueryBuilder}; use thiserror::Error; pub use translate::{translate, TranslateError}; @@ -29,6 +33,16 @@ pub fn compile( return Err(CompileError::UnsupportedRequest.into()); }; + let ctx = Context::new(connection_info, None); let (graph, _serializer) = QueryGraphBuilder::new(query_schema).build(query)?; - Ok(translate(graph, connection_info).map_err(CompileError::from)?) + let res = match connection_info.sql_family() { + SqlFamily::Postgres => translate(graph, &SqlQueryBuilder::>::new(ctx)), + // feature flags are disabled for now + // SqlFamily::Mysql => translate(graph, &SqlQueryBuilder::>::new(ctx)), + // SqlFamily::Sqlite => translate(graph, &SqlQueryBuilder::>::new(ctx)), + // SqlFamily::Mssql => translate(graph, &SqlQueryBuilder::>::new(ctx)), + _ => unimplemented!(), + }; + + Ok(res.map_err(CompileError::TranslateError)?) 
} diff --git a/query-engine/core/src/compiler/translate.rs b/query-engine/core/src/compiler/translate.rs index 81742b50c304..a5fa6b30f0f1 100644 --- a/query-engine/core/src/compiler/translate.rs +++ b/query-engine/core/src/compiler/translate.rs @@ -1,8 +1,8 @@ mod query; use crate::{EdgeRef, Node, NodeRef, Query, QueryGraph}; -use quaint::connector::ConnectionInfo; use query::translate_query; +use query_builder::QueryBuilder; use thiserror::Error; use super::expression::{Binding, Expression}; @@ -12,41 +12,41 @@ pub enum TranslateError { #[error("node {0} has no content")] NodeContentEmpty(String), - #[error("{0}")] - QuaintError(#[from] quaint::error::Error), + #[error("query builder error: {0}")] + QueryBuildFailure(#[source] Box), } pub type TranslateResult = Result; -pub fn translate(mut graph: QueryGraph, connection_info: &ConnectionInfo) -> TranslateResult { +pub fn translate(mut graph: QueryGraph, builder: &dyn QueryBuilder) -> TranslateResult { graph .root_nodes() .into_iter() - .map(|node| NodeTranslator::new(&mut graph, node, &[], connection_info).translate()) + .map(|node| NodeTranslator::new(&mut graph, node, &[], builder).translate()) .collect::>>() .map(Expression::Seq) } -struct NodeTranslator<'a, 'b, 'c> { +struct NodeTranslator<'a, 'b> { graph: &'a mut QueryGraph, node: NodeRef, #[allow(dead_code)] parent_edges: &'b [EdgeRef], - connection_info: &'c ConnectionInfo, + query_builder: &'b dyn QueryBuilder, } -impl<'a, 'b, 'c> NodeTranslator<'a, 'b, 'c> { +impl<'a, 'b> NodeTranslator<'a, 'b> { fn new( graph: &'a mut QueryGraph, node: NodeRef, parent_edges: &'b [EdgeRef], - connection_info: &'c ConnectionInfo, + query_builder: &'b dyn QueryBuilder, ) -> Self { Self { graph, node, parent_edges, - connection_info, + query_builder, } } @@ -71,7 +71,7 @@ impl<'a, 'b, 'c> NodeTranslator<'a, 'b, 'c> { .try_into() .expect("current node must be query"); - translate_query(query, self.connection_info) + translate_query(query, self.query_builder) } 
#[allow(dead_code)] @@ -106,7 +106,7 @@ impl<'a, 'b, 'c> NodeTranslator<'a, 'b, 'c> { .into_iter() .map(|(_, node)| { let edges = self.graph.incoming_edges(&node); - NodeTranslator::new(self.graph, node, &edges, self.connection_info).translate() + NodeTranslator::new(self.graph, node, &edges, self.query_builder).translate() }) .collect::, _>>()?; @@ -128,7 +128,7 @@ impl<'a, 'b, 'c> NodeTranslator<'a, 'b, 'c> { .map(|(_, node)| { let name = node.id(); let edges = self.graph.incoming_edges(&node); - let expr = NodeTranslator::new(self.graph, node, &edges, self.connection_info).translate()?; + let expr = NodeTranslator::new(self.graph, node, &edges, self.query_builder).translate()?; Ok(Binding { name, expr }) }) .collect::>>()?; diff --git a/query-engine/core/src/compiler/translate/query.rs b/query-engine/core/src/compiler/translate/query.rs index 23c9e7321972..1fe1acf243ad 100644 --- a/query-engine/core/src/compiler/translate/query.rs +++ b/query-engine/core/src/compiler/translate/query.rs @@ -1,42 +1,17 @@ -mod convert; mod read; mod write; -use quaint::{ - prelude::{ConnectionInfo, SqlFamily}, - visitor::Visitor, -}; -use query_builder::DbQuery; +use query_builder::QueryBuilder; use read::translate_read_query; -use sql_query_builder::Context; use write::translate_write_query; use crate::{compiler::expression::Expression, Query}; use super::TranslateResult; -pub(crate) fn translate_query(query: Query, connection_info: &ConnectionInfo) -> TranslateResult { - let ctx = Context::new(connection_info, None); - +pub(crate) fn translate_query(query: Query, builder: &dyn QueryBuilder) -> TranslateResult { match query { - Query::Read(rq) => translate_read_query(rq, &ctx), - Query::Write(wq) => translate_write_query(wq, &ctx), + Query::Read(rq) => translate_read_query(rq, builder), + Query::Write(wq) => translate_write_query(wq, builder), } } - -fn build_db_query<'a>(query: impl Into>, ctx: &Context<'_>) -> TranslateResult { - let (sql, params) = match 
ctx.connection_info.sql_family() { - SqlFamily::Postgres => quaint::visitor::Postgres::build(query)?, - // TODO: implement proper switch for other databases once proper feature flags are supported/logic is extracted - _ => unimplemented!(), - // SqlFamily::Mysql => quaint::visitor::Mysql::build(query)?, - // SqlFamily::Sqlite => quaint::visitor::Sqlite::build(query)?, - // SqlFamily::Mssql => quaint::visitor::Mssql::build(query)?, - }; - - let params = params - .into_iter() - .map(convert::quaint_value_to_prisma_value) - .collect::>(); - Ok(DbQuery::new(sql, params)) -} diff --git a/query-engine/core/src/compiler/translate/query/read.rs b/query-engine/core/src/compiler/translate/query/read.rs index 0839416ce7c4..f736331ec8ac 100644 --- a/query-engine/core/src/compiler/translate/query/read.rs +++ b/query-engine/core/src/compiler/translate/query/read.rs @@ -4,40 +4,37 @@ use crate::{ compiler::{ expression::{Binding, Expression, JoinExpression}, translate::TranslateResult, + TranslateError, }, FilteredQuery, ReadQuery, RelatedRecordsQuery, }; use itertools::Itertools; +use query_builder::{QueryArgumentsExt, QueryBuilder}; use query_structure::{ - ConditionValue, Filter, ModelProjection, PrismaValue, QueryMode, ScalarCondition, ScalarFilter, ScalarProjection, + ConditionValue, Filter, PrismaValue, QueryArguments, QueryMode, ScalarCondition, ScalarFilter, ScalarProjection, }; -use sql_query_builder::{read, AsColumns, Context, QueryArgumentsExt}; -use super::build_db_query; - -pub(crate) fn translate_read_query(query: ReadQuery, ctx: &Context<'_>) -> TranslateResult { +pub(crate) fn translate_read_query(query: ReadQuery, builder: &dyn QueryBuilder) -> TranslateResult { Ok(match query { ReadQuery::RecordQuery(rq) => { let selected_fields = rq.selected_fields.without_relations().into_virtuals_last(); - let query = read::get_records( - &rq.model, - ModelProjection::from(&selected_fields) - .as_columns(ctx) - .mark_all_selected(), - selected_fields.virtuals(), + let args = 
QueryArguments::from(( + rq.model.clone(), rq.filter.expect("ReadOne query should always have filter set"), - ctx, - ) - .limit(1); + )) + .with_take(Some(1)); + let query = builder + .build_get_records(&rq.model, args, &selected_fields) + .map_err(TranslateError::QueryBuildFailure)?; - let expr = Expression::Query(build_db_query(query, ctx)?); + let expr = Expression::Query(query); let expr = Expression::Unique(Box::new(expr)); if rq.nested.is_empty() { expr } else { - add_inmemory_join(expr, rq.nested, ctx)? + add_inmemory_join(expr, rq.nested, builder)? } } @@ -46,17 +43,11 @@ pub(crate) fn translate_read_query(query: ReadQuery, ctx: &Context<'_>) -> Trans let needs_reversed_order = mrq.args.needs_reversed_order(); // TODO: we ignore chunking for now - let query = read::get_records( - &mrq.model, - ModelProjection::from(&selected_fields) - .as_columns(ctx) - .mark_all_selected(), - selected_fields.virtuals(), - mrq.args, - ctx, - ); - - let expr = Expression::Query(build_db_query(query, ctx)?); + let query = builder + .build_get_records(&mrq.model, mrq.args, &selected_fields) + .map_err(TranslateError::QueryBuildFailure)?; + + let expr = Expression::Query(query); let expr = if needs_reversed_order { Expression::Reverse(Box::new(expr)) @@ -67,15 +58,15 @@ pub(crate) fn translate_read_query(query: ReadQuery, ctx: &Context<'_>) -> Trans if mrq.nested.is_empty() { expr } else { - add_inmemory_join(expr, mrq.nested, ctx)? + add_inmemory_join(expr, mrq.nested, builder)? } } ReadQuery::RelatedRecordsQuery(rrq) => { if rrq.parent_field.relation().is_many_to_many() { - build_read_m2m_query(rrq, ctx)? + build_read_m2m_query(rrq, builder)? } else { - build_read_one2m_query(rrq, ctx)? + build_read_one2m_query(rrq, builder)? 
} } @@ -83,7 +74,11 @@ pub(crate) fn translate_read_query(query: ReadQuery, ctx: &Context<'_>) -> Trans }) } -fn add_inmemory_join(parent: Expression, nested: Vec, ctx: &Context<'_>) -> TranslateResult { +fn add_inmemory_join( + parent: Expression, + nested: Vec, + builder: &dyn QueryBuilder, +) -> TranslateResult { let all_linking_fields = nested .iter() .flat_map(|nested| match nested { @@ -139,7 +134,7 @@ fn add_inmemory_join(parent: Expression, nested: Vec, ctx: &Context<' })); } - let child_query = translate_read_query(ReadQuery::RelatedRecordsQuery(rrq), ctx)?; + let child_query = translate_read_query(ReadQuery::RelatedRecordsQuery(rrq), builder)?; Ok(JoinExpression { child: child_query, @@ -164,29 +159,27 @@ fn add_inmemory_join(parent: Expression, nested: Vec, ctx: &Context<' }) } -fn build_read_m2m_query(_query: RelatedRecordsQuery, _ctx: &Context<'_>) -> TranslateResult { +fn build_read_m2m_query(_query: RelatedRecordsQuery, _builder: &dyn QueryBuilder) -> TranslateResult { todo!() } -fn build_read_one2m_query(rrq: RelatedRecordsQuery, ctx: &Context<'_>) -> TranslateResult { +fn build_read_one2m_query(rrq: RelatedRecordsQuery, builder: &dyn QueryBuilder) -> TranslateResult { let selected_fields = rrq.selected_fields.without_relations().into_virtuals_last(); let needs_reversed_order = rrq.args.needs_reversed_order(); let to_one_relation = !rrq.parent_field.arity().is_list(); // TODO: we ignore chunking for now - let query = read::get_records( - &rrq.parent_field.related_model(), - ModelProjection::from(&selected_fields) - .as_columns(ctx) - .mark_all_selected(), - selected_fields.virtuals(), - rrq.args, - ctx, - ); - let query = if to_one_relation { query.limit(1) } else { query }; + let args = if to_one_relation { + rrq.args.with_take(Some(1)) + } else { + rrq.args + }; + let query = builder + .build_get_records(&rrq.parent_field.related_model(), args, &selected_fields) + .map_err(TranslateError::QueryBuildFailure)?; - let mut expr = 
Expression::Query(build_db_query(query, ctx)?); + let mut expr = Expression::Query(query); if to_one_relation { expr = Expression::Unique(Box::new(expr)); @@ -199,6 +192,6 @@ fn build_read_one2m_query(rrq: RelatedRecordsQuery, ctx: &Context<'_>) -> Transl if rrq.nested.is_empty() { Ok(expr) } else { - add_inmemory_join(expr, rrq.nested, ctx) + add_inmemory_join(expr, rrq.nested, builder) } } diff --git a/query-engine/core/src/compiler/translate/query/write.rs b/query-engine/core/src/compiler/translate/query/write.rs index 4e361a6f746e..286fd11e9147 100644 --- a/query-engine/core/src/compiler/translate/query/write.rs +++ b/query-engine/core/src/compiler/translate/query/write.rs @@ -1,49 +1,43 @@ -use query_structure::ModelProjection; -use sql_query_builder::{write, Context}; -use sql_query_connector::generate_insert_statements; +use query_builder::QueryBuilder; use crate::{ - compiler::{expression::Expression, translate::TranslateResult}, + compiler::{expression::Expression, translate::TranslateResult, TranslateError}, WriteQuery, }; -use super::build_db_query; - -pub(crate) fn translate_write_query(query: WriteQuery, ctx: &Context<'_>) -> TranslateResult { +pub(crate) fn translate_write_query(query: WriteQuery, builder: &dyn QueryBuilder) -> TranslateResult { Ok(match query { WriteQuery::CreateRecord(cr) => { // TODO: MySQL needs additional logic to generate IDs on our side. // See sql_query_connector::database::operations::write::create_record - let query = write::create_record(&cr.model, cr.args, &ModelProjection::from(&cr.selected_fields), ctx); + let query = builder + .build_create_record(&cr.model, cr.args, &cr.selected_fields) + .map_err(TranslateError::QueryBuildFailure)?; // TODO: we probably need some additional node type or extra info in the WriteQuery node // to help the client executor figure out the returned ID in the case when it's inferred // from the query arguments. - Expression::Query(build_db_query(query, ctx)?) 
+ Expression::Query(query) } WriteQuery::CreateManyRecords(cmr) => { if let Some(selected_fields) = cmr.selected_fields { Expression::Concat( - generate_insert_statements( - &cmr.model, - cmr.args, - cmr.skip_duplicates, - Some(&selected_fields.fields.into()), - ctx, - ) - .into_iter() - .map(|query| build_db_query(query, ctx)) - .map(|maybe_db_query| maybe_db_query.map(Expression::Execute)) - .collect::>>()?, + builder + .build_inserts(&cmr.model, cmr.args, cmr.skip_duplicates, Some(&selected_fields.fields)) + .map_err(TranslateError::QueryBuildFailure)? + .into_iter() + .map(Expression::Execute) + .collect::>(), ) } else { Expression::Sum( - generate_insert_statements(&cmr.model, cmr.args, cmr.skip_duplicates, None, ctx) + builder + .build_inserts(&cmr.model, cmr.args, cmr.skip_duplicates, None) + .map_err(TranslateError::QueryBuildFailure)? .into_iter() - .map(|query| build_db_query(query, ctx)) - .map(|maybe_db_query| maybe_db_query.map(Expression::Execute)) - .collect::>>()?, + .map(Expression::Execute) + .collect::>(), ) } } diff --git a/query-engine/query-builders/query-builder/Cargo.toml b/query-engine/query-builders/query-builder/Cargo.toml index 4c35b489e828..3bc481aee186 100644 --- a/query-engine/query-builders/query-builder/Cargo.toml +++ b/query-engine/query-builders/query-builder/Cargo.toml @@ -7,3 +7,6 @@ version = "0.1.0" serde.workspace = true query-structure = { path = "../../query-structure" } + +[features] +relation_joins = [] diff --git a/query-engine/query-builders/query-builder/src/lib.rs b/query-engine/query-builders/query-builder/src/lib.rs index 240ca848580c..682ec5e21c2b 100644 --- a/query-engine/query-builders/query-builder/src/lib.rs +++ b/query-engine/query-builders/query-builder/src/lib.rs @@ -1,5 +1,32 @@ -use query_structure::PrismaValue; +use query_structure::{FieldSelection, Model, PrismaValue, QueryArguments, WriteArgs}; use serde::Serialize; +mod query_arguments_ext; + +pub use query_arguments_ext::QueryArgumentsExt; + +pub 
trait QueryBuilder { + fn build_get_records( + &self, + model: &Model, + query_arguments: QueryArguments, + selected_fields: &FieldSelection, + ) -> Result>; + + fn build_create_record( + &self, + model: &Model, + args: WriteArgs, + selected_fields: &FieldSelection, + ) -> Result>; + + fn build_inserts( + &self, + model: &Model, + args: Vec, + skip_duplicates: bool, + selected_fields: Option<&FieldSelection>, + ) -> Result, Box>; +} #[derive(Debug, Serialize)] pub struct DbQuery { diff --git a/query-engine/query-builders/sql-query-builder/src/query_arguments_ext.rs b/query-engine/query-builders/query-builder/src/query_arguments_ext.rs similarity index 100% rename from query-engine/query-builders/sql-query-builder/src/query_arguments_ext.rs rename to query-engine/query-builders/query-builder/src/query_arguments_ext.rs diff --git a/query-engine/query-builders/sql-query-builder/Cargo.toml b/query-engine/query-builders/sql-query-builder/Cargo.toml index 80cccff5f961..c24274d2749e 100644 --- a/query-engine/query-builders/sql-query-builder/Cargo.toml +++ b/query-engine/query-builders/sql-query-builder/Cargo.toml @@ -13,7 +13,8 @@ psl = { path = "../../../psl/psl" } itertools.workspace = true chrono.workspace = true +bigdecimal.workspace = true serde_json.workspace = true [features] -relation_joins = [] +relation_joins = ["query-builder/relation_joins"] diff --git a/query-engine/core/src/compiler/translate/query/convert.rs b/query-engine/query-builders/sql-query-builder/src/convert.rs similarity index 98% rename from query-engine/core/src/compiler/translate/query/convert.rs rename to query-engine/query-builders/sql-query-builder/src/convert.rs index 2ea8463f93c0..e4a4c0cbc3d6 100644 --- a/query-engine/core/src/compiler/translate/query/convert.rs +++ b/query-engine/query-builders/sql-query-builder/src/convert.rs @@ -1,7 +1,7 @@ use bigdecimal::{BigDecimal, FromPrimitive}; use chrono::{DateTime, NaiveDate, Utc}; +use prisma_value::{PlaceholderType, PrismaValue}; use 
quaint::ast::VarType; -use query_structure::{PlaceholderType, PrismaValue}; pub(crate) fn quaint_value_to_prisma_value(value: quaint::Value<'_>) -> PrismaValue { match value.typed { diff --git a/query-engine/query-builders/sql-query-builder/src/cursor_condition.rs b/query-engine/query-builders/sql-query-builder/src/cursor_condition.rs index 1f91eff8a299..542885bcd8fc 100644 --- a/query-engine/query-builders/sql-query-builder/src/cursor_condition.rs +++ b/query-engine/query-builders/sql-query-builder/src/cursor_condition.rs @@ -2,11 +2,11 @@ use crate::{ join_utils::AliasedJoin, model_extensions::{AsColumn, AsColumns, AsTable, SelectionResultExt}, ordering::OrderByDefinition, - query_arguments_ext::QueryArgumentsExt, Context, }; use itertools::Itertools; use quaint::ast::*; +use query_builder::QueryArgumentsExt; use query_structure::*; #[derive(Debug)] diff --git a/query-engine/query-builders/sql-query-builder/src/lib.rs b/query-engine/query-builders/sql-query-builder/src/lib.rs index dafabf1f3772..f4f8cee9749b 100644 --- a/query-engine/query-builders/sql-query-builder/src/lib.rs +++ b/query-engine/query-builders/sql-query-builder/src/lib.rs @@ -1,5 +1,6 @@ pub mod column_metadata; mod context; +mod convert; mod cursor_condition; mod filter; mod join_utils; @@ -7,25 +8,97 @@ pub mod limit; mod model_extensions; mod nested_aggregations; mod ordering; -mod query_arguments_ext; pub mod read; #[cfg(feature = "relation_joins")] pub mod select; mod sql_trace; pub mod write; -use quaint::ast::{Column, Comparable, ConditionTree, Query, Row, Values}; -use query_structure::SelectionResult; +use std::marker::PhantomData; + +use quaint::{ + ast::{Column, Comparable, ConditionTree, Query, Row, Values}, + visitor::Visitor, +}; +use query_builder::{DbQuery, QueryBuilder}; +use query_structure::{FieldSelection, Model, ModelProjection, QueryArguments, SelectionResult, WriteArgs}; pub use column_metadata::ColumnMetadata; pub use context::Context; pub use filter::FilterBuilder; pub 
use model_extensions::{AsColumn, AsColumns, AsTable, RelationFieldExt, SelectionResultExt}; -pub use query_arguments_ext::QueryArgumentsExt; pub use sql_trace::SqlTraceComment; const PARAMETER_LIMIT: usize = 2000; +pub struct SqlQueryBuilder<'a, Visitor> { + context: Context<'a>, + phantom: PhantomData, +} + +impl<'a, V> SqlQueryBuilder<'a, V> { + pub fn new(context: Context<'a>) -> Self { + Self { + context, + phantom: PhantomData, + } + } + + fn convert_query(&self, query: impl Into>) -> Result> + where + V: Visitor<'a>, + { + let (sql, params) = V::build(query)?; + let params = params + .into_iter() + .map(convert::quaint_value_to_prisma_value) + .collect::>(); + Ok(DbQuery::new(sql, params)) + } +} + +impl<'a, V: Visitor<'a>> QueryBuilder for SqlQueryBuilder<'a, V> { + fn build_get_records( + &self, + model: &Model, + query_arguments: QueryArguments, + selected_fields: &FieldSelection, + ) -> Result> { + let query = read::get_records( + model, + ModelProjection::from(selected_fields) + .as_columns(&self.context) + .mark_all_selected(), + selected_fields.virtuals(), + query_arguments, + &self.context, + ); + self.convert_query(query) + } + + fn build_create_record( + &self, + model: &Model, + args: WriteArgs, + selected_fields: &FieldSelection, + ) -> Result> { + let query = write::create_record(model, args, &selected_fields.into(), &self.context); + self.convert_query(query) + } + + fn build_inserts( + &self, + model: &Model, + args: Vec, + skip_duplicates: bool, + selected_fields: Option<&FieldSelection>, + ) -> Result, Box> { + let projection = selected_fields.map(ModelProjection::from); + let query = write::generate_insert_statements(model, args, skip_duplicates, projection.as_ref(), &self.context); + query.into_iter().map(|q| self.convert_query(q)).collect() + } +} + pub fn chunked_conditions( columns: &[Column<'static>], records: &[&SelectionResult], diff --git a/query-engine/query-builders/sql-query-builder/src/ordering.rs 
b/query-engine/query-builders/sql-query-builder/src/ordering.rs index 3906a3ca0aa9..dfddd19a8ec5 100644 --- a/query-engine/query-builders/sql-query-builder/src/ordering.rs +++ b/query-engine/query-builders/sql-query-builder/src/ordering.rs @@ -1,7 +1,8 @@ -use crate::{join_utils::*, model_extensions::*, query_arguments_ext::QueryArgumentsExt, Context}; +use crate::{join_utils::*, model_extensions::*, Context}; use itertools::Itertools; use psl::{datamodel_connector::ConnectorCapability, reachable_only_with_capability}; use quaint::ast::*; +use query_builder::QueryArgumentsExt; use query_structure::*; static ORDER_JOIN_PREFIX: &str = "orderby_"; diff --git a/query-engine/query-builders/sql-query-builder/src/write.rs b/query-engine/query-builders/sql-query-builder/src/write.rs index 1059cb6069f8..d9307b01e569 100644 --- a/query-engine/query-builders/sql-query-builder/src/write.rs +++ b/query-engine/query-builders/sql-query-builder/src/write.rs @@ -1,5 +1,6 @@ use crate::limit::wrap_with_limit_subquery_if_needed; use crate::{model_extensions::*, sql_trace::SqlTraceComment, Context}; +use itertools::Itertools; use quaint::ast::*; use query_structure::*; use std::{collections::HashSet, convert::TryInto}; @@ -300,3 +301,107 @@ pub fn delete_relation_table_records( .so_that(parent_id_criteria.and(child_id_criteria)) .add_traceparent(ctx.traceparent) } + +/// Generates a list of insert statements to execute. If `selected_fields` is set, insert statements +/// will return the specified columns of inserted rows. 
+pub fn generate_insert_statements( + model: &Model, + args: Vec<WriteArgs>, + skip_duplicates: bool, + selected_fields: Option<&ModelProjection>, + ctx: &Context<'_>, +) -> Vec<Insert<'static>> { + let affected_fields = collect_affected_fields(&args, model); + + if affected_fields.is_empty() { + args.into_iter() + .map(|_| create_records_empty(model, skip_duplicates, selected_fields, ctx)) + .collect() + } else { + let partitioned_batches = partition_into_batches(args, ctx); + + partitioned_batches + .into_iter() + .map(|batch| create_records_nonempty(model, batch, skip_duplicates, &affected_fields, selected_fields, ctx)) + .collect() + } +} + +/// Returns a set of fields that are used in the arguments for the create operation. +fn collect_affected_fields(args: &[WriteArgs], model: &Model) -> HashSet<ScalarFieldRef> { + let mut fields = HashSet::new(); + args.iter().for_each(|arg| fields.extend(arg.keys())); + + fields + .into_iter() + .map(|dsfn| model.fields().scalar().find(|sf| sf.db_name() == &**dsfn).unwrap()) + .collect() +} + +/// Partitions data into batches, respecting `max_bind_values` and `max_insert_rows` settings from +/// the `Context`. +fn partition_into_batches(args: Vec<WriteArgs>, ctx: &Context<'_>) -> Vec<Vec<WriteArgs>> { + let batches = if let Some(max_params) = ctx.max_bind_values() { + // We need to split inserts if they are above a parameter threshold, as well as split based on number of rows. + // -> Horizontal partitioning by row number, vertical by number of args. + args.into_iter() + .peekable() + .batching(|iter| { + let mut param_count: usize = 0; + let mut batch = vec![]; + + while param_count < max_params { + // If the param count _including_ the next item doesn't exceed the limit, + // we continue filling up the current batch.
+ let proceed = match iter.peek() { + Some(next) => (param_count + next.len()) <= max_params, + None => break, + }; + + if proceed { + match iter.next() { + Some(next) => { + param_count += next.len(); + batch.push(next) + } + None => break, + } + } else { + break; + } + } + + if batch.is_empty() { + None + } else { + Some(batch) + } + }) + .collect_vec() + } else { + vec![args] + }; + + if let Some(max_rows) = ctx.max_insert_rows() { + let capacity = batches.len(); + batches + .into_iter() + .fold(Vec::with_capacity(capacity), |mut batches, next_batch| { + if next_batch.len() > max_rows { + batches.extend( + next_batch + .into_iter() + .chunks(max_rows) + .into_iter() + .map(|chunk| chunk.into_iter().collect_vec()), + ); + } else { + batches.push(next_batch); + } + + batches + }) + } else { + batches + } +} diff --git a/query-engine/query-engine/Cargo.toml b/query-engine/query-engine/Cargo.toml index db011f9238d7..5d51d16237c6 100644 --- a/query-engine/query-engine/Cargo.toml +++ b/query-engine/query-engine/Cargo.toml @@ -39,6 +39,7 @@ serial_test = "*" quaint.workspace = true indoc.workspace = true indexmap.workspace = true +sql-query-builder = { path = "../query-builders/sql-query-builder" } [build-dependencies] build-utils.path = "../../libs/build-utils" diff --git a/query-engine/query-engine/examples/compiler.rs b/query-engine/query-engine/examples/compiler.rs index 099ea2a80359..e72a1acdf987 100644 --- a/query-engine/query-engine/examples/compiler.rs +++ b/query-engine/query-engine/examples/compiler.rs @@ -1,9 +1,13 @@ use std::sync::Arc; -use quaint::connector::{ConnectionInfo, ExternalConnectionInfo, SqlFamily}; +use quaint::{ + prelude::{ConnectionInfo, ExternalConnectionInfo, SqlFamily}, + visitor::Postgres, +}; use query_core::{query_graph_builder::QueryGraphBuilder, QueryDocument}; use request_handlers::{JsonBody, JsonSingleQuery, RequestBody}; use serde_json::json; +use sql_query_builder::{Context, SqlQueryBuilder}; pub fn main() -> anyhow::Result<()> 
{ let schema_string = include_str!("./schema.prisma"); @@ -73,7 +77,10 @@ pub fn main() -> anyhow::Result<()> { println!("{graph}"); - let expr = query_core::compiler::translate(graph, &connection_info)?; + let ctx = Context::new(&connection_info, None); + let builder = SqlQueryBuilder::>::new(ctx); + + let expr = query_core::compiler::translate(graph, &builder)?; println!("{}", expr.pretty_print(true, 80)?); diff --git a/query-engine/query-structure/src/query_arguments.rs b/query-engine/query-structure/src/query_arguments.rs index 6739f3eeb0ae..6afc7b9e41c4 100644 --- a/query-engine/query-structure/src/query_arguments.rs +++ b/query-engine/query-structure/src/query_arguments.rs @@ -89,6 +89,11 @@ impl QueryArguments { } } + pub fn with_take(mut self, take: Option) -> Self { + self.take = take; + self + } + pub fn do_nothing(&self) -> bool { self.cursor.is_none() && self.take.is_none()