diff --git a/examples/next/src/components/providers/StarknetProvider.tsx b/examples/next/src/components/providers/StarknetProvider.tsx index 4ab05792e..6c9e8a623 100644 --- a/examples/next/src/components/providers/StarknetProvider.tsx +++ b/examples/next/src/components/providers/StarknetProvider.tsx @@ -16,6 +16,35 @@ export const ETH_CONTRACT_ADDRESS = export const STRK_CONTRACT_ADDRESS = "0x04718f5a0Fc34cC1AF16A1cdee98fFB20C31f5cD61D6Ab07201858f4287c938D"; +const messageForChain = (chainId: constants.StarknetChainId) => { + return { + types: { + StarknetDomain: [ + { name: "name", type: "shortstring" }, + { name: "version", type: "shortstring" }, + { name: "chainId", type: "shortstring" }, + { name: "revision", type: "shortstring" }, + ], + Person: [ + { name: "name", type: "felt" }, + { name: "wallet", type: "felt" }, + ], + Mail: [ + { name: "from", type: "Person" }, + { name: "to", type: "Person" }, + { name: "contents", type: "felt" }, + ], + }, + primaryType: "Mail", + domain: { + name: "StarkNet Mail", + version: "1", + revision: "1", + chainId: chainId, + }, + }; +}; + const policies: SessionPolicies = { contracts: { [ETH_CONTRACT_ADDRESS]: { @@ -51,32 +80,8 @@ const policies: SessionPolicies = { }, }, messages: [ - // { - // types: { - // StarknetDomain: [ - // { name: "name", type: "shortstring" }, - // { name: "version", type: "shortstring" }, - // { name: "chainId", type: "shortstring" }, - // { name: "revision", type: "shortstring" }, - // ], - // Person: [ - // { name: "name", type: "felt" }, - // { name: "wallet", type: "felt" }, - // ], - // Mail: [ - // { name: "from", type: "Person" }, - // { name: "to", type: "Person" }, - // { name: "contents", type: "felt" }, - // ], - // }, - // primaryType: "Mail", - // domain: { - // name: "StarkNet Mail", - // version: "1", - // revision: "1", - // chainId: "SN_SEPOLIA", - // }, - // }, + messageForChain(constants.StarknetChainId.SN_MAIN), + messageForChain(constants.StarknetChainId.SN_SEPOLIA), ], }; diff --git a/packages/controller/src/controller.ts b/packages/controller/src/controller.ts index d6f8b0679..e162aea54 100644 --- a/packages/controller/src/controller.ts +++ b/packages/controller/src/controller.ts @@ -37,7 +37,6 @@ export default class ControllerProvider extends BaseProvider { let chainId: ChainId | undefined; const url = new URL(chain.rpcUrl); const parts = url.pathname.split("/"); - if (parts.includes("starknet")) { if (parts.includes("mainnet")) { chainId = constants.StarknetChainId.SN_MAIN; @@ -60,6 +59,17 @@ export default class ControllerProvider extends BaseProvider { chains.set(chainId, chain); } + if ( + options.policies?.messages?.length && + options.policies.messages.length !== chains.size + ) { + console.warn( + "Each message policy is associated with a specific chain. " + + "The number of message policies does not match the number of chains specified - " + + "session message signing may not work on some chains.", + ); + } + this.chains = chains; this.selectedChain = options.defaultChainId; diff --git a/packages/keychain/src/components/connect/CreateSession.tsx b/packages/keychain/src/components/connect/CreateSession.tsx index 865dd7101..8165cafb9 100644 --- a/packages/keychain/src/components/connect/CreateSession.tsx +++ b/packages/keychain/src/components/connect/CreateSession.tsx @@ -1,13 +1,11 @@ import { Container, Content, Footer } from "@/components/layout"; import { BigNumberish, shortString } from "starknet"; import { ControllerError } from "@/utils/connection"; -import { useCallback, useEffect, useMemo, useState } from "react"; +import { useCallback, useMemo, useState } from "react"; import { useConnection } from "@/hooks/connection"; import { ControllerErrorAlert } from "@/components/ErrorAlert"; import { SessionConsent } from "@/components/connect"; import { Upgrade } from "./Upgrade"; -import { ErrorCode } from "@cartridge/account-wasm"; -import { TypedDataPolicy } from "@cartridge/presets"; import { ParsedSessionPolicies } from "@/hooks/session"; import { UnverifiedSessionSummary } from "@/components/session/UnverifiedSessionSummary"; import { VerifiedSessionSummary } from "@/components/session/VerifiedSessionSummary"; @@ -33,7 +31,6 @@ export function CreateSession({ }) { const { controller, upgrade, chainId, theme, logout } = useConnection(); const [isConnecting, setIsConnecting] = useState(false); - const [isDisabled, setIsDisabled] = useState(false); const [isConsent, setIsConsent] = useState(false); const [duration, setDuration] = useState(DEFAULT_SESSION_DURATION); const expiresAt = useMemo( @@ -43,32 +40,17 @@ export function CreateSession({ const [maxFee] = useState(); const [error, setError] = useState(); - useEffect(() => { - if (!chainId) return; - const normalizedChainId = normalizeChainId(chainId); - - const violatingPolicy = policies.messages?.find( - (policy: TypedDataPolicy) => - "domain" in policy && - (!policy.domain.chainId || - normalizeChainId(policy.domain.chainId) !== normalizedChainId), - ); - - if (violatingPolicy) { - setError({ - code: ErrorCode.PolicyChainIdMismatch, - message: `Policy for ${ - (violatingPolicy as TypedDataPolicy).domain.name - }.${ - (violatingPolicy as TypedDataPolicy).primaryType - } has mismatched chain ID.`, - }); - setIsDisabled(true); - } else { - setError(undefined); - setIsDisabled(false); - } - }, [chainId, policies]); + const chainSpecificMessages = useMemo(() => { + if (!policies.messages || !chainId) return []; + return policies.messages.filter((message) => { + return ( + !("domain" in message) || + (message.domain.chainId && + normalizeChainId(message.domain.chainId) === + normalizeChainId(chainId)) + ); + }); + }, [policies.messages, chainId]); const onCreateSession = useCallback(async () => { if (!controller || !policies) return; @@ -139,9 +121,16 @@ export function CreateSession({ {policies?.verified ? ( - + ) : ( - + )}