Skip to content

Commit

Permalink
feat(torii-graphql): filter on nested values (#2905)
Browse files Browse the repository at this point in the history
* feat(torii-graphql): filter on nested values

* fmt

* support truly nested types in where input

* remove prints

* correctly handle enum option filtering

* fmt
  • Loading branch information
Larkooo authored Jan 14, 2025
1 parent 1ca1257 commit c54c427
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 73 deletions.
216 changes: 147 additions & 69 deletions crates/torii/graphql/src/object/inputs/where_input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,84 @@ use crate::types::TypeData;
pub struct WhereInputObject {
pub type_name: String,
pub type_mapping: TypeMapping,
pub nested_inputs: Vec<WhereInputObject>,
}

impl WhereInputObject {
// Iterate through an object's type mapping and create a new mapping for whereInput. For each of
// the object type (model member), we add 6 additional types for comparators (great than,
// not equal, etc)
pub fn new(type_name: &str, object_types: &TypeMapping) -> Self {
let where_mapping = object_types
.iter()
.filter(|(_, type_data)| !type_data.is_nested() && !type_data.is_list())
.flat_map(|(type_name, type_data)| {
// TODO: filter on nested and enum objects
if type_data.type_ref() == TypeRef::named("Enum")
|| type_data.type_ref() == TypeRef::named("bool")
{
return vec![(Name::new(type_name), type_data.clone())];
fn build_field_mapping(type_name: &str, type_data: &TypeData) -> Vec<(Name, TypeData)> {
if type_data.type_ref() == TypeRef::named("Enum")
|| type_data.type_ref() == TypeRef::named("bool")
{
return vec![(Name::new(type_name), type_data.clone())];
}

Comparator::iter().fold(
vec![(Name::new(type_name), type_data.clone())],
|mut acc, comparator| {
let name = format!("{}{}", type_name, comparator.as_ref());
match comparator {
Comparator::In | Comparator::NotIn => {
acc.push((Name::new(name), TypeData::List(Box::new(type_data.clone()))))
}
_ => {
acc.push((Name::new(name), type_data.clone()));
}
}
acc
},
)
}

Comparator::iter().fold(
vec![(Name::new(type_name), type_data.clone())],
|mut acc, comparator| {
let name = format!("{}{}", type_name, comparator.as_ref());

match comparator {
Comparator::In | Comparator::NotIn => acc.push((
Name::new(name),
TypeData::List(Box::new(type_data.clone())),
pub fn new(type_name: &str, object_types: &TypeMapping) -> Self {
let mut nested_inputs = Vec::new();
let mut where_mapping = TypeMapping::new();

for (field_name, type_data) in object_types {
if !type_data.is_list() {
match type_data {
TypeData::Nested((_, nested_types)) => {
// Create nested input object
let nested_input = WhereInputObject::new(
&format!("{}_{}", type_name, field_name),
nested_types,
);

// Add field for the nested input using TypeData::Nested
where_mapping.insert(
Name::new(field_name),
TypeData::Nested((
TypeRef::named(&nested_input.type_name),
nested_types.clone(),
)),
_ => {
acc.push((Name::new(name), type_data.clone()));
}
);
nested_inputs.push(nested_input);
}
_ => {
// Add regular field with comparators
for (name, mapped_type) in Self::build_field_mapping(field_name, type_data)
{
where_mapping.insert(name, mapped_type);
}
}
}
}
}

Self {
type_name: format!("{}WhereInput", type_name),
type_mapping: where_mapping,
nested_inputs,
}
}
}

acc
},
)
})
.collect();

Self { type_name: format!("{}WhereInput", type_name), type_mapping: where_mapping }
impl WhereInputObject {
pub fn input_objects(&self) -> Vec<InputObject> {
let mut objects = vec![self.input_object()];
for nested in &self.nested_inputs {
objects.extend(nested.input_objects());
}
objects
}
}

Expand All @@ -79,6 +117,77 @@ pub fn where_argument(field: Field, type_name: &str) -> Field {
field.argument(InputValue::new("where", TypeRef::named(format!("{}WhereInput", type_name))))
}

fn parse_nested_where(
input_object: &ValueAccessor<'_>,
type_name: &str,
type_data: &TypeData,
) -> Result<Vec<Filter>> {
match type_data {
TypeData::Nested((_, nested_mapping)) => {
let nested_input = input_object.object()?;
nested_mapping
.iter()
.filter_map(|(field_name, field_type)| {
nested_input.get(field_name).map(|input| {
let nested_filters = parse_where_value(
input,
&format!("{}.{}", type_name, field_name),
field_type,
)?;
Ok(nested_filters)
})
})
.collect::<Result<Vec<_>>>()
.map(|filters| filters.into_iter().flatten().collect())
}
_ => Ok(vec![]),
}
}

fn parse_where_value(
input: ValueAccessor<'_>,
field_path: &str,
type_data: &TypeData,
) -> Result<Vec<Filter>> {
match type_data {
TypeData::Simple(_) => {
if type_data.type_ref() == TypeRef::named("Enum") {
let value = input.string()?;
let mut filter =
parse_filter(&Name::new(field_path), FilterValue::String(value.to_string()));
// complex enums have a nested option field for their variant name.
// we trim the .option suffix to get the actual db field name
filter.field = filter.field.trim_end_matches(".option").to_string();
return Ok(vec![filter]);
}

let primitive = Primitive::from_str(&type_data.type_ref().to_string())?;
let filter_value = match primitive.to_sql_type() {
SqlType::Integer => parse_integer(input, field_path, primitive)?,
SqlType::Text => parse_string(input, field_path, primitive)?,
};

Ok(vec![parse_filter(&Name::new(field_path), filter_value)])
}
TypeData::List(inner) => {
let list = input.list()?;
let values = list
.iter()
.map(|value| {
let primitive = Primitive::from_str(&inner.type_ref().to_string())?;
match primitive.to_sql_type() {
SqlType::Integer => parse_integer(value, field_path, primitive),
SqlType::Text => parse_string(value, field_path, primitive),
}
})
.collect::<Result<Vec<_>>>()?;

Ok(vec![parse_filter(&Name::new(field_path), FilterValue::List(values))])
}
TypeData::Nested(_) => parse_nested_where(&input, field_path, type_data),
}
}

pub fn parse_where_argument(
ctx: &ResolverContext<'_>,
where_mapping: &TypeMapping,
Expand All @@ -87,44 +196,13 @@ pub fn parse_where_argument(
let input_object = where_input.object()?;
where_mapping
.iter()
.filter_map(|(type_name, type_data)| {
input_object.get(type_name).map(|input| match type_data {
TypeData::Simple(_) => {
if type_data.type_ref() == TypeRef::named("Enum") {
let value = input.string().unwrap();
return Ok(Some(parse_filter(
type_name,
FilterValue::String(value.to_string()),
)));
}

let primitive = Primitive::from_str(&type_data.type_ref().to_string())?;
let filter_value = match primitive.to_sql_type() {
SqlType::Integer => parse_integer(input, type_name, primitive)?,
SqlType::Text => parse_string(input, type_name, primitive)?,
};

Ok(Some(parse_filter(type_name, filter_value)))
}
TypeData::List(inner) => {
let list = input.list()?;
let values = list
.iter()
.map(|value| {
let primitive = Primitive::from_str(&inner.type_ref().to_string())?;
match primitive.to_sql_type() {
SqlType::Integer => parse_integer(value, type_name, primitive),
SqlType::Text => parse_string(value, type_name, primitive),
}
})
.collect::<Result<Vec<_>>>()?;

Ok(Some(parse_filter(type_name, FilterValue::List(values))))
}
_ => Err(GqlError::new("Nested types are not supported")),
})
.filter_map(|(field_name, type_data)| {
input_object
.get(field_name)
.map(|input| parse_where_value(input, field_name, type_data))
})
.collect::<Result<Option<Vec<_>>>>()
.collect::<Result<Vec<_>>>()
.map(|filters| Some(filters.into_iter().flatten().collect()))
})
}

Expand Down
5 changes: 4 additions & 1 deletion crates/torii/graphql/src/object/model_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ impl BasicObject for ModelDataObject {

impl ResolvableObject for ModelDataObject {
fn input_objects(&self) -> Option<Vec<InputObject>> {
Some(vec![self.where_input.input_object(), self.order_input.input_object()])
let mut objects = vec![];
objects.extend(self.where_input.input_objects());
objects.push(self.order_input.input_object());
Some(objects)
}

fn enum_objects(&self) -> Option<Vec<Enum>> {
Expand Down
6 changes: 3 additions & 3 deletions crates/torii/graphql/src/query/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ fn build_conditions(keys: &Option<Vec<String>>, filters: &Option<Vec<Filter>>) -

if let Some(filters) = filters {
conditions.extend(filters.iter().map(|filter| match &filter.value {
FilterValue::Int(i) => format!("{} {} {}", filter.field, filter.comparator, i),
FilterValue::String(s) => format!("{} {} '{}'", filter.field, filter.comparator, s),
FilterValue::Int(i) => format!("[{}] {} {}", filter.field, filter.comparator, i),
FilterValue::String(s) => format!("[{}] {} '{}'", filter.field, filter.comparator, s),
FilterValue::List(list) => {
let values = list
.iter()
Expand All @@ -234,7 +234,7 @@ fn build_conditions(keys: &Option<Vec<String>>, filters: &Option<Vec<Filter>>) -
})
.collect::<Vec<_>>()
.join(", ");
format!("{} {} ({})", filter.field, filter.comparator, values)
format!("[{}] {} ({})", filter.field, filter.comparator, values)
}
}));
}
Expand Down

0 comments on commit c54c427

Please sign in to comment.