diff --git a/packages/core/src/StateNode.ts b/packages/core/src/StateNode.ts index 375551fca9..4f403900bb 100644 --- a/packages/core/src/StateNode.ts +++ b/packages/core/src/StateNode.ts @@ -144,6 +144,7 @@ export class StateNode< public tags: string[] = []; public transitions!: Map[]>; public always?: Array>; + public invariant?: ({ context }: { context: TContext }) => void; constructor( /** The raw config used to create the machine. */ @@ -216,6 +217,7 @@ export class StateNode< this.output = this.type === 'final' || !this.parent ? this.config.output : undefined; this.tags = toArray(config.tags).slice(); + this.invariant = config.invariant; } /** @internal */ diff --git a/packages/core/src/stateUtils.ts b/packages/core/src/stateUtils.ts index fce7765ed1..ac0f3026e8 100644 --- a/packages/core/src/stateUtils.ts +++ b/packages/core/src/stateUtils.ts @@ -22,7 +22,6 @@ import { AnyMachineSnapshot, AnyStateNode, AnyTransitionDefinition, - DelayExpr, DelayedTransitionDefinition, EventObject, HistoryValue, @@ -1690,6 +1689,8 @@ export function macrostep( ); addMicrostate(nextSnapshot, event, []); + // No need to check invariant since the state is the same + return { snapshot: nextSnapshot, microstates @@ -1763,6 +1764,13 @@ export function macrostep( addMicrostate(nextSnapshot, nextEvent, enabledTransitions); } + // Check invariants + for (const sn of nextSnapshot._nodes) { + if (sn.invariant) { + sn.invariant({ context: nextSnapshot.context }); + } + } + if (nextSnapshot.status !== 'active') { stopChildren(nextSnapshot, nextEvent, actorScope); } diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts index 13fbbcf924..f26e8d7821 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -1027,6 +1027,8 @@ export interface StateNodeConfig< /** A default target for a history state */ target?: string; + + invariant?: ({ context }: { context: TContext }) => void; } export type AnyStateNodeConfig = StateNodeConfig< diff --git a/packages/core/test/invariant.test.ts b/packages/core/test/invariant.test.ts new file mode 100644 index 0000000000..52f9f4dcd2 --- /dev/null +++ b/packages/core/test/invariant.test.ts @@ -0,0 +1,221 @@ +import { assign, createActor, createMachine } from '../src'; + +describe('state invariants', () => { + it('throws an error and does not transition if the invariant throws', () => { + const machine = createMachine({ + initial: 'idle', + states: { + idle: { + on: { + loadUser: { + target: 'userLoaded' + } + } + }, + userLoaded: { + invariant: (x) => { + if (!x.context.user) { + throw new Error('User not loaded'); + } + } + } + } + }); + const spy = jest.fn(); + + const actor = createActor(machine); + actor.subscribe({ + error: spy + }); + actor.start(); + + actor.send({ type: 'loadUser' }); + + expect(spy).toHaveBeenCalledWith(new Error('User not loaded')); + + expect(actor.getSnapshot().value).toEqual('idle'); + }); + + it('transitions as normal if the invariant does not fail', () => { + const machine = createMachine({ + initial: 'idle', + states: { + idle: { + on: { + loadUser: { + target: 'userLoaded', + actions: assign({ user: () => ({ name: 'David' }) }) + } + } + }, + userLoaded: { + invariant: (x) => { + if (!x.context.user) { + throw new Error('User not loaded'); + } + } + } + } + }); + const spy = jest.fn(); + + const actor = createActor(machine); + actor.subscribe({ + error: spy + }); + actor.start(); + + actor.send({ type: 'loadUser' }); + + expect(spy).not.toHaveBeenCalled(); + + expect(actor.getSnapshot().value).toEqual('userLoaded'); + }); + + it('throws an error and does not transition if the invariant fails on a transition within the state', () => { + const machine = createMachine({ + initial: 'userLoaded', + states: { + userLoaded: { + initial: 'active', + states: { + active: { + on: { + deactivate: 'inactive' + } + }, + inactive: { + entry: assign({ user: null }) + } + }, + invariant: (x) => { + if (!x.context.user) { + throw new Error('User not loaded'); + } + }, + entry: assign({ user: { name: 'David' } }) + } + } + }); + const spy = jest.fn(); + + const actor = createActor(machine); + actor.subscribe({ + error: spy + }); + actor.start(); + + actor.send({ type: 'deactivate' }); + + expect(spy).toHaveBeenCalledWith(new Error('User not loaded')); + expect(actor.getSnapshot().value).toEqual({ userLoaded: 'active' }); + }); + + it('does not throw an error when exiting a state with an invariant if the exit action clears the context', () => { + const machine = createMachine({ + initial: 'userLoaded', + states: { + userLoaded: { + invariant: (x) => { + if (!x.context.user) { + throw new Error('User not loaded'); + } + }, + entry: assign({ user: { name: 'David' } }), + exit: assign({ user: null }), + on: { + logout: 'idle' + } + }, + idle: {} + } + }); + const spy = jest.fn(); + + const actor = createActor(machine); + actor.subscribe({ + error: spy + }); + actor.start(); + + actor.send({ type: 'logout' }); + + expect(spy).not.toHaveBeenCalled(); + expect(actor.getSnapshot().value).toEqual('idle'); + }); + + it('parallel regions check for state invariants', () => { + const spy = jest.fn(); + + const machine = createMachine({ + initial: 'p', + types: { + context: {} as { user: { name: string; age: number } | null } + }, + context: { + user: { + name: 'David', + age: 30 + } + }, + states: { + p: { + type: 'parallel', + states: { + a: { + invariant: (x) => { + if (!x.context.user) { + throw new Error('User not loaded'); + } + }, + on: { + updateAge: { + actions: assign({ + user: (x) => ({ ...x.context.user!, age: -3 }) + }) + } + } + }, + b: { + invariant: (x) => { + if (x.context.user!.age < 0) { + throw new Error('User age cannot be negative'); + } + }, + on: { + deleteUser: { + actions: assign({ + user: () => null + }) + } + } + } + } + } + } + }); + + const actor = createActor(machine); + + actor.subscribe({ + error: spy + }); + + actor.start(); + + expect(actor.getSnapshot().value).toEqual({ + p: { + a: {}, + b: {} + } + }); + + actor.send({ + type: 'updateAge' + }); + + expect(spy).toHaveBeenCalledWith(new Error('User age cannot be negative')); + + expect(actor.getSnapshot().status).toEqual('error'); + }); +});