From c54c4273cc759c8d3afb962a7c6f665b9c3e555b Mon Sep 17 00:00:00 2001 From: Larko <59736843+Larkooo@users.noreply.github.com> Date: Wed, 15 Jan 2025 00:27:59 +0700 Subject: [PATCH] feat(torii-graphql): filter on nested values (#2905) * feat(torii-graphql): filter on nested values * fmt * support truly nested types in where input * remove prints * correctly handle enum option filtering * fmt --- .../graphql/src/object/inputs/where_input.rs | 216 ++++++++++++------ crates/torii/graphql/src/object/model_data.rs | 5 +- crates/torii/graphql/src/query/data.rs | 6 +- 3 files changed, 154 insertions(+), 73 deletions(-) diff --git a/crates/torii/graphql/src/object/inputs/where_input.rs b/crates/torii/graphql/src/object/inputs/where_input.rs index 8cea27172e..d78dd454f7 100644 --- a/crates/torii/graphql/src/object/inputs/where_input.rs +++ b/crates/torii/graphql/src/object/inputs/where_input.rs @@ -16,46 +16,84 @@ use crate::types::TypeData; pub struct WhereInputObject { pub type_name: String, pub type_mapping: TypeMapping, + pub nested_inputs: Vec, } impl WhereInputObject { - // Iterate through an object's type mapping and create a new mapping for whereInput. For each of - // the object type (model member), we add 6 additional types for comparators (great than, - // not equal, etc) - pub fn new(type_name: &str, object_types: &TypeMapping) -> Self { - let where_mapping = object_types - .iter() - .filter(|(_, type_data)| !type_data.is_nested() && !type_data.is_list()) - .flat_map(|(type_name, type_data)| { - // TODO: filter on nested and enum objects - if type_data.type_ref() == TypeRef::named("Enum") - || type_data.type_ref() == TypeRef::named("bool") - { - return vec![(Name::new(type_name), type_data.clone())]; + fn build_field_mapping(type_name: &str, type_data: &TypeData) -> Vec<(Name, TypeData)> { + if type_data.type_ref() == TypeRef::named("Enum") + || type_data.type_ref() == TypeRef::named("bool") + { + return vec![(Name::new(type_name), type_data.clone())]; + } + + Comparator::iter().fold( + vec![(Name::new(type_name), type_data.clone())], + |mut acc, comparator| { + let name = format!("{}{}", type_name, comparator.as_ref()); + match comparator { + Comparator::In | Comparator::NotIn => { + acc.push((Name::new(name), TypeData::List(Box::new(type_data.clone())))) + } + _ => { + acc.push((Name::new(name), type_data.clone())); + } } + acc + }, + ) + } - Comparator::iter().fold( - vec![(Name::new(type_name), type_data.clone())], - |mut acc, comparator| { - let name = format!("{}{}", type_name, comparator.as_ref()); - - match comparator { - Comparator::In | Comparator::NotIn => acc.push(( - Name::new(name), - TypeData::List(Box::new(type_data.clone())), + pub fn new(type_name: &str, object_types: &TypeMapping) -> Self { + let mut nested_inputs = Vec::new(); + let mut where_mapping = TypeMapping::new(); + + for (field_name, type_data) in object_types { + if !type_data.is_list() { + match type_data { + TypeData::Nested((_, nested_types)) => { + // Create nested input object + let nested_input = WhereInputObject::new( + &format!("{}_{}", type_name, field_name), + nested_types, + ); + + // Add field for the nested input using TypeData::Nested + where_mapping.insert( + Name::new(field_name), + TypeData::Nested(( + TypeRef::named(&nested_input.type_name), + nested_types.clone(), )), - _ => { - acc.push((Name::new(name), type_data.clone())); - } + ); + nested_inputs.push(nested_input); + } + _ => { + // Add regular field with comparators + for (name, mapped_type) in Self::build_field_mapping(field_name, type_data) + { + where_mapping.insert(name, mapped_type); } + } + } + } + } + + Self { + type_name: format!("{}WhereInput", type_name), + type_mapping: where_mapping, + nested_inputs, + } + } +} - acc - }, - ) - }) - .collect(); - - Self { type_name: format!("{}WhereInput", type_name), type_mapping: where_mapping } +impl WhereInputObject { + pub fn input_objects(&self) -> Vec { + let mut objects = vec![self.input_object()]; + for nested in &self.nested_inputs { + objects.extend(nested.input_objects()); + } + objects } } @@ -79,6 +117,77 @@ pub fn where_argument(field: Field, type_name: &str) -> Field { field.argument(InputValue::new("where", TypeRef::named(format!("{}WhereInput", type_name)))) } +fn parse_nested_where( + input_object: &ValueAccessor<'_>, + type_name: &str, + type_data: &TypeData, +) -> Result> { + match type_data { + TypeData::Nested((_, nested_mapping)) => { + let nested_input = input_object.object()?; + nested_mapping + .iter() + .filter_map(|(field_name, field_type)| { + nested_input.get(field_name).map(|input| { + let nested_filters = parse_where_value( + input, + &format!("{}.{}", type_name, field_name), + field_type, + )?; + Ok(nested_filters) + }) + }) + .collect::>>() + .map(|filters| filters.into_iter().flatten().collect()) + } + _ => Ok(vec![]), + } +} + +fn parse_where_value( + input: ValueAccessor<'_>, + field_path: &str, + type_data: &TypeData, +) -> Result> { + match type_data { + TypeData::Simple(_) => { + if type_data.type_ref() == TypeRef::named("Enum") { + let value = input.string()?; + let mut filter = + parse_filter(&Name::new(field_path), FilterValue::String(value.to_string())); + // complex enums have a nested option field for their variant name. + // we trim the .option suffix to get the actual db field name + filter.field = filter.field.trim_end_matches(".option").to_string(); + return Ok(vec![filter]); + } + + let primitive = Primitive::from_str(&type_data.type_ref().to_string())?; + let filter_value = match primitive.to_sql_type() { + SqlType::Integer => parse_integer(input, field_path, primitive)?, + SqlType::Text => parse_string(input, field_path, primitive)?, + }; + + Ok(vec![parse_filter(&Name::new(field_path), filter_value)]) + } + TypeData::List(inner) => { + let list = input.list()?; + let values = list + .iter() + .map(|value| { + let primitive = Primitive::from_str(&inner.type_ref().to_string())?; + match primitive.to_sql_type() { + SqlType::Integer => parse_integer(value, field_path, primitive), + SqlType::Text => parse_string(value, field_path, primitive), + } + }) + .collect::>>()?; + + Ok(vec![parse_filter(&Name::new(field_path), FilterValue::List(values))]) + } + TypeData::Nested(_) => parse_nested_where(&input, field_path, type_data), + } +} + pub fn parse_where_argument( ctx: &ResolverContext<'_>, where_mapping: &TypeMapping, @@ -87,44 +196,13 @@ pub fn parse_where_argument( let input_object = where_input.object()?; where_mapping .iter() - .filter_map(|(type_name, type_data)| { - input_object.get(type_name).map(|input| match type_data { - TypeData::Simple(_) => { - if type_data.type_ref() == TypeRef::named("Enum") { - let value = input.string().unwrap(); - return Ok(Some(parse_filter( - type_name, - FilterValue::String(value.to_string()), - ))); - } - - let primitive = Primitive::from_str(&type_data.type_ref().to_string())?; - let filter_value = match primitive.to_sql_type() { - SqlType::Integer => parse_integer(input, type_name, primitive)?, - SqlType::Text => parse_string(input, type_name, primitive)?, - }; - - Ok(Some(parse_filter(type_name, filter_value))) - } - TypeData::List(inner) => { - let list = input.list()?; - let values = list - .iter() - .map(|value| { - let primitive = Primitive::from_str(&inner.type_ref().to_string())?; - match primitive.to_sql_type() { - SqlType::Integer => parse_integer(value, type_name, primitive), - SqlType::Text => parse_string(value, type_name, primitive), - } - }) - .collect::>>()?; - - Ok(Some(parse_filter(type_name, FilterValue::List(values)))) - } - _ => Err(GqlError::new("Nested types are not supported")), - }) + .filter_map(|(field_name, type_data)| { + input_object + .get(field_name) + .map(|input| parse_where_value(input, field_name, type_data)) }) - .collect::>>>() + .collect::>>() + .map(|filters| Some(filters.into_iter().flatten().collect())) }) } diff --git a/crates/torii/graphql/src/object/model_data.rs b/crates/torii/graphql/src/object/model_data.rs index 8f6f0d0bc7..c1d2a89150 100644 --- a/crates/torii/graphql/src/object/model_data.rs +++ b/crates/torii/graphql/src/object/model_data.rs @@ -74,7 +74,10 @@ impl BasicObject for ModelDataObject { impl ResolvableObject for ModelDataObject { fn input_objects(&self) -> Option> { - Some(vec![self.where_input.input_object(), self.order_input.input_object()]) + let mut objects = vec![]; + objects.extend(self.where_input.input_objects()); + objects.push(self.order_input.input_object()); + Some(objects) } fn enum_objects(&self) -> Option> { diff --git a/crates/torii/graphql/src/query/data.rs b/crates/torii/graphql/src/query/data.rs index 8a165e6cee..83a5617621 100644 --- a/crates/torii/graphql/src/query/data.rs +++ b/crates/torii/graphql/src/query/data.rs @@ -222,8 +222,8 @@ fn build_conditions(keys: &Option>, filters: &Option>) - if let Some(filters) = filters { conditions.extend(filters.iter().map(|filter| match &filter.value { - FilterValue::Int(i) => format!("{} {} {}", filter.field, filter.comparator, i), - FilterValue::String(s) => format!("{} {} '{}'", filter.field, filter.comparator, s), + FilterValue::Int(i) => format!("[{}] {} {}", filter.field, filter.comparator, i), + FilterValue::String(s) => format!("[{}] {} '{}'", filter.field, filter.comparator, s), FilterValue::List(list) => { let values = list .iter() @@ -234,7 +234,7 @@ fn build_conditions(keys: &Option>, filters: &Option>) - }) .collect::>() .join(", "); - format!("{} {} ({})", filter.field, filter.comparator, values) + format!("[{}] {} ({})", filter.field, filter.comparator, values) } })); }