Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(ui): persisted workflow collection generators #7559

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions invokeai/app/invocations/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from typing import Literal

from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
Classification,
invocation,
)
from invokeai.app.invocations.fields import (
ImageField,
Input,
InputField,
)
from invokeai.app.invocations.primitives import FloatOutput, ImageOutput, IntegerOutput, StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext

BATCH_GROUP_IDS = Literal[
"None",
"Group 1",
"Group 2",
"Group 3",
"Group 4",
"Group 5",
]


class NotExecutableNodeError(Exception):
def __init__(self, message: str = "This class should never be executed or instantiated directly."):
super().__init__(message)

pass


class BaseBatchInvocation(BaseInvocation):
batch_group_id: BATCH_GROUP_IDS = InputField(
default="None",
description="The ID of this batch node's group. If provided, all batch nodes in with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
input=Input.Direct,
title="Batch Group",
)

def __init__(self):
raise NotExecutableNodeError()


@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""

images: list[ImageField] = InputField(
default=[], min_length=1, description="The images to batch over", input=Input.Direct
)

def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotExecutableNodeError()


@invocation(
"string_batch",
title="String Batch",
tags=["primitives", "string", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class StringBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each string in the batch."""

strings: list[str] = InputField(
default=[], min_length=1, description="The strings to batch over", input=Input.Direct
)

def invoke(self, context: InvocationContext) -> StringOutput:
raise NotExecutableNodeError()


@invocation(
"integer_batch",
title="Integer Batch",
tags=["primitives", "integer", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class IntegerBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""

integers: list[int] = InputField(
default=[], min_length=1, description="The integers to batch over", input=Input.Direct
)

def invoke(self, context: InvocationContext) -> IntegerOutput:
raise NotExecutableNodeError()


@invocation(
"float_batch",
title="Float Batch",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class FloatBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each float in the batch."""

floats: list[float] = InputField(
default=[], min_length=1, description="The floats to batch over", input=Input.Direct
)

def invoke(self, context: InvocationContext) -> FloatOutput:
raise NotExecutableNodeError()
100 changes: 1 addition & 99 deletions invokeai/app/invocations/primitives.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal, Optional
from typing import Optional

import torch

from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
Expand Down Expand Up @@ -539,100 +538,3 @@ def invoke(self, context: InvocationContext) -> BoundingBoxOutput:


# endregion

BATCH_GROUP_IDS = Literal[
"None",
"Group 1",
"Group 2",
"Group 3",
"Group 4",
"Group 5",
]


class BaseBatchInvocation(BaseInvocation):
batch_group_id: BATCH_GROUP_IDS = InputField(
default="None",
description="The ID of this batch node's group. If provided, all batch nodes in with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
input=Input.Direct,
title="Batch Group",
)

def __init__(self):
raise NotImplementedError("This class should never be executed or instantiated directly.")


@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""

images: list[ImageField] = InputField(
default=[], min_length=1, description="The images to batch over", input=Input.Direct
)

def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")


@invocation(
"string_batch",
title="String Batch",
tags=["primitives", "string", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class StringBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each string in the batch."""

strings: list[str] = InputField(
default=[], min_length=1, description="The strings to batch over", input=Input.Direct
)

def invoke(self, context: InvocationContext) -> StringOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")


@invocation(
"integer_batch",
title="Integer Batch",
tags=["primitives", "integer", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class IntegerBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""

integers: list[int] = InputField(
default=[], min_length=1, description="The integers to batch over", input=Input.Direct
)

def invoke(self, context: InvocationContext) -> IntegerOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")


@invocation(
"float_batch",
title="Float Batch",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class FloatBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each float in the batch."""

floats: list[float] = InputField(
default=[], min_length=1, description="The floats to batch over", input=Input.Direct
)

def invoke(self, context: InvocationContext) -> FloatOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")
12 changes: 7 additions & 5 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,13 @@
},
"nodes": {
"noBatchGroup": "no group",
"generator": "Generator",
"generatedValues": "Generated Values",
"commitValues": "Commit Values",
"addValue": "Add Value",
"addNode": "Add Node",
"lockLinearView": "Lock Linear View",
"unlockLinearView": "Unlock Linear View",
"addNodeToolTip": "Add Node (Shift+A, Space)",
"addLinearView": "Add to Linear View",
"animatedEdges": "Animated Edges",
Expand Down Expand Up @@ -994,11 +1000,7 @@
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
"modelAccessError": "Unable to find model {{key}}, resetting to default",
"saveToGallery": "Save To Gallery",
"addItem": "Add Item",
"generateValues": "Generate Values",
"floatRangeGenerator": "Float Range Generator",
"integerRangeGenerator": "Integer Range Generator"
"saveToGallery": "Save To Gallery"
},
"parameters": {
"aspect": "Aspect",
Expand Down
4 changes: 0 additions & 4 deletions invokeai/frontend/web/src/app/components/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicP
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { ImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { FloatRangeGeneratorModal } from 'features/nodes/components/FloatRangeGeneratorModal';
import { IntegerRangeGeneratorModal } from 'features/nodes/components/IntegerRangeGeneratorModal';
import { ShareWorkflowModal } from 'features/nodes/components/sidePanel/WorkflowListMenu/ShareWorkflowModal';
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { DeleteStylePresetDialog } from 'features/stylePresets/components/DeleteStylePresetDialog';
Expand Down Expand Up @@ -112,8 +110,6 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
<ImageContextMenu />
<FullscreenDropzone />
<VideosModal />
<FloatRangeGeneratorModal />
<IntegerRangeGeneratorModal />
</ErrorBoundary>
);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
isIntegerFieldCollectionInputInstance,
isStringFieldCollectionInputInstance,
} from 'features/nodes/types/field';
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
import type { InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
Expand Down Expand Up @@ -140,10 +141,11 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =

// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
const resolvedValue = resolveNumberFieldCollectionValue(integers);
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromStringBatch, integers.value);
addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
} else {
addProductBatchDataCollectionItem(edgesFromStringBatch, integers.value);
addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
}
}

Expand All @@ -163,10 +165,11 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =

// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
const resolvedValue = resolveNumberFieldCollectionValue(floats);
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromStringBatch, floats.value);
addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
} else {
addProductBatchDataCollectionItem(edgesFromStringBatch, floats.value);
addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
}
}

Expand Down
6 changes: 4 additions & 2 deletions invokeai/frontend/web/src/app/store/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,10 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
serializableCheck: import.meta.env.MODE === 'development',
immutableCheck: import.meta.env.MODE === 'development',
serializableCheck: false,
immutableCheck: false,
// serializableCheck: import.meta.env.MODE === 'development',
// immutableCheck: import.meta.env.MODE === 'development',
})
.concat(api.middleware)
.concat(dynamicMiddlewares)
Expand Down
Loading
Loading