diff --git a/cli/init.go b/cli/init.go index cab3dc2e8..b5a936933 100644 --- a/cli/init.go +++ b/cli/init.go @@ -16,18 +16,20 @@ limitations under the License. package cli import ( + "errors" "fmt" "os" "path/filepath" "strings" "github.com/spf13/cobra" - "github.com/yomorun/yomo/cli/serverless/golang" + "github.com/yomorun/yomo/cli/template" "github.com/yomorun/yomo/pkg/file" "github.com/yomorun/yomo/pkg/log" ) var name string +var sfnType string // initCmd represents the init command var initCmd = &cobra.Command{ @@ -48,17 +50,28 @@ var initCmd = &cobra.Command{ name = strings.ReplaceAll(name, " ", "_") // create app.go fname := filepath.Join(name, defaultSFNSourceFile) - contentTmpl := golang.InitTmpl + contentTmpl, err := template.GetContent("init", sfnType, "", false) + if err != nil { + log.FailureStatusEvent(os.Stdout, err.Error()) + return + } if err := file.PutContents(fname, contentTmpl); err != nil { - log.FailureStatusEvent(os.Stdout, "Write stream function into app.go file failure with the error: %v", err) + log.FailureStatusEvent(os.Stdout, "Write stream function into %s file failure with the error: %v", fname, err) return } - // create app_test.go testName := filepath.Join(name, defaultSFNTestSourceFile) - if err := file.PutContents(testName, golang.InitTestTmpl); err != nil { - log.FailureStatusEvent(os.Stdout, "Write unittest tmpl into app_test.go file failure with the error: %v", err) - return + testTmpl, err := template.GetContent("init", sfnType, "", true) + if err != nil { + if !errors.Is(err, template.ErrUnsupportedTest) { + log.FailureStatusEvent(os.Stdout, err.Error()) + return + } + } else { + if err := file.PutContents(testName, testTmpl); err != nil { + log.FailureStatusEvent(os.Stdout, "Write unittest tmpl into %s file failure with the error: %v", testName, err) + return + } } // create .env @@ -79,4 +92,5 @@ func init() { rootCmd.AddCommand(initCmd) initCmd.Flags().StringVarP(&name, "name", "n", "", "The name of Stream Function") + initCmd.Flags().StringVarP(&sfnType, "type", "t", "llm", "The type of Stream Function, support normal and llm") } diff --git a/cli/serverless/golang/template.go b/cli/serverless/golang/template.go index d1767dd0b..4c8f17c80 100644 --- a/cli/serverless/golang/template.go +++ b/cli/serverless/golang/template.go @@ -9,12 +9,6 @@ import ( //go:embed templates/main.tmpl var MainFuncTmpl []byte -//go:embed templates/init.tmpl -var InitTmpl []byte - -//go:embed templates/init_test.tmpl -var InitTestTmpl []byte - //go:embed templates/wasi_main.tmpl var WasiMainFuncTmpl []byte diff --git a/cli/serverless/golang/templates/init.tmpl b/cli/template/go/init_llm.tmpl similarity index 100% rename from cli/serverless/golang/templates/init.tmpl rename to cli/template/go/init_llm.tmpl diff --git a/cli/serverless/golang/templates/init_test.tmpl b/cli/template/go/init_llm_test.tmpl similarity index 100% rename from cli/serverless/golang/templates/init_test.tmpl rename to cli/template/go/init_llm_test.tmpl diff --git a/cli/template/go/init_normal.tmpl b/cli/template/go/init_normal.tmpl new file mode 100644 index 000000000..aa24478b8 --- /dev/null +++ b/cli/template/go/init_normal.tmpl @@ -0,0 +1,40 @@ +package main + +import ( + "fmt" + "strings" + + "github.com/yomorun/yomo/serverless" +) + +// Init is an optional function invoked during the initialization phase of the +// sfn instance. It's designed for setup tasks like global variable +// initialization, establishing database connections, or loading models into +// GPU memory. If initialization fails, the sfn instance will halt and terminate. +// This function can be omitted if no initialization tasks are needed. +func Init() error { + return nil +} + +// DataTags specifies the data tags to which this serverless function +// subscribes, essential for data reception. Upon receiving data with these +// tags, the Handler function is triggered. +func DataTags() []uint32 { + return []uint32{0x33} +} + +// Handler orchestrates the core processing logic of this function. +// - ctx.Tag() identifies the tag of the incoming data. +// - ctx.Data() accesses the raw data. +// - ctx.Write() forwards processed data downstream. +func Handler(ctx serverless.Context) { + data := ctx.Data() + fmt.Printf("<< sfn received[%d Bytes]: %s\n", len(data), data) + output := strings.ToUpper(string(data)) + err := ctx.Write(0x34, []byte(output)) + if err != nil { + fmt.Printf(">> sfn write error: %v\n", err) + return + } + fmt.Printf(">> sfn written[%d Bytes]: %s\n", len(output), output) +} diff --git a/cli/template/go/init_normal_test.tmpl b/cli/template/go/init_normal_test.tmpl new file mode 100644 index 000000000..5cec7d053 --- /dev/null +++ b/cli/template/go/init_normal_test.tmpl @@ -0,0 +1,39 @@ +package main + +import ( + "fmt" + "reflect" + "testing" + + "github.com/yomorun/yomo/serverless/mock" +) + +func TestHandler(t *testing.T) { + tests := []struct { + name string + ctx *mock.MockContext + // want is the expected data and tag that be written by ctx.Write + want []mock.WriteRecord + }{ + { + name: "upper", + ctx: mock.NewMockContext([]byte("hello"), 0x33), + want: []mock.WriteRecord{ + {Data: []byte("HELLO"), Tag: 0x34}, + }, + }, + // TODO: add more test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Handler(tt.ctx) + got := tt.ctx.RecordsWritten() + + fmt.Println(string(got[0].Data)) + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("TestHandler got: %v, want: %v", got, tt.want) + } + }) + } +} diff --git a/cli/template/template.go b/cli/template/template.go new file mode 100644 index 000000000..27745a283 --- /dev/null +++ b/cli/template/template.go @@ -0,0 +1,93 @@ +package template + +import ( + "embed" + "errors" + "os" + "strings" +) + +//go:embed go +var fs embed.FS + +var ( + ErrUnsupportedSfnType = errors.New("unsupported sfn type") + ErrorUnsupportedLang = errors.New("unsupported lang") + ErrUnsupportedTest = errors.New("unsupported test") +) + +var ( + SupportedSfnTypes = []string{"llm", "normal"} + SupportedLangs = []string{"go", "node"} +) + +// get template content +func GetContent(command string, sfnType string, lang string, isTest bool) ([]byte, error) { + if command == "" { + command = "init" + } + sfnType, err := validateSfnType(sfnType) + if err != nil { + return nil, err + } + lang, err = validateLang(lang) + if err != nil { + return nil, err + } + sb := new(strings.Builder) + sb.WriteString(lang) + sb.WriteString("/") + sb.WriteString(command) + sb.WriteString("_") + sb.WriteString(sfnType) + if isTest { + sb.WriteString("_test") + } + sb.WriteString(".tmpl") + + // valdiate the path exists + name := sb.String() + f, err := fs.Open(name) + if err != nil { + if os.IsNotExist(err) { + if isTest { + return nil, ErrUnsupportedTest + } + return nil, err + } + return nil, err + } + defer f.Close() + _, err = f.Stat() + if err != nil { + return nil, err + } + + return fs.ReadFile(name) +} + +func validateSfnType(sfnType string) (string, error) { + if sfnType == "" { + // default sfn type + return "llm", nil + } + for _, t := range SupportedSfnTypes { + if t == sfnType { + return sfnType, nil + } + } + return sfnType, ErrUnsupportedSfnType +} + +func validateLang(lang string) (string, error) { + if lang == "" { + // default lang + return "go", nil + } + for _, l := range SupportedLangs { + if l == lang { + return lang, nil + } + } + return lang, ErrorUnsupportedLang +}