diff --git a/Cargo.lock b/Cargo.lock index c7d80ecd4a..c86692ab43 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8300,6 +8300,7 @@ dependencies = [ "base64 0.21.5", "camino", "chrono", + "crypto-bigint", "dojo-test-utils", "dojo-types", "dojo-world", diff --git a/crates/dojo-types/src/primitive.rs b/crates/dojo-types/src/primitive.rs index d5a159bad9..2e91d0fdbb 100644 --- a/crates/dojo-types/src/primitive.rs +++ b/crates/dojo-types/src/primitive.rs @@ -33,6 +33,8 @@ pub enum PrimitiveError { NotEnoughFieldElements, #[error("Unsupported CairoType for SQL formatting")] UnsupportedType, + #[error("Set value type mismatch")] + TypeMismatch, #[error(transparent)] ValueOutOfRange(#[from] ValueOutOfRangeError), } @@ -44,97 +46,60 @@ pub enum SqlType { Text, } -impl Primitive { - /// If the `Primitive` is a u8, returns the associated [`u8`]. Returns `None` otherwise. - pub fn as_u8(&self) -> Option { - match self { - Primitive::U8(value) => *value, - _ => None, - } - } - - /// If the `Primitive` is a u16, returns the associated [`u16`]. Returns `None` otherwise. - pub fn as_u16(&self) -> Option { - match self { - Primitive::U16(value) => *value, - _ => None, - } - } - - /// If the `Primitive` is a u32, returns the associated [`u32`]. Returns `None` otherwise. - pub fn as_u32(&self) -> Option { - match self { - Primitive::U32(value) => *value, - _ => None, - } - } - - /// If the `Primitive` is a u64, returns the associated [`u64`]. Returns `None` otherwise. - pub fn as_u64(&self) -> Option { - match self { - Primitive::U64(value) => *value, - _ => None, - } - } - - /// If the `Primitive` is a u128, returns the associated [`u128`]. Returns `None` otherwise. - pub fn as_u128(&self) -> Option { - match self { - Primitive::U128(value) => *value, - _ => None, - } - } - - /// If the `Primitive` is a u256, returns the associated [`U256`]. Returns `None` otherwise. - pub fn as_u256(&self) -> Option { - match self { - Primitive::U256(value) => *value, - _ => None, - } - } - - /// If the `Primitive` is a felt252, returns the associated [`FieldElement`]. Returns `None` - /// otherwise. - pub fn as_felt252(&self) -> Option { - match self { - Primitive::Felt252(value) => *value, - _ => None, - } - } - - /// If the `Primitive` is a ClassHash, returns the associated [`FieldElement`]. Returns `None` - /// otherwise. - pub fn as_class_hash(&self) -> Option { - match self { - Primitive::ClassHash(value) => *value, - _ => None, +/// Macro to generate setter methods for Primitive enum variants. +macro_rules! set_primitive { + ($method_name:ident, $variant:ident, $type:ty) => { + /// Sets the inner value of the `Primitive` enum if variant matches. + pub fn $method_name(&mut self, value: Option<$type>) -> Result<(), PrimitiveError> { + match self { + Primitive::$variant(_) => { + *self = Primitive::$variant(value); + Ok(()) + } + _ => Err(PrimitiveError::TypeMismatch), + } } - } + }; +} - /// If the `Primitive` is a ContractAddress, returns the associated [`FieldElement`]. Returns - /// `None` otherwise. - pub fn as_contract_address(&self) -> Option { - match self { - Primitive::ContractAddress(value) => *value, - _ => None, +/// Macro to generate getter methods for Primitive enum variants. +macro_rules! as_primitive { + ($method_name:ident, $variant:ident, $type:ty) => { + /// If the `Primitive` is variant type, returns the associated vartiant value. Returns + /// `None` otherwise. + pub fn $method_name(&self) -> Option<$type> { + match self { + Primitive::$variant(value) => *value, + _ => None, + } } - } + }; +} - /// If the `Primitive` is a usize, returns the associated [`u32`]. Returns `None` otherwise. - pub fn as_usize(&self) -> Option { - match self { - Primitive::USize(value) => *value, - _ => None, - } - } +impl Primitive { + as_primitive!(as_u8, U8, u8); + as_primitive!(as_u16, U16, u16); + as_primitive!(as_u32, U32, u32); + as_primitive!(as_u64, U64, u64); + as_primitive!(as_u128, U128, u128); + as_primitive!(as_u256, U256, U256); + as_primitive!(as_bool, Bool, bool); + as_primitive!(as_usize, USize, u32); + as_primitive!(as_felt252, Felt252, FieldElement); + as_primitive!(as_class_hash, ClassHash, FieldElement); + as_primitive!(as_contract_address, ContractAddress, FieldElement); - /// If the `Primitive` is a bool, returns the associated [`bool`]. Returns `None` otherwise. - pub fn as_bool(&self) -> Option { - match self { - Primitive::Bool(value) => *value, - _ => None, - } - } + set_primitive!(set_u8, U8, u8); + set_primitive!(set_u16, U16, u16); + set_primitive!(set_u32, U32, u32); + set_primitive!(set_u64, U64, u64); + set_primitive!(set_u128, U128, u128); + set_primitive!(set_u256, U256, U256); + set_primitive!(set_bool, Bool, bool); + set_primitive!(set_usize, USize, u32); + set_primitive!(set_felt252, Felt252, FieldElement); + set_primitive!(set_class_hash, ClassHash, FieldElement); + set_primitive!(set_contract_address, ContractAddress, FieldElement); pub fn to_sql_type(&self) -> SqlType { match self { @@ -333,28 +298,39 @@ mod tests { } #[test] - fn as_inner_value() { - let primitive = Primitive::U8(Some(1u8)); + fn inner_value_getter_setter() { + let mut primitive = Primitive::U8(None); + primitive.set_u8(Some(1u8)).unwrap(); assert_eq!(primitive.as_u8(), Some(1u8)); - let primitive = Primitive::U16(Some(1u16)); + let mut primitive = Primitive::U16(None); + primitive.set_u16(Some(1u16)).unwrap(); assert_eq!(primitive.as_u16(), Some(1u16)); - let primitive = Primitive::U32(Some(1u32)); + let mut primitive = Primitive::U32(None); + primitive.set_u32(Some(1u32)).unwrap(); assert_eq!(primitive.as_u32(), Some(1u32)); - let primitive = Primitive::U64(Some(1u64)); + let mut primitive = Primitive::U64(None); + primitive.set_u64(Some(1u64)).unwrap(); assert_eq!(primitive.as_u64(), Some(1u64)); - let primitive = Primitive::U128(Some(1u128)); + let mut primitive = Primitive::U128(None); + primitive.set_u128(Some(1u128)).unwrap(); assert_eq!(primitive.as_u128(), Some(1u128)); - let primitive = Primitive::U256(Some(U256::from(1u128))); + let mut primitive = Primitive::U256(None); + primitive.set_u256(Some(U256::from(1u128))).unwrap(); assert_eq!(primitive.as_u256(), Some(U256::from(1u128))); - let primitive = Primitive::USize(Some(1u32)); + let mut primitive = Primitive::USize(None); + primitive.set_usize(Some(1u32)).unwrap(); assert_eq!(primitive.as_usize(), Some(1u32)); - let primitive = Primitive::Bool(Some(true)); + let mut primitive = Primitive::Bool(None); + primitive.set_bool(Some(true)).unwrap(); assert_eq!(primitive.as_bool(), Some(true)); - let primitive = Primitive::Felt252(Some(FieldElement::from(1u128))); + let mut primitive = Primitive::Felt252(None); + primitive.set_felt252(Some(FieldElement::from(1u128))).unwrap(); assert_eq!(primitive.as_felt252(), Some(FieldElement::from(1u128))); - let primitive = Primitive::ClassHash(Some(FieldElement::from(1u128))); + let mut primitive = Primitive::ClassHash(None); + primitive.set_class_hash(Some(FieldElement::from(1u128))).unwrap(); assert_eq!(primitive.as_class_hash(), Some(FieldElement::from(1u128))); - let primitive = Primitive::ContractAddress(Some(FieldElement::from(1u128))); + let mut primitive = Primitive::ContractAddress(None); + primitive.set_contract_address(Some(FieldElement::from(1u128))).unwrap(); assert_eq!(primitive.as_contract_address(), Some(FieldElement::from(1u128))); } } diff --git a/crates/dojo-types/src/schema.rs b/crates/dojo-types/src/schema.rs index ec86efd39c..f791c7bb96 100644 --- a/crates/dojo-types/src/schema.rs +++ b/crates/dojo-types/src/schema.rs @@ -269,6 +269,16 @@ impl Enum { Ok(self.options[option].name.clone()) } + pub fn set_option(&mut self, name: &str) -> Result<(), EnumError> { + match self.options.iter().position(|option| option.name == name) { + Some(index) => { + self.option = Some(index as u8); + Ok(()) + } + None => Err(EnumError::OptionInvalid), + } + } + pub fn to_sql_value(&self) -> Result { self.option() } diff --git a/crates/torii/core/Cargo.toml b/crates/torii/core/Cargo.toml index d6020e930d..9ad2757856 100644 --- a/crates/torii/core/Cargo.toml +++ b/crates/torii/core/Cargo.toml @@ -13,6 +13,7 @@ anyhow.workspace = true async-trait.workspace = true base64.workspace = true chrono.workspace = true +crypto-bigint = { version = "0.5.3", features = [ "serde" ] } dojo-types = { path = "../../dojo-types" } dojo-world = { path = "../../dojo-world", features = [ "contracts", "manifest" ] } futures-channel = "0.3.0" diff --git a/crates/torii/core/src/cache.rs b/crates/torii/core/src/cache.rs new file mode 100644 index 0000000000..dc8d39bc63 --- /dev/null +++ b/crates/torii/core/src/cache.rs @@ -0,0 +1,56 @@ +use std::collections::HashMap; + +use dojo_types::schema::Ty; +use sqlx::SqlitePool; +use tokio::sync::RwLock; + +use crate::error::{Error, QueryError}; +use crate::model::{parse_sql_model_members, SqlModelMember}; + +type ModelName = String; + +pub struct ModelCache { + pool: SqlitePool, + schemas: RwLock>, +} + +impl ModelCache { + pub fn new(pool: SqlitePool) -> Self { + Self { pool, schemas: RwLock::new(HashMap::new()) } + } + + pub async fn schema(&self, model: &str) -> Result { + { + let schemas = self.schemas.read().await; + if let Some(schema) = schemas.get(model) { + return Ok(schema.clone()); + } + } + + self.update_schema(model).await + } + + async fn update_schema(&self, model: &str) -> Result { + let model_members: Vec = sqlx::query_as( + "SELECT id, model_idx, member_idx, name, type, type_enum, enum_options, key FROM \ + model_members WHERE model_id = ? ORDER BY model_idx ASC, member_idx ASC", + ) + .bind(model) + .fetch_all(&self.pool) + .await?; + + if model_members.is_empty() { + return Err(QueryError::ModelNotFound(model.into()).into()); + } + + let ty = parse_sql_model_members(model, &model_members); + let mut schemas = self.schemas.write().await; + schemas.insert(model.into(), ty.clone()); + + Ok(ty) + } + + pub async fn clear(&self) { + self.schemas.write().await.clear(); + } +} diff --git a/crates/torii/core/src/error.rs b/crates/torii/core/src/error.rs index cdcf6b9b95..0d73633076 100644 --- a/crates/torii/core/src/error.rs +++ b/crates/torii/core/src/error.rs @@ -1,3 +1,7 @@ +use std::num::ParseIntError; + +use dojo_types::primitive::PrimitiveError; +use dojo_types::schema::EnumError; use starknet::core::types::{FromByteSliceError, FromStrError}; use starknet::core::utils::CairoShortStringToFeltError; @@ -9,6 +13,10 @@ pub enum Error { Sql(#[from] sqlx::Error), #[error(transparent)] QueryError(#[from] QueryError), + #[error(transparent)] + PrimitiveError(#[from] PrimitiveError), + #[error(transparent)] + EnumError(#[from] EnumError), } #[derive(Debug, thiserror::Error)] @@ -19,10 +27,16 @@ pub enum ParseError { CairoShortStringToFelt(#[from] CairoShortStringToFeltError), #[error(transparent)] FromByteSliceError(#[from] FromByteSliceError), + #[error(transparent)] + ParseIntError(#[from] ParseIntError), } #[derive(Debug, thiserror::Error)] pub enum QueryError { #[error("unsupported query")] UnsupportedQuery, + #[error("model not found: {0}")] + ModelNotFound(String), + #[error("exceeds sqlite `JOIN` limit (64)")] + SqliteJoinLimit, } diff --git a/crates/torii/core/src/lib.rs b/crates/torii/core/src/lib.rs index 877aab65a9..e36c3f2e3b 100644 --- a/crates/torii/core/src/lib.rs +++ b/crates/torii/core/src/lib.rs @@ -3,6 +3,7 @@ use sqlx::FromRow; use crate::types::SQLFieldElement; +pub mod cache; pub mod engine; pub mod error; pub mod model; diff --git a/crates/torii/core/src/model.rs b/crates/torii/core/src/model.rs index e701cf28e0..c05831c664 100644 --- a/crates/torii/core/src/model.rs +++ b/crates/torii/core/src/model.rs @@ -1,10 +1,16 @@ +use std::str::FromStr; + use async_trait::async_trait; +use crypto_bigint::U256; +use dojo_types::primitive::Primitive; use dojo_types::schema::{Enum, EnumOption, Member, Struct, Ty}; use dojo_world::contracts::model::ModelReader; -use sqlx::{Pool, Sqlite}; +use sqlx::sqlite::SqliteRow; +use sqlx::{Pool, Row, Sqlite}; use starknet::core::types::FieldElement; use super::error::{self, Error}; +use crate::error::{ParseError, QueryError}; pub struct ModelSQLReader { /// The name of the model @@ -144,11 +150,154 @@ pub fn parse_sql_model_members(model: &str, model_members_all: &[SqlModelMember] parse_sql_model_members_impl(model, model_members_all) } +/// Creates a query that fetches all models and their nested data. +pub fn build_sql_query(model_schemas: &Vec) -> Result { + fn parse_struct( + path: &str, + schema: &Struct, + selections: &mut Vec, + tables: &mut Vec, + ) { + for child in &schema.children { + match &child.ty { + Ty::Struct(s) => { + let table_name = format!("{}${}", path, s.name); + parse_struct(&table_name, s, selections, tables); + + tables.push(table_name); + } + _ => { + // alias selected columns to avoid conflicts in `JOIN` + selections.push(format!( + "{}.external_{} AS \"{}.{}\"", + path, child.name, path, child.name + )); + } + } + } + } + + let primary_table = model_schemas[0].name(); + let mut global_selections = Vec::new(); + let mut global_tables = model_schemas + .iter() + .enumerate() + .filter(|(index, _)| *index != 0) // primary_table don't `JOIN` itself + .map(|(_, schema)| schema.name()) + .collect::>(); + + for ty in model_schemas { + let schema = ty.as_struct().expect("schema should be struct"); + let model_table = &schema.name; + let mut selections = Vec::new(); + let mut tables = Vec::new(); + + parse_struct(model_table, schema, &mut selections, &mut tables); + + global_selections.push(selections.join(", ")); + global_tables.extend(tables); + } + + // TODO: Fallback to subqueries, SQLite has a max limit of 64 on 'table 'JOIN' + if global_tables.len() > 64 { + return Err(QueryError::SqliteJoinLimit.into()); + } + + let selections_clause = global_selections.join(", "); + let join_clause = global_tables + .into_iter() + .map(|table| format!(" LEFT JOIN {table} ON {primary_table}.entity_id = {table}.entity_id")) + .collect::>() + .join(" "); + + Ok(format!("SELECT {selections_clause} FROM {primary_table}{join_clause}")) +} + +/// Populate the values of a Ty (schema) from SQLite row. +pub fn map_row_to_ty(path: &str, struct_ty: &mut Struct, row: &SqliteRow) -> Result<(), Error> { + for member in struct_ty.children.iter_mut() { + let column_name = format!("{}.{}", path, member.name); + match &mut member.ty { + Ty::Primitive(primitive) => { + match &primitive { + Primitive::Bool(_) => { + let value = row.try_get::(&column_name)?; + primitive.set_bool(Some(value))?; + } + Primitive::USize(_) => { + let value = row.try_get::(&column_name)?; + primitive.set_usize(Some(value))?; + } + Primitive::U8(_) => { + let value = row.try_get::(&column_name)?; + primitive.set_u8(Some(value))?; + } + Primitive::U16(_) => { + let value = row.try_get::(&column_name)?; + primitive.set_u16(Some(value))?; + } + Primitive::U32(_) => { + let value = row.try_get::(&column_name)?; + primitive.set_u32(Some(value))?; + } + Primitive::U64(_) => { + let value = row.try_get::(&column_name)?; + primitive.set_u64(Some(value as u64))?; + } + Primitive::U128(_) => { + let value = row.try_get::(&column_name)?; + let hex_str = value.trim_start_matches("0x"); + primitive.set_u128(Some( + u128::from_str_radix(hex_str, 16).map_err(ParseError::ParseIntError)?, + ))?; + } + Primitive::U256(_) => { + let value = row.try_get::(&column_name)?; + let hex_str = value.trim_start_matches("0x"); + primitive.set_u256(Some(U256::from_be_hex(hex_str)))?; + } + Primitive::Felt252(_) => { + let value = row.try_get::(&column_name)?; + primitive.set_felt252(Some( + FieldElement::from_str(&value).map_err(ParseError::FromStr)?, + ))?; + } + Primitive::ClassHash(_) => { + let value = row.try_get::(&column_name)?; + primitive.set_class_hash(Some( + FieldElement::from_str(&value).map_err(ParseError::FromStr)?, + ))?; + } + Primitive::ContractAddress(_) => { + let value = row.try_get::(&column_name)?; + primitive.set_contract_address(Some( + FieldElement::from_str(&value).map_err(ParseError::FromStr)?, + ))?; + } + }; + } + Ty::Enum(enum_ty) => { + let value = row.try_get::(&column_name)?; + enum_ty.set_option(&value)?; + } + Ty::Struct(struct_ty) => { + let path = [path, &struct_ty.name].join("$"); + map_row_to_ty(&path, struct_ty, row)?; + } + ty => { + unimplemented!("unimplemented type_enum: {ty}"); + } + }; + } + + Ok(()) +} + #[cfg(test)] mod tests { use dojo_types::schema::{Enum, EnumOption, Member, Struct, Ty}; - use super::SqlModelMember; + use super::{build_sql_query, SqlModelMember}; use crate::model::parse_sql_model_members; #[test] @@ -321,4 +470,51 @@ mod tests { assert_eq!(parse_sql_model_members("Moves", &model_members), expected_ty); } + + #[test] + fn struct_ty_to_query() { + let ty = Ty::Struct(Struct { + name: "Position".into(), + children: vec![ + dojo_types::schema::Member { + name: "name".into(), + key: false, + ty: Ty::Primitive("felt252".parse().unwrap()), + }, + dojo_types::schema::Member { + name: "age".into(), + key: false, + ty: Ty::Primitive("u8".parse().unwrap()), + }, + dojo_types::schema::Member { + name: "vec".into(), + key: false, + ty: Ty::Struct(Struct { + name: "Vec2".into(), + children: vec![ + Member { + name: "x".into(), + key: false, + ty: Ty::Primitive("u256".parse().unwrap()), + }, + Member { + name: "y".into(), + key: false, + ty: Ty::Primitive("u256".parse().unwrap()), + }, + ], + }), + }, + ], + }); + + let query = build_sql_query(&vec![ty]).unwrap(); + assert_eq!( + query, + "SELECT Position.external_name AS \"Position.name\", Position.external_age AS \ + \"Position.age\", Position$Vec2.external_x AS \"Position$Vec2.x\", \ + Position$Vec2.external_y AS \"Position$Vec2.y\" FROM Position LEFT JOIN \ + Position$Vec2 ON Position.entity_id = Position$Vec2.entity_id" + ); + } } diff --git a/crates/torii/graphql/src/tests/entities_test.rs b/crates/torii/graphql/src/tests/entities_test.rs index a2119ba986..3f7c807828 100644 --- a/crates/torii/graphql/src/tests/entities_test.rs +++ b/crates/torii/graphql/src/tests/entities_test.rs @@ -92,7 +92,7 @@ mod tests { assert_eq!(connection.edges.len(), 10); assert_eq!(connection.total_count, 20); assert_eq!(&first_entity.node.model_names, "Subrecord"); - assert_eq!(&last_entity.node.model_names, "Record"); + assert_eq!(&last_entity.node.model_names, "Record,RecordSibling"); // first key param - returns all entities with `0x0` as first key let entities = entities_query(&schema, "(keys: [\"0x0\"])").await; diff --git a/crates/torii/graphql/src/tests/models_test.rs b/crates/torii/graphql/src/tests/models_test.rs index 0c6e3ca181..de28a4ea49 100644 --- a/crates/torii/graphql/src/tests/models_test.rs +++ b/crates/torii/graphql/src/tests/models_test.rs @@ -86,7 +86,7 @@ mod tests { assert_eq!(connection.total_count, 10); assert_eq!(connection.edges.len(), 10); assert_eq!(&record.node.__typename, "Record"); - assert_eq!(&entity.model_names, "Record"); + assert_eq!(&entity.model_names, "Record,RecordSibling"); assert_eq!(entity.keys.clone().unwrap(), vec!["0x0"]); assert_eq!(record.node.depth, "Zero"); assert_eq!(nested.depth, "One"); diff --git a/crates/torii/graphql/src/tests/types-test/src/contracts.cairo b/crates/torii/graphql/src/tests/types-test/src/contracts.cairo index 1e5a0a3dca..638df904a6 100644 --- a/crates/torii/graphql/src/tests/types-test/src/contracts.cairo +++ b/crates/torii/graphql/src/tests/types-test/src/contracts.cairo @@ -8,7 +8,7 @@ trait IRecords { #[dojo::contract] mod records { use starknet::{ContractAddress, get_caller_address}; - use types_test::models::{Record, Subrecord, Nested, NestedMore, NestedMoreMore, Depth}; + use types_test::models::{Record, RecordSibling, Subrecord, Nested, NestedMore, NestedMoreMore, Depth}; use types_test::{seed, random}; use super::IRecords; @@ -90,6 +90,9 @@ mod records { random_u8, random_u128 }, + RecordSibling { + record_id, random_u8 + }, Subrecord { record_id, subrecord_id, type_u8: record_idx.into(), random_u8, } diff --git a/crates/torii/graphql/src/tests/types-test/src/models.cairo b/crates/torii/graphql/src/tests/types-test/src/models.cairo index af855773d4..a68b11ae8e 100644 --- a/crates/torii/graphql/src/tests/types-test/src/models.cairo +++ b/crates/torii/graphql/src/tests/types-test/src/models.cairo @@ -21,6 +21,13 @@ struct Record { random_u128: u128, } +#[derive(Model, Copy, Drop, Serde)] +struct RecordSibling { + #[key] + record_id: u32, + random_u8: u8 +} + #[derive(Copy, Drop, Serde, Introspect)] struct Nested { depth: Depth, diff --git a/crates/torii/grpc/proto/types.proto b/crates/torii/grpc/proto/types.proto index 1c853c005b..80d9ae8da6 100644 --- a/crates/torii/grpc/proto/types.proto +++ b/crates/torii/grpc/proto/types.proto @@ -29,6 +29,38 @@ message ModelMetadata { bytes schema = 6; } +message Enum { + // Variant + uint32 option = 1; + // Variants of the enum + repeated string options = 2; +} + +message Member { + // Name of the member + string name = 1; + // Type of member + oneof member_type { + Value value = 2; + Enum enum = 3; + Model struct = 4; + } +} + +message Model { + // Name of the model + string name = 1; + // Members of the model + repeated Member members = 2; +} + +message Entity { + // The entity key + bytes key = 1; + // Models of the entity + repeated Model models = 2; +} + message StorageEntry { // The key of the changed value string key = 1; @@ -55,6 +87,8 @@ message EntityUpdate { message EntityQuery { Clause clause = 1; + uint32 limit = 2; + uint32 offset = 3; } message Clause { @@ -72,7 +106,7 @@ message KeysClause { message AttributeClause { string model = 1; - string attribute = 2; + string member = 2; ComparisonOperator operator = 3; Value value = 4; } diff --git a/crates/torii/grpc/proto/world.proto b/crates/torii/grpc/proto/world.proto index 66f5fce09d..ec6dccc469 100644 --- a/crates/torii/grpc/proto/world.proto +++ b/crates/torii/grpc/proto/world.proto @@ -9,8 +9,11 @@ service World { rpc WorldMetadata (MetadataRequest) returns (MetadataResponse); - // Subscribes to entity updates. + // Subscribes to entities updates. rpc SubscribeEntities (SubscribeEntitiesRequest) returns (stream SubscribeEntitiesResponse); + + // Retrieve entity + rpc RetrieveEntities (RetrieveEntitiesRequest) returns (RetrieveEntitiesResponse); } @@ -33,3 +36,12 @@ message SubscribeEntitiesResponse { // List of entities that have been updated. types.EntityUpdate entity_update = 1; } + +message RetrieveEntitiesRequest { + // The entities to retrieve + types.EntityQuery query = 1; +} + +message RetrieveEntitiesResponse { + repeated types.Entity entities = 1; +} diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index d5ccac1680..8b7d484817 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -5,13 +5,18 @@ pub mod subscription; use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; +use std::str::FromStr; use std::sync::Arc; +use dojo_types::primitive::Primitive; +use dojo_types::schema::{Struct, Ty}; use futures::Stream; use proto::world::{ - MetadataRequest, MetadataResponse, SubscribeEntitiesRequest, SubscribeEntitiesResponse, + MetadataRequest, MetadataResponse, RetrieveEntitiesRequest, RetrieveEntitiesResponse, + SubscribeEntitiesRequest, SubscribeEntitiesResponse, }; -use sqlx::{Pool, Sqlite}; +use sqlx::sqlite::SqliteRow; +use sqlx::{Pool, Row, Sqlite}; use starknet::core::utils::cairo_short_string_to_felt; use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::JsonRpcClient; @@ -21,10 +26,12 @@ use tokio::sync::mpsc::Receiver; use tokio_stream::wrappers::{ReceiverStream, TcpListenerStream}; use tonic::transport::Server; use tonic::{Request, Response, Status}; -use torii_core::error::{Error, ParseError}; -use torii_core::model::{parse_sql_model_members, SqlModelMember}; +use torii_core::cache::ModelCache; +use torii_core::error::{Error, ParseError, QueryError}; +use torii_core::model::build_sql_query; use self::subscription::SubscribeRequest; +use crate::proto::types::clause::ClauseType; use crate::proto::world::world_server::WorldServer; use crate::proto::{self}; @@ -33,6 +40,7 @@ pub struct DojoWorld { world_address: FieldElement, pool: Pool, subscriber_manager: Arc, + model_cache: Arc, } impl DojoWorld { @@ -51,7 +59,9 @@ impl DojoWorld { Arc::clone(&subscriber_manager), )); - Self { pool, world_address, subscriber_manager } + let model_cache = Arc::new(ModelCache::new(pool.clone())); + + Self { pool, model_cache, world_address, subscriber_manager } } } @@ -78,7 +88,7 @@ impl DojoWorld { let mut models_metadata = Vec::with_capacity(models.len()); for model in models { - let schema = self.model_schema(&model.0).await?; + let schema = self.model_cache.schema(&model.0).await?; models_metadata.push(proto::types::ModelMetadata { name: model.0, class_hash: model.1, @@ -98,16 +108,79 @@ impl DojoWorld { }) } - async fn model_schema(&self, model: &str) -> Result { - let model_members: Vec = sqlx::query_as( - "SELECT id, model_idx, member_idx, name, type, type_enum, enum_options, key FROM \ - model_members WHERE model_id = ? ORDER BY model_idx ASC, member_idx ASC", + async fn entities_by_keys( + &self, + keys_clause: proto::types::KeysClause, + limit: u32, + offset: u32, + ) -> Result, Error> { + let keys = keys_clause + .keys + .iter() + .map(|bytes| { + if bytes.is_empty() { + return Ok("%".to_string()); + } + Ok(FieldElement::from_byte_slice_be(bytes) + .map(|felt| format!("{:#x}", felt)) + .map_err(ParseError::FromByteSliceError)?) + }) + .collect::, Error>>()?; + let keys_pattern = keys.join("/") + "/%"; + + let db_entities: Vec<(String, String)> = sqlx::query_as( + "SELECT id, model_names FROM entities WHERE keys LIKE ? ORDER BY event_id ASC LIMIT ? \ + OFFSET ?", ) - .bind(model) + .bind(&keys_pattern) + .bind(limit) + .bind(offset) .fetch_all(&self.pool) .await?; - Ok(parse_sql_model_members(model, &model_members)) + let mut entities = Vec::new(); + for (entity_id, models_str) in db_entities { + let model_names: Vec<&str> = models_str.split(',').collect(); + let mut schemas = Vec::new(); + for model in &model_names { + schemas.push(self.model_cache.schema(model).await?); + } + + let entity_query = + format!("{} WHERE {}.entity_id = ?", build_sql_query(&schemas)?, schemas[0].name()); + let row = sqlx::query(&entity_query).bind(&entity_id).fetch_one(&self.pool).await?; + + let mut models = Vec::new(); + for schema in schemas { + let struct_ty = schema.as_struct().expect("schema should be struct"); + models.push(Self::map_row_to_model(&schema.name(), struct_ty, &row)?); + } + + let key = FieldElement::from_str(&entity_id).map_err(ParseError::FromStr)?; + entities.push(proto::types::Entity { key: key.to_bytes_be().to_vec(), models }) + } + + Ok(entities) + } + + async fn entities_by_attribute( + &self, + _attribute: proto::types::AttributeClause, + _limit: u32, + _offset: u32, + ) -> Result, Error> { + // TODO: Implement + Err(QueryError::UnsupportedQuery.into()) + } + + async fn entities_by_composite( + &self, + _composite: proto::types::CompositeClause, + _limit: u32, + _offset: u32, + ) -> Result, Error> { + // TODO: Implement + Err(QueryError::UnsupportedQuery.into()) } pub async fn model_metadata(&self, model: &str) -> Result { @@ -124,7 +197,7 @@ impl DojoWorld { .fetch_one(&self.pool) .await?; - let schema = self.model_schema(model).await?; + let schema = self.model_cache.schema(model).await?; let layout = hex::decode(&layout).unwrap(); Ok(proto::types::ModelMetadata { @@ -161,6 +234,111 @@ impl DojoWorld { self.subscriber_manager.add_subscriber(subs).await } + + async fn retrieve_entities( + &self, + query: proto::types::EntityQuery, + ) -> Result { + let clause_type = query + .clause + .ok_or(QueryError::UnsupportedQuery)? + .clause_type + .ok_or(QueryError::UnsupportedQuery)?; + + let entities = match clause_type { + ClauseType::Keys(keys) => { + self.entities_by_keys(keys, query.limit, query.offset).await? + } + ClauseType::Attribute(attribute) => { + self.entities_by_attribute(attribute, query.limit, query.offset).await? + } + ClauseType::Composite(composite) => { + self.entities_by_composite(composite, query.limit, query.offset).await? + } + }; + + Ok(RetrieveEntitiesResponse { entities }) + } + + fn map_row_to_model( + path: &str, + struct_ty: &Struct, + row: &SqliteRow, + ) -> Result { + let members = struct_ty + .children + .iter() + .map(|member| { + let column_name = format!("{}.{}", path, member.name); + let name = member.name.clone(); + let member = match &member.ty { + Ty::Primitive(primitive) => { + let value_type = match primitive { + Primitive::Bool(_) => proto::types::value::ValueType::BoolValue( + row.try_get::(&column_name)?, + ), + Primitive::U8(_) + | Primitive::U16(_) + | Primitive::U32(_) + | Primitive::U64(_) + | Primitive::USize(_) => { + let value = row.try_get::(&column_name)?; + proto::types::value::ValueType::UintValue(value as u64) + } + Primitive::U128(_) + | Primitive::U256(_) + | Primitive::Felt252(_) + | Primitive::ClassHash(_) + | Primitive::ContractAddress(_) => { + let value = row.try_get::(&column_name)?; + proto::types::value::ValueType::StringValue(value) + } + }; + + proto::types::Member { + name, + member_type: Some(proto::types::member::MemberType::Value( + proto::types::Value { value_type: Some(value_type) }, + )), + } + } + Ty::Enum(enum_ty) => { + let value = row.try_get::(&column_name)?; + let options = enum_ty + .options + .iter() + .map(|e| e.name.to_string()) + .collect::>(); + let option = + options.iter().position(|o| o == &value).expect("wrong enum value") + as u32; + proto::types::Member { + name: enum_ty.name.clone(), + member_type: Some(proto::types::member::MemberType::Enum( + proto::types::Enum { option, options }, + )), + } + } + Ty::Struct(struct_ty) => { + let path = [path, &struct_ty.name].join("$"); + proto::types::Member { + name, + member_type: Some(proto::types::member::MemberType::Struct( + Self::map_row_to_model(&path, struct_ty, row)?, + )), + } + } + ty => { + unimplemented!("unimplemented type_enum: {ty}"); + } + }; + + Ok(member) + }) + .collect::, Error>>()?; + + Ok(proto::types::Model { name: struct_ty.name.clone(), members }) + } } type ServiceResult = Result, Status>; @@ -194,6 +372,21 @@ impl proto::world::world_server::World for DojoWorld { .map_err(|e| Status::internal(e.to_string()))?; Ok(Response::new(Box::pin(ReceiverStream::new(rx)) as Self::SubscribeEntitiesStream)) } + + async fn retrieve_entities( + &self, + request: Request, + ) -> Result, Status> { + let query = request + .into_inner() + .query + .ok_or_else(|| Status::invalid_argument("Missing query argument"))?; + + let entities = + self.retrieve_entities(query).await.map_err(|e| Status::internal(e.to_string()))?; + + Ok(Response::new(entities)) + } } pub async fn new( diff --git a/crates/torii/grpc/src/types.rs b/crates/torii/grpc/src/types.rs index e62dc3b8a8..00efe50358 100644 --- a/crates/torii/grpc/src/types.rs +++ b/crates/torii/grpc/src/types.rs @@ -13,6 +13,8 @@ use crate::proto; #[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] pub struct Query { pub clause: Clause, + pub limit: u32, + pub offset: u32, } #[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] @@ -31,7 +33,7 @@ pub struct KeysClause { #[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] pub struct AttributeClause { pub model: String, - pub attribute: String, + pub member: String, pub operator: ComparisonOperator, pub value: Value, } @@ -105,7 +107,7 @@ impl TryFrom for dojo_types::WorldMetadata { impl From for proto::types::EntityQuery { fn from(value: Query) -> Self { - Self { clause: Some(value.clause.into()) } + Self { clause: Some(value.clause.into()), limit: value.limit, offset: value.offset } } } @@ -152,7 +154,7 @@ impl From for proto::types::AttributeClause { fn from(value: AttributeClause) -> Self { Self { model: value.model, - attribute: value.attribute, + member: value.member, operator: value.operator as i32, value: Some(value.value.into()), }