Skip to content

Commit

Permalink
feat(ui): seeded random generators
Browse files Browse the repository at this point in the history
- Add JS Mersenne Twister implementation dependency to use as seeded PRNG. This is not a cryptographically secure algorithm.
- Add nullish seed field to float and integer random generators.
- Add UI to control the seed.
- When seed is not set, behaviour is unchanged - the values are randomized when you Invoke. When seed is set, the random distribution is deterministic depending on the seed. In this case, we can display the values to the user.
  • Loading branch information
psychedelicious committed Jan 17, 2025
1 parent c24eae1 commit bb46567
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 38 deletions.
1 change: 1 addition & 0 deletions invokeai/frontend/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
"konva": "^9.3.15",
"lodash-es": "^4.17.21",
"lru-cache": "^11.0.1",
"mtwist": "^1.0.2",
"nanoid": "^5.0.7",
"nanostores": "^0.11.3",
"new-github-issue-url": "^1.0.0",
Expand Down
7 changes: 7 additions & 0 deletions invokeai/frontend/web/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@
"min": "Min",
"max": "Max",
"values": "Values",
"resetToDefaults": "Reset to Defaults"
"resetToDefaults": "Reset to Defaults",
"seed": "Seed"
},
"hrf": {
"hrf": "High Resolution Fix",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {
getFloatGeneratorDefaults,
resolveFloatGeneratorField,
} from 'features/nodes/types/field';
import { round } from 'lodash-es';
import { isNil, round } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
Expand Down Expand Up @@ -63,7 +63,10 @@ export const FloatGeneratorFieldInputComponent = memo(

const [debouncedField] = useDebounce(field, 300);
const resolvedValuesAsString = useMemo(() => {
if (debouncedField.value.type === FloatGeneratorUniformRandomDistributionType) {
if (
debouncedField.value.type === FloatGeneratorUniformRandomDistributionType &&
isNil(debouncedField.value.seed)
) {
const { count } = debouncedField.value;
return `<${t('nodes.generatorNRandomValues', { count })}>`;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { Checkbox, CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { FloatGeneratorUniformRandomDistribution } from 'features/nodes/types/field';
import { isNil } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

Expand Down Expand Up @@ -29,21 +30,47 @@ export const FloatGeneratorUniformRandomDistributionSettings = memo(
},
[onChange, state]
);
const onToggleSeed = useCallback(() => {
onChange({ ...state, seed: isNil(state.seed) ? 0 : null });
}, [onChange, state]);
const onChangeSeed = useCallback(
(seed?: number | null) => {
onChange({ ...state, seed });
},
[onChange, state]
);

return (
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.min')}</FormLabel>
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.max')}</FormLabel>
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
<Flex gap={2} flexDir="column">
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.min')}</FormLabel>
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.max')}</FormLabel>
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel alignItems="center" justifyContent="space-between" m={0} display="flex" w="full">
{t('common.seed')}
<Checkbox onChange={onToggleSeed} isChecked={!isNil(state.seed)} />
</FormLabel>
<CompositeNumberInput
isDisabled={isNil(state.seed)}
// This cast is save only because we disable the element when seed is not a number - the `...` is
// rendered in the input field in this case
value={state.seed ?? ('...' as unknown as number)}
onChange={onChangeSeed}
min={-Infinity}
max={Infinity}
/>
</FormControl>
</Flex>
</Flex>
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {
IntegerGeneratorUniformRandomDistributionType,
resolveIntegerGeneratorField,
} from 'features/nodes/types/field';
import { round } from 'lodash-es';
import { isNil, round } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
Expand Down Expand Up @@ -65,7 +65,10 @@ export const IntegerGeneratorFieldInputComponent = memo(

const [debouncedField] = useDebounce(field, 300);
const resolvedValuesAsString = useMemo(() => {
if (debouncedField.value.type === IntegerGeneratorUniformRandomDistributionType) {
if (
debouncedField.value.type === IntegerGeneratorUniformRandomDistributionType &&
isNil(debouncedField.value.seed)
) {
const { count } = debouncedField.value;
return `<${t('nodes.generatorNRandomValues', { count })}>`;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { Checkbox, CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { IntegerGeneratorUniformRandomDistribution } from 'features/nodes/types/field';
import { isNil } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

Expand Down Expand Up @@ -29,21 +30,47 @@ export const IntegerGeneratorUniformRandomDistributionSettings = memo(
},
[onChange, state]
);
const onToggleSeed = useCallback(() => {
onChange({ ...state, seed: isNil(state.seed) ? 0 : null });
}, [onChange, state]);
const onChangeSeed = useCallback(
(seed?: number | null) => {
onChange({ ...state, seed });
},
[onChange, state]
);

return (
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.min')}</FormLabel>
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.max')}</FormLabel>
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
<Flex gap={2} flexDir="column">
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.min')}</FormLabel>
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.max')}</FormLabel>
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel alignItems="center" justifyContent="space-between" m={0} display="flex" w="full">
{t('common.seed')}
<Checkbox onChange={onToggleSeed} isChecked={!isNil(state.seed)} />
</FormLabel>
<CompositeNumberInput
isDisabled={isNil(state.seed)}
// This cast is save only because we disable the element when seed is not a number - the `...` is
// rendered in the input field in this case
value={state.seed ?? ('...' as unknown as number)}
onChange={onChangeSeed}
min={-Infinity}
max={Infinity}
/>
</FormControl>
</Flex>
</Flex>
);
}
Expand Down
22 changes: 17 additions & 5 deletions invokeai/frontend/web/src/features/nodes/types/field.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { buildTypeGuard } from 'features/parameters/types/parameterSchemas';
import { trim } from 'lodash-es';
import { isNil, trim } from 'lodash-es';
import MersenneTwister from 'mtwist';
import { assert } from 'tsafe';
import { z } from 'zod';

Expand Down Expand Up @@ -1057,13 +1058,22 @@ const zFloatGeneratorUniformRandomDistribution = z.object({
min: z.number().default(0),
max: z.number().default(1),
count: z.number().int().default(10),
seed: z.number().int().nullish(),
values: z.array(z.number()).nullish(),
});
export type FloatGeneratorUniformRandomDistribution = z.infer<typeof zFloatGeneratorUniformRandomDistribution>;
const getFloatGeneratorUniformRandomDistributionDefaults = () => zFloatGeneratorUniformRandomDistribution.parse({});
const getRng = (seed?: number | null) => {
if (isNil(seed)) {
return () => Math.random();
}
const m = new MersenneTwister(seed);
return () => m.random();
};
const getFloatGeneratorUniformRandomDistributionValues = (generator: FloatGeneratorUniformRandomDistribution) => {
const { min, max, count } = generator;
const values = Array.from({ length: count }, () => Math.random() * (max - min) + min);
const { min, max, count, seed } = generator;
const rng = getRng(seed);
const values = Array.from({ length: count }, (_) => rng() * (max - min) + min);
return values;
};

Expand Down Expand Up @@ -1191,13 +1201,15 @@ const zIntegerGeneratorUniformRandomDistribution = z.object({
min: z.number().int().default(0),
max: z.number().int().default(10),
count: z.number().int().default(10),
seed: z.number().int().nullish(),
values: z.array(z.number().int()).nullish(),
});
export type IntegerGeneratorUniformRandomDistribution = z.infer<typeof zIntegerGeneratorUniformRandomDistribution>;
const getIntegerGeneratorUniformRandomDistributionDefaults = () => zIntegerGeneratorUniformRandomDistribution.parse({});
const getIntegerGeneratorUniformRandomDistributionValues = (generator: IntegerGeneratorUniformRandomDistribution) => {
const { min, max, count } = generator;
const values = Array.from({ length: count }, () => Math.floor(Math.random() * (max - min + 1)) + min);
const { min, max, count, seed } = generator;
const rng = getRng(seed);
const values = Array.from({ length: count }, () => Math.floor(rng() * (max - min + 1)) + min);
return values;
};

Expand Down

0 comments on commit bb46567

Please sign in to comment.