Skip to content

Commit

Permalink
feat: support vllm in controller
Browse files Browse the repository at this point in the history
- set vllm as the default runtime

Signed-off-by: jerryzhuang <[email protected]>
  • Loading branch information
zhuangqh committed Nov 14, 2024
1 parent f3ef4c8 commit 8244ebc
Show file tree
Hide file tree
Showing 22 changed files with 675 additions and 249 deletions.
30 changes: 30 additions & 0 deletions api/v1alpha1/labels.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@

package v1alpha1

import (
"github.com/kaito-project/kaito/pkg/featuregates"
"github.com/kaito-project/kaito/pkg/model"
"github.com/kaito-project/kaito/pkg/utils/consts"
)

const (

// Non-prefixed labels/annotations are reserved for end-use.
Expand Down Expand Up @@ -30,4 +36,28 @@ const (

// RAGEngineRevisionAnnotation is the Annotations for revision number
RAGEngineRevisionAnnotation = "ragengine.kaito.io/revision"

// AnnotationWorkspaceRuntime is the annotation for runtime selection.
AnnotationWorkspaceRuntime = KAITOPrefix + "runtime"
)

// GetWorkspaceRuntimeName returns the runtime name of the workspace.
func GetWorkspaceRuntimeName(ws *Workspace) model.RuntimeName {
if ws == nil {
panic("workspace is nil")
}
runtime := model.RuntimeNameHuggingfaceTransformers
if featuregates.FeatureGates[consts.FeatureFlagVLLM] {
runtime = model.RuntimeNameVLLM
}

name := ws.Annotations[AnnotationWorkspaceRuntime]
switch name {
case string(model.RuntimeNameHuggingfaceTransformers):
runtime = model.RuntimeNameHuggingfaceTransformers
case string(model.RuntimeNameVLLM):
runtime = model.RuntimeNameVLLM
}

return runtime
}
4 changes: 2 additions & 2 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace stri
// Currently require a preset to specified, in future we can consider defining a template
if r.Preset == nil {
errs = errs.Also(apis.ErrMissingField("Preset"))
} else if presetName := string(r.Preset.Name); !utils.IsValidPreset(presetName) {
} else if presetName := string(r.Preset.Name); !plugin.IsValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName"))
}
return errs
Expand Down Expand Up @@ -407,7 +407,7 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
if i.Preset != nil {
presetName := string(i.Preset.Name)
// Validate preset name
if !utils.IsValidPreset(presetName) {
if !plugin.IsValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName"))
}
// Validate private preset has private image specified
Expand Down
1 change: 1 addition & 0 deletions pkg/featuregates/featuregates.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ var (
// FeatureGates is a map that holds the feature gates and their default values for Kaito.
FeatureGates = map[string]bool{
consts.FeatureFlagKarpenter: false,
consts.FeatureFlagVLLM: false,
// Add more feature gates here
}
)
Expand Down
145 changes: 132 additions & 13 deletions pkg/model/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package model

import (
"time"

"github.com/kaito-project/kaito/pkg/utils"
)

type Model interface {
Expand All @@ -13,23 +15,140 @@ type Model interface {
SupportTuning() bool
}

// RuntimeName is LLM runtime name.
type RuntimeName string

const (
RuntimeNameHuggingfaceTransformers RuntimeName = "transformers"
RuntimeNameVLLM RuntimeName = "vllm"

InferenceFileHuggingface = "/workspace/tfs/inference_api.py"
InferenceFileVLLM = "/workspace/vllm/inference_api.py"
)

// PresetParam defines the preset inference parameters for a model.
type PresetParam struct {
ModelFamilyName string // The name of the model family.
ImageAccessMode string // Defines where the Image is Public or Private.
DiskStorageRequirement string // Disk storage requirements for the model.
GPUCountRequirement string // Number of GPUs required for the Preset. Used for inference.
TotalGPUMemoryRequirement string // Total GPU memory required for the Preset. Used for inference.
PerGPUMemoryRequirement string // GPU memory required per GPU. Used for inference.
TuningPerGPUMemoryRequirement map[string]int // Min GPU memory per tuning method (batch size 1). Used for tuning.
TorchRunParams map[string]string // Parameters for configuring the torchrun command.
TorchRunRdzvParams map[string]string // Optional rendezvous parameters for distributed training/inference using torchrun (elastic).
BaseCommand string // The initial command (e.g., 'torchrun', 'accelerate launch') used in the command line.
ModelRunParams map[string]string // Parameters for running the model training/inference.
Tag string // The model image tag
ModelFamilyName string // The name of the model family.
ImageAccessMode string // Defines where the Image is Public or Private.

DiskStorageRequirement string // Disk storage requirements for the model.
GPUCountRequirement string // Number of GPUs required for the Preset. Used for inference.
TotalGPUMemoryRequirement string // Total GPU memory required for the Preset. Used for inference.
PerGPUMemoryRequirement string // GPU memory required per GPU. Used for inference.
TuningPerGPUMemoryRequirement map[string]int // Min GPU memory per tuning method (batch size 1). Used for tuning.
WorldSize int // Defines the number of processes required for distributed inference.

RuntimeParam

// ReadinessTimeout defines the maximum duration for creating the workload.
// This timeout accommodates the size of the image, ensuring pull completion
// even under slower network conditions or unforeseen delays.
ReadinessTimeout time.Duration
WorldSize int // Defines the number of processes required for distributed inference.
Tag string // The model image tag
}

// RuntimeParam defines the llm runtime parameters.
type RuntimeParam struct {
Transformers HuggingfaceTransformersParam
VLLM VLLMParam
}

type HuggingfaceTransformersParam struct {
BaseCommand string // The initial command (e.g., 'torchrun', 'accelerate launch') used in the command line.
TorchRunParams map[string]string // Parameters for configuring the torchrun command.
TorchRunRdzvParams map[string]string // Optional rendezvous parameters for distributed training/inference using torchrun (elastic).
ModelRunParams map[string]string // Parameters for running the model training/inference.
}

type VLLMParam struct {
BaseCommand string
// The model name used in the openai serving API.
// see https://platform.openai.com/docs/api-reference/chat/create#chat-create-model.
ModelName string
// Parameters for distributed inference.
DistributionParams map[string]string
// Parameters for running the model training/inference.
ModelRunParams map[string]string
}

func (p *PresetParam) DeepCopy() *PresetParam {
if p == nil {
return nil
}
out := new(PresetParam)
*out = *p
out.RuntimeParam = p.RuntimeParam.DeepCopy()
out.TuningPerGPUMemoryRequirement = make(map[string]int, len(p.TuningPerGPUMemoryRequirement))
for k, v := range p.TuningPerGPUMemoryRequirement {
out.TuningPerGPUMemoryRequirement[k] = v
}
return out
}

func (rp *RuntimeParam) DeepCopy() RuntimeParam {
if rp == nil {
return RuntimeParam{}
}
out := RuntimeParam{}
out.Transformers = rp.Transformers.DeepCopy()
out.VLLM = rp.VLLM.DeepCopy()
return out
}

func (h *HuggingfaceTransformersParam) DeepCopy() HuggingfaceTransformersParam {
if h == nil {
return HuggingfaceTransformersParam{}
}
out := HuggingfaceTransformersParam{}
out.BaseCommand = h.BaseCommand
out.TorchRunParams = make(map[string]string, len(h.TorchRunParams))
for k, v := range h.TorchRunParams {
out.TorchRunParams[k] = v
}
out.TorchRunRdzvParams = make(map[string]string, len(h.TorchRunRdzvParams))
for k, v := range h.TorchRunRdzvParams {
out.TorchRunRdzvParams[k] = v
}
out.ModelRunParams = make(map[string]string, len(h.ModelRunParams))
for k, v := range h.ModelRunParams {
out.ModelRunParams[k] = v
}
return out
}

func (v *VLLMParam) DeepCopy() VLLMParam {
if v == nil {
return VLLMParam{}
}
out := VLLMParam{}
out.BaseCommand = v.BaseCommand
out.DistributionParams = make(map[string]string, len(v.DistributionParams))
for k, v := range v.DistributionParams {
out.DistributionParams[k] = v
}
out.ModelRunParams = make(map[string]string, len(v.ModelRunParams))
for k, v := range v.ModelRunParams {
out.ModelRunParams[k] = v
}
return out
}

// builds the container command:
// eg. torchrun <TORCH_PARAMS> <OPTIONAL_RDZV_PARAMS> baseCommand <MODEL_PARAMS>
func (p *PresetParam) GetInferenceCommand(runtime RuntimeName, skuNumGPUs string) []string {
switch runtime {
case RuntimeNameHuggingfaceTransformers:
torchCommand := utils.BuildCmdStr(p.Transformers.BaseCommand, p.Transformers.TorchRunParams, p.Transformers.TorchRunRdzvParams)
modelCommand := utils.BuildCmdStr(InferenceFileHuggingface, p.Transformers.ModelRunParams)
return utils.ShellCmd(torchCommand + " " + modelCommand)
case RuntimeNameVLLM:
if p.VLLM.ModelName != "" {
p.VLLM.ModelRunParams["served-model-name"] = p.VLLM.ModelName
}
p.VLLM.ModelRunParams["tensor-parallel-size"] = skuNumGPUs
modelCommand := utils.BuildCmdStr(InferenceFileVLLM, p.VLLM.ModelRunParams)
return utils.ShellCmd(p.VLLM.BaseCommand + " " + modelCommand)
default:
return nil
}
}
5 changes: 0 additions & 5 deletions pkg/utils/common-preset.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package utils

import (
"github.com/kaito-project/kaito/pkg/utils/plugin"
corev1 "k8s.io/api/core/v1"
)

Expand Down Expand Up @@ -150,7 +149,3 @@ func ConfigAdapterVolume() (corev1.Volume, corev1.VolumeMount) {
}
return volume, volumeMount
}

func IsValidPreset(preset string) bool {
return plugin.KaitoModelRegister.Has(preset)
}
14 changes: 8 additions & 6 deletions pkg/utils/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@ func MergeConfigMaps(baseMap, overrideMap map[string]string) map[string]string {
return merged
}

func BuildCmdStr(baseCommand string, runParams map[string]string) string {
func BuildCmdStr(baseCommand string, runParams ...map[string]string) string {
updatedBaseCommand := baseCommand
for key, value := range runParams {
if value == "" {
updatedBaseCommand = fmt.Sprintf("%s --%s", updatedBaseCommand, key)
} else {
updatedBaseCommand = fmt.Sprintf("%s --%s=%s", updatedBaseCommand, key, value)
for _, runParam := range runParams {
for key, value := range runParam {
if value == "" {
updatedBaseCommand = fmt.Sprintf("%s --%s", updatedBaseCommand, key)
} else {
updatedBaseCommand = fmt.Sprintf("%s --%s=%s", updatedBaseCommand, key, value)
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/utils/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package utils

import (
"context"
"sigs.k8s.io/controller-runtime/pkg/client"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -12,6 +11,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/kubernetes/scheme"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
)

Expand Down
5 changes: 4 additions & 1 deletion pkg/utils/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ const (
// RAGEngineFinalizer is used to make sure that ragengine controller handles garbage collection.
RAGEngineFinalizer = "ragengine.finalizer.kaito.sh"
DefaultReleaseNamespaceEnvVar = "RELEASE_NAMESPACE"
FeatureFlagKarpenter = "Karpenter"
AzureCloudName = "azure"
AWSCloudName = "aws"
GPUString = "gpu"
Expand All @@ -20,6 +19,10 @@ const (
GiBToBytes = 1024 * 1024 * 1024 // Conversion factor from GiB to bytes
NvidiaGPU = "nvidia.com/gpu"

// Feature flags
FeatureFlagKarpenter = "Karpenter"
FeatureFlagVLLM = "vLLM"

// Nodeclaim related consts
KaitoNodePoolName = "kaito"
LabelNodePool = "karpenter.sh/nodepool"
Expand Down
4 changes: 4 additions & 0 deletions pkg/utils/plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,7 @@ func (reg *ModelRegister) Has(name string) bool {
_, ok := reg.models[name]
return ok
}

func IsValidPreset(preset string) bool {
return KaitoModelRegister.Has(preset)
}
22 changes: 18 additions & 4 deletions pkg/utils/test/testModel.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@ type testModel struct{}
func (*testModel) GetInferenceParameters() *model.PresetParam {
return &model.PresetParam{
GPUCountRequirement: "1",
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: "python3",
RuntimeParam: model.RuntimeParam{
VLLM: model.VLLMParam{
BaseCommand: "python3",
},
Transformers: model.HuggingfaceTransformersParam{
BaseCommand: "accelerate launch",
},
},
ReadinessTimeout: time.Duration(30) * time.Minute,
}
}
func (*testModel) GetTuningParameters() *model.PresetParam {
Expand All @@ -37,8 +44,15 @@ type testDistributedModel struct{}
func (*testDistributedModel) GetInferenceParameters() *model.PresetParam {
return &model.PresetParam{
GPUCountRequirement: "1",
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: "python3",
RuntimeParam: model.RuntimeParam{
VLLM: model.VLLMParam{
BaseCommand: "python3",
},
Transformers: model.HuggingfaceTransformersParam{
BaseCommand: "accelerate launch",
},
},
ReadinessTimeout: time.Duration(30) * time.Minute,
}
}
func (*testDistributedModel) GetTuningParameters() *model.PresetParam {
Expand Down
26 changes: 26 additions & 0 deletions pkg/utils/test/testUtils.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package test
import (
"github.com/aws/karpenter-core/pkg/apis/v1alpha5"
"github.com/kaito-project/kaito/api/v1alpha1"
"github.com/kaito-project/kaito/pkg/model"
"github.com/samber/lo"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -139,6 +140,31 @@ var (
},
},
}
MockWorkspaceWithPresetVLLM = &v1alpha1.Workspace{
ObjectMeta: metav1.ObjectMeta{
Name: "testWorkspace",
Namespace: "kaito",
Annotations: map[string]string{
v1alpha1.AnnotationWorkspaceRuntime: string(model.RuntimeNameVLLM),
},
},
Resource: v1alpha1.ResourceSpec{
Count: &gpuNodeCount,
InstanceType: "Standard_NC12s_v3",
LabelSelector: &metav1.LabelSelector{
MatchLabels: map[string]string{
"apps": "test",
},
},
},
Inference: &v1alpha1.InferenceSpec{
Preset: &v1alpha1.PresetSpec{
PresetMeta: v1alpha1.PresetMeta{
Name: "test-model",
},
},
},
}
)

var MockWorkspaceWithPresetHash = "89ae127050ec264a5ce84db48ef7226574cdf1299e6bd27fe90b927e34cc8adb"
Expand Down
2 changes: 1 addition & 1 deletion pkg/workspace/controllers/workspace_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a
} else if apierrors.IsNotFound(err) {
var workloadObj client.Object
// Need to create a new workload
workloadObj, err = inference.CreatePresetInference(ctx, wObj, revisionStr, inferenceParam, model.SupportDistributedInference(), c.Client)
workloadObj, err = inference.CreatePresetInference(ctx, wObj, revisionStr, model, c.Client)
if err != nil {
return
}
Expand Down
Loading

0 comments on commit 8244ebc

Please sign in to comment.