From 2cb0bf58a6216bd9c4c67dbe96a6a0d4057062ec Mon Sep 17 00:00:00 2001 From: broody Date: Mon, 20 Nov 2023 14:03:26 -0800 Subject: [PATCH] map rows to proto model --- crates/dojo-types/src/primitive.rs | 28 +++--- crates/torii/core/src/cache.rs | 10 +-- crates/torii/core/src/model.rs | 60 ++++++------- crates/torii/grpc/proto/types.proto | 12 ++- crates/torii/grpc/src/server/mod.rs | 129 +++++++++++++++++++++++----- scripts/rust_fmt.sh | 2 +- 6 files changed, 166 insertions(+), 75 deletions(-) diff --git a/crates/dojo-types/src/primitive.rs b/crates/dojo-types/src/primitive.rs index dbb266df86..b146cb2521 100644 --- a/crates/dojo-types/src/primitive.rs +++ b/crates/dojo-types/src/primitive.rs @@ -63,7 +63,7 @@ macro_rules! set_primitive { } /// Macro to generate getter methods for Primitive enum variants. -macro_rules! get_primitive { +macro_rules! as_primitive { ($method_name:ident, $variant:ident, $type:ty) => { /// If the `Primitive` is type T, returns the associated [`T`]. Returns `None` otherwise. pub fn $method_name(&self) -> Option<$type> { @@ -76,17 +76,17 @@ macro_rules! get_primitive { } impl Primitive { - get_primitive!(as_u8, U8, u8); - get_primitive!(as_u16, U16, u16); - get_primitive!(as_u32, U32, u32); - get_primitive!(as_u64, U64, u64); - get_primitive!(as_u128, U128, u128); - get_primitive!(as_u256, U256, U256); - get_primitive!(as_bool, Bool, bool); - get_primitive!(as_usize, USize, u32); - get_primitive!(as_felt252, Felt252, FieldElement); - get_primitive!(as_class_hash, ClassHash, FieldElement); - get_primitive!(as_contract_address, ContractAddress, FieldElement); + as_primitive!(as_u8, U8, u8); + as_primitive!(as_u16, U16, u16); + as_primitive!(as_u32, U32, u32); + as_primitive!(as_u64, U64, u64); + as_primitive!(as_u128, U128, u128); + as_primitive!(as_u256, U256, U256); + as_primitive!(as_bool, Bool, bool); + as_primitive!(as_usize, USize, u32); + as_primitive!(as_felt252, Felt252, FieldElement); + as_primitive!(as_class_hash, ClassHash, FieldElement); + as_primitive!(as_contract_address, ContractAddress, FieldElement); set_primitive!(set_u8, U8, u8); set_primitive!(set_u16, U16, u16); @@ -285,8 +285,8 @@ mod tests { let mut deserialized = primitive; deserialized.deserialize(&mut serialized.clone()).unwrap(); - assert_eq!(sql_value, "0xaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbccccccccccccccccdddddddddddddddd"); - assert_eq!( + assert_eq!(sql_value, + "0xaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbccccccccccccccccdddddddddddddddd"); assert_eq!( serialized, vec![ FieldElement::from_str("0xccccccccccccccccdddddddddddddddd").unwrap(), diff --git a/crates/torii/core/src/cache.rs b/crates/torii/core/src/cache.rs index a67ba5befc..0296b46d25 100644 --- a/crates/torii/core/src/cache.rs +++ b/crates/torii/core/src/cache.rs @@ -26,8 +26,8 @@ impl ModelCache { pub async fn schema(&self, model: &str) -> Result, Error> { { let schemas = self.schemas.read().await; - if let Some(schema_data) = schemas.get(model) { - return Ok(Arc::clone(schema_data)); + if let Some(schema) = schemas.get(model) { + return Ok(Arc::clone(schema)); } } @@ -49,12 +49,12 @@ impl ModelCache { let ty = parse_sql_model_members(model, &model_members); let sql = build_sql_model_query(ty.as_struct().unwrap()); - let schema_data = Arc::new(SchemaData { ty, sql }); + let schema = Arc::new(SchemaData { ty, sql }); let mut schemas = self.schemas.write().await; - schemas.insert(model.into(), Arc::clone(&schema_data)); + schemas.insert(model.into(), Arc::clone(&schema)); - Ok(schema_data) + Ok(schema) } pub async fn clear(&self) { diff --git a/crates/torii/core/src/model.rs b/crates/torii/core/src/model.rs index 41cf9b4af7..433cd29cb8 100644 --- a/crates/torii/core/src/model.rs +++ b/crates/torii/core/src/model.rs @@ -150,7 +150,7 @@ pub fn parse_sql_model_members(model: &str, model_members_all: &[SqlModelMember] parse_sql_model_members_impl(model, model_members_all) } -/// A helper function to build a model query including all nested structs +/// A helper function to build a model query including all nested structs and its the entity id pub fn build_sql_model_query(schema: &Struct) -> String { fn build_sql_model_query_impl( path: &str, @@ -198,45 +198,45 @@ pub fn build_sql_model_query(schema: &Struct) -> String { } /// Converts SQLite rows into a vector of `Ty` based on a specified schema. -pub fn map_rows_to_tys(schema: &Struct, rows: &[SqliteRow]) -> Result, Error> { +pub fn map_rows_to_tys(schema: &Ty, rows: &[SqliteRow]) -> Result, Error> { fn populate_struct_from_row( path: &str, struct_ty: &mut Struct, row: &SqliteRow, ) -> Result<(), Error> { - for child in struct_ty.children.iter_mut() { - let column_name = format!("{}.{}", path, child.name); - match &mut child.ty { - Ty::Primitive(p) => { - match &p { + for member in struct_ty.children.iter_mut() { + let column_name = format!("{}.{}", path, member.name); + match &mut member.ty { + Ty::Primitive(primitive) => { + match &primitive { Primitive::Bool(_) => { - let value = row.try_get::(&column_name)?; - p.set_bool(Some(value == 1))?; + let value = row.try_get::(&column_name)?; + primitive.set_bool(Some(value))?; } Primitive::USize(_) => { - let value = row.try_get::(&column_name)?; - p.set_usize(Some(value as u32))?; + let value = row.try_get::(&column_name)?; + primitive.set_usize(Some(value))?; } Primitive::U8(_) => { - let value = row.try_get::(&column_name)?; - p.set_u8(Some(value as u8))?; + let value = row.try_get::(&column_name)?; + primitive.set_u8(Some(value))?; } Primitive::U16(_) => { - let value = row.try_get::(&column_name)?; - p.set_u16(Some(value as u16))?; + let value = row.try_get::(&column_name)?; + primitive.set_u16(Some(value))?; } Primitive::U32(_) => { - let value = row.try_get::(&column_name)?; - p.set_u32(Some(value as u32))?; + let value = row.try_get::(&column_name)?; + primitive.set_u32(Some(value))?; } Primitive::U64(_) => { let value = row.try_get::(&column_name)?; - p.set_u64(Some(value as u64))?; + primitive.set_u64(Some(value as u64))?; } Primitive::U128(_) => { let value = row.try_get::(&column_name)?; let hex_str = value.trim_start_matches("0x"); - p.set_u128(Some( + primitive.set_u128(Some( u128::from_str_radix(hex_str, 16) .map_err(ParseError::ParseIntError)?, ))?; @@ -244,35 +244,35 @@ pub fn map_rows_to_tys(schema: &Struct, rows: &[SqliteRow]) -> Result, E Primitive::U256(_) => { let value = row.try_get::(&column_name)?; let hex_str = value.trim_start_matches("0x"); - p.set_u256(Some(U256::from_be_hex(hex_str)))?; + primitive.set_u256(Some(U256::from_be_hex(hex_str)))?; } Primitive::Felt252(_) => { let value = row.try_get::(&column_name)?; - p.set_felt252(Some( + primitive.set_felt252(Some( FieldElement::from_str(&value).map_err(ParseError::FromStr)?, ))?; } Primitive::ClassHash(_) => { let value = row.try_get::(&column_name)?; - p.set_class_hash(Some( + primitive.set_class_hash(Some( FieldElement::from_str(&value).map_err(ParseError::FromStr)?, ))?; } Primitive::ContractAddress(_) => { let value = row.try_get::(&column_name)?; - p.set_contract_address(Some( + primitive.set_contract_address(Some( FieldElement::from_str(&value).map_err(ParseError::FromStr)?, ))?; } }; } - Ty::Enum(e) => { + Ty::Enum(enum_ty) => { let value = row.try_get::(&column_name)?; - e.set_option(&value)?; + enum_ty.set_option(&value)?; } - Ty::Struct(nested) => { - let path = [path, &nested.name].join("$"); - populate_struct_from_row(&path, nested, row)?; + Ty::Struct(struct_ty) => { + let path = [path, &struct_ty.name].join("$"); + populate_struct_from_row(&path, struct_ty, row)?; } ty => { unimplemented!("unimplemented type_enum: {ty}"); @@ -285,8 +285,8 @@ pub fn map_rows_to_tys(schema: &Struct, rows: &[SqliteRow]) -> Result, E rows.iter() .map(|row| { - let mut struct_ty = schema.clone(); - populate_struct_from_row(&schema.name, &mut struct_ty, row)?; + let mut struct_ty = schema.as_struct().expect("schema should be struct ty").clone(); + populate_struct_from_row(&schema.name(), &mut struct_ty, row)?; Ok(Ty::Struct(struct_ty)) }) diff --git a/crates/torii/grpc/proto/types.proto b/crates/torii/grpc/proto/types.proto index d9dab2f1f3..1dbb7b86e9 100644 --- a/crates/torii/grpc/proto/types.proto +++ b/crates/torii/grpc/proto/types.proto @@ -29,13 +29,21 @@ message ModelMetadata { bytes schema = 6; } +message Enum { + // Variant + uint32 option = 1; + // Variants of the enum + repeated string options = 2; +} + message Member { // Name of the member string name = 1; // Type of value - oneof value_type { + oneof member_type { Value value = 2; - Model struct = 3; + Enum enum = 3; + Model struct = 4; } } diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index 3d0674f546..05279c820c 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -7,12 +7,15 @@ use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; +use dojo_types::primitive::Primitive; +use dojo_types::schema::{Struct, Ty}; use futures::Stream; use proto::world::{ MetadataRequest, MetadataResponse, RetrieveEntitiesRequest, RetrieveEntitiesResponse, SubscribeEntitiesRequest, SubscribeEntitiesResponse, }; -use sqlx::{Pool, Sqlite}; +use sqlx::sqlite::SqliteRow; +use sqlx::{Pool, Row, Sqlite}; use starknet::core::utils::cairo_short_string_to_felt; use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::JsonRpcClient; @@ -24,7 +27,6 @@ use tonic::transport::Server; use tonic::{Request, Response, Status}; use torii_core::cache::ModelCache; use torii_core::error::{Error, ParseError, QueryError}; -use torii_core::model::map_rows_to_tys; use self::subscription::SubscribeRequest; use crate::proto::types::clause::ClauseType; @@ -84,14 +86,14 @@ impl DojoWorld { let mut models_metadata = Vec::with_capacity(models.len()); for model in models { - let schema_data = self.model_cache.schema(&model.0).await?; + let schema = self.model_cache.schema(&model.0).await?; models_metadata.push(proto::types::ModelMetadata { name: model.0, class_hash: model.1, packed_size: model.2, unpacked_size: model.3, layout: hex::decode(&model.4).unwrap(), - schema: serde_json::to_vec(&schema_data.ty).unwrap(), + schema: serde_json::to_vec(&schema.ty).unwrap(), }); } @@ -115,25 +117,11 @@ impl DojoWorld { &self, attribute: proto::types::AttributeClause, ) -> Result, Error> { - let schema_data = self.model_cache.schema(&attribute.model).await?; - let results = sqlx::query(&schema_data.sql).fetch_all(&self.pool).await?; - let tys = map_rows_to_tys(schema_data.ty.as_struct().unwrap(), &results)?; - - let mut entities = Vec::with_capacity(tys.len()); - - for ty in tys { - entities.push(proto::types::Entity { - key: "".to_string(), - models: vec![ - proto::types::Model { - name: ty.name(), - data: serde_json::to_vec(&ty).unwrap() - } - ] - }) - } - - Ok(vec![]) + let schema = self.model_cache.schema(&attribute.model).await?; + let rows = sqlx::query(&schema.sql).fetch_all(&self.pool).await?; + let models = self.map_rows_to_models(&schema.ty, &rows).await?; + + Ok(vec![proto::types::Entity { key: "".to_string(), models }]) } async fn entities_by_composite( @@ -213,6 +201,101 @@ impl DojoWorld { Ok(RetrieveEntitiesResponse { entities }) } + + async fn map_rows_to_models( + &self, + schema: &Ty, + rows: &[SqliteRow], + ) -> Result, Error> { + fn row_to_model( + path: &str, + struct_ty: &Struct, + row: &SqliteRow, + ) -> Result { + let members = struct_ty + .children + .iter() + .map(|member| { + let column_name = format!("{}.{}", path, member.name); + let name = member.name.clone(); + let member = match &member.ty { + Ty::Primitive(primitive) => { + let value_type = match primitive { + Primitive::Bool(_) => proto::types::value::ValueType::BoolValue( + row.try_get::(&column_name)?, + ), + Primitive::U8(_) + | Primitive::U16(_) + | Primitive::U32(_) + | Primitive::U64(_) + | Primitive::USize(_) => { + let value = row.try_get::(&column_name)?; + proto::types::value::ValueType::UintValue(value as u64) + } + Primitive::U128(_) + | Primitive::U256(_) + | Primitive::Felt252(_) + | Primitive::ClassHash(_) + | Primitive::ContractAddress(_) => { + let value = row.try_get::(&column_name)?; + proto::types::value::ValueType::StringValue(value) + } + }; + + proto::types::Member { + name, + member_type: Some(proto::types::member::MemberType::Value( + proto::types::Value { value_type: Some(value_type) }, + )), + } + } + + Ty::Enum(enum_ty) => { + let value = row.try_get::(&column_name)?; + let options = enum_ty + .options + .iter() + .map(|e| e.name.to_string()) + .collect::>(); + let option = + options.iter().position(|o| o == &value).expect("wrong enum value") + as u32; + proto::types::Member { + name: enum_ty.name.clone(), + member_type: Some(proto::types::member::MemberType::Enum( + proto::types::Enum { option, options }, + )), + } + } + Ty::Struct(struct_ty) => { + let path = [path, &struct_ty.name].join("$"); + proto::types::Member { + name, + member_type: Some(proto::types::member::MemberType::Struct( + row_to_model(&path, struct_ty, row)?, + )), + } + } + ty => { + unimplemented!("unimplemented type_enum: {ty}"); + } + }; + + Ok(member) + }) + .collect::, Error>>()?; + + Ok(proto::types::Model { name: struct_ty.name.clone(), members }) + } + + rows.iter() + .map(|row| { + let struct_ty = schema.as_struct().expect("schema should be struct ty").clone(); + + row_to_model(&schema.name(), &struct_ty, row) + }) + .collect::, Error>>() + } } type ServiceResult = Result, Status>; diff --git a/scripts/rust_fmt.sh b/scripts/rust_fmt.sh index 095c105201..62a418693a 100755 --- a/scripts/rust_fmt.sh +++ b/scripts/rust_fmt.sh @@ -1,3 +1,3 @@ #!/bin/bash -cargo +nightly fmt --all -- "$@" +cargo +nightly fmt --check --all -- "$@"