Skip to content

Commit

Permalink
feat: generic query building (#5127)
Browse files Browse the repository at this point in the history
* feat: generic query building

* fix: compiler example

* fix: comment out non pg variants for now

* fix: unused imports

* chore: remove currently unnecessary impl

* fix: enable postgres feature for now
  • Loading branch information
jacek-prisma authored Jan 16, 2025
1 parent 9cd8301 commit 5f0c232
Show file tree
Hide file tree
Showing 25 changed files with 340 additions and 276 deletions.
4 changes: 3 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ members = [

[workspace.dependencies]
async-trait = { version = "0.1.77" }
bigdecimal = "0.3"
enumflags2 = { version = "0.7", features = ["serde"] }
futures = "0.3"
psl = { path = "./psl/psl" }
Expand Down
3 changes: 3 additions & 0 deletions query-engine/connectors/sql-query-connector/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ path = "../query-connector"
[dependencies.query-structure]
path = "../../query-structure"

[dependencies.query-builder]
path = "../../query-builders/query-builder"

[dependencies.sql-query-builder]
path = "../../query-builders/sql-query-builder"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ use crate::{QueryExt, Queryable, SqlError};
use connector_interface::*;
use futures::stream::{FuturesUnordered, StreamExt};
use quaint::ast::*;
use query_builder::QueryArgumentsExt;
use query_structure::*;
use sql_query_builder::{column_metadata, read, AsColumns, AsTable, Context, QueryArgumentsExt, RelationFieldExt};
use sql_query_builder::{column_metadata, read, AsColumns, AsTable, Context, RelationFieldExt};

pub(crate) async fn get_single_record(
conn: &dyn Queryable,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::borrow::Cow;

use itertools::{Either, Itertools};
use query_builder::QueryArgumentsExt;
use query_structure::{QueryArguments, Record};
use sql_query_builder::QueryArgumentsExt;

macro_rules! processor_state {
($name:ident $(-> $transition:ident($bound:ident))?) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use super::update::*;
use crate::row::ToSqlRow;
use crate::value::to_prisma_value;
use crate::{error::SqlError, QueryExt, Queryable};
use itertools::Itertools;
use quaint::ast::{Insert, Query};
use quaint::ast::Query;
use quaint::prelude::ResultSet;
use quaint::{
error::ErrorKind,
Expand All @@ -12,32 +11,9 @@ use quaint::{
use query_structure::*;
use sql_query_builder::{column_metadata, write, Context, FilterBuilder, SelectionResultExt, SqlTraceComment};
use std::borrow::Cow;
use std::{
collections::{HashMap, HashSet},
ops::Deref,
};
use std::collections::HashMap;
use user_facing_errors::query_engine::DatabaseConstraint;

#[cfg(target_arch = "wasm32")]
macro_rules! trace {
(target: $target:expr, $($arg:tt)+) => {{
// No-op in WebAssembly
}};
($($arg:tt)+) => {{
// No-op in WebAssembly
}};
}

#[cfg(not(target_arch = "wasm32"))]
macro_rules! trace {
(target: $target:expr, $($arg:tt)+) => {
tracing::log::trace!(target: $target, $($arg)+);
};
($($arg:tt)+) => {
tracing::log::trace!($($arg)+);
};
}

async fn generate_id(
conn: &dyn Queryable,
id_field: &FieldSelection,
Expand Down Expand Up @@ -191,49 +167,6 @@ pub(crate) async fn create_record(
}
}

/// Returns a set of fields that are used in the arguments for the create operation.
fn collect_affected_fields(args: &[WriteArgs], model: &Model) -> HashSet<ScalarFieldRef> {
let mut fields = HashSet::new();
args.iter().for_each(|arg| fields.extend(arg.keys()));

fields
.into_iter()
.map(|dsfn| model.fields().scalar().find(|sf| sf.db_name() == dsfn.deref()).unwrap())
.collect()
}

/// Generates a list of insert statements to execute. If `selected_fields` is set, insert statements
/// will return the specified columns of inserted rows.
pub fn generate_insert_statements(
model: &Model,
args: Vec<WriteArgs>,
skip_duplicates: bool,
selected_fields: Option<&ModelProjection>,
ctx: &Context<'_>,
) -> Vec<Insert<'static>> {
let affected_fields = collect_affected_fields(&args, model);

if affected_fields.is_empty() {
args.into_iter()
.map(|_| write::create_records_empty(model, skip_duplicates, selected_fields, ctx))
.collect()
} else {
let partitioned_batches = partition_into_batches(args, ctx);
trace!("Total of {} batches to be executed.", partitioned_batches.len());
trace!(
"Batch sizes: {:?}",
partitioned_batches.iter().map(|b| b.len()).collect_vec()
);

partitioned_batches
.into_iter()
.map(|batch| {
write::create_records_nonempty(model, batch, skip_duplicates, &affected_fields, selected_fields, ctx)
})
.collect()
}
}

/// Inserts records specified as a list of `WriteArgs`. Returns number of inserted records.
pub(crate) async fn create_records_count(
conn: &dyn Queryable,
Expand All @@ -242,7 +175,7 @@ pub(crate) async fn create_records_count(
skip_duplicates: bool,
ctx: &Context<'_>,
) -> crate::Result<usize> {
let inserts = generate_insert_statements(model, args, skip_duplicates, None, ctx);
let inserts = write::generate_insert_statements(model, args, skip_duplicates, None, ctx);
let mut count = 0;
for insert in inserts {
count += conn.execute(insert.into()).await?;
Expand All @@ -265,7 +198,7 @@ pub(crate) async fn create_records_returning(
let idents = selected_fields.type_identifiers_with_arities();
let meta = column_metadata::create(&field_names, &idents);
let mut records = ManyRecords::new(field_names.clone());
let inserts = generate_insert_statements(model, args, skip_duplicates, Some(&selected_fields.into()), ctx);
let inserts = write::generate_insert_statements(model, args, skip_duplicates, Some(&selected_fields.into()), ctx);

for insert in inserts {
let result_set = conn.query(insert.into()).await?;
Expand All @@ -281,74 +214,6 @@ pub(crate) async fn create_records_returning(
Ok(records)
}

/// Partitions data into batches, respecting `max_bind_values` and `max_insert_rows` settings from
/// the `Context`.
fn partition_into_batches(args: Vec<WriteArgs>, ctx: &Context<'_>) -> Vec<Vec<WriteArgs>> {
let batches = if let Some(max_params) = ctx.max_bind_values() {
// We need to split inserts if they are above a parameter threshold, as well as split based on number of rows.
// -> Horizontal partitioning by row number, vertical by number of args.
args.into_iter()
.peekable()
.batching(|iter| {
let mut param_count: usize = 0;
let mut batch = vec![];

while param_count < max_params {
// If the param count _including_ the next item doens't exceed the limit,
// we continue filling up the current batch.
let proceed = match iter.peek() {
Some(next) => (param_count + next.len()) <= max_params,
None => break,
};

if proceed {
match iter.next() {
Some(next) => {
param_count += next.len();
batch.push(next)
}
None => break,
}
} else {
break;
}
}

if batch.is_empty() {
None
} else {
Some(batch)
}
})
.collect_vec()
} else {
vec![args]
};

if let Some(max_rows) = ctx.max_insert_rows() {
let capacity = batches.len();
batches
.into_iter()
.fold(Vec::with_capacity(capacity), |mut batches, next_batch| {
if next_batch.len() > max_rows {
batches.extend(
next_batch
.into_iter()
.chunks(max_rows)
.into_iter()
.map(|chunk| chunk.into_iter().collect_vec()),
);
} else {
batches.push(next_batch);
}

batches
})
} else {
batches
}
}

/// Update one record in a database defined in `conn` and the records
/// defined in `args`, resulting the identifiers that were modified in the
/// operation.
Expand Down
2 changes: 0 additions & 2 deletions query-engine/connectors/sql-query-connector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ mod value;
use self::{query_ext::QueryExt, row::*};
use quaint::prelude::Queryable;

pub use database::operations::write::generate_insert_statements;

pub use database::FromSource;
#[cfg(feature = "driver-adapters")]
pub use database::Js;
Expand Down
9 changes: 2 additions & 7 deletions query-engine/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ graphql-protocol = []

[dependencies]
async-trait.workspace = true
bigdecimal = "0.3"
bigdecimal.workspace = true
chrono.workspace = true
connection-string.workspace = true
connector = { path = "../connectors/query-connector", package = "query-connector" }
Expand Down Expand Up @@ -45,10 +45,5 @@ lru = "0.7.7"
enumflags2.workspace = true
derive_more.workspace = true

# HACK: query builders need to be a separate crate, and maybe the compiler too
# HACK: we hardcode PostgreSQL as the dialect for now
sql-query-connector = { path = "../connectors/sql-query-connector", features = [
"postgresql",
] }
# HACK: this should not be in core either
quaint.workspace = true
quaint = { workspace = true, features = ["postgresql"] }
18 changes: 16 additions & 2 deletions query-engine/core/src/compiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@ pub mod translate;
use std::sync::Arc;

pub use expression::Expression;
use quaint::connector::ConnectionInfo;
use quaint::{
prelude::{ConnectionInfo, SqlFamily},
visitor,
};
use schema::QuerySchema;
use sql_query_builder::{Context, SqlQueryBuilder};
use thiserror::Error;
pub use translate::{translate, TranslateError};

Expand All @@ -29,6 +33,16 @@ pub fn compile(
return Err(CompileError::UnsupportedRequest.into());
};

let ctx = Context::new(connection_info, None);
let (graph, _serializer) = QueryGraphBuilder::new(query_schema).build(query)?;
Ok(translate(graph, connection_info).map_err(CompileError::from)?)
let res = match connection_info.sql_family() {
SqlFamily::Postgres => translate(graph, &SqlQueryBuilder::<visitor::Postgres<'_>>::new(ctx)),
// feature flags are disabled for now
// SqlFamily::Mysql => translate(graph, &SqlQueryBuilder::<visitor::Mysql<'_>>::new(ctx)),
// SqlFamily::Sqlite => translate(graph, &SqlQueryBuilder::<visitor::Sqlite<'_>>::new(ctx)),
// SqlFamily::Mssql => translate(graph, &SqlQueryBuilder::<visitor::Mssql<'_>>::new(ctx)),
_ => unimplemented!(),
};

Ok(res.map_err(CompileError::TranslateError)?)
}
Loading

0 comments on commit 5f0c232

Please sign in to comment.