Skip to content

Commit

Permalink
feat: add the ast for parsing modelfile
Browse files Browse the repository at this point in the history
Signed-off-by: Gaius <[email protected]>
  • Loading branch information
gaius-qi committed Oct 24, 2024
1 parent a0d46bc commit df7650d
Show file tree
Hide file tree
Showing 11 changed files with 1,160 additions and 0 deletions.
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,22 @@ module github.com/CloudNativeAI/modctl
go 1.22.4

require (
github.com/emirpasic/gods v1.18.1
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.8.1
github.com/spf13/viper v1.19.0
github.com/stretchr/testify v1.9.0
)

require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
Expand Down
69 changes: 69 additions & 0 deletions pkg/modelfile/command/command.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright 2024 The CNAI Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package command

// Define the command strings for modelfile.
const (
// CONFIG is the command to set the configuration of the model, which is used for
// the model to be served, such as the config.json, generation_config.json, etc.
// The CONFIG command can be used multiple times in a modelfile, it
// will be copied the config file to the artifact package as a layer.
CONFIG = "config"

// MODEL is the command to set the model file path. The value of this command
// is the regex of the model file path to match the model file name.
// The MODEL command can be used multiple times in a modelfile, it will scan
// the model file path by the regex and copy each model file to the artifact
// package, and each model file will be a layer.
MODEL = "model"

// NAME is the command to set the model name, such as llama3-8b-instruct, gpt2-xl,
// qwen2-vl-72b-instruct, etc.
NAME = "name"

// ARCH is the command to set the architecture of the model, such as transformer,
// cnn, rnn, etc.
ARCH = "arch"

// FAMILY is the command to set the family of the model, such as llama3, gpt2, qwen2, etc.
FAMILY = "family"

// FORMAT is the command to set the format of the model, such as onnx, tensorflow, pytorch, etc.
FORMAT = "format"

// PARAMSIZE is the command to set the parameter size of the model.
PARAMSIZE = "paramsize"

// PRECISION is the command to set the precision of the model, such as bf16, fp16, int8, etc.
PRECISION = "precision"

// QUANTIZATION is the command to set the quantization of the model, such as awq, gptq, etc.
QUANTIZATION = "quantization"
)

// Commands is a list of all the commands that can be used in a modelfile.
var Commands = []string{
CONFIG,
MODEL,
NAME,
ARCH,
FAMILY,
FORMAT,
PARAMSIZE,
PRECISION,
QUANTIZATION,
}
206 changes: 206 additions & 0 deletions pkg/modelfile/modelfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,209 @@
*/

package modelfile

import (
"fmt"
"os"

log "github.com/sirupsen/logrus"

modefilecommand "github.com/CloudNativeAI/modctl/pkg/modelfile/command"
"github.com/CloudNativeAI/modctl/pkg/modelfile/parser"
"github.com/emirpasic/gods/sets/hashset"
)

// Modelfile is the interface for the modelfile. It is used to parse
// the modelfile by the path and get the information of the modelfile.
type Modelfile interface {
// GetConfigs returns the args of the config command in the modelfile,
// and deduplicates the args. The order of the args is the same as the
// order in the modelfile.
GetConfigs() []string

// GetModels returns the args of the model command in the modelfile,
// and deduplicates the args. The order of the args is the same as The
// order in the modelfile.
GetModels() []string

// GetName returns the value of the name command in the modelfile.
GetName() string

// GetArch returns the value of the arch command in the modelfile.
GetArch() string

// GetFamily returns the value of the family command in the modelfile.
GetFamily() string

// GetFormat returns the value of the format command in the modelfile.
GetFormat() string

// GetParamsize returns the value of the paramsize command in the modelfile.
GetParamsize() string

// GetPrecision returns the value of the precision command in the modelfile.
GetPrecision() string

// GetQuantization returns the value of the quantization command in the modelfile.
GetQuantization() string
}

// modelfile is the implementation of the Modelfile interface.
type modelfile struct {
config *hashset.Set
model *hashset.Set
name string
arch string
family string
format string
paramsize string
precision string
quantization string
}

