diff --git a/crates/dojo-types/src/schema.rs b/crates/dojo-types/src/schema.rs index a6ff30ebc5..ec86efd39c 100644 --- a/crates/dojo-types/src/schema.rs +++ b/crates/dojo-types/src/schema.rs @@ -20,62 +20,6 @@ impl Member { } } -#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] -pub struct EntityQuery { - pub model: String, - pub clause: Clause, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] -pub enum Clause { - Keys(KeysClause), - Attribute(AttributeClause), - Composite(CompositeClause), -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] -pub struct KeysClause { - pub keys: Vec, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] -pub struct AttributeClause { - pub attribute: String, - pub operator: ComparisonOperator, - pub value: Value, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] -pub struct CompositeClause { - pub operator: LogicalOperator, - pub clauses: Vec, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] -pub enum LogicalOperator { - And, - Or, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] -pub enum ComparisonOperator { - Eq, - Neq, - Gt, - Gte, - Lt, - Lte, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] -pub enum Value { - String(String), - Int(i64), - UInt(u64), - Bool(bool), - Bytes(Vec), -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelMetadata { pub schema: Ty, diff --git a/crates/torii/client/src/client/mod.rs b/crates/torii/client/src/client/mod.rs index 6c58f2c559..ba713b5c5a 100644 --- a/crates/torii/client/src/client/mod.rs +++ b/crates/torii/client/src/client/mod.rs @@ -7,7 +7,7 @@ use std::collections::HashSet; use std::sync::Arc; use dojo_types::packing::unpack; -use dojo_types::schema::{Clause, EntityQuery, Ty}; +use dojo_types::schema::Ty; use dojo_types::WorldMetadata; use dojo_world::contracts::WorldContractReader; use parking_lot::{RwLock, RwLockReadGuard}; @@ -17,6 +17,7 @@ use starknet::providers::JsonRpcClient; use starknet_crypto::FieldElement; use tokio::sync::RwLock as AsyncRwLock; use torii_grpc::client::EntityUpdateStreaming; +use torii_grpc::types::KeysClause; use self::error::{Error, ParseError}; use self::storage::ModelStorage; @@ -46,7 +47,7 @@ impl Client { torii_url: String, rpc_url: String, world: FieldElement, - queries: Option>, + entities_keys: Option>, ) -> Result { let mut grpc_client = torii_grpc::client::WorldClient::new(torii_url, world).await?; @@ -61,23 +62,18 @@ impl Client { let provider = JsonRpcClient::new(HttpTransport::new(rpc_url)); let world_reader = WorldContractReader::new(world, provider); - if let Some(queries) = queries { - subbed_entities.add_entities(queries)?; + if let Some(keys) = entities_keys { + subbed_entities.add_entities(keys)?; // TODO: change this to querying the gRPC url instead - let subbed_entities = subbed_entities.entities.read().clone(); - for EntityQuery { model, clause } in subbed_entities { - let model_reader = world_reader.model(&model).await?; - let keys = if let Clause::Keys(clause) = clause { - clause.keys - } else { - return Err(Error::UnsupportedQuery); - }; - let values = model_reader.entity_storage(&keys).await?; + let subbed_entities = subbed_entities.entities_keys.read().clone(); + for keys in subbed_entities { + let model_reader = world_reader.model(&keys.model).await?; + let values = model_reader.entity_storage(&keys.keys).await?; client_storage.set_entity_storage( - cairo_short_string_to_felt(&model).unwrap(), - keys, + cairo_short_string_to_felt(&keys.model).unwrap(), + keys.keys, values, )?; } @@ -98,8 +94,8 @@ impl Client { self.metadata.read() } - pub fn subscribed_entities(&self) -> RwLockReadGuard<'_, HashSet> { - self.subscribed_entities.entities.read() + pub fn subscribed_entities(&self) -> RwLockReadGuard<'_, HashSet> { + self.subscribed_entities.entities_keys.read() } /// Returns the model value of an entity. @@ -109,27 +105,20 @@ impl Client { /// /// If the requested entity is not among the synced entities, it will attempt to fetch it from /// the RPC. - pub async fn entity(&self, entity: &EntityQuery) -> Result, Error> { - let Some(mut schema) = self.metadata.read().model(&entity.model).map(|m| m.schema.clone()) + pub async fn entity(&self, keys: &KeysClause) -> Result, Error> { + let Some(mut schema) = self.metadata.read().model(&keys.model).map(|m| m.schema.clone()) else { return Ok(None); }; - let keys = if let Clause::Keys(clause) = entity.clone().clause { - clause.keys - } else { - return Err(Error::UnsupportedQuery); - }; - - if !self.subscribed_entities.is_synced(entity) { - let model = self.world_reader.model(&entity.model).await?; - return Ok(Some(model.entity(&keys).await?)); + if !self.subscribed_entities.is_synced(keys) { + let model = self.world_reader.model(&keys.model).await?; + return Ok(Some(model.entity(&keys.keys).await?)); } let Ok(Some(raw_values)) = self.storage.get_entity_storage( - cairo_short_string_to_felt(&entity.model) - .map_err(ParseError::CairoShortStringToFelt)?, - &keys, + cairo_short_string_to_felt(&keys.model).map_err(ParseError::CairoShortStringToFelt)?, + &keys.keys, ) else { return Ok(Some(schema)); }; @@ -137,12 +126,12 @@ impl Client { let layout = self .metadata .read() - .model(&entity.model) + .model(&keys.model) .map(|m| m.layout.clone()) .expect("qed; layout should exist"); let unpacked = unpack(raw_values, layout).unwrap(); - let mut keys_and_unpacked = [keys.to_vec(), unpacked].concat(); + let mut keys_and_unpacked = [keys.keys.to_vec(), unpacked].concat(); schema.deserialize(&mut keys_and_unpacked).unwrap(); @@ -152,8 +141,9 @@ impl Client { /// Initiate the entity subscriptions and returns a [SubscriptionService] which when await'ed /// will execute the subscription service and starts the syncing process. pub async fn start_subscription(&self) -> Result { - let entities = self.subscribed_entities.entities.read().clone().into_iter().collect(); - let sub_res_stream = self.initiate_subscription(entities).await?; + let entities_keys = + self.subscribed_entities.entities_keys.read().clone().into_iter().collect(); + let sub_res_stream = self.initiate_subscription(entities_keys).await?; let (service, handle) = SubscriptionService::new( Arc::clone(&self.storage), @@ -169,21 +159,15 @@ impl Client { /// Adds entities to the list of entities to be synced. /// /// NOTE: This will establish a new subscription stream with the server. - pub async fn add_entities_to_sync(&self, entities: Vec) -> Result<(), Error> { - for entity in &entities { - let keys = if let Clause::Keys(clause) = entity.clone().clause { - clause.keys - } else { - return Err(Error::UnsupportedQuery); - }; - - self.initiate_entity(&entity.model, keys.clone()).await?; + pub async fn add_entities_to_sync(&self, entities_keys: Vec) -> Result<(), Error> { + for keys in &entities_keys { + self.initiate_entity(&keys.model, keys.keys.clone()).await?; } - self.subscribed_entities.add_entities(entities)?; + self.subscribed_entities.add_entities(entities_keys)?; let updated_entities = - self.subscribed_entities.entities.read().clone().into_iter().collect(); + self.subscribed_entities.entities_keys.read().clone().into_iter().collect(); let sub_res_stream = self.initiate_subscription(updated_entities).await?; match self.sub_client_handle.get() { @@ -196,11 +180,14 @@ impl Client { /// Removes entities from the list of entities to be synced. /// /// NOTE: This will establish a new subscription stream with the server. - pub async fn remove_entities_to_sync(&self, entities: Vec) -> Result<(), Error> { - self.subscribed_entities.remove_entities(entities)?; + pub async fn remove_entities_to_sync( + &self, + entities_keys: Vec, + ) -> Result<(), Error> { + self.subscribed_entities.remove_entities(entities_keys)?; let updated_entities = - self.subscribed_entities.entities.read().clone().into_iter().collect(); + self.subscribed_entities.entities_keys.read().clone().into_iter().collect(); let sub_res_stream = self.initiate_subscription(updated_entities).await?; match self.sub_client_handle.get() { @@ -216,10 +203,10 @@ impl Client { async fn initiate_subscription( &self, - entities: Vec, + keys: Vec, ) -> Result { let mut grpc_client = self.inner.write().await; - let stream = grpc_client.subscribe_entities(entities).await?; + let stream = grpc_client.subscribe_entities(keys).await?; Ok(stream) } diff --git a/crates/torii/client/src/client/storage.rs b/crates/torii/client/src/client/storage.rs index 2f486825b2..d56b7db506 100644 --- a/crates/torii/client/src/client/storage.rs +++ b/crates/torii/client/src/client/storage.rs @@ -168,7 +168,7 @@ mod tests { use std::collections::HashMap; use std::sync::Arc; - use dojo_types::schema::{KeysClause, Ty}; + use dojo_types::schema::Ty; use dojo_types::WorldMetadata; use parking_lot::RwLock; use starknet::core::utils::cairo_short_string_to_felt; @@ -202,13 +202,8 @@ mod tests { fn err_if_set_values_too_many() { let storage = create_dummy_storage(); let keys = vec![felt!("0x12345")]; - let entity = dojo_types::schema::EntityQuery { - model: "Position".into(), - clause: dojo_types::schema::Clause::Keys(KeysClause { keys: keys.clone() }), - }; - let values = vec![felt!("1"), felt!("2"), felt!("3"), felt!("4"), felt!("5")]; - let model = cairo_short_string_to_felt(&entity.model).unwrap(); + let model = cairo_short_string_to_felt("Position").unwrap(); let result = storage.set_entity_storage(model, keys, values); assert!(storage.storage.read().is_empty()); @@ -222,13 +217,8 @@ mod tests { fn err_if_set_values_too_few() { let storage = create_dummy_storage(); let keys = vec![felt!("0x12345")]; - let entity = dojo_types::schema::EntityQuery { - model: "Position".into(), - clause: dojo_types::schema::Clause::Keys(KeysClause { keys: keys.clone() }), - }; - let values = vec![felt!("1"), felt!("2")]; - let model = cairo_short_string_to_felt(&entity.model).unwrap(); + let model = cairo_short_string_to_felt("Position").unwrap(); let result = storage.set_entity_storage(model, keys, values); assert!(storage.storage.read().is_empty()); @@ -242,15 +232,10 @@ mod tests { fn set_and_get_entity_value() { let storage = create_dummy_storage(); let keys = vec![felt!("0x12345")]; - let entity = dojo_types::schema::EntityQuery { - model: "Position".into(), - clause: dojo_types::schema::Clause::Keys(KeysClause { keys: keys.clone() }), - }; assert!(storage.storage.read().is_empty(), "storage must be empty initially"); - let model = storage.metadata.read().model(&entity.model).cloned().unwrap(); - + let model = storage.metadata.read().model("Position").cloned().unwrap(); let expected_storage_addresses = compute_all_storage_addresses( cairo_short_string_to_felt(&model.name).unwrap(), &keys, @@ -258,7 +243,7 @@ mod tests { ); let expected_values = vec![felt!("1"), felt!("2"), felt!("3"), felt!("4")]; - let model_name_in_felt = cairo_short_string_to_felt(&entity.model).unwrap(); + let model_name_in_felt = cairo_short_string_to_felt("Position").unwrap(); storage .set_entity_storage(model_name_in_felt, keys.clone(), expected_values.clone()) diff --git a/crates/torii/client/src/client/subscription.rs b/crates/torii/client/src/client/subscription.rs index 8bd82346f1..898d008130 100644 --- a/crates/torii/client/src/client/subscription.rs +++ b/crates/torii/client/src/client/subscription.rs @@ -4,7 +4,6 @@ use std::future::Future; use std::sync::Arc; use std::task::Poll; -use dojo_types::schema::{Clause, EntityQuery}; use dojo_types::WorldMetadata; use futures::channel::mpsc::{self, Receiver, Sender}; use futures_util::StreamExt; @@ -13,6 +12,7 @@ use starknet::core::types::{StateDiff, StateUpdate}; use starknet::core::utils::cairo_short_string_to_felt; use starknet_crypto::FieldElement; use torii_grpc::client::EntityUpdateStreaming; +use torii_grpc::types::KeysClause; use super::error::{Error, ParseError}; use super::ModelStorage; @@ -24,61 +24,54 @@ pub enum SubscriptionEvent { pub struct SubscribedEntities { metadata: Arc>, - pub(super) entities: RwLock>, + pub(super) entities_keys: RwLock>, /// All the relevant storage addresses derived from the subscribed entities pub(super) subscribed_storage_addresses: RwLock>, } impl SubscribedEntities { - pub(super) fn is_synced(&self, entity: &EntityQuery) -> bool { - self.entities.read().contains(entity) + pub(super) fn is_synced(&self, keys: &KeysClause) -> bool { + self.entities_keys.read().contains(keys) } pub(super) fn new(metadata: Arc>) -> Self { Self { metadata, - entities: Default::default(), + entities_keys: Default::default(), subscribed_storage_addresses: Default::default(), } } - pub(super) fn add_entities(&self, entities: Vec) -> Result<(), Error> { - for entity in entities { - Self::add_entity(self, entity)?; + pub(super) fn add_entities(&self, entities_keys: Vec) -> Result<(), Error> { + for keys in entities_keys { + Self::add_entity(self, keys)?; } Ok(()) } - pub(super) fn remove_entities(&self, entities: Vec) -> Result<(), Error> { - for entity in entities { - Self::remove_entity(self, entity)?; + pub(super) fn remove_entities(&self, entities_keys: Vec) -> Result<(), Error> { + for keys in entities_keys { + Self::remove_entity(self, keys)?; } Ok(()) } - pub(super) fn add_entity(&self, entity: EntityQuery) -> Result<(), Error> { - if !self.entities.write().insert(entity.clone()) { + pub(super) fn add_entity(&self, keys: KeysClause) -> Result<(), Error> { + if !self.entities_keys.write().insert(keys.clone()) { return Ok(()); } - let keys = if let Clause::Keys(clause) = entity.clause { - clause.keys - } else { - return Err(Error::UnsupportedQuery); - }; - let model_packed_size = self .metadata .read() .models - .get(&entity.model) + .get(&keys.model) .map(|c| c.packed_size) - .ok_or(Error::UnknownModel(entity.model.clone()))?; + .ok_or(Error::UnknownModel(keys.model.clone()))?; let storage_addresses = compute_all_storage_addresses( - cairo_short_string_to_felt(&entity.model) - .map_err(ParseError::CairoShortStringToFelt)?, - &keys, + cairo_short_string_to_felt(&keys.model).map_err(ParseError::CairoShortStringToFelt)?, + &keys.keys, model_packed_size, ); @@ -90,29 +83,22 @@ impl SubscribedEntities { Ok(()) } - pub(super) fn remove_entity(&self, entity: EntityQuery) -> Result<(), Error> { - if !self.entities.write().remove(&entity) { + pub(super) fn remove_entity(&self, keys: KeysClause) -> Result<(), Error> { + if !self.entities_keys.write().remove(&keys) { return Ok(()); } - let keys = if let Clause::Keys(clause) = entity.clause { - clause.keys - } else { - return Err(Error::UnsupportedQuery); - }; - let model_packed_size = self .metadata .read() .models - .get(&entity.model) + .get(&keys.model) .map(|c| c.packed_size) - .ok_or(Error::UnknownModel(entity.model.clone()))?; + .ok_or(Error::UnknownModel(keys.model.clone()))?; let storage_addresses = compute_all_storage_addresses( - cairo_short_string_to_felt(&entity.model) - .map_err(ParseError::CairoShortStringToFelt)?, - &keys, + cairo_short_string_to_felt(&keys.model).map_err(ParseError::CairoShortStringToFelt)?, + &keys.keys, model_packed_size, ); @@ -256,11 +242,12 @@ mod tests { use std::collections::HashMap; use std::sync::Arc; - use dojo_types::schema::{KeysClause, Ty}; + use dojo_types::schema::Ty; use dojo_types::WorldMetadata; use parking_lot::RwLock; use starknet::core::utils::cairo_short_string_to_felt; use starknet::macros::felt; + use torii_grpc::types::KeysClause; use crate::utils::compute_all_storage_addresses; @@ -295,29 +282,26 @@ mod tests { let metadata = self::create_dummy_metadata(); - let entity = dojo_types::schema::EntityQuery { - model: model_name, - clause: dojo_types::schema::Clause::Keys(KeysClause { keys }), - }; + let keys = KeysClause { model: model_name, keys }; let subscribed_entities = super::SubscribedEntities::new(Arc::new(RwLock::new(metadata))); - subscribed_entities.add_entities(vec![entity.clone()]).expect("able to add entity"); + subscribed_entities.add_entities(vec![keys.clone()]).expect("able to add entity"); let actual_storage_addresses_count = subscribed_entities.subscribed_storage_addresses.read().len(); let actual_storage_addresses = subscribed_entities.subscribed_storage_addresses.read().clone(); - assert!(subscribed_entities.entities.read().contains(&entity)); + assert!(subscribed_entities.entities_keys.read().contains(&keys)); assert_eq!(actual_storage_addresses_count, expected_storage_addresses.len()); assert!(expected_storage_addresses.all(|addr| actual_storage_addresses.contains(&addr))); - subscribed_entities.remove_entities(vec![entity.clone()]).expect("able to remove entities"); + subscribed_entities.remove_entities(vec![keys.clone()]).expect("able to remove entities"); let actual_storage_addresses_count_after = subscribed_entities.subscribed_storage_addresses.read().len(); assert_eq!(actual_storage_addresses_count_after, 0); - assert!(!subscribed_entities.entities.read().contains(&entity)); + assert!(!subscribed_entities.entities_keys.read().contains(&keys)); } } diff --git a/crates/torii/core/src/error.rs b/crates/torii/core/src/error.rs index 8ccccdc601..cdcf6b9b95 100644 --- a/crates/torii/core/src/error.rs +++ b/crates/torii/core/src/error.rs @@ -7,8 +7,8 @@ pub enum Error { Parse(#[from] ParseError), #[error(transparent)] Sql(#[from] sqlx::Error), - #[error("unsupported query clause")] - UnsupportedQuery, + #[error(transparent)] + QueryError(#[from] QueryError), } #[derive(Debug, thiserror::Error)] @@ -20,3 +20,9 @@ pub enum ParseError { #[error(transparent)] FromByteSliceError(#[from] FromByteSliceError), } + +#[derive(Debug, thiserror::Error)] +pub enum QueryError { + #[error("unsupported query")] + UnsupportedQuery, +} diff --git a/crates/torii/graphql/src/tests/types-test/Scarb.lock b/crates/torii/graphql/src/tests/types-test/Scarb.lock index 2ca7018569..6ad1bf2c03 100644 --- a/crates/torii/graphql/src/tests/types-test/Scarb.lock +++ b/crates/torii/graphql/src/tests/types-test/Scarb.lock @@ -3,14 +3,14 @@ version = 1 [[package]] name = "dojo" -version = "0.3.6" +version = "0.3.10" dependencies = [ "dojo_plugin", ] [[package]] name = "dojo_plugin" -version = "0.3.6" +version = "0.3.10" [[package]] name = "types_test" diff --git a/crates/torii/grpc/proto/types.proto b/crates/torii/grpc/proto/types.proto index eb2e6e9fad..1c853c005b 100644 --- a/crates/torii/grpc/proto/types.proto +++ b/crates/torii/grpc/proto/types.proto @@ -54,8 +54,7 @@ message EntityUpdate { } message EntityQuery { - string model = 1; - Clause clause = 2; + Clause clause = 1; } message Clause { @@ -67,18 +66,21 @@ message Clause { } message KeysClause { - repeated bytes keys = 1; + string model = 1; + repeated bytes keys = 2; } message AttributeClause { - string attribute = 1; - ComparisonOperator operator = 2; - Value value = 3; + string model = 1; + string attribute = 2; + ComparisonOperator operator = 3; + Value value = 4; } message CompositeClause { - LogicalOperator operator = 1; - repeated Clause clauses = 2; + string model = 1; + LogicalOperator operator = 2; + repeated Clause clauses = 3; } enum LogicalOperator { diff --git a/crates/torii/grpc/proto/world.proto b/crates/torii/grpc/proto/world.proto index 10734e7e8e..66f5fce09d 100644 --- a/crates/torii/grpc/proto/world.proto +++ b/crates/torii/grpc/proto/world.proto @@ -25,8 +25,8 @@ message MetadataResponse { } message SubscribeEntitiesRequest { - // The list of entity queries to subscribe to. - repeated types.EntityQuery queries = 1; + // The list of entity keys to subscribe to. + repeated types.KeysClause entities_keys = 1; } message SubscribeEntitiesResponse { diff --git a/crates/torii/grpc/src/client.rs b/crates/torii/grpc/src/client.rs index c97b250b16..e4097849c3 100644 --- a/crates/torii/grpc/src/client.rs +++ b/crates/torii/grpc/src/client.rs @@ -1,5 +1,4 @@ //! Client implementation for the gRPC service. - use futures_util::stream::MapOk; use futures_util::{Stream, StreamExt, TryStreamExt}; use proto::world::{world_client, SubscribeEntitiesRequest}; @@ -8,6 +7,7 @@ use starknet_crypto::FieldElement; use crate::proto::world::{MetadataRequest, SubscribeEntitiesResponse}; use crate::proto::{self}; +use crate::types::KeysClause; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -67,12 +67,12 @@ impl WorldClient { /// Subscribe to the state diff for a set of entities of a World. pub async fn subscribe_entities( &mut self, - queries: Vec, + entities_keys: Vec, ) -> Result { let stream = self .inner .subscribe_entities(SubscribeEntitiesRequest { - queries: queries.into_iter().map(|e| e.into()).collect(), + entities_keys: entities_keys.into_iter().map(|e| e.into()).collect(), }) .await .map_err(Error::Grpc) diff --git a/crates/torii/grpc/src/lib.rs b/crates/torii/grpc/src/lib.rs index 4b19ce10cb..1fb5da7eb7 100644 --- a/crates/torii/grpc/src/lib.rs +++ b/crates/torii/grpc/src/lib.rs @@ -3,7 +3,7 @@ extern crate wasm_prost as prost; #[cfg(target_arch = "wasm32")] extern crate wasm_tonic as tonic; -pub mod conversion; +pub mod types; #[cfg(feature = "client")] pub mod client; diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index c4e2088e70..d5ccac1680 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -7,7 +7,6 @@ use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; -use dojo_types::schema::KeysClause; use futures::Stream; use proto::world::{ MetadataRequest, MetadataResponse, SubscribeEntitiesRequest, SubscribeEntitiesResponse, @@ -26,7 +25,6 @@ use torii_core::error::{Error, ParseError}; use torii_core::model::{parse_sql_model_members, SqlModelMember}; use self::subscription::SubscribeRequest; -use crate::proto::types::clause::ClauseType; use crate::proto::world::world_server::WorldServer; use crate::proto::{self}; @@ -141,30 +139,19 @@ impl DojoWorld { async fn subscribe_entities( &self, - queries: Vec, + entities_keys: Vec, ) -> Result>, Error> { - let mut subs = Vec::with_capacity(queries.len()); - for query in queries { - let clause: KeysClause = query - .clause - .ok_or(Error::UnsupportedQuery) - .and_then(|clause| clause.clause_type.ok_or(Error::UnsupportedQuery)) - .and_then(|clause_type| match clause_type { - ClauseType::Keys(clause) => Ok(clause), - _ => Err(Error::UnsupportedQuery), - })? - .try_into() - .map_err(ParseError::FromByteSliceError)?; - - let model = cairo_short_string_to_felt(&query.model) + let mut subs = Vec::with_capacity(entities_keys.len()); + for keys in entities_keys { + let model = cairo_short_string_to_felt(&keys.model) .map_err(ParseError::CairoShortStringToFelt)?; let proto::types::ModelMetadata { packed_size, .. } = - self.model_metadata(&query.model).await?; + self.model_metadata(&keys.model).await?; subs.push(SubscribeRequest { - keys: clause.keys, + keys, model: subscription::ModelMetadata { name: model, packed_size: packed_size as usize, @@ -172,9 +159,7 @@ impl DojoWorld { }); } - let res = self.subscriber_manager.add_subscriber(subs).await; - - Ok(res) + self.subscriber_manager.add_subscriber(subs).await } } @@ -202,9 +187,11 @@ impl proto::world::world_server::World for DojoWorld { &self, request: Request, ) -> ServiceResult { - let SubscribeEntitiesRequest { queries } = request.into_inner(); - let rx = - self.subscribe_entities(queries).await.map_err(|e| Status::internal(e.to_string()))?; + let SubscribeEntitiesRequest { entities_keys } = request.into_inner(); + let rx = self + .subscribe_entities(entities_keys) + .await + .map_err(|e| Status::internal(e.to_string()))?; Ok(Response::new(Box::pin(ReceiverStream::new(rx)) as Self::SubscribeEntitiesStream)) } } diff --git a/crates/torii/grpc/src/server/subscription.rs b/crates/torii/grpc/src/server/subscription.rs index 4750d69cdd..bc20140df7 100644 --- a/crates/torii/grpc/src/server/subscription.rs +++ b/crates/torii/grpc/src/server/subscription.rs @@ -6,7 +6,7 @@ use std::task::Poll; use futures_util::future::BoxFuture; use futures_util::FutureExt; use rand::Rng; -use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; use starknet::core::types::{ BlockId, ContractStorageDiffItem, MaybePendingStateUpdate, StateUpdate, StorageEntry, }; @@ -15,10 +15,12 @@ use starknet::providers::Provider; use starknet_crypto::{poseidon_hash_many, FieldElement}; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::RwLock; +use torii_core::error::{Error, ParseError}; use tracing::{debug, error, trace}; -use super::error::SubscriptionError as Error; +use super::error::SubscriptionError; use crate::proto; +use crate::types::KeysClause; pub struct ModelMetadata { pub name: FieldElement, @@ -27,7 +29,22 @@ pub struct ModelMetadata { pub struct SubscribeRequest { pub model: ModelMetadata, - pub keys: Vec, + pub keys: proto::types::KeysClause, +} + +impl SubscribeRequest { + // pub fn slots(&self) -> Result, QueryError> { + // match self.query.clause { + // Clause::Keys(KeysClause { keys }) => { + // let base = poseidon_hash_many(&[ + // short_string!("dojo_storage"), + // req.model.name, + // poseidon_hash_many(&keys), + // ]); + // } + // _ => Err(QueryError::UnsupportedQuery), + // } + // } } pub struct Subscriber { @@ -45,33 +62,41 @@ pub struct SubscriberManager { impl SubscriberManager { pub(super) async fn add_subscriber( &self, - entities: Vec, - ) -> Receiver> { + reqs: Vec, + ) -> Result>, Error> + { let id = rand::thread_rng().gen::(); let (sender, receiver) = channel(1); // convert the list of entites into a list storage addresses - let storage_addresses = entities - .par_iter() - .map(|entity| { + let storage_addresses = reqs + .into_iter() + .map(|req| { + let keys: KeysClause = + req.keys.try_into().map_err(ParseError::FromByteSliceError)?; + let base = poseidon_hash_many(&[ short_string!("dojo_storage"), - entity.model.name, - poseidon_hash_many(&entity.keys), + req.model.name, + poseidon_hash_many(&keys.keys), ]); - (0..entity.model.packed_size) + let res = (0..req.model.packed_size) .into_par_iter() .map(|i| base + i.into()) - .collect::>() + .collect::>(); + + Ok(res) }) + .collect::, Error>>()? + .into_iter() .flatten() .collect::>(); self.subscribers.write().await.insert(id, Subscriber { storage_addresses, sender }); - receiver + Ok(receiver) } pub(super) async fn remove_subscriber(&self, id: usize) { @@ -79,8 +104,8 @@ impl SubscriberManager { } } -type PublishStateUpdateResult = Result<(), Error>; -type RequestStateUpdateResult = Result; +type PublishStateUpdateResult = Result<(), SubscriptionError>; +type RequestStateUpdateResult = Result; #[must_use = "Service does nothing unless polled"] pub struct Service { @@ -115,8 +140,10 @@ where } async fn fetch_state_update(provider: P, block_num: u64) -> (P, u64, RequestStateUpdateResult) { - let res = - provider.get_state_update(BlockId::Number(block_num)).await.map_err(Error::Provider); + let res = provider + .get_state_update(BlockId::Number(block_num)) + .await + .map_err(SubscriptionError::Provider); (provider, block_num, res) } diff --git a/crates/torii/grpc/src/conversion.rs b/crates/torii/grpc/src/types.rs similarity index 77% rename from crates/torii/grpc/src/conversion.rs rename to crates/torii/grpc/src/types.rs index 987059cfb2..e62dc3b8a8 100644 --- a/crates/torii/grpc/src/conversion.rs +++ b/crates/torii/grpc/src/types.rs @@ -1,9 +1,8 @@ use std::collections::HashMap; use std::str::FromStr; -use dojo_types::schema::{ - AttributeClause, Clause, CompositeClause, EntityQuery, KeysClause, Ty, Value, -}; +use dojo_types::schema::Ty; +use serde::{Deserialize, Serialize}; use starknet::core::types::{ ContractStorageDiffItem, FromByteSliceError, FromStrError, StateDiff, StateUpdate, StorageEntry, }; @@ -11,6 +10,64 @@ use starknet_crypto::FieldElement; use crate::proto; +#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] +pub struct Query { + pub clause: Clause, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] +pub enum Clause { + Keys(KeysClause), + Attribute(AttributeClause), + Composite(CompositeClause), +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] +pub struct KeysClause { + pub model: String, + pub keys: Vec, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] +pub struct AttributeClause { + pub model: String, + pub attribute: String, + pub operator: ComparisonOperator, + pub value: Value, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] +pub struct CompositeClause { + pub model: String, + pub operator: LogicalOperator, + pub clauses: Vec, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] +pub enum LogicalOperator { + And, + Or, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] +pub enum ComparisonOperator { + Eq, + Neq, + Gt, + Gte, + Lt, + Lte, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] +pub enum Value { + String(String), + Int(i64), + UInt(u64), + Bool(bool), + Bytes(Vec), +} + impl TryFrom for dojo_types::schema::ModelMetadata { type Error = FromStrError; fn try_from(value: proto::types::ModelMetadata) -> Result { @@ -46,9 +103,9 @@ impl TryFrom for dojo_types::WorldMetadata { } } -impl From for proto::types::EntityQuery { - fn from(value: EntityQuery) -> Self { - Self { model: value.model, clause: Some(value.clause.into()) } +impl From for proto::types::EntityQuery { + fn from(value: Query) -> Self { + Self { clause: Some(value.clause.into()) } } } @@ -70,7 +127,10 @@ impl From for proto::types::Clause { impl From for proto::types::KeysClause { fn from(value: KeysClause) -> Self { - Self { keys: value.keys.iter().map(|k| k.to_bytes_be().into()).collect() } + Self { + model: value.model, + keys: value.keys.iter().map(|k| k.to_bytes_be().into()).collect(), + } } } @@ -84,13 +144,14 @@ impl TryFrom for KeysClause { .map(|k| FieldElement::from_byte_slice_be(&k)) .collect::, _>>()?; - Ok(Self { keys }) + Ok(Self { model: value.model, keys }) } } impl From for proto::types::AttributeClause { fn from(value: AttributeClause) -> Self { Self { + model: value.model, attribute: value.attribute, operator: value.operator as i32, value: Some(value.value.into()), @@ -101,6 +162,7 @@ impl From for proto::types::AttributeClause { impl From for proto::types::CompositeClause { fn from(value: CompositeClause) -> Self { Self { + model: value.model, operator: value.operator as i32, clauses: value.clauses.into_iter().map(|clause| clause.into()).collect(), } diff --git a/examples/spawn-and-move/Scarb.lock b/examples/spawn-and-move/Scarb.lock index 9ffa906f1d..2370c6221a 100644 --- a/examples/spawn-and-move/Scarb.lock +++ b/examples/spawn-and-move/Scarb.lock @@ -3,18 +3,18 @@ version = 1 [[package]] name = "dojo" -version = "0.3.6" +version = "0.3.10" dependencies = [ "dojo_plugin", ] [[package]] name = "dojo_examples" -version = "0.3.6" +version = "0.3.10" dependencies = [ "dojo", ] [[package]] name = "dojo_plugin" -version = "0.3.6" +version = "0.3.10"