// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package llama2chat

import (
	"time"

	kaitov1alpha1 "github.com/kaito-project/kaito/api/v1alpha1"
	"github.com/kaito-project/kaito/pkg/model"
	"github.com/kaito-project/kaito/pkg/utils/plugin"
	"github.com/kaito-project/kaito/pkg/workspace/inference"
)
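
// init registers the three Llama 2 chat presets with kaito's global model
// registry, keyed by preset name.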
func init() {
	plugin.KaitoModelRegister.Register(&plugin.Registration{
		Name:     "llama-2-7b-chat",
		Instance: &llama2chatA,
	})
	plugin.KaitoModelRegister.Register(&plugin.Registration{
		Name:     "llama-2-13b-chat",
		Instance: &llama2chatB,
	})
	plugin.KaitoModelRegister.Register(&plugin.Registration{
		Name:     "llama-2-70b-chat",
		Instance: &llama2chatC,
	})
}
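
// All three presets launch through the same torchrun entrypoint and share
// the same sequence-length and batch-size limits.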
var (
	baseCommandPresetLlama = "cd /workspace/llama/llama-2 && torchrun"

	llamaRunParams = map[string]string{
		"max_seq_len":    "512",
		"max_batch_size": "8",
	}
)
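
// llama2chatA serves the llama-2-7b-chat preset on a single GPU.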
var llama2chatA llama2Chat7b

type llama2Chat7b struct{}

func (*llama2Chat7b) GetInferenceParameters() *model.PresetParam {
	return &model.PresetParam{
		ModelFamilyName:           "LLaMa2",
		ImageAccessMode:           string(kaitov1alpha1.ModelImageAccessModePrivate),
		DiskStorageRequirement:    "34Gi",
		GPUCountRequirement:       "1",
		TotalGPUMemoryRequirement: "16Gi",
		PerGPUMemoryRequirement:   "14Gi", // Llama 2 runs with tensor parallelism; each GPU needs more memory than its tensor shard.
		RuntimeParam: model.RuntimeParam{
			Transformers: model.HuggingfaceTransformersParam{
				BaseCommand:        baseCommandPresetLlama,
				TorchRunParams:     inference.DefaultTorchRunParams,
				TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams,
				InferenceMainFile:  "inference_api.py",
				ModelRunParams:     llamaRunParams,
			},
		},
		ReadinessTimeout: 10 * time.Minute,
		WorldSize:        1,
		// Tag is omitted: llama uses the private image access mode, so the image tag is determined by the user.
	}
}

func (*llama2Chat7b) GetTuningParameters() *model.PresetParam {
	return nil // fine-tuning is not currently supported
}

func (*llama2Chat7b) SupportDistributedInference() bool {
	return false
}

func (*llama2Chat7b) SupportTuning() bool {
	return false
}
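
// llama2chatB serves the llama-2-13b-chat preset, sharded across two GPUs
// via distributed inference.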
var llama2chatB llama2Chat13b

type llama2Chat13b struct{}

func (*llama2Chat13b) GetInferenceParameters() *model.PresetParam {
	return &model.PresetParam{
		ModelFamilyName:           "LLaMa2",
		ImageAccessMode:           string(kaitov1alpha1.ModelImageAccessModePrivate),
		DiskStorageRequirement:    "46Gi",
		GPUCountRequirement:       "2",
		TotalGPUMemoryRequirement: "30Gi",
		PerGPUMemoryRequirement:   "15Gi", // Llama 2 runs with tensor parallelism; each GPU needs more memory than its tensor shard.
		RuntimeParam: model.RuntimeParam{
			Transformers: model.HuggingfaceTransformersParam{
				BaseCommand:        baseCommandPresetLlama,
				TorchRunParams:     inference.DefaultTorchRunParams,
				TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams,
				InferenceMainFile:  "inference_api.py",
				ModelRunParams:     llamaRunParams,
			},
		},
		ReadinessTimeout: 20 * time.Minute,
		WorldSize:        2,
		// Tag is omitted: llama uses the private image access mode, so the image tag is determined by the user.
	}
}

func (*llama2Chat13b) GetTuningParameters() *model.PresetParam {
	return nil // fine-tuning is not currently supported
}

func (*llama2Chat13b) SupportDistributedInference() bool {
	return true
}

func (*llama2Chat13b) SupportTuning() bool {
	return false
}
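
// llama2chatC serves the llama-2-70b-chat preset, sharded across eight GPUs
// via distributed inference.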
var llama2chatC llama2Chat70b

type llama2Chat70b struct{}

func (*llama2Chat70b) GetInferenceParameters() *model.PresetParam {
	return &model.PresetParam{
		ModelFamilyName:           "LLaMa2",
		ImageAccessMode:           string(kaitov1alpha1.ModelImageAccessModePrivate),
		DiskStorageRequirement:    "158Gi",
		GPUCountRequirement:       "8",
		TotalGPUMemoryRequirement: "192Gi",
		PerGPUMemoryRequirement:   "19Gi", // Llama 2 runs with tensor parallelism; each GPU needs more memory than its tensor shard.
		RuntimeParam: model.RuntimeParam{
			Transformers: model.HuggingfaceTransformersParam{
				BaseCommand:        baseCommandPresetLlama,
				TorchRunParams:     inference.DefaultTorchRunParams,
				TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams,
				InferenceMainFile:  "inference_api.py",
				ModelRunParams:     llamaRunParams,
			},
		},
		ReadinessTimeout: 30 * time.Minute,
		WorldSize:        8,
		// Tag is omitted: llama uses the private image access mode, so the image tag is determined by the user.
	}
}

func (*llama2Chat70b) GetTuningParameters() *model.PresetParam {
	return nil // fine-tuning is not currently supported
}

func (*llama2Chat70b) SupportDistributedInference() bool {
	return true
}

func (*llama2Chat70b) SupportTuning() bool {
	return false
}
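
// Illustrative usage (a sketch; assumes the registry exposes a MustGet
// accessor, as used elsewhere in kaito): resolve a preset by its registered
// name and read its inference requirements.
//
//	m := plugin.KaitoModelRegister.MustGet("llama-2-7b-chat")
//	params := m.GetInferenceParameters()
//	fmt.Println(params.GPUCountRequirement) // "1"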