Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Message policy support chain switching #1265

Merged
merged 14 commits into from
Jan 10, 2025
57 changes: 31 additions & 26 deletions examples/next/src/components/providers/StarknetProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,35 @@ export const ETH_CONTRACT_ADDRESS =
export const STRK_CONTRACT_ADDRESS =
"0x04718f5a0Fc34cC1AF16A1cdee98fFB20C31f5cD61D6Ab07201858f4287c938D";

const messageForChain = (chainId: constants.StarknetChainId) => {
return {
types: {
StarknetDomain: [
{ name: "name", type: "shortstring" },
{ name: "version", type: "shortstring" },
{ name: "chainId", type: "shortstring" },
{ name: "revision", type: "shortstring" },
],
Person: [
{ name: "name", type: "felt" },
{ name: "wallet", type: "felt" },
],
Mail: [
{ name: "from", type: "Person" },
{ name: "to", type: "Person" },
{ name: "contents", type: "felt" },
],
},
primaryType: "Mail",
domain: {
name: "StarkNet Mail",
version: "1",
revision: "1",
chainId: chainId,
},
};
};

const policies: SessionPolicies = {
contracts: {
[ETH_CONTRACT_ADDRESS]: {
Expand Down Expand Up @@ -51,32 +80,8 @@ const policies: SessionPolicies = {
},
},
messages: [
// {
// types: {
// StarknetDomain: [
// { name: "name", type: "shortstring" },
// { name: "version", type: "shortstring" },
// { name: "chainId", type: "shortstring" },
// { name: "revision", type: "shortstring" },
// ],
// Person: [
// { name: "name", type: "felt" },
// { name: "wallet", type: "felt" },
// ],
// Mail: [
// { name: "from", type: "Person" },
// { name: "to", type: "Person" },
// { name: "contents", type: "felt" },
// ],
// },
// primaryType: "Mail",
// domain: {
// name: "StarkNet Mail",
// version: "1",
// revision: "1",
// chainId: "SN_SEPOLIA",
// },
// },
messageForChain(constants.StarknetChainId.SN_MAIN),
messageForChain(constants.StarknetChainId.SN_SEPOLIA),
],
};

Expand Down
12 changes: 11 additions & 1 deletion packages/controller/src/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ export default class ControllerProvider extends BaseProvider {
let chainId: ChainId | undefined;
const url = new URL(chain.rpcUrl);
const parts = url.pathname.split("/");

if (parts.includes("starknet")) {
if (parts.includes("mainnet")) {
chainId = constants.StarknetChainId.SN_MAIN;
Expand All @@ -60,6 +59,17 @@ export default class ControllerProvider extends BaseProvider {
chains.set(chainId, chain);
}

if (
options.policies?.messages?.length &&
options.policies.messages.length !== chains.size
) {
console.warn(
"Each message policy is associated with a specific chain. " +
"The number of message policies does not match the number of chains specified - " +
"session message signing may not work on some chains.",
);
}

this.chains = chains;
this.selectedChain = options.defaultChainId;

Expand Down
57 changes: 22 additions & 35 deletions packages/keychain/src/components/connect/CreateSession.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import { Container, Content, Footer } from "@/components/layout";
import { BigNumberish, shortString } from "starknet";
import { ControllerError } from "@/utils/connection";
import { useCallback, useEffect, useMemo, useState } from "react";
import { useCallback, useMemo, useState } from "react";
import { useConnection } from "@/hooks/connection";
import { ControllerErrorAlert } from "@/components/ErrorAlert";
import { SessionConsent } from "@/components/connect";
import { Upgrade } from "./Upgrade";
import { ErrorCode } from "@cartridge/account-wasm";
import { TypedDataPolicy } from "@cartridge/presets";
import { ParsedSessionPolicies } from "@/hooks/session";
import { UnverifiedSessionSummary } from "@/components/session/UnverifiedSessionSummary";
import { VerifiedSessionSummary } from "@/components/session/VerifiedSessionSummary";
Expand All @@ -33,7 +31,6 @@ export function CreateSession({
}) {
const { controller, upgrade, chainId, theme, logout } = useConnection();
const [isConnecting, setIsConnecting] = useState(false);
const [isDisabled, setIsDisabled] = useState(false);
const [isConsent, setIsConsent] = useState(false);
const [duration, setDuration] = useState<bigint>(DEFAULT_SESSION_DURATION);
const expiresAt = useMemo(
Expand All @@ -43,32 +40,17 @@ export function CreateSession({
const [maxFee] = useState<BigNumberish>();
const [error, setError] = useState<ControllerError | Error>();

useEffect(() => {
if (!chainId) return;
const normalizedChainId = normalizeChainId(chainId);

const violatingPolicy = policies.messages?.find(
(policy: TypedDataPolicy) =>
"domain" in policy &&
(!policy.domain.chainId ||
normalizeChainId(policy.domain.chainId) !== normalizedChainId),
);

if (violatingPolicy) {
setError({
code: ErrorCode.PolicyChainIdMismatch,
message: `Policy for ${
(violatingPolicy as TypedDataPolicy).domain.name
}.${
(violatingPolicy as TypedDataPolicy).primaryType
} has mismatched chain ID.`,
});
setIsDisabled(true);
} else {
setError(undefined);
setIsDisabled(false);
}
}, [chainId, policies]);
const chainSpecificMessages = useMemo(() => {
if (!policies.messages || !chainId) return [];
return policies.messages.filter((message) => {
return (
!("domain" in message) ||
(message.domain.chainId &&
normalizeChainId(message.domain.chainId) ===
normalizeChainId(chainId))
);
});
}, [policies.messages, chainId]);

const onCreateSession = useCallback(async () => {
if (!controller || !policies) return;
Expand Down Expand Up @@ -139,9 +121,16 @@ export function CreateSession({
<Content gap={6}>
<SessionConsent isVerified={policies?.verified} />
{policies?.verified ? (
<VerifiedSessionSummary game={theme.name} policies={policies} />
<VerifiedSessionSummary
game={theme.name}
contracts={policies.contracts}
messages={chainSpecificMessages}
/>
) : (
<UnverifiedSessionSummary policies={policies} />
<UnverifiedSessionSummary
contracts={policies.contracts}
messages={chainSpecificMessages}
/>
)}
</Content>
<Footer>
Expand Down Expand Up @@ -199,9 +188,7 @@ export function CreateSession({
</Button>
<Button
className="flex-1"
disabled={
isDisabled || isConnecting || (!policies?.verified && !isConsent)
}
disabled={isConnecting || (!policies?.verified && !isConsent)}
isLoading={isConnecting}
onClick={onCreateSession}
>
Expand Down
11 changes: 9 additions & 2 deletions packages/keychain/src/components/connect/RegisterSession.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,16 @@ export function RegisterSession({
<Content>
<SessionConsent isVerified={policies?.verified} />
{policies?.verified ? (
<VerifiedSessionSummary game={theme.name} policies={policies} />
<VerifiedSessionSummary
game={theme.name}
contracts={policies.contracts}
messages={policies.messages}
/>
) : (
<UnverifiedSessionSummary policies={policies} />
<UnverifiedSessionSummary
contracts={policies.contracts}
messages={policies.messages}
/>
)}
</Content>
</ExecutionContainer>
Expand Down
96 changes: 50 additions & 46 deletions packages/keychain/src/components/session/AggregateCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,35 @@ import { useExplorer } from "@starknet-react/core";
import { constants } from "starknet";
import { Method } from "@cartridge/presets";
import { useChainId } from "@/hooks/connection";
import { ParsedSessionPolicies } from "@/hooks/session";
import { SessionContracts, SessionMessages } from "@/hooks/session";
import { Link } from "react-router-dom";
import { AccordionCard } from "./AccordionCard";
import { MessageContent } from "./MessageCard";

interface AggregateCardProps {
title: string;
icon: React.ReactNode;
policies: ParsedSessionPolicies;
contracts?: SessionContracts;
messages?: SessionMessages;
}

export function AggregateCard({ title, icon, policies }: AggregateCardProps) {
export function AggregateCard({
title,
icon,
contracts,
messages,
}: AggregateCardProps) {
const chainId = useChainId();
const explorer = useExplorer();

const totalMethods = Object.values(policies.contracts || {}).reduce(
const totalMethods = Object.values(contracts || {}).reduce(
(acc, contract) => {
return acc + (contract.methods?.length || 0);
},
0,
);

const totalMessages = policies.messages?.length ?? 0;
const totalMessages = messages?.length ?? 0;
const count = totalMethods + totalMessages;

return (
Expand All @@ -43,52 +49,50 @@ export function AggregateCard({ title, icon, policies }: AggregateCardProps) {
}
className="gap-2"
>
{Object.entries(policies.contracts || {}).map(
([address, { name, methods }]) => (
<div key={address} className="flex flex-col gap-2">
<div className="flex items-center justify-between bg-secondary text-xs">
<div className="py-2 font-bold">{name}</div>
<Link
to={
chainId === constants.StarknetChainId.SN_MAIN ||
chainId === constants.StarknetChainId.SN_SEPOLIA
? explorer.contract(address)
: `#` // TODO: Add explorer for worlds.dev
}
target="_blank"
className="text-muted-foreground hover:underline"
>
{formatAddress(address, { first: 5, last: 5 })}
</Link>
</div>
{Object.entries(contracts || {}).map(([address, { name, methods }]) => (
<div key={address} className="flex flex-col gap-2">
<div className="flex items-center justify-between bg-secondary text-xs">
<div className="py-2 font-bold">{name}</div>
<Link
to={
chainId === constants.StarknetChainId.SN_MAIN ||
chainId === constants.StarknetChainId.SN_SEPOLIA
? explorer.contract(address)
: `#` // TODO: Add explorer for worlds.dev
}
target="_blank"
className="text-muted-foreground hover:underline"
>
{formatAddress(address, { first: 5, last: 5 })}
</Link>
</div>

<div className="flex flex-col gap-px rounded overflow-auto border border-background">
{methods.map((method: Method) => (
<div
key={method.name}
className="flex flex-col p-3 gap-3 text-xs"
>
<div className="flex items-center justify-between">
<div className="font-bold text-accent-foreground">
{method.name}
</div>
<div className="text-muted-foreground">
{method.entrypoint}
</div>
<div className="flex flex-col gap-px rounded overflow-auto border border-background">
{methods.map((method: Method) => (
<div
key={method.name}
className="flex flex-col p-3 gap-3 text-xs"
>
<div className="flex items-center justify-between">
<div className="font-bold text-accent-foreground">
{method.name}
</div>
<div className="text-muted-foreground">
{method.entrypoint}
</div>
{method.description && (
<div className="text-muted-foreground">
{method.description}
</div>
)}
</div>
))}
</div>
{method.description && (
<div className="text-muted-foreground">
{method.description}
</div>
)}
</div>
))}
</div>
),
)}
</div>
))}

{policies.messages && <MessageContent messages={policies.messages} />}
{messages && <MessageContent messages={messages} />}
</AccordionCard>
);
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import { toArray } from "@cartridge/controller";
import { ParsedSessionPolicies } from "@/hooks/session";
import { SessionContracts, SessionMessages } from "@/hooks/session";

import { MessageCard } from "./MessageCard";
import { ContractCard } from "./ContractCard";

export function UnverifiedSessionSummary({
policies,
contracts,
messages,
}: {
policies: ParsedSessionPolicies;
contracts?: SessionContracts;
messages?: SessionMessages;
}) {
return (
<div className="flex flex-col gap-4">
{Object.entries(policies.contracts ?? {}).map(([address, contract]) => {
{Object.entries(contracts ?? {}).map(([address, contract]) => {
const methods = toArray(contract.methods);
const title = !contract.meta?.name ? "Contract" : contract.meta.name;
const icon = contract.meta?.icon;
Expand All @@ -28,8 +30,8 @@ export function UnverifiedSessionSummary({
);
})}

{policies.messages && policies.messages.length > 0 && (
<MessageCard messages={policies.messages} isExpanded />
{messages && messages.length > 0 && (
<MessageCard messages={messages} isExpanded />
)}
</div>
);
Expand Down
Loading
Loading