// NewModelfile creates a new modelfile by the path of the modelfile.
// It parses the modelfile and returns the modelfile interface.
func NewModelfile(path string) (Modelfile, error) {
mf := &modelfile{
config: hashset.New(),
model: hashset.New(),
}
if err := mf.parseFile(path); err != nil {
return nil, err
}

return mf, nil
}

// parseFile parses the modelfile by the path, and validates the args of the commands.
func (mf *modelfile) parseFile(path string) error {
f, err := os.Open(path)
if err != nil {
return err
}
defer f.Close()

ast, err := parser.Parse(f)
if err != nil {
return err
}

for _, child := range ast.GetChildren() {
switch child.GetValue() {
case modefilecommand.CONFIG:
mf.config.Add(child.GetNext().GetValue())
case modefilecommand.MODEL:
mf.model.Add(child.GetNext().GetValue())
case modefilecommand.NAME:
if mf.name != "" {
return fmt.Errorf("duplicate name command on line %d", child.GetStartLine())
}
mf.name = child.GetNext().GetValue()
case modefilecommand.ARCH:
if mf.arch != "" {
return fmt.Errorf("duplicate arc command on line %d", child.GetStartLine())
}
mf.arch = child.GetNext().GetValue()
case modefilecommand.FAMILY:
if mf.family != "" {
return fmt.Errorf("duplicate family command on line %d", child.GetStartLine())
}
mf.family = child.GetNext().GetValue()
case modefilecommand.FORMAT:
if mf.format != "" {
return fmt.Errorf("duplicate format command on line %d", child.GetStartLine())
}
mf.format = child.GetNext().GetValue()
case modefilecommand.PARAMSIZE:
if mf.paramsize != "" {
return fmt.Errorf("duplicate paramsize command on line %d", child.GetStartLine())
}
mf.paramsize = child.GetNext().GetValue()
case modefilecommand.PRECISION:
if mf.precision != "" {
return fmt.Errorf("duplicate precision command on line %d", child.GetStartLine())
}
mf.precision = child.GetNext().GetValue()
case modefilecommand.QUANTIZATION:
if mf.quantization != "" {
return fmt.Errorf("duplicate quantization command on line %d", child.GetStartLine())
}
mf.quantization = child.GetNext().GetValue()
default:
return fmt.Errorf("unknown command %s on line %d", child.GetValue(), child.GetStartLine())
}
}

return nil
}

// GetConfigs returns the args of the config command in the modelfile,
// and deduplicates the args. The order of the args is the same as the
// order in the modelfile.
func (mf *modelfile) GetConfigs() []string {
var configs []string
for _, rawConfig := range mf.config.Values() {
config, ok := rawConfig.(string)
if !ok {
log.Warnf("failed to convert config to string: %v", rawConfig)
continue
}

configs = append(configs, config)
}

return configs
}

// GetModels returns the args of the model command in the modelfile,
// and deduplicates the args. The order of the args is the same as the
// order in the modelfile.
func (mf *modelfile) GetModels() []string {
var models []string
for _, rawModel := range mf.model.Values() {
model, ok := rawModel.(string)
if !ok {
log.Warnf("failed to convert model to string: %v", rawModel)
continue
}

models = append(models, model)
}

return models
}

// GetName returns the value of the name command in the modelfile.
func (mf *modelfile) GetName() string {
return mf.name
}

// GetArch returns the value of the arch command in the modelfile.
func (mf *modelfile) GetArch() string {
return mf.arch
}

// GetFamily returns the value of the family command in the modelfile.
func (mf *modelfile) GetFamily() string {
return mf.family
}

// GetFormat returns the value of the format command in the modelfile.
func (mf *modelfile) GetFormat() string {
return mf.format
}

// GetParamsize returns the value of the paramsize command in the modelfile.
func (mf *modelfile) GetParamsize() string {
return mf.paramsize
}

// GetPrecision returns the value of the precision command in the modelfile.
func (mf *modelfile) GetPrecision() string {
return mf.precision
}

// GetQuantization returns the value of the quantization command in the modelfile.
func (mf *modelfile) GetQuantization() string {
return mf.quantization
}
Loading

0 comments on commit df7650d

Please sign in to comment.