Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support vllm in controller #635

Merged
merged 4 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ GINKGO_FOCUS ?=
GINKGO_SKIP ?=
GINKGO_NODES ?= 2
GINKGO_NO_COLOR ?= false
GINKGO_TIMEOUT ?= 180m
GINKGO_ARGS ?= -focus="$(GINKGO_FOCUS)" -skip="$(GINKGO_SKIP)" -nodes=$(GINKGO_NODES) -no-color=$(GINKGO_NO_COLOR) -timeout=$(GINKGO_TIMEOUT) --fail-fast
GINKGO_TIMEOUT ?= 120m
GINKGO_ARGS ?= -focus="$(GINKGO_FOCUS)" -skip="$(GINKGO_SKIP)" -nodes=$(GINKGO_NODES) -no-color=$(GINKGO_NO_COLOR) --output-interceptor-mode=none -timeout=$(GINKGO_TIMEOUT) --fail-fast

$(E2E_TEST):
(cd test/e2e && go test -c . -o $(E2E_TEST))
Expand Down
31 changes: 31 additions & 0 deletions api/v1alpha1/labels.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@

package v1alpha1

import (
"github.com/kaito-project/kaito/pkg/featuregates"
"github.com/kaito-project/kaito/pkg/model"
"github.com/kaito-project/kaito/pkg/utils/consts"
)

const (

// Non-prefixed labels/annotations are reserved for end-use.
Expand Down Expand Up @@ -30,4 +36,29 @@ const (

// RAGEngineRevisionAnnotation is the Annotations for revision number
RAGEngineRevisionAnnotation = "ragengine.kaito.io/revision"

// AnnotationWorkspaceRuntime is the annotation for runtime selection.
AnnotationWorkspaceRuntime = KAITOPrefix + "runtime"
)

// GetWorkspaceRuntimeName returns the runtime name of the workspace.
func GetWorkspaceRuntimeName(ws *Workspace) model.RuntimeName {
if ws == nil {
panic("workspace is nil")
}

if !featuregates.FeatureGates[consts.FeatureFlagVLLM] {
return model.RuntimeNameHuggingfaceTransformers
}

runtime := model.RuntimeNameVLLM
name := ws.Annotations[AnnotationWorkspaceRuntime]
switch name {
case string(model.RuntimeNameHuggingfaceTransformers):
runtime = model.RuntimeNameHuggingfaceTransformers
case string(model.RuntimeNameVLLM):
runtime = model.RuntimeNameVLLM
}

return runtime
}
4 changes: 2 additions & 2 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace stri
// Currently require a preset to specified, in future we can consider defining a template
if r.Preset == nil {
errs = errs.Also(apis.ErrMissingField("Preset"))
} else if presetName := string(r.Preset.Name); !utils.IsValidPreset(presetName) {
} else if presetName := string(r.Preset.Name); !plugin.IsValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName"))
}
return errs
Expand Down Expand Up @@ -404,7 +404,7 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
if i.Preset != nil {
presetName := string(i.Preset.Name)
// Validate preset name
if !utils.IsValidPreset(presetName) {
if !plugin.IsValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName"))
}
// Validate private preset has private image specified
Expand Down
1 change: 1 addition & 0 deletions pkg/featuregates/featuregates.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ var (
// FeatureGates is a map that holds the feature gates and their default values for Kaito.
FeatureGates = map[string]bool{
consts.FeatureFlagKarpenter: false,
consts.FeatureFlagVLLM: true,
// Add more feature gates here
}
)
Expand Down
145 changes: 132 additions & 13 deletions pkg/model/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package model

import (
"time"

"github.com/kaito-project/kaito/pkg/utils"
)

type Model interface {
Expand All @@ -13,23 +15,140 @@ type Model interface {
SupportTuning() bool
}

// RuntimeName is LLM runtime name.
type RuntimeName string

const (
RuntimeNameHuggingfaceTransformers RuntimeName = "transformers"
RuntimeNameVLLM RuntimeName = "vllm"
)

// PresetParam defines the preset inference parameters for a model.
type PresetParam struct {
ModelFamilyName string // The name of the model family.
ImageAccessMode string // Defines where the Image is Public or Private.
DiskStorageRequirement string // Disk storage requirements for the model.
GPUCountRequirement string // Number of GPUs required for the Preset. Used for inference.
TotalGPUMemoryRequirement string // Total GPU memory required for the Preset. Used for inference.
PerGPUMemoryRequirement string // GPU memory required per GPU. Used for inference.
TuningPerGPUMemoryRequirement map[string]int // Min GPU memory per tuning method (batch size 1). Used for tuning.
TorchRunParams map[string]string // Parameters for configuring the torchrun command.
TorchRunRdzvParams map[string]string // Optional rendezvous parameters for distributed training/inference using torchrun (elastic).
BaseCommand string // The initial command (e.g., 'torchrun', 'accelerate launch') used in the command line.
ModelRunParams map[string]string // Parameters for running the model training/inference.
Tag string // The model image tag
zhuangqh marked this conversation as resolved.
Show resolved Hide resolved
ModelFamilyName string // The name of the model family.
ImageAccessMode string // Defines where the Image is Public or Private.

DiskStorageRequirement string // Disk storage requirements for the model.
GPUCountRequirement string // Number of GPUs required for the Preset. Used for inference.
TotalGPUMemoryRequirement string // Total GPU memory required for the Preset. Used for inference.
PerGPUMemoryRequirement string // GPU memory required per GPU. Used for inference.
TuningPerGPUMemoryRequirement map[string]int // Min GPU memory per tuning method (batch size 1). Used for tuning.
WorldSize int // Defines the number of processes required for distributed inference.

RuntimeParam

// ReadinessTimeout defines the maximum duration for creating the workload.
// This timeout accommodates the size of the image, ensuring pull completion
// even under slower network conditions or unforeseen delays.
ReadinessTimeout time.Duration
WorldSize int // Defines the number of processes required for distributed inference.
Tag string // The model image tag
}

// RuntimeParam defines the llm runtime parameters.
type RuntimeParam struct {
Transformers HuggingfaceTransformersParam
VLLM VLLMParam
}

type HuggingfaceTransformersParam struct {
BaseCommand string // The initial command (e.g., 'torchrun', 'accelerate launch') used in the command line.
TorchRunParams map[string]string // Parameters for configuring the torchrun command.
TorchRunRdzvParams map[string]string // Optional rendezvous parameters for distributed training/inference using torchrun (elastic).
InferenceMainFile string // The main file for inference.
ModelRunParams map[string]string // Parameters for running the model training/inference.
}

type VLLMParam struct {
BaseCommand string
// The model name used in the openai serving API.
// see https://platform.openai.com/docs/api-reference/chat/create#chat-create-model.
ModelName string
// Parameters for distributed inference.
DistributionParams map[string]string
// Parameters for running the model training/inference.
ModelRunParams map[string]string
}

func (p *PresetParam) DeepCopy() *PresetParam {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we may update some params according to the node counts. Thus, we must deepcopy it at first.

if p == nil {
return nil
}
out := new(PresetParam)
*out = *p
out.RuntimeParam = p.RuntimeParam.DeepCopy()
out.TuningPerGPUMemoryRequirement = make(map[string]int, len(p.TuningPerGPUMemoryRequirement))
for k, v := range p.TuningPerGPUMemoryRequirement {
out.TuningPerGPUMemoryRequirement[k] = v
}
return out
}

func (rp *RuntimeParam) DeepCopy() RuntimeParam {
if rp == nil {
return RuntimeParam{}
}
out := RuntimeParam{}
out.Transformers = rp.Transformers.DeepCopy()
out.VLLM = rp.VLLM.DeepCopy()
return out
}

func (h *HuggingfaceTransformersParam) DeepCopy() HuggingfaceTransformersParam {
if h == nil {
return HuggingfaceTransformersParam{}
}
out := HuggingfaceTransformersParam{}
out.BaseCommand = h.BaseCommand
out.InferenceMainFile = h.InferenceMainFile
out.TorchRunParams = make(map[string]string, len(h.TorchRunParams))
for k, v := range h.TorchRunParams {
out.TorchRunParams[k] = v
}
out.TorchRunRdzvParams = make(map[string]string, len(h.TorchRunRdzvParams))
for k, v := range h.TorchRunRdzvParams {
out.TorchRunRdzvParams[k] = v
}
out.ModelRunParams = make(map[string]string, len(h.ModelRunParams))
for k, v := range h.ModelRunParams {
out.ModelRunParams[k] = v
}
return out
}

func (v *VLLMParam) DeepCopy() VLLMParam {
if v == nil {
return VLLMParam{}
}
out := VLLMParam{}
out.BaseCommand = v.BaseCommand
out.ModelName = v.ModelName
out.DistributionParams = make(map[string]string, len(v.DistributionParams))
for k, v := range v.DistributionParams {
out.DistributionParams[k] = v
}
out.ModelRunParams = make(map[string]string, len(v.ModelRunParams))
for k, v := range v.ModelRunParams {
out.ModelRunParams[k] = v
}
return out
}

// builds the container command:
// eg. torchrun <TORCH_PARAMS> <OPTIONAL_RDZV_PARAMS> baseCommand <MODEL_PARAMS>
func (p *PresetParam) GetInferenceCommand(runtime RuntimeName, skuNumGPUs string) []string {
switch runtime {
case RuntimeNameHuggingfaceTransformers:
torchCommand := utils.BuildCmdStr(p.Transformers.BaseCommand, p.Transformers.TorchRunParams, p.Transformers.TorchRunRdzvParams)
modelCommand := utils.BuildCmdStr(p.Transformers.InferenceMainFile, p.Transformers.ModelRunParams)
return utils.ShellCmd(torchCommand + " " + modelCommand)
case RuntimeNameVLLM:
if p.VLLM.ModelName != "" {
p.VLLM.ModelRunParams["served-model-name"] = p.VLLM.ModelName
}
p.VLLM.ModelRunParams["tensor-parallel-size"] = skuNumGPUs
modelCommand := utils.BuildCmdStr(p.VLLM.BaseCommand, p.VLLM.ModelRunParams)
return utils.ShellCmd(modelCommand)
default:
return nil
}
}
5 changes: 0 additions & 5 deletions pkg/utils/common-preset.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package utils

import (
"github.com/kaito-project/kaito/pkg/utils/plugin"
corev1 "k8s.io/api/core/v1"
)

Expand Down Expand Up @@ -150,7 +149,3 @@ func ConfigAdapterVolume() (corev1.Volume, corev1.VolumeMount) {
}
return volume, volumeMount
}

func IsValidPreset(preset string) bool {
return plugin.KaitoModelRegister.Has(preset)
}
14 changes: 8 additions & 6 deletions pkg/utils/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@ func MergeConfigMaps(baseMap, overrideMap map[string]string) map[string]string {
return merged
}

func BuildCmdStr(baseCommand string, runParams map[string]string) string {
func BuildCmdStr(baseCommand string, runParams ...map[string]string) string {
updatedBaseCommand := baseCommand
for key, value := range runParams {
if value == "" {
updatedBaseCommand = fmt.Sprintf("%s --%s", updatedBaseCommand, key)
} else {
updatedBaseCommand = fmt.Sprintf("%s --%s=%s", updatedBaseCommand, key, value)
for _, runParam := range runParams {
for key, value := range runParam {
if value == "" {
updatedBaseCommand = fmt.Sprintf("%s --%s", updatedBaseCommand, key)
} else {
updatedBaseCommand = fmt.Sprintf("%s --%s=%s", updatedBaseCommand, key, value)
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/utils/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package utils

import (
"context"
"sigs.k8s.io/controller-runtime/pkg/client"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -12,6 +11,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/kubernetes/scheme"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
)

Expand Down
5 changes: 4 additions & 1 deletion pkg/utils/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ const (
// RAGEngineFinalizer is used to make sure that ragengine controller handles garbage collection.
RAGEngineFinalizer = "ragengine.finalizer.kaito.sh"
DefaultReleaseNamespaceEnvVar = "RELEASE_NAMESPACE"
FeatureFlagKarpenter = "Karpenter"
AzureCloudName = "azure"
AWSCloudName = "aws"
GPUString = "gpu"
Expand All @@ -20,6 +19,10 @@ const (
GiBToBytes = 1024 * 1024 * 1024 // Conversion factor from GiB to bytes
NvidiaGPU = "nvidia.com/gpu"

// Feature flags
FeatureFlagKarpenter = "Karpenter"
FeatureFlagVLLM = "vLLM"

// Nodeclaim related consts
KaitoNodePoolName = "kaito"
LabelNodePool = "karpenter.sh/nodepool"
Expand Down
4 changes: 4 additions & 0 deletions pkg/utils/plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,7 @@ func (reg *ModelRegister) Has(name string) bool {
_, ok := reg.models[name]
return ok
}

func IsValidPreset(preset string) bool {
return KaitoModelRegister.Has(preset)
}
24 changes: 20 additions & 4 deletions pkg/utils/test/testModel.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@ type testModel struct{}
func (*testModel) GetInferenceParameters() *model.PresetParam {
return &model.PresetParam{
GPUCountRequirement: "1",
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: "python3",
RuntimeParam: model.RuntimeParam{
VLLM: model.VLLMParam{
BaseCommand: "python3 /workspace/vllm/inference_api.py",
},
Transformers: model.HuggingfaceTransformersParam{
BaseCommand: "accelerate launch",
InferenceMainFile: "/workspace/tfs/inference_api.py",
},
},
ReadinessTimeout: time.Duration(30) * time.Minute,
}
}
func (*testModel) GetTuningParameters() *model.PresetParam {
Expand All @@ -37,8 +45,16 @@ type testDistributedModel struct{}
func (*testDistributedModel) GetInferenceParameters() *model.PresetParam {
return &model.PresetParam{
GPUCountRequirement: "1",
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: "python3",
RuntimeParam: model.RuntimeParam{
VLLM: model.VLLMParam{
BaseCommand: "python3 /workspace/vllm/inference_api.py",
},
Transformers: model.HuggingfaceTransformersParam{
BaseCommand: "accelerate launch",
InferenceMainFile: "/workspace/tfs/inference_api.py",
},
},
ReadinessTimeout: time.Duration(30) * time.Minute,
}
}
func (*testDistributedModel) GetTuningParameters() *model.PresetParam {
Expand Down
Loading
Loading