From 070ae2f8e904015c1128822fcdb5e0589dfa288a Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Fri, 3 Jan 2025 16:28:35 +0530 Subject: [PATCH 01/22] Refactor `bundle init` --- cmd/bundle/init.go | 219 +++--------------- integration/bundle/helpers_test.go | 38 +-- integration/bundle/init_test.go | 176 ++++++++++++++ libs/template/builtin.go | 1 + libs/template/config.go | 21 ++ libs/template/config_test.go | 39 ++++ libs/template/materialize.go | 94 -------- libs/template/materialize_test.go | 24 -- libs/template/reader.go | 199 ++++++++++++++++ libs/template/template.go | 145 ++++++++++++ .../template/template_test.go | 55 ++--- libs/template/writer.go | 169 ++++++++++++++ 12 files changed, 823 insertions(+), 357 deletions(-) delete mode 100644 libs/template/materialize.go delete mode 100644 libs/template/materialize_test.go create mode 100644 libs/template/reader.go create mode 100644 libs/template/template.go rename cmd/bundle/init_test.go => libs/template/template_test.go (64%) create mode 100644 libs/template/writer.go diff --git a/cmd/bundle/init.go b/cmd/bundle/init.go index 687c141eca..4da5a69be7 100644 --- a/cmd/bundle/init.go +++ b/cmd/bundle/init.go @@ -4,152 +4,17 @@ import ( "context" "errors" "fmt" - "io/fs" - "os" "path/filepath" - "slices" "strings" "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/dbr" "github.com/databricks/cli/libs/filer" - "github.com/databricks/cli/libs/git" "github.com/databricks/cli/libs/template" "github.com/spf13/cobra" ) -var gitUrlPrefixes = []string{ - "https://", - "git@", -} - -type nativeTemplate struct { - name string - gitUrl string - description string - aliases []string - hidden bool -} - -const customTemplate = "custom..." - -var nativeTemplates = []nativeTemplate{ - { - name: "default-python", - description: "The default Python template for Notebooks / Delta Live Tables / Workflows", - }, - { - name: "default-sql", - description: "The default SQL template for .sql files that run with Databricks SQL", - }, - { - name: "dbt-sql", - description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)", - }, - { - name: "mlops-stacks", - gitUrl: "https://github.com/databricks/mlops-stacks", - description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)", - aliases: []string{"mlops-stack"}, - }, - { - name: "default-pydabs", - gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git", - hidden: true, - description: "The default PyDABs template", - }, - { - name: customTemplate, - description: "Bring your own template", - }, -} - -// Return template descriptions for command-line help -func nativeTemplateHelpDescriptions() string { - var lines []string - for _, template := range nativeTemplates { - if template.name != customTemplate && !template.hidden { - lines = append(lines, fmt.Sprintf("- %s: %s", template.name, template.description)) - } - } - return strings.Join(lines, "\n") -} - -// Return template options for an interactive prompt -func nativeTemplateOptions() []cmdio.Tuple { - names := make([]cmdio.Tuple, 0, len(nativeTemplates)) - for _, template := range nativeTemplates { - if template.hidden { - continue - } - tuple := cmdio.Tuple{ - Name: template.name, - Id: template.description, - } - names = append(names, tuple) - } - return names -} - -func getNativeTemplateByDescription(description string) string { - for _, template := range nativeTemplates { - if template.description == description { - return template.name - } - } - return "" -} - -func getUrlForNativeTemplate(name string) string { - for _, template := range nativeTemplates { - if template.name == name { - return template.gitUrl - } - if slices.Contains(template.aliases, name) { - return template.gitUrl - } - } - return "" -} - -func getFsForNativeTemplate(name string) (fs.FS, error) { - builtin, err := template.Builtin() - if err != nil { - return nil, err - } - - // If this is a built-in template, the return value will be non-nil. - var templateFS fs.FS - for _, entry := range builtin { - if entry.Name == name { - templateFS = entry.FS - break - } - } - - return templateFS, nil -} - -func isRepoUrl(url string) bool { - result := false - for _, prefix := range gitUrlPrefixes { - if strings.HasPrefix(url, prefix) { - result = true - break - } - } - return result -} - -// Computes the repo name from the repo URL. Treats the last non empty word -// when splitting at '/' as the repo name. For example: for url git@github.com:databricks/cli.git -// the name would be "cli.git" -func repoName(url string) string { - parts := strings.Split(strings.TrimRight(url, "/"), "/") - return parts[len(parts)-1] -} - func constructOutputFiler(ctx context.Context, outputDir string) (filer.Filer, error) { outputDir, err := filepath.Abs(outputDir) if err != nil { @@ -182,7 +47,7 @@ TEMPLATE_PATH optionally specifies which template to use. It can be one of the f - a local file system path with a template directory - a Git repository URL, e.g. https://github.com/my/repository -See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more information on templates.`, nativeTemplateHelpDescriptions()), +See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more information on templates.`, template.HelpDescriptions()), } var configFile string @@ -196,7 +61,6 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf cmd.Flags().StringVar(&branch, "tag", "", "Git tag to use for template initialization") cmd.Flags().StringVar(&tag, "branch", "", "Git branch to use for template initialization") - cmd.PreRunE = root.MustWorkspaceClient cmd.RunE = func(cmd *cobra.Command, args []string) error { if tag != "" && branch != "" { return errors.New("only one of --tag or --branch can be specified") @@ -208,82 +72,51 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf ref = tag } + var tmpl *template.Template + var err error ctx := cmd.Context() - var templatePath string + if len(args) > 0 { - templatePath = args[0] + // User already specified a template local path or a Git URL. Use that + // information to configure a reader for the template + tmpl = template.Get(template.Custom) + // TODO: Get rid of the name arg. + if template.IsGitRepoUrl(args[0]) { + tmpl.SetReader(template.NewGitReader("", args[0], ref, templateDir)) + } else { + tmpl.SetReader(template.NewLocalReader("", args[0])) + } } else { - var err error - if !cmdio.IsPromptSupported(ctx) { - return errors.New("please specify a template") + tmplId, err := template.PromptForTemplateId(cmd.Context(), ref, templateDir) + if tmplId == template.Custom { + // If a user selects custom during the prompt, ask them to provide a path or Git URL + // as a positional argument. + cmdio.LogString(ctx, "Please specify a path or Git repository to use a custom template.") + cmdio.LogString(ctx, "See https://docs.databricks.com/en/dev-tools/bundles/templates.html to learn more about custom templates.") + return nil } - description, err := cmdio.SelectOrdered(ctx, nativeTemplateOptions(), "Template to use") if err != nil { return err } - templatePath = getNativeTemplateByDescription(description) - } - - outputFiler, err := constructOutputFiler(ctx, outputDir) - if err != nil { - return err - } - - if templatePath == customTemplate { - cmdio.LogString(ctx, "Please specify a path or Git repository to use a custom template.") - cmdio.LogString(ctx, "See https://docs.databricks.com/en/dev-tools/bundles/templates.html to learn more about custom templates.") - return nil - } - // Expand templatePath to a git URL if it's an alias for a known native template - // and we know it's git URL. - if gitUrl := getUrlForNativeTemplate(templatePath); gitUrl != "" { - templatePath = gitUrl + tmpl = template.Get(tmplId) } - if !isRepoUrl(templatePath) { - if templateDir != "" { - return errors.New("--template-dir can only be used with a Git repository URL") - } + defer tmpl.Reader.Close() - templateFS, err := getFsForNativeTemplate(templatePath) - if err != nil { - return err - } - - // If this is not a built-in template, then it must be a local file system path. - if templateFS == nil { - templateFS = os.DirFS(templatePath) - } - - // skip downloading the repo because input arg is not a URL. We assume - // it's a path on the local file system in that case - return template.Materialize(ctx, configFile, templateFS, outputFiler) - } - - // Create a temporary directory with the name of the repository. The '*' - // character is replaced by a random string in the generated temporary directory. - repoDir, err := os.MkdirTemp("", repoName(templatePath)+"-*") + outputFiler, err := constructOutputFiler(ctx, outputDir) if err != nil { return err } - // start the spinner - promptSpinner := cmdio.Spinner(ctx) - promptSpinner <- "Downloading the template\n" + tmpl.Writer.Initialize(tmpl.Reader, configFile, outputFiler) - // TODO: Add automated test that the downloaded git repo is cleaned up. - // Clone the repository in the temporary directory - err = git.Clone(ctx, templatePath, ref, repoDir) - close(promptSpinner) + err = tmpl.Writer.Materialize(ctx) if err != nil { return err } - // Clean up downloaded repository once the template is materialized. - defer os.RemoveAll(repoDir) - templateFS := os.DirFS(filepath.Join(repoDir, templateDir)) - return template.Materialize(ctx, configFile, templateFS, outputFiler) + return tmpl.Writer.LogTelemetry(ctx) } return cmd } diff --git a/integration/bundle/helpers_test.go b/integration/bundle/helpers_test.go index e884cd8c68..60177297e6 100644 --- a/integration/bundle/helpers_test.go +++ b/integration/bundle/helpers_test.go @@ -8,18 +8,13 @@ import ( "os" "os/exec" "path/filepath" - "strings" "github.com/databricks/cli/bundle" - "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/internal/testcli" "github.com/databricks/cli/internal/testutil" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" - "github.com/databricks/cli/libs/filer" - "github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/folders" - "github.com/databricks/cli/libs/template" "github.com/databricks/databricks-sdk-go" "github.com/stretchr/testify/require" ) @@ -32,19 +27,32 @@ func initTestTemplate(t testutil.TestingT, ctx context.Context, templateName str } func initTestTemplateWithBundleRoot(t testutil.TestingT, ctx context.Context, templateName string, config map[string]any, bundleRoot string) string { - templateRoot := filepath.Join("bundles", templateName) + return "" - configFilePath := writeConfigFile(t, config) + // TODO: Make this function work but do not log telemetry. - ctx = root.SetWorkspaceClient(ctx, nil) - cmd := cmdio.NewIO(ctx, flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles") - ctx = cmdio.InContext(ctx, cmd) + // templateRoot := filepath.Join("bundles", templateName) - out, err := filer.NewLocalClient(bundleRoot) - require.NoError(t, err) - err = template.Materialize(ctx, configFilePath, os.DirFS(templateRoot), out) - require.NoError(t, err) - return bundleRoot + // configFilePath := writeConfigFile(t, config) + + // ctx = root.SetWorkspaceClient(ctx, nil) + // cmd := cmdio.NewIO(ctx, flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles") + // ctx = cmdio.InContext(ctx, cmd) + // ctx = telemetry.WithMockLogger(ctx) + + // out, err := filer.NewLocalClient(bundleRoot) + // require.NoError(t, err) + // tmpl := template.TemplateX{ + // TemplateOpts: template.TemplateOpts{ + // ConfigFilePath: configFilePath, + // TemplateFS: os.DirFS(templateRoot), + // OutputFiler: out, + // }, + // } + + // err = tmpl.Materialize(ctx) + // require.NoError(t, err) + // return bundleRoot } func writeConfigFile(t testutil.TestingT, config map[string]any) string { diff --git a/integration/bundle/init_test.go b/integration/bundle/init_test.go index f5c263ca3d..3826f55433 100644 --- a/integration/bundle/init_test.go +++ b/integration/bundle/init_test.go @@ -15,6 +15,7 @@ import ( "github.com/databricks/cli/internal/testcli" "github.com/databricks/cli/internal/testutil" "github.com/databricks/cli/libs/iamutil" + "github.com/databricks/cli/libs/telemetry" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -42,6 +43,9 @@ func TestBundleInitOnMlopsStacks(t *testing.T) { ctx, wt := acc.WorkspaceTest(t) w := wt.W + // Use mock logger to introspect the telemetry payload. + ctx = telemetry.WithMockLogger(ctx) + tmpDir1 := t.TempDir() tmpDir2 := t.TempDir() @@ -64,6 +68,28 @@ func TestBundleInitOnMlopsStacks(t *testing.T) { assert.NoFileExists(t, filepath.Join(tmpDir2, "repo_name", projectName, "README.md")) testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", "mlops-stacks", "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir1, "config.json")) + // Assert the telemetry payload is correctly logged. + tlmyEvents := telemetry.Introspect(ctx) + require.Len(t, telemetry.Introspect(ctx), 1) + event := tlmyEvents[0].BundleInitEvent + assert.Equal(t, "mlops-stacks", event.TemplateName) + + get := func(key string) string { + for _, v := range event.TemplateEnumArgs { + if v.Key == key { + return v.Value + } + } + return "" + } + + // Enum values should be present in the telemetry payload. + assert.Equal(t, "no", get("input_include_models_in_unity_catalog")) + assert.Equal(t, strings.ToLower(env), get("input_cloud")) + // Freeform strings should not be present in the telemetry payload. + assert.Equal(t, "", get("input_project_name")) + assert.Equal(t, "", get("input_root_dir")) + // Assert that the README.md file was created contents := testutil.ReadFile(t, filepath.Join(tmpDir2, "repo_name", projectName, "README.md")) assert.Contains(t, contents, fmt.Sprintf("# %s", projectName)) @@ -99,6 +125,156 @@ func TestBundleInitOnMlopsStacks(t *testing.T) { assert.Contains(t, job.Settings.Name, fmt.Sprintf("dev-%s-batch-inference-job", projectName)) } +func TestBundleInitTelemetryForDefaultTemplates(t *testing.T) { + projectName := testutil.RandomName("name_") + + tcases := []struct { + name string + args map[string]string + expectedArgs map[string]string + }{ + { + name: "dbt-sql", + args: map[string]string{ + "project_name": fmt.Sprintf("dbt-sql-%s", projectName), + "http_path": "/sql/1.0/warehouses/id", + "default_catalog": "abcd", + "personal_schemas": "yes, use a schema based on the current user name during development", + }, + expectedArgs: map[string]string{ + "personal_schemas": "yes, use a schema based on the current user name during development", + }, + }, + { + name: "default-python", + args: map[string]string{ + "project_name": fmt.Sprintf("default_python_%s", projectName), + "include_notebook": "yes", + "include_dlt": "yes", + "include_python": "no", + }, + expectedArgs: map[string]string{ + "include_notebook": "yes", + "include_dlt": "yes", + "include_python": "no", + }, + }, + { + name: "default-sql", + args: map[string]string{ + "project_name": fmt.Sprintf("sql_project_%s", projectName), + "http_path": "/sql/1.0/warehouses/id", + "default_catalog": "abcd", + "personal_schemas": "yes, automatically use a schema based on the current user name during development", + }, + expectedArgs: map[string]string{ + "personal_schemas": "yes, automatically use a schema based on the current user name during development", + }, + }, + } + + for _, tc := range tcases { + ctx, _ := acc.WorkspaceTest(t) + + // Use mock logger to introspect the telemetry payload. + ctx = telemetry.WithMockLogger(ctx) + + tmpDir1 := t.TempDir() + tmpDir2 := t.TempDir() + + // Create a config file with the project name and root dir + initConfig := tc.args + b, err := json.Marshal(initConfig) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir1, "config.json"), b, 0o644) + require.NoError(t, err) + + // Run bundle init + assert.NoDirExists(t, filepath.Join(tmpDir2, tc.args["project_name"])) + testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", tc.name, "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir1, "config.json")) + assert.DirExists(t, filepath.Join(tmpDir2, tc.args["project_name"])) + + // Assert the telemetry payload is correctly logged. + logs := telemetry.Introspect(ctx) + require.Len(t, logs, 1) + event := logs[0].BundleInitEvent + assert.Equal(t, event.TemplateName, tc.name) + + get := func(key string) string { + for _, v := range event.TemplateEnumArgs { + if v.Key == key { + return v.Value + } + } + return "" + } + + // Assert the template enum args are correctly logged. + assert.Len(t, event.TemplateEnumArgs, len(tc.expectedArgs)) + for k, v := range tc.expectedArgs { + assert.Equal(t, get(k), v) + } + } +} + +func TestBundleInitTelemetryForCustomTemplates(t *testing.T) { + ctx, _ := acc.WorkspaceTest(t) + + tmpDir1 := t.TempDir() + tmpDir2 := t.TempDir() + tmpDir3 := t.TempDir() + + err := os.Mkdir(filepath.Join(tmpDir1, "template"), 0o755) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir1, "template", "foo.txt.tmpl"), []byte("{{bundle_uuid}}"), 0o644) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir1, "databricks_template_schema.json"), []byte(` +{ + "properties": { + "a": { + "description": "whatever", + "type": "string" + }, + "b": { + "description": "whatever", + "type": "string", + "enum": ["yes", "no"] + } + } +} +`), 0o644) + require.NoError(t, err) + + // Create a config file with the project name and root dir + initConfig := map[string]string{ + "a": "v1", + "b": "yes", + } + b, err := json.Marshal(initConfig) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir3, "config.json"), b, 0o644) + require.NoError(t, err) + + // Use mock logger to introspect the telemetry payload. + ctx = telemetry.WithMockLogger(ctx) + + // Run bundle init. + testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", tmpDir1, "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir3, "config.json")) + + // Assert the telemetry payload is correctly logged. For custom templates we should + // never set template_enum_args. + tlmyEvents := telemetry.Introspect(ctx) + require.Len(t, len(tlmyEvents), 1) + event := tlmyEvents[0].BundleInitEvent + assert.Equal(t, "custom", event.TemplateName) + assert.Empty(t, event.TemplateEnumArgs) + + // Ensure that the UUID returned by the `bundle_uuid` helper is the same UUID + // that's logged in the telemetry event. + fileC := testutil.ReadFile(t, filepath.Join(tmpDir2, "foo.txt")) + assert.Equal(t, event.Uuid, fileC) +} + func TestBundleInitHelpers(t *testing.T) { ctx, wt := acc.WorkspaceTest(t) w := wt.W diff --git a/libs/template/builtin.go b/libs/template/builtin.go index dcb3a88582..96cdcbb961 100644 --- a/libs/template/builtin.go +++ b/libs/template/builtin.go @@ -15,6 +15,7 @@ type BuiltinTemplate struct { } // Builtin returns the list of all built-in templates. +// TODO: Make private? func Builtin() ([]BuiltinTemplate, error) { templates, err := fs.Sub(builtinTemplates, "templates") if err != nil { diff --git a/libs/template/config.go b/libs/template/config.go index 8e7695b915..34eee065c6 100644 --- a/libs/template/config.go +++ b/libs/template/config.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io/fs" + "slices" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/jsonschema" @@ -273,3 +274,23 @@ func (c *config) validate() error { } return nil } + +// Return enum values selected by the user during template initialization. These +// values are safe to send over in telemetry events due to their limited cardinality. +func (c *config) enumValues() map[string]string { + res := map[string]string{} + for k, p := range c.schema.Properties { + if p.Type != jsonschema.StringType { + continue + } + if p.Enum == nil { + continue + } + v := c.values[k] + + if slices.Contains(p.Enum, v) { + res[k] = v.(string) + } + } + return res +} diff --git a/libs/template/config_test.go b/libs/template/config_test.go index 515a0b9f5b..3f971a8622 100644 --- a/libs/template/config_test.go +++ b/libs/template/config_test.go @@ -564,3 +564,42 @@ func TestPromptIsSkippedAnyOf(t *testing.T) { assert.True(t, skip) assert.Equal(t, "hello-world", c.values["xyz"]) } + +func TestConfigEnumValues(t *testing.T) { + c := &config{ + schema: &jsonschema.Schema{ + Properties: map[string]*jsonschema.Schema{ + "a": { + Type: jsonschema.StringType, + }, + "b": { + Type: jsonschema.BooleanType, + }, + "c": { + Type: jsonschema.StringType, + Enum: []any{"v1", "v2"}, + }, + "d": { + Type: jsonschema.StringType, + Enum: []any{"v3", "v4"}, + }, + "e": { + Type: jsonschema.StringType, + Enum: []any{"v5", "v6"}, + }, + }, + }, + values: map[string]any{ + "a": "w1", + "b": false, + "c": "v1", + "d": "v3", + "e": "v7", + }, + } + + assert.Equal(t, map[string]string{ + "c": "v1", + "d": "v3", + }, c.enumValues()) +} diff --git a/libs/template/materialize.go b/libs/template/materialize.go deleted file mode 100644 index 86a6a8c37a..0000000000 --- a/libs/template/materialize.go +++ /dev/null @@ -1,94 +0,0 @@ -package template - -import ( - "context" - "errors" - "fmt" - "io/fs" - - "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/filer" -) - -const ( - libraryDirName = "library" - templateDirName = "template" - schemaFileName = "databricks_template_schema.json" -) - -// This function materializes the input templates as a project, using user defined -// configurations. -// Parameters: -// -// ctx: context containing a cmdio object. This is used to prompt the user -// configFilePath: file path containing user defined config values -// templateFS: root of the template definition -// outputFiler: filer to use for writing the initialized template -func Materialize(ctx context.Context, configFilePath string, templateFS fs.FS, outputFiler filer.Filer) error { - if _, err := fs.Stat(templateFS, schemaFileName); errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("not a bundle template: expected to find a template schema file at %s", schemaFileName) - } - - config, err := newConfig(ctx, templateFS, schemaFileName) - if err != nil { - return err - } - - // Read and assign config values from file - if configFilePath != "" { - err = config.assignValuesFromFile(configFilePath) - if err != nil { - return err - } - } - - helpers := loadHelpers(ctx) - r, err := newRenderer(ctx, config.values, helpers, templateFS, templateDirName, libraryDirName) - if err != nil { - return err - } - - // Print welcome message - welcome := config.schema.WelcomeMessage - if welcome != "" { - welcome, err = r.executeTemplate(welcome) - if err != nil { - return err - } - cmdio.LogString(ctx, welcome) - } - - // Prompt user for any missing config values. Assign default values if - // terminal is not TTY - err = config.promptOrAssignDefaultValues(r) - if err != nil { - return err - } - err = config.validate() - if err != nil { - return err - } - - // Walk and render the template, since input configuration is complete - err = r.walk() - if err != nil { - return err - } - - err = r.persistToDisk(ctx, outputFiler) - if err != nil { - return err - } - - success := config.schema.SuccessMessage - if success == "" { - cmdio.LogString(ctx, "✨ Successfully initialized template") - } else { - success, err = r.executeTemplate(success) - if err != nil { - return err - } - cmdio.LogString(ctx, success) - } - return nil -} diff --git a/libs/template/materialize_test.go b/libs/template/materialize_test.go deleted file mode 100644 index f7cd916e33..0000000000 --- a/libs/template/materialize_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package template - -import ( - "context" - "fmt" - "os" - "testing" - - "github.com/databricks/cli/cmd/root" - "github.com/databricks/databricks-sdk-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMaterializeForNonTemplateDirectory(t *testing.T) { - tmpDir := t.TempDir() - w, err := databricks.NewWorkspaceClient(&databricks.Config{}) - require.NoError(t, err) - ctx := root.SetWorkspaceClient(context.Background(), w) - - // Try to materialize a non-template directory. - err = Materialize(ctx, "", os.DirFS(tmpDir), nil) - assert.EqualError(t, err, fmt.Sprintf("not a bundle template: expected to find a template schema file at %s", schemaFileName)) -} diff --git a/libs/template/reader.go b/libs/template/reader.go new file mode 100644 index 0000000000..6cfaf9cb6c --- /dev/null +++ b/libs/template/reader.go @@ -0,0 +1,199 @@ +package template + +import ( + "context" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/git" +) + +// TODO: Add tests for all these readers. +type Reader interface { + // FS returns a file system that contains the template + // definition files. This function is NOT thread safe. + FS(ctx context.Context) (fs.FS, error) + + // Close releases any resources associated with the reader + // like cleaning up temporary directories. + Close() error + + Name() string +} + +type builtinReader struct { + name string + fsCached fs.FS +} + +func (r *builtinReader) FS(ctx context.Context) (fs.FS, error) { + // If the FS has already been loaded, return it. + if r.fsCached != nil { + return r.fsCached, nil + } + + builtin, err := Builtin() + if err != nil { + return nil, err + } + + var templateFS fs.FS + for _, entry := range builtin { + if entry.Name == r.name { + templateFS = entry.FS + break + } + } + + r.fsCached = templateFS + return r.fsCached, nil +} + +func (r *builtinReader) Close() error { + return nil +} + +func (r *builtinReader) Name() string { + return r.name +} + +type gitReader struct { + name string + // URL of the git repository that contains the template + gitUrl string + // tag or branch to checkout + ref string + // subdirectory within the repository that contains the template + templateDir string + // temporary directory where the repository is cloned + tmpRepoDir string + + fsCached fs.FS +} + +// Computes the repo name from the repo URL. Treats the last non empty word +// when splitting at '/' as the repo name. For example: for url git@github.com:databricks/cli.git +// the name would be "cli.git" +func repoName(url string) string { + parts := strings.Split(strings.TrimRight(url, "/"), "/") + return parts[len(parts)-1] +} + +var gitUrlPrefixes = []string{ + "https://", + "git@", +} + +// TODO: Copy over tests for this function. +func IsGitRepoUrl(url string) bool { + result := false + for _, prefix := range gitUrlPrefixes { + if strings.HasPrefix(url, prefix) { + result = true + break + } + } + return result +} + +// TODO: Can I remove the name from here and other readers? +func NewGitReader(name, gitUrl, ref, templateDir string) Reader { + return &gitReader{ + name: name, + gitUrl: gitUrl, + ref: ref, + templateDir: templateDir, + } +} + +// TODO: Test the idempotency of this function as well. +func (r *gitReader) FS(ctx context.Context) (fs.FS, error) { + // If the FS has already been loaded, return it. + if r.fsCached != nil { + return r.fsCached, nil + } + + // Create a temporary directory with the name of the repository. The '*' + // character is replaced by a random string in the generated temporary directory. + repoDir, err := os.MkdirTemp("", repoName(r.gitUrl)+"-*") + if err != nil { + return nil, err + } + r.tmpRepoDir = repoDir + + // start the spinner + promptSpinner := cmdio.Spinner(ctx) + promptSpinner <- "Downloading the template\n" + + err = git.Clone(ctx, r.gitUrl, r.ref, repoDir) + close(promptSpinner) + if err != nil { + return nil, err + } + + r.fsCached = os.DirFS(filepath.Join(repoDir, r.templateDir)) + return r.fsCached, nil +} + +func (r *gitReader) Close() error { + if r.tmpRepoDir == "" { + return nil + } + + return os.RemoveAll(r.tmpRepoDir) +} + +func (r *gitReader) Name() string { + return r.name +} + +type localReader struct { + name string + // Path on the local filesystem that contains the template + path string + + fsCached fs.FS +} + +func NewLocalReader(name, path string) Reader { + return &localReader{ + name: name, + path: path, + } +} + +func (r *localReader) FS(ctx context.Context) (fs.FS, error) { + // If the FS has already been loaded, return it. + if r.fsCached != nil { + return r.fsCached, nil + } + + r.fsCached = os.DirFS(r.path) + return r.fsCached, nil +} + +func (r *localReader) Close() error { + return nil +} + +func (r *localReader) Name() string { + return r.name +} + +type failReader struct{} + +func (r *failReader) FS(ctx context.Context) (fs.FS, error) { + return nil, fmt.Errorf("this is a placeholder reader that always fails. Please configure a real reader.") +} + +func (r *failReader) Close() error { + return fmt.Errorf("this is a placeholder reader that always fails. Please configure a real reader.") +} + +func (r *failReader) Name() string { + return "failReader" +} diff --git a/libs/template/template.go b/libs/template/template.go new file mode 100644 index 0000000000..1467ff2e5d --- /dev/null +++ b/libs/template/template.go @@ -0,0 +1,145 @@ +package template + +import ( + "context" + "fmt" + "strings" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/filer" +) + +type Template struct { + Reader Reader + Writer Writer + + Id TemplateId + Description string + Aliases []string + Hidden bool +} + +// TODO: Make details private? +// TODO: Combine this with the generic template struct? +type NativeTemplate struct { + Name string + Description string + Aliases []string + GitUrl string + Hidden bool + IsOwnedByDatabricks bool +} + +type TemplateId string + +const ( + DefaultPython TemplateId = "default-python" + DefaultSql TemplateId = "default-sql" + DbtSql TemplateId = "dbt-sql" + MlopsStacks TemplateId = "mlops-stacks" + DefaultPydabs TemplateId = "default-pydabs" + Custom TemplateId = "custom" +) + +var allTemplates = []Template{ + { + Id: DefaultPython, + Description: "The default Python template for Notebooks / Delta Live Tables / Workflows", + Reader: &builtinReader{name: "default-python"}, + Writer: &writerWithTelemetry{}, + }, + { + Id: DefaultSql, + Description: "The default SQL template for .sql files that run with Databricks SQL", + Reader: &builtinReader{name: "default-sql"}, + Writer: &writerWithTelemetry{}, + }, + { + Id: DbtSql, + Description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)", + Reader: &builtinReader{name: "dbt-sql"}, + Writer: &writerWithTelemetry{}, + }, + { + Id: MlopsStacks, + Description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)", + Aliases: []string{"mlops-stack"}, + Reader: &gitReader{gitUrl: "https://github.com/databricks/mlops-stacks"}, + Writer: &writerWithTelemetry{}, + }, + { + Id: DefaultPydabs, + Hidden: true, + Description: "The default PyDABs template", + Reader: &gitReader{gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git"}, + Writer: &writerWithTelemetry{}, + }, + { + Id: Custom, + Description: "Bring your own template", + Reader: &failReader{}, + Writer: &defaultWriter{}, + }, +} + +func HelpDescriptions() string { + var lines []string + for _, template := range allTemplates { + if template.Id != Custom && !template.Hidden { + lines = append(lines, fmt.Sprintf("- %s: %s", template.Id, template.Description)) + } + } + return strings.Join(lines, "\n") +} + +func options() []cmdio.Tuple { + names := make([]cmdio.Tuple, 0, len(allTemplates)) + for _, template := range allTemplates { + if template.Hidden { + continue + } + tuple := cmdio.Tuple{ + Name: string(template.Id), + Id: template.Description, + } + names = append(names, tuple) + } + return names +} + +// TODO CONTINUE defining the methods that the init command will finally rely on. +func PromptForTemplateId(ctx context.Context, ref, templateDir string) (TemplateId, error) { + if !cmdio.IsPromptSupported(ctx) { + return "", fmt.Errorf("please specify a template") + } + description, err := cmdio.SelectOrdered(ctx, options(), "Template to use") + if err != nil { + return "", err + } + + for _, template := range allTemplates { + if template.Description == description { + return template.Id, nil + } + } + + panic("this should never happen - template not found") +} + +func (tmpl *Template) InitializeWriter(configPath string, outputFiler filer.Filer) { + tmpl.Writer.Initialize(tmpl.Reader, configPath, outputFiler) +} + +func (tmpl *Template) SetReader(r Reader) { + tmpl.Reader = r +} + +func Get(id TemplateId) *Template { + for _, template := range allTemplates { + if template.Id == id { + return &template + } + } + + return nil +} diff --git a/cmd/bundle/init_test.go b/libs/template/template_test.go similarity index 64% rename from cmd/bundle/init_test.go rename to libs/template/template_test.go index 475b2e1499..6b6ca0d0e9 100644 --- a/cmd/bundle/init_test.go +++ b/libs/template/template_test.go @@ -1,4 +1,4 @@ -package bundle +package template import ( "testing" @@ -7,12 +7,31 @@ import ( "github.com/stretchr/testify/assert" ) +func TestTemplateHelpDescriptions(t *testing.T) { + expected := `- default-python: The default Python template for Notebooks / Delta Live Tables / Workflows +- default-sql: The default SQL template for .sql files that run with Databricks SQL +- dbt-sql: The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks) +- mlops-stacks: The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)` + assert.Equal(t, expected, HelpDescriptions()) +} + +func TestTemplateOptions(t *testing.T) { + expected := []cmdio.Tuple{ + {Name: "default-python", Id: "The default Python template for Notebooks / Delta Live Tables / Workflows"}, + {Name: "default-sql", Id: "The default SQL template for .sql files that run with Databricks SQL"}, + {Name: "dbt-sql", Id: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)"}, + {Name: "mlops-stacks", Id: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)"}, + {Name: "custom", Id: "Bring your own template"}, + } + assert.Equal(t, expected, options()) +} + func TestBundleInitIsRepoUrl(t *testing.T) { - assert.True(t, isRepoUrl("git@github.com:databricks/cli.git")) - assert.True(t, isRepoUrl("https://github.com/databricks/cli.git")) + assert.True(t, IsGitRepoUrl("git@github.com:databricks/cli.git")) + assert.True(t, IsGitRepoUrl("https://github.com/databricks/cli.git")) - assert.False(t, isRepoUrl("./local")) - assert.False(t, isRepoUrl("foo")) + assert.False(t, IsGitRepoUrl("./local")) + assert.False(t, IsGitRepoUrl("foo")) } func TestBundleInitRepoName(t *testing.T) { @@ -26,29 +45,3 @@ func TestBundleInitRepoName(t *testing.T) { assert.Equal(t, "invalid-url", repoName("invalid-url")) assert.Equal(t, "www.github.com", repoName("https://www.github.com")) } - -func TestNativeTemplateOptions(t *testing.T) { - expected := []cmdio.Tuple{ - {Name: "default-python", Id: "The default Python template for Notebooks / Delta Live Tables / Workflows"}, - {Name: "default-sql", Id: "The default SQL template for .sql files that run with Databricks SQL"}, - {Name: "dbt-sql", Id: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)"}, - {Name: "mlops-stacks", Id: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)"}, - {Name: "custom...", Id: "Bring your own template"}, - } - assert.Equal(t, expected, nativeTemplateOptions()) -} - -func TestNativeTemplateHelpDescriptions(t *testing.T) { - expected := `- default-python: The default Python template for Notebooks / Delta Live Tables / Workflows -- default-sql: The default SQL template for .sql files that run with Databricks SQL -- dbt-sql: The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks) -- mlops-stacks: The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)` - assert.Equal(t, expected, nativeTemplateHelpDescriptions()) -} - -func TestGetUrlForNativeTemplate(t *testing.T) { - assert.Equal(t, "https://github.com/databricks/mlops-stacks", getUrlForNativeTemplate("mlops-stacks")) - assert.Equal(t, "https://github.com/databricks/mlops-stacks", getUrlForNativeTemplate("mlops-stack")) - assert.Equal(t, "", getUrlForNativeTemplate("default-python")) - assert.Equal(t, "", getUrlForNativeTemplate("invalid")) -} diff --git a/libs/template/writer.go b/libs/template/writer.go new file mode 100644 index 0000000000..29a5deb698 --- /dev/null +++ b/libs/template/writer.go @@ -0,0 +1,169 @@ +package template + +import ( + "context" + "errors" + "fmt" + "io/fs" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/filer" +) + +// TODO: Retain coverage for the missing schema test case +// func TestMaterializeForNonTemplateDirectory(t *testing.T) { +// tmpDir := t.TempDir() +// w, err := databricks.NewWorkspaceClient(&databricks.Config{}) +// require.NoError(t, err) +// ctx := root.SetWorkspaceClient(context.Background(), w) + +// tmpl := TemplateX{ +// TemplateOpts: TemplateOpts{ +// ConfigFilePath: "", +// TemplateFS: os.DirFS(tmpDir), +// OutputFiler: nil, +// }, +// } + +// // Try to materialize a non-template directory. +// err = tmpl.Materialize(ctx) +// assert.EqualError(t, err, fmt.Sprintf("not a bundle template: expected to find a template schema file at %s", schemaFileName)) +// } + + +// TODO: Add tests for these writers, mocking the cmdio library +// at the same time. +const ( + libraryDirName = "library" + templateDirName = "template" + schemaFileName = "databricks_template_schema.json" +) + +type Writer interface { + Initialize(reader Reader, configPath string, outputFiler filer.Filer) + Materialize(ctx context.Context) error + LogTelemetry(ctx context.Context) error +} + +type defaultWriter struct { + reader Reader + configPath string + outputFiler filer.Filer + + // Internal state + config *config + renderer *renderer +} + +func (tmpl *defaultWriter) Initialize(reader Reader, configPath string, outputFiler filer.Filer) { + tmpl.configPath = configPath + tmpl.outputFiler = outputFiler +} + +func (tmpl *defaultWriter) promptForInput(ctx context.Context) error { + readerFs, err := tmpl.reader.FS(ctx) + if err != nil { + return err + } + if _, err := fs.Stat(readerFs, schemaFileName); errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("not a bundle template: expected to find a template schema file at %s", schemaFileName) + } + + tmpl.config, err = newConfig(ctx, readerFs, schemaFileName) + if err != nil { + return err + } + + // Read and assign config values from file + if tmpl.configPath != "" { + err = tmpl.config.assignValuesFromFile(tmpl.configPath) + if err != nil { + return err + } + } + + helpers := loadHelpers(ctx) + tmpl.renderer, err = newRenderer(ctx, tmpl.config.values, helpers, readerFs, templateDirName, libraryDirName) + if err != nil { + return err + } + + // Print welcome message + welcome := tmpl.config.schema.WelcomeMessage + if welcome != "" { + welcome, err = tmpl.renderer.executeTemplate(welcome) + if err != nil { + return err + } + cmdio.LogString(ctx, welcome) + } + + // Prompt user for any missing config values. Assign default values if + // terminal is not TTY + err = tmpl.config.promptOrAssignDefaultValues(tmpl.renderer) + if err != nil { + return err + } + return tmpl.config.validate() +} + +func (tmpl *defaultWriter) printSuccessMessage(ctx context.Context) error { + success := tmpl.config.schema.SuccessMessage + if success == "" { + cmdio.LogString(ctx, "✨ Successfully initialized template") + return nil + } + + success, err := tmpl.renderer.executeTemplate(success) + if err != nil { + return err + } + cmdio.LogString(ctx, success) + return nil +} + +func (tmpl *defaultWriter) Materialize(ctx context.Context) error { + err := tmpl.promptForInput(ctx) + if err != nil { + return err + } + + // Walk the template file tree and compute in-memory representations of the + // output files. + err = tmpl.renderer.walk() + if err != nil { + return err + } + + // Flush the output files to disk. + err = tmpl.renderer.persistToDisk(ctx, tmpl.outputFiler) + if err != nil { + return err + } + + return tmpl.printSuccessMessage(ctx) +} + +func (tmpl *defaultWriter) LogTelemetry(ctx context.Context) error { + // no-op + return nil +} + +type writerWithTelemetry struct { + defaultWriter +} + +func (tmpl *writerWithTelemetry) LogTelemetry(ctx context.Context) error { + // Log telemetry. TODO. + return nil +} + +func NewWriterWithTelemetry(reader Reader, configPath string, outputFiler filer.Filer) Writer { + return &writerWithTelemetry{ + defaultWriter: defaultWriter{ + reader: reader, + configPath: configPath, + outputFiler: outputFiler, + }, + } +} From daa2b919aa872dbc47cfdcdd17f64511ac682a2f Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Fri, 3 Jan 2025 16:30:22 +0530 Subject: [PATCH 02/22] undo test changes --- integration/bundle/init_test.go | 176 -------------------------------- 1 file changed, 176 deletions(-) diff --git a/integration/bundle/init_test.go b/integration/bundle/init_test.go index 3826f55433..f5c263ca3d 100644 --- a/integration/bundle/init_test.go +++ b/integration/bundle/init_test.go @@ -15,7 +15,6 @@ import ( "github.com/databricks/cli/internal/testcli" "github.com/databricks/cli/internal/testutil" "github.com/databricks/cli/libs/iamutil" - "github.com/databricks/cli/libs/telemetry" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -43,9 +42,6 @@ func TestBundleInitOnMlopsStacks(t *testing.T) { ctx, wt := acc.WorkspaceTest(t) w := wt.W - // Use mock logger to introspect the telemetry payload. - ctx = telemetry.WithMockLogger(ctx) - tmpDir1 := t.TempDir() tmpDir2 := t.TempDir() @@ -68,28 +64,6 @@ func TestBundleInitOnMlopsStacks(t *testing.T) { assert.NoFileExists(t, filepath.Join(tmpDir2, "repo_name", projectName, "README.md")) testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", "mlops-stacks", "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir1, "config.json")) - // Assert the telemetry payload is correctly logged. - tlmyEvents := telemetry.Introspect(ctx) - require.Len(t, telemetry.Introspect(ctx), 1) - event := tlmyEvents[0].BundleInitEvent - assert.Equal(t, "mlops-stacks", event.TemplateName) - - get := func(key string) string { - for _, v := range event.TemplateEnumArgs { - if v.Key == key { - return v.Value - } - } - return "" - } - - // Enum values should be present in the telemetry payload. - assert.Equal(t, "no", get("input_include_models_in_unity_catalog")) - assert.Equal(t, strings.ToLower(env), get("input_cloud")) - // Freeform strings should not be present in the telemetry payload. - assert.Equal(t, "", get("input_project_name")) - assert.Equal(t, "", get("input_root_dir")) - // Assert that the README.md file was created contents := testutil.ReadFile(t, filepath.Join(tmpDir2, "repo_name", projectName, "README.md")) assert.Contains(t, contents, fmt.Sprintf("# %s", projectName)) @@ -125,156 +99,6 @@ func TestBundleInitOnMlopsStacks(t *testing.T) { assert.Contains(t, job.Settings.Name, fmt.Sprintf("dev-%s-batch-inference-job", projectName)) } -func TestBundleInitTelemetryForDefaultTemplates(t *testing.T) { - projectName := testutil.RandomName("name_") - - tcases := []struct { - name string - args map[string]string - expectedArgs map[string]string - }{ - { - name: "dbt-sql", - args: map[string]string{ - "project_name": fmt.Sprintf("dbt-sql-%s", projectName), - "http_path": "/sql/1.0/warehouses/id", - "default_catalog": "abcd", - "personal_schemas": "yes, use a schema based on the current user name during development", - }, - expectedArgs: map[string]string{ - "personal_schemas": "yes, use a schema based on the current user name during development", - }, - }, - { - name: "default-python", - args: map[string]string{ - "project_name": fmt.Sprintf("default_python_%s", projectName), - "include_notebook": "yes", - "include_dlt": "yes", - "include_python": "no", - }, - expectedArgs: map[string]string{ - "include_notebook": "yes", - "include_dlt": "yes", - "include_python": "no", - }, - }, - { - name: "default-sql", - args: map[string]string{ - "project_name": fmt.Sprintf("sql_project_%s", projectName), - "http_path": "/sql/1.0/warehouses/id", - "default_catalog": "abcd", - "personal_schemas": "yes, automatically use a schema based on the current user name during development", - }, - expectedArgs: map[string]string{ - "personal_schemas": "yes, automatically use a schema based on the current user name during development", - }, - }, - } - - for _, tc := range tcases { - ctx, _ := acc.WorkspaceTest(t) - - // Use mock logger to introspect the telemetry payload. - ctx = telemetry.WithMockLogger(ctx) - - tmpDir1 := t.TempDir() - tmpDir2 := t.TempDir() - - // Create a config file with the project name and root dir - initConfig := tc.args - b, err := json.Marshal(initConfig) - require.NoError(t, err) - err = os.WriteFile(filepath.Join(tmpDir1, "config.json"), b, 0o644) - require.NoError(t, err) - - // Run bundle init - assert.NoDirExists(t, filepath.Join(tmpDir2, tc.args["project_name"])) - testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", tc.name, "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir1, "config.json")) - assert.DirExists(t, filepath.Join(tmpDir2, tc.args["project_name"])) - - // Assert the telemetry payload is correctly logged. - logs := telemetry.Introspect(ctx) - require.Len(t, logs, 1) - event := logs[0].BundleInitEvent - assert.Equal(t, event.TemplateName, tc.name) - - get := func(key string) string { - for _, v := range event.TemplateEnumArgs { - if v.Key == key { - return v.Value - } - } - return "" - } - - // Assert the template enum args are correctly logged. - assert.Len(t, event.TemplateEnumArgs, len(tc.expectedArgs)) - for k, v := range tc.expectedArgs { - assert.Equal(t, get(k), v) - } - } -} - -func TestBundleInitTelemetryForCustomTemplates(t *testing.T) { - ctx, _ := acc.WorkspaceTest(t) - - tmpDir1 := t.TempDir() - tmpDir2 := t.TempDir() - tmpDir3 := t.TempDir() - - err := os.Mkdir(filepath.Join(tmpDir1, "template"), 0o755) - require.NoError(t, err) - err = os.WriteFile(filepath.Join(tmpDir1, "template", "foo.txt.tmpl"), []byte("{{bundle_uuid}}"), 0o644) - require.NoError(t, err) - err = os.WriteFile(filepath.Join(tmpDir1, "databricks_template_schema.json"), []byte(` -{ - "properties": { - "a": { - "description": "whatever", - "type": "string" - }, - "b": { - "description": "whatever", - "type": "string", - "enum": ["yes", "no"] - } - } -} -`), 0o644) - require.NoError(t, err) - - // Create a config file with the project name and root dir - initConfig := map[string]string{ - "a": "v1", - "b": "yes", - } - b, err := json.Marshal(initConfig) - require.NoError(t, err) - err = os.WriteFile(filepath.Join(tmpDir3, "config.json"), b, 0o644) - require.NoError(t, err) - - // Use mock logger to introspect the telemetry payload. - ctx = telemetry.WithMockLogger(ctx) - - // Run bundle init. - testcli.RequireSuccessfulRun(t, ctx, "bundle", "init", tmpDir1, "--output-dir", tmpDir2, "--config-file", filepath.Join(tmpDir3, "config.json")) - - // Assert the telemetry payload is correctly logged. For custom templates we should - // never set template_enum_args. - tlmyEvents := telemetry.Introspect(ctx) - require.Len(t, len(tlmyEvents), 1) - event := tlmyEvents[0].BundleInitEvent - assert.Equal(t, "custom", event.TemplateName) - assert.Empty(t, event.TemplateEnumArgs) - - // Ensure that the UUID returned by the `bundle_uuid` helper is the same UUID - // that's logged in the telemetry event. - fileC := testutil.ReadFile(t, filepath.Join(tmpDir2, "foo.txt")) - assert.Equal(t, event.Uuid, fileC) -} - func TestBundleInitHelpers(t *testing.T) { ctx, wt := acc.WorkspaceTest(t) w := wt.W From a743139f8372ae968c5082da9c501461ff6630ec Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Fri, 3 Jan 2025 16:31:49 +0530 Subject: [PATCH 03/22] remove final telemetry bits --- libs/template/config.go | 21 ------------------- libs/template/config_test.go | 39 ------------------------------------ 2 files changed, 60 deletions(-) diff --git a/libs/template/config.go b/libs/template/config.go index 34eee065c6..8e7695b915 100644 --- a/libs/template/config.go +++ b/libs/template/config.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "io/fs" - "slices" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/jsonschema" @@ -274,23 +273,3 @@ func (c *config) validate() error { } return nil } - -// Return enum values selected by the user during template initialization. These -// values are safe to send over in telemetry events due to their limited cardinality. -func (c *config) enumValues() map[string]string { - res := map[string]string{} - for k, p := range c.schema.Properties { - if p.Type != jsonschema.StringType { - continue - } - if p.Enum == nil { - continue - } - v := c.values[k] - - if slices.Contains(p.Enum, v) { - res[k] = v.(string) - } - } - return res -} diff --git a/libs/template/config_test.go b/libs/template/config_test.go index 3f971a8622..515a0b9f5b 100644 --- a/libs/template/config_test.go +++ b/libs/template/config_test.go @@ -564,42 +564,3 @@ func TestPromptIsSkippedAnyOf(t *testing.T) { assert.True(t, skip) assert.Equal(t, "hello-world", c.values["xyz"]) } - -func TestConfigEnumValues(t *testing.T) { - c := &config{ - schema: &jsonschema.Schema{ - Properties: map[string]*jsonschema.Schema{ - "a": { - Type: jsonschema.StringType, - }, - "b": { - Type: jsonschema.BooleanType, - }, - "c": { - Type: jsonschema.StringType, - Enum: []any{"v1", "v2"}, - }, - "d": { - Type: jsonschema.StringType, - Enum: []any{"v3", "v4"}, - }, - "e": { - Type: jsonschema.StringType, - Enum: []any{"v5", "v6"}, - }, - }, - }, - values: map[string]any{ - "a": "w1", - "b": false, - "c": "v1", - "d": "v3", - "e": "v7", - }, - } - - assert.Equal(t, map[string]string{ - "c": "v1", - "d": "v3", - }, c.enumValues()) -} From d238dd833cf0e6019a309324d92d43926b4fff05 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Fri, 3 Jan 2025 17:28:22 +0530 Subject: [PATCH 04/22] add resolver --- cmd/bundle/init.go | 86 +++++++-------------------------------- libs/template/reader.go | 43 ++------------------ libs/template/resolve.go | 70 +++++++++++++++++++++++++++++++ libs/template/template.go | 52 ++++++++++------------- libs/template/writer.go | 57 +++++++++++++++++--------- 5 files changed, 147 insertions(+), 161 deletions(-) create mode 100644 libs/template/resolve.go diff --git a/cmd/bundle/init.go b/cmd/bundle/init.go index 4da5a69be7..307e367d60 100644 --- a/cmd/bundle/init.go +++ b/cmd/bundle/init.go @@ -1,40 +1,15 @@ package bundle import ( - "context" "errors" "fmt" - "path/filepath" - "strings" "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/dbr" - "github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/template" "github.com/spf13/cobra" ) -func constructOutputFiler(ctx context.Context, outputDir string) (filer.Filer, error) { - outputDir, err := filepath.Abs(outputDir) - if err != nil { - return nil, err - } - - // If the CLI is running on DBR and we're writing to the workspace file system, - // use the extension-aware workspace filesystem filer to instantiate the template. - // - // It is not possible to write notebooks through the workspace filesystem's FUSE mount. - // Therefore this is the only way we can initialize templates that contain notebooks - // when running the CLI on DBR and initializing a template to the workspace. - // - if strings.HasPrefix(outputDir, "/Workspace/") && dbr.RunsOnRuntime(ctx) { - return filer.NewWorkspaceFilesExtensionsClient(root.WorkspaceClient(ctx), outputDir) - } - - return filer.NewLocalClient(outputDir) -} - func newInitCommand() *cobra.Command { cmd := &cobra.Command{ Use: "init [TEMPLATE_PATH]", @@ -62,61 +37,28 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf cmd.Flags().StringVar(&tag, "branch", "", "Git branch to use for template initialization") cmd.RunE = func(cmd *cobra.Command, args []string) error { - if tag != "" && branch != "" { - return errors.New("only one of --tag or --branch can be specified") - } - - // Git ref to use for template initialization - ref := branch - if tag != "" { - ref = tag + r := template.Resolver{ + TemplatePathOrUrl: args[0], + ConfigFile: configFile, + OutputDir: outputDir, + TemplateDir: templateDir, + Tag: tag, + Branch: branch, } - var tmpl *template.Template - var err error ctx := cmd.Context() - - if len(args) > 0 { - // User already specified a template local path or a Git URL. Use that - // information to configure a reader for the template - tmpl = template.Get(template.Custom) - // TODO: Get rid of the name arg. - if template.IsGitRepoUrl(args[0]) { - tmpl.SetReader(template.NewGitReader("", args[0], ref, templateDir)) - } else { - tmpl.SetReader(template.NewLocalReader("", args[0])) - } - } else { - tmplId, err := template.PromptForTemplateId(cmd.Context(), ref, templateDir) - if tmplId == template.Custom { - // If a user selects custom during the prompt, ask them to provide a path or Git URL - // as a positional argument. - cmdio.LogString(ctx, "Please specify a path or Git repository to use a custom template.") - cmdio.LogString(ctx, "See https://docs.databricks.com/en/dev-tools/bundles/templates.html to learn more about custom templates.") - return nil - } - if err != nil { - return err - } - - tmpl = template.Get(tmplId) + tmpl, err := r.Resolve(ctx) + if errors.Is(err, template.ErrCustomSelected) { + cmdio.LogString(ctx, "Please specify a path or Git repository to use a custom template.") + cmdio.LogString(ctx, "See https://docs.databricks.com/en/dev-tools/bundles/templates.html to learn more about custom templates.") + return nil } - - defer tmpl.Reader.Close() - - outputFiler, err := constructOutputFiler(ctx, outputDir) - if err != nil { - return err - } - - tmpl.Writer.Initialize(tmpl.Reader, configFile, outputFiler) - - err = tmpl.Writer.Materialize(ctx) if err != nil { return err } + defer tmpl.Reader.Close() - return tmpl.Writer.LogTelemetry(ctx) + return tmpl.Writer.Materialize(ctx, tmpl.Reader) } return cmd } diff --git a/libs/template/reader.go b/libs/template/reader.go index 6cfaf9cb6c..19d4ec243e 100644 --- a/libs/template/reader.go +++ b/libs/template/reader.go @@ -21,12 +21,10 @@ type Reader interface { // Close releases any resources associated with the reader // like cleaning up temporary directories. Close() error - - Name() string } type builtinReader struct { - name string + name TemplateName fsCached fs.FS } @@ -43,7 +41,7 @@ func (r *builtinReader) FS(ctx context.Context) (fs.FS, error) { var templateFS fs.FS for _, entry := range builtin { - if entry.Name == r.name { + if entry.Name == string(r.name) { templateFS = entry.FS break } @@ -57,13 +55,7 @@ func (r *builtinReader) Close() error { return nil } -func (r *builtinReader) Name() string { - return r.name -} - type gitReader struct { - name string - // URL of the git repository that contains the template gitUrl string // tag or branch to checkout ref string @@ -88,7 +80,7 @@ var gitUrlPrefixes = []string{ "git@", } -// TODO: Copy over tests for this function. +// TODO: Make private? func IsGitRepoUrl(url string) bool { result := false for _, prefix := range gitUrlPrefixes { @@ -100,16 +92,6 @@ func IsGitRepoUrl(url string) bool { return result } -// TODO: Can I remove the name from here and other readers? -func NewGitReader(name, gitUrl, ref, templateDir string) Reader { - return &gitReader{ - name: name, - gitUrl: gitUrl, - ref: ref, - templateDir: templateDir, - } -} - // TODO: Test the idempotency of this function as well. func (r *gitReader) FS(ctx context.Context) (fs.FS, error) { // If the FS has already been loaded, return it. @@ -147,10 +129,6 @@ func (r *gitReader) Close() error { return os.RemoveAll(r.tmpRepoDir) } -func (r *gitReader) Name() string { - return r.name -} - type localReader struct { name string // Path on the local filesystem that contains the template @@ -159,13 +137,6 @@ type localReader struct { fsCached fs.FS } -func NewLocalReader(name, path string) Reader { - return &localReader{ - name: name, - path: path, - } -} - func (r *localReader) FS(ctx context.Context) (fs.FS, error) { // If the FS has already been loaded, return it. if r.fsCached != nil { @@ -180,10 +151,6 @@ func (r *localReader) Close() error { return nil } -func (r *localReader) Name() string { - return r.name -} - type failReader struct{} func (r *failReader) FS(ctx context.Context) (fs.FS, error) { @@ -193,7 +160,3 @@ func (r *failReader) FS(ctx context.Context) (fs.FS, error) { func (r *failReader) Close() error { return fmt.Errorf("this is a placeholder reader that always fails. Please configure a real reader.") } - -func (r *failReader) Name() string { - return "failReader" -} diff --git a/libs/template/resolve.go b/libs/template/resolve.go new file mode 100644 index 0000000000..a4099e2758 --- /dev/null +++ b/libs/template/resolve.go @@ -0,0 +1,70 @@ +package template + +import ( + "context" + "errors" +) + +type Resolver struct { + TemplatePathOrUrl string + ConfigFile string + OutputDir string + TemplateDir string + Tag string + Branch string +} + +var ErrCustomSelected = errors.New("custom template selected") + +// Configures the reader and the writer for template and returns +// a handle to the template. +// Prompts the user if needed. +func (r Resolver) Resolve(ctx context.Context) (*Template, error) { + if r.Tag != "" && r.Branch != "" { + return nil, errors.New("only one of --tag or --branch can be specified") + } + + // Git ref to use for template initialization + ref := r.Branch + if r.Tag != "" { + ref = r.Tag + } + + var tmpl *Template + if r.TemplatePathOrUrl == "" { + // Prompt the user to select a template + // if a template path or URL is not provided. + tmplId, err := SelectTemplate(ctx) + if err != nil { + return nil, err + } + + if tmplId == Custom { + return nil, ErrCustomSelected + } + + tmpl = Get(tmplId) + } else { + // Based on the provided template path or URL, + // configure a reader for the template. + tmpl = Get(Custom) + if IsGitRepoUrl(r.TemplatePathOrUrl) { + tmpl.Reader = &gitReader{ + gitUrl: r.TemplatePathOrUrl, + ref: ref, + templateDir: r.TemplateDir, + } + } else { + tmpl.Reader = &localReader{ + path: r.TemplatePathOrUrl, + } + } + } + + err := tmpl.Writer.Configure(ctx, r.ConfigFile, r.OutputDir) + if err != nil { + return nil, err + } + + return tmpl, nil +} diff --git a/libs/template/template.go b/libs/template/template.go index 1467ff2e5d..ec8e1ac152 100644 --- a/libs/template/template.go +++ b/libs/template/template.go @@ -6,14 +6,14 @@ import ( "strings" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/filer" ) type Template struct { + // TODO: Make private as much as possible. Reader Reader Writer Writer - Id TemplateId + Name TemplateName Description string Aliases []string Hidden bool @@ -30,52 +30,52 @@ type NativeTemplate struct { IsOwnedByDatabricks bool } -type TemplateId string +type TemplateName string const ( - DefaultPython TemplateId = "default-python" - DefaultSql TemplateId = "default-sql" - DbtSql TemplateId = "dbt-sql" - MlopsStacks TemplateId = "mlops-stacks" - DefaultPydabs TemplateId = "default-pydabs" - Custom TemplateId = "custom" + DefaultPython TemplateName = "default-python" + DefaultSql TemplateName = "default-sql" + DbtSql TemplateName = "dbt-sql" + MlopsStacks TemplateName = "mlops-stacks" + DefaultPydabs TemplateName = "default-pydabs" + Custom TemplateName = "custom" ) var allTemplates = []Template{ { - Id: DefaultPython, + Name: DefaultPython, Description: "The default Python template for Notebooks / Delta Live Tables / Workflows", Reader: &builtinReader{name: "default-python"}, Writer: &writerWithTelemetry{}, }, { - Id: DefaultSql, + Name: DefaultSql, Description: "The default SQL template for .sql files that run with Databricks SQL", Reader: &builtinReader{name: "default-sql"}, Writer: &writerWithTelemetry{}, }, { - Id: DbtSql, + Name: DbtSql, Description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)", Reader: &builtinReader{name: "dbt-sql"}, Writer: &writerWithTelemetry{}, }, { - Id: MlopsStacks, + Name: MlopsStacks, Description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)", Aliases: []string{"mlops-stack"}, Reader: &gitReader{gitUrl: "https://github.com/databricks/mlops-stacks"}, Writer: &writerWithTelemetry{}, }, { - Id: DefaultPydabs, + Name: DefaultPydabs, Hidden: true, Description: "The default PyDABs template", Reader: &gitReader{gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git"}, Writer: &writerWithTelemetry{}, }, { - Id: Custom, + Name: Custom, Description: "Bring your own template", Reader: &failReader{}, Writer: &defaultWriter{}, @@ -85,8 +85,8 @@ var allTemplates = []Template{ func HelpDescriptions() string { var lines []string for _, template := range allTemplates { - if template.Id != Custom && !template.Hidden { - lines = append(lines, fmt.Sprintf("- %s: %s", template.Id, template.Description)) + if template.Name != Custom && !template.Hidden { + lines = append(lines, fmt.Sprintf("- %s: %s", template.Name, template.Description)) } } return strings.Join(lines, "\n") @@ -99,7 +99,7 @@ func options() []cmdio.Tuple { continue } tuple := cmdio.Tuple{ - Name: string(template.Id), + Name: string(template.Name), Id: template.Description, } names = append(names, tuple) @@ -108,7 +108,7 @@ func options() []cmdio.Tuple { } // TODO CONTINUE defining the methods that the init command will finally rely on. -func PromptForTemplateId(ctx context.Context, ref, templateDir string) (TemplateId, error) { +func SelectTemplate(ctx context.Context) (TemplateName, error) { if !cmdio.IsPromptSupported(ctx) { return "", fmt.Errorf("please specify a template") } @@ -119,24 +119,16 @@ func PromptForTemplateId(ctx context.Context, ref, templateDir string) (Template for _, template := range allTemplates { if template.Description == description { - return template.Id, nil + return template.Name, nil } } panic("this should never happen - template not found") } -func (tmpl *Template) InitializeWriter(configPath string, outputFiler filer.Filer) { - tmpl.Writer.Initialize(tmpl.Reader, configPath, outputFiler) -} - -func (tmpl *Template) SetReader(r Reader) { - tmpl.Reader = r -} - -func Get(id TemplateId) *Template { +func Get(name TemplateName) *Template { for _, template := range allTemplates { - if template.Id == id { + if template.Name == name { return &template } } diff --git a/libs/template/writer.go b/libs/template/writer.go index 29a5deb698..b0ec1ad46f 100644 --- a/libs/template/writer.go +++ b/libs/template/writer.go @@ -5,8 +5,12 @@ import ( "errors" "fmt" "io/fs" + "path/filepath" + "strings" + "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/dbr" "github.com/databricks/cli/libs/filer" ) @@ -30,7 +34,6 @@ import ( // assert.EqualError(t, err, fmt.Sprintf("not a bundle template: expected to find a template schema file at %s", schemaFileName)) // } - // TODO: Add tests for these writers, mocking the cmdio library // at the same time. const ( @@ -40,13 +43,12 @@ const ( ) type Writer interface { - Initialize(reader Reader, configPath string, outputFiler filer.Filer) - Materialize(ctx context.Context) error + Configure(ctx context.Context, configPath, outputDir string) error + Materialize(ctx context.Context, r Reader) error LogTelemetry(ctx context.Context) error } type defaultWriter struct { - reader Reader configPath string outputFiler filer.Filer @@ -55,13 +57,40 @@ type defaultWriter struct { renderer *renderer } -func (tmpl *defaultWriter) Initialize(reader Reader, configPath string, outputFiler filer.Filer) { +func constructOutputFiler(ctx context.Context, outputDir string) (filer.Filer, error) { + outputDir, err := filepath.Abs(outputDir) + if err != nil { + return nil, err + } + + // If the CLI is running on DBR and we're writing to the workspace file system, + // use the extension-aware workspace filesystem filer to instantiate the template. + // + // It is not possible to write notebooks through the workspace filesystem's FUSE mount. + // Therefore this is the only way we can initialize templates that contain notebooks + // when running the CLI on DBR and initializing a template to the workspace. + // + if strings.HasPrefix(outputDir, "/Workspace/") && dbr.RunsOnRuntime(ctx) { + return filer.NewWorkspaceFilesExtensionsClient(root.WorkspaceClient(ctx), outputDir) + } + + return filer.NewLocalClient(outputDir) +} + +func (tmpl *defaultWriter) Configure(ctx context.Context, configPath string, outputDir string) error { tmpl.configPath = configPath + + outputFiler, err := constructOutputFiler(ctx, outputDir) + if err != nil { + return err + } + tmpl.outputFiler = outputFiler + return nil } -func (tmpl *defaultWriter) promptForInput(ctx context.Context) error { - readerFs, err := tmpl.reader.FS(ctx) +func (tmpl *defaultWriter) promptForInput(ctx context.Context, reader Reader) error { + readerFs, err := reader.FS(ctx) if err != nil { return err } @@ -122,8 +151,8 @@ func (tmpl *defaultWriter) printSuccessMessage(ctx context.Context) error { return nil } -func (tmpl *defaultWriter) Materialize(ctx context.Context) error { - err := tmpl.promptForInput(ctx) +func (tmpl *defaultWriter) Materialize(ctx context.Context, reader Reader) error { + err := tmpl.promptForInput(ctx, reader) if err != nil { return err } @@ -157,13 +186,3 @@ func (tmpl *writerWithTelemetry) LogTelemetry(ctx context.Context) error { // Log telemetry. TODO. return nil } - -func NewWriterWithTelemetry(reader Reader, configPath string, outputFiler filer.Filer) Writer { - return &writerWithTelemetry{ - defaultWriter: defaultWriter{ - reader: reader, - configPath: configPath, - outputFiler: outputFiler, - }, - } -} From 2965c302683e64587b39b94b647cdac2ecf39282 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Fri, 3 Jan 2025 18:29:50 +0530 Subject: [PATCH 05/22] add unit tests for reader --- integration/bundle/helpers_test.go | 40 +++++------ libs/template/builtin.go | 13 ++-- libs/template/builtin_test.go | 4 +- libs/template/reader.go | 23 +++--- libs/template/reader_test.go | 112 +++++++++++++++++++++++++++++ libs/template/resolve.go | 5 +- libs/template/template.go | 69 ++++++++---------- libs/template/template_test.go | 8 +-- 8 files changed, 190 insertions(+), 84 deletions(-) create mode 100644 libs/template/reader_test.go diff --git a/integration/bundle/helpers_test.go b/integration/bundle/helpers_test.go index 60177297e6..032b7f95ec 100644 --- a/integration/bundle/helpers_test.go +++ b/integration/bundle/helpers_test.go @@ -8,13 +8,17 @@ import ( "os" "os/exec" "path/filepath" + "strings" "github.com/databricks/cli/bundle" + "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/internal/testcli" "github.com/databricks/cli/internal/testutil" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/folders" + "github.com/databricks/cli/libs/template" "github.com/databricks/databricks-sdk-go" "github.com/stretchr/testify/require" ) @@ -27,32 +31,28 @@ func initTestTemplate(t testutil.TestingT, ctx context.Context, templateName str } func initTestTemplateWithBundleRoot(t testutil.TestingT, ctx context.Context, templateName string, config map[string]any, bundleRoot string) string { - return "" + templateRoot := filepath.Join("bundles", templateName) - // TODO: Make this function work but do not log telemetry. + configFilePath := writeConfigFile(t, config) - // templateRoot := filepath.Join("bundles", templateName) + ctx = root.SetWorkspaceClient(ctx, nil) + cmd := cmdio.NewIO(ctx, flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles") + ctx = cmdio.InContext(ctx, cmd) - // configFilePath := writeConfigFile(t, config) + r := template.Resolver{ + TemplatePathOrUrl: templateRoot, + ConfigFile: configFilePath, + OutputDir: bundleRoot, + } - // ctx = root.SetWorkspaceClient(ctx, nil) - // cmd := cmdio.NewIO(ctx, flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles") - // ctx = cmdio.InContext(ctx, cmd) - // ctx = telemetry.WithMockLogger(ctx) + tmpl, err := r.Resolve(ctx) + require.NoError(t, err) + defer tmpl.Reader.Close() - // out, err := filer.NewLocalClient(bundleRoot) - // require.NoError(t, err) - // tmpl := template.TemplateX{ - // TemplateOpts: template.TemplateOpts{ - // ConfigFilePath: configFilePath, - // TemplateFS: os.DirFS(templateRoot), - // OutputFiler: out, - // }, - // } + err = tmpl.Writer.Materialize(ctx, tmpl.Reader) + require.NoError(t, err) - // err = tmpl.Materialize(ctx) - // require.NoError(t, err) - // return bundleRoot + return bundleRoot } func writeConfigFile(t testutil.TestingT, config map[string]any) string { diff --git a/libs/template/builtin.go b/libs/template/builtin.go index 96cdcbb961..5b10534ef5 100644 --- a/libs/template/builtin.go +++ b/libs/template/builtin.go @@ -8,15 +8,14 @@ import ( //go:embed all:templates var builtinTemplates embed.FS -// BuiltinTemplate represents a template that is built into the CLI. -type BuiltinTemplate struct { +// builtinTemplate represents a template that is built into the CLI. +type builtinTemplate struct { Name string FS fs.FS } -// Builtin returns the list of all built-in templates. -// TODO: Make private? -func Builtin() ([]BuiltinTemplate, error) { +// builtin returns the list of all built-in templates. +func builtin() ([]builtinTemplate, error) { templates, err := fs.Sub(builtinTemplates, "templates") if err != nil { return nil, err @@ -27,7 +26,7 @@ func Builtin() ([]BuiltinTemplate, error) { return nil, err } - var out []BuiltinTemplate + var out []builtinTemplate for _, entry := range entries { if !entry.IsDir() { continue @@ -38,7 +37,7 @@ func Builtin() ([]BuiltinTemplate, error) { return nil, err } - out = append(out, BuiltinTemplate{ + out = append(out, builtinTemplate{ Name: entry.Name(), FS: templateFS, }) diff --git a/libs/template/builtin_test.go b/libs/template/builtin_test.go index 79e04cb841..162a227ea9 100644 --- a/libs/template/builtin_test.go +++ b/libs/template/builtin_test.go @@ -9,12 +9,12 @@ import ( ) func TestBuiltin(t *testing.T) { - out, err := Builtin() + out, err := builtin() require.NoError(t, err) assert.GreaterOrEqual(t, len(out), 3) // Create a map of templates by name for easier lookup - templates := make(map[string]*BuiltinTemplate) + templates := make(map[string]*builtinTemplate) for _, tmpl := range out { templates[tmpl.Name] = &tmpl } diff --git a/libs/template/reader.go b/libs/template/reader.go index 19d4ec243e..56d264eddc 100644 --- a/libs/template/reader.go +++ b/libs/template/reader.go @@ -9,10 +9,8 @@ import ( "strings" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/git" ) -// TODO: Add tests for all these readers. type Reader interface { // FS returns a file system that contains the template // definition files. This function is NOT thread safe. @@ -24,7 +22,7 @@ type Reader interface { } type builtinReader struct { - name TemplateName + name string fsCached fs.FS } @@ -34,19 +32,23 @@ func (r *builtinReader) FS(ctx context.Context) (fs.FS, error) { return r.fsCached, nil } - builtin, err := Builtin() + builtin, err := builtin() if err != nil { return nil, err } var templateFS fs.FS for _, entry := range builtin { - if entry.Name == string(r.name) { + if entry.Name == r.name { templateFS = entry.FS break } } + if templateFS == nil { + return nil, fmt.Errorf("builtin template %s not found", r.name) + } + r.fsCached = templateFS return r.fsCached, nil } @@ -64,6 +66,10 @@ type gitReader struct { // temporary directory where the repository is cloned tmpRepoDir string + // Function to clone the repository. This is a function pointer to allow + // mocking in tests. + cloneFunc func(ctx context.Context, url, reference, targetPath string) error + fsCached fs.FS } @@ -80,8 +86,7 @@ var gitUrlPrefixes = []string{ "git@", } -// TODO: Make private? -func IsGitRepoUrl(url string) bool { +func isRepoUrl(url string) bool { result := false for _, prefix := range gitUrlPrefixes { if strings.HasPrefix(url, prefix) { @@ -92,7 +97,6 @@ func IsGitRepoUrl(url string) bool { return result } -// TODO: Test the idempotency of this function as well. func (r *gitReader) FS(ctx context.Context) (fs.FS, error) { // If the FS has already been loaded, return it. if r.fsCached != nil { @@ -111,7 +115,7 @@ func (r *gitReader) FS(ctx context.Context) (fs.FS, error) { promptSpinner := cmdio.Spinner(ctx) promptSpinner <- "Downloading the template\n" - err = git.Clone(ctx, r.gitUrl, r.ref, repoDir) + err = r.cloneFunc(ctx, r.gitUrl, r.ref, repoDir) close(promptSpinner) if err != nil { return nil, err @@ -130,7 +134,6 @@ func (r *gitReader) Close() error { } type localReader struct { - name string // Path on the local filesystem that contains the template path string diff --git a/libs/template/reader_test.go b/libs/template/reader_test.go new file mode 100644 index 0000000000..1cce2321ba --- /dev/null +++ b/libs/template/reader_test.go @@ -0,0 +1,112 @@ +package template + +import ( + "context" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/databricks/cli/internal/testutil" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/flags" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuiltInReader(t *testing.T) { + exists := []string{ + "default-python", + "default-sql", + "dbt-sql", + } + + for _, name := range exists { + r := &builtinReader{name: name} + fs, err := r.FS(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, fs) + } + + // TODO: Read one of the files to confirm further test this reader. + r := &builtinReader{name: "doesnotexist"} + _, err := r.FS(context.Background()) + assert.EqualError(t, err, "builtin template doesnotexist not found") +} + +func TestGitUrlReader(t *testing.T) { + ctx := context.Background() + cmd := cmdio.NewIO(ctx, flags.OutputJSON, strings.NewReader(""), os.Stdout, os.Stderr, "", "bundles") + ctx = cmdio.InContext(ctx, cmd) + + var args []string + numCalls := 0 + cloneFunc := func(ctx context.Context, url, reference, targetPath string) error { + numCalls++ + args = []string{url, reference, targetPath} + err := os.MkdirAll(filepath.Join(targetPath, "a/b/c"), 0o755) + require.NoError(t, err) + testutil.WriteFile(t, filepath.Join(targetPath, "a", "b", "c", "somefile"), "somecontent") + return nil + } + r := &gitReader{ + gitUrl: "someurl", + cloneFunc: cloneFunc, + ref: "sometag", + templateDir: "a/b/c", + } + + // Assert cloneFunc is called with the correct args. + fs, err := r.FS(ctx) + require.NoError(t, err) + require.NotEmpty(t, r.tmpRepoDir) + assert.DirExists(t, r.tmpRepoDir) + assert.Equal(t, []string{"someurl", "sometag", r.tmpRepoDir}, args) + + // Assert the fs returned is rooted at the templateDir. + fd, err := fs.Open("somefile") + require.NoError(t, err) + defer fd.Close() + b, err := io.ReadAll(fd) + require.NoError(t, err) + assert.Equal(t, "somecontent", string(b)) + + // Assert the FS is cached. cloneFunc should not be called again. + _, err = r.FS(ctx) + require.NoError(t, err) + assert.Equal(t, 1, numCalls) + + // Assert Close cleans up the tmpRepoDir. + err = r.Close() + require.NoError(t, err) + assert.NoDirExists(t, r.tmpRepoDir) +} + +func TestLocalReader(t *testing.T) { + tmpDir := t.TempDir() + testutil.WriteFile(t, filepath.Join(tmpDir, "somefile"), "somecontent") + ctx := context.Background() + + r := &localReader{path: tmpDir} + fs, err := r.FS(ctx) + require.NoError(t, err) + + // Assert the fs returned is rooted at correct location. + fd, err := fs.Open("somefile") + require.NoError(t, err) + defer fd.Close() + b, err := io.ReadAll(fd) + require.NoError(t, err) + assert.Equal(t, "somecontent", string(b)) + + // Assert close does not error + assert.NoError(t, r.Close()) +} + +func TestFailReader(t *testing.T) { + r := &failReader{} + assert.Error(t, r.Close()) + _, err := r.FS(context.Background()) + assert.Error(t, err) +} diff --git a/libs/template/resolve.go b/libs/template/resolve.go index a4099e2758..7d3b68bff4 100644 --- a/libs/template/resolve.go +++ b/libs/template/resolve.go @@ -3,6 +3,8 @@ package template import ( "context" "errors" + + "github.com/databricks/cli/libs/git" ) type Resolver struct { @@ -48,11 +50,12 @@ func (r Resolver) Resolve(ctx context.Context) (*Template, error) { // Based on the provided template path or URL, // configure a reader for the template. tmpl = Get(Custom) - if IsGitRepoUrl(r.TemplatePathOrUrl) { + if isRepoUrl(r.TemplatePathOrUrl) { tmpl.Reader = &gitReader{ gitUrl: r.TemplatePathOrUrl, ref: ref, templateDir: r.TemplateDir, + cloneFunc: git.Clone, } } else { tmpl.Reader = &localReader{ diff --git a/libs/template/template.go b/libs/template/template.go index ec8e1ac152..75b913d871 100644 --- a/libs/template/template.go +++ b/libs/template/template.go @@ -6,28 +6,17 @@ import ( "strings" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/git" ) type Template struct { - // TODO: Make private as much as possible. Reader Reader Writer Writer - Name TemplateName - Description string - Aliases []string - Hidden bool -} - -// TODO: Make details private? -// TODO: Combine this with the generic template struct? -type NativeTemplate struct { - Name string - Description string - Aliases []string - GitUrl string - Hidden bool - IsOwnedByDatabricks bool + name TemplateName + description string + aliases []string + hidden bool } type TemplateName string @@ -43,40 +32,40 @@ const ( var allTemplates = []Template{ { - Name: DefaultPython, - Description: "The default Python template for Notebooks / Delta Live Tables / Workflows", + name: DefaultPython, + description: "The default Python template for Notebooks / Delta Live Tables / Workflows", Reader: &builtinReader{name: "default-python"}, Writer: &writerWithTelemetry{}, }, { - Name: DefaultSql, - Description: "The default SQL template for .sql files that run with Databricks SQL", + name: DefaultSql, + description: "The default SQL template for .sql files that run with Databricks SQL", Reader: &builtinReader{name: "default-sql"}, Writer: &writerWithTelemetry{}, }, { - Name: DbtSql, - Description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)", + name: DbtSql, + description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)", Reader: &builtinReader{name: "dbt-sql"}, Writer: &writerWithTelemetry{}, }, { - Name: MlopsStacks, - Description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)", - Aliases: []string{"mlops-stack"}, - Reader: &gitReader{gitUrl: "https://github.com/databricks/mlops-stacks"}, + name: MlopsStacks, + description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)", + aliases: []string{"mlops-stack"}, + Reader: &gitReader{gitUrl: "https://github.com/databricks/mlops-stacks", cloneFunc: git.Clone}, Writer: &writerWithTelemetry{}, }, { - Name: DefaultPydabs, - Hidden: true, - Description: "The default PyDABs template", - Reader: &gitReader{gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git"}, + name: DefaultPydabs, + hidden: true, + description: "The default PyDABs template", + Reader: &gitReader{gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git", cloneFunc: git.Clone}, Writer: &writerWithTelemetry{}, }, { - Name: Custom, - Description: "Bring your own template", + name: Custom, + description: "Bring your own template", Reader: &failReader{}, Writer: &defaultWriter{}, }, @@ -85,8 +74,8 @@ var allTemplates = []Template{ func HelpDescriptions() string { var lines []string for _, template := range allTemplates { - if template.Name != Custom && !template.Hidden { - lines = append(lines, fmt.Sprintf("- %s: %s", template.Name, template.Description)) + if template.name != Custom && !template.hidden { + lines = append(lines, fmt.Sprintf("- %s: %s", template.name, template.description)) } } return strings.Join(lines, "\n") @@ -95,12 +84,12 @@ func HelpDescriptions() string { func options() []cmdio.Tuple { names := make([]cmdio.Tuple, 0, len(allTemplates)) for _, template := range allTemplates { - if template.Hidden { + if template.hidden { continue } tuple := cmdio.Tuple{ - Name: string(template.Name), - Id: template.Description, + Name: string(template.name), + Id: template.description, } names = append(names, tuple) } @@ -118,8 +107,8 @@ func SelectTemplate(ctx context.Context) (TemplateName, error) { } for _, template := range allTemplates { - if template.Description == description { - return template.Name, nil + if template.description == description { + return template.name, nil } } @@ -128,7 +117,7 @@ func SelectTemplate(ctx context.Context) (TemplateName, error) { func Get(name TemplateName) *Template { for _, template := range allTemplates { - if template.Name == name { + if template.name == name { return &template } } diff --git a/libs/template/template_test.go b/libs/template/template_test.go index 6b6ca0d0e9..73d818dfe4 100644 --- a/libs/template/template_test.go +++ b/libs/template/template_test.go @@ -27,11 +27,11 @@ func TestTemplateOptions(t *testing.T) { } func TestBundleInitIsRepoUrl(t *testing.T) { - assert.True(t, IsGitRepoUrl("git@github.com:databricks/cli.git")) - assert.True(t, IsGitRepoUrl("https://github.com/databricks/cli.git")) + assert.True(t, isRepoUrl("git@github.com:databricks/cli.git")) + assert.True(t, isRepoUrl("https://github.com/databricks/cli.git")) - assert.False(t, IsGitRepoUrl("./local")) - assert.False(t, IsGitRepoUrl("foo")) + assert.False(t, isRepoUrl("./local")) + assert.False(t, isRepoUrl("foo")) } func TestBundleInitRepoName(t *testing.T) { From ed24f2880b101b8a671f681d9e87ea332aa33ebd Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Fri, 3 Jan 2025 18:32:43 +0530 Subject: [PATCH 06/22] more test --- libs/template/reader_test.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/libs/template/reader_test.go b/libs/template/reader_test.go index 1cce2321ba..74e70a5e13 100644 --- a/libs/template/reader_test.go +++ b/libs/template/reader_test.go @@ -27,12 +27,23 @@ func TestBuiltInReader(t *testing.T) { fs, err := r.FS(context.Background()) assert.NoError(t, err) assert.NotNil(t, fs) + + // Assert file content returned is accurate and every template has a welcome + // message defined. + fd, err := fs.Open("databricks_template_schema.json") + require.NoError(t, err) + defer fd.Close() + b, err := io.ReadAll(fd) + require.NoError(t, err) + assert.Contains(t, string(b), "welcome_message") } - // TODO: Read one of the files to confirm further test this reader. r := &builtinReader{name: "doesnotexist"} _, err := r.FS(context.Background()) assert.EqualError(t, err, "builtin template doesnotexist not found") + + // Close should not error. + assert.NoError(t, r.Close()) } func TestGitUrlReader(t *testing.T) { From af89580f8e6fb0ab600fc21a8db832f62ec64d29 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 6 Jan 2025 13:41:28 +0530 Subject: [PATCH 07/22] fix tests and lint --- cmd/bundle/init.go | 1 + libs/template/reader_test.go | 6 ++--- libs/template/resolve.go | 52 +++++++++++++++++++++++++----------- libs/template/template.go | 1 - libs/template/writer.go | 25 ++--------------- libs/template/writer_test.go | 50 ++++++++++++++++++++++++++++++++++ 6 files changed, 93 insertions(+), 42 deletions(-) create mode 100644 libs/template/writer_test.go diff --git a/cmd/bundle/init.go b/cmd/bundle/init.go index 307e367d60..755a09b9fb 100644 --- a/cmd/bundle/init.go +++ b/cmd/bundle/init.go @@ -36,6 +36,7 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf cmd.Flags().StringVar(&branch, "tag", "", "Git tag to use for template initialization") cmd.Flags().StringVar(&tag, "branch", "", "Git branch to use for template initialization") + cmd.PreRunE = root.MustWorkspaceClient cmd.RunE = func(cmd *cobra.Command, args []string) error { r := template.Resolver{ TemplatePathOrUrl: args[0], diff --git a/libs/template/reader_test.go b/libs/template/reader_test.go index 74e70a5e13..f1e037fca4 100644 --- a/libs/template/reader_test.go +++ b/libs/template/reader_test.go @@ -32,10 +32,10 @@ func TestBuiltInReader(t *testing.T) { // message defined. fd, err := fs.Open("databricks_template_schema.json") require.NoError(t, err) - defer fd.Close() b, err := io.ReadAll(fd) require.NoError(t, err) assert.Contains(t, string(b), "welcome_message") + assert.NoError(t, fd.Close()) } r := &builtinReader{name: "doesnotexist"} @@ -78,10 +78,10 @@ func TestGitUrlReader(t *testing.T) { // Assert the fs returned is rooted at the templateDir. fd, err := fs.Open("somefile") require.NoError(t, err) - defer fd.Close() b, err := io.ReadAll(fd) require.NoError(t, err) assert.Equal(t, "somecontent", string(b)) + assert.NoError(t, fd.Close()) // Assert the FS is cached. cloneFunc should not be called again. _, err = r.FS(ctx) @@ -106,10 +106,10 @@ func TestLocalReader(t *testing.T) { // Assert the fs returned is rooted at correct location. fd, err := fs.Open("somefile") require.NoError(t, err) - defer fd.Close() b, err := io.ReadAll(fd) require.NoError(t, err) assert.Equal(t, "somecontent", string(b)) + assert.NoError(t, fd.Close()) // Assert close does not error assert.NoError(t, r.Close()) diff --git a/libs/template/resolve.go b/libs/template/resolve.go index 7d3b68bff4..88299b808a 100644 --- a/libs/template/resolve.go +++ b/libs/template/resolve.go @@ -8,12 +8,24 @@ import ( ) type Resolver struct { + // One of the following three: + // 1. Path to a local template directory. + // 2. URL to a Git repository containing a template. + // 3. Name of a built-in template. TemplatePathOrUrl string - ConfigFile string - OutputDir string - TemplateDir string - Tag string - Branch string + + // Path to a JSON file containing the configuration values to be used for + // template initialization. + ConfigFile string + + // Directory to write the initialized template to. + OutputDir string + + // Directory path within a Git repository containing the template. + TemplateDir string + + Tag string + Branch string } var ErrCustomSelected = errors.New("custom template selected") @@ -32,23 +44,32 @@ func (r Resolver) Resolve(ctx context.Context) (*Template, error) { ref = r.Tag } - var tmpl *Template + var err error + var templateName TemplateName + if r.TemplatePathOrUrl == "" { // Prompt the user to select a template // if a template path or URL is not provided. - tmplId, err := SelectTemplate(ctx) + templateName, err = SelectTemplate(ctx) if err != nil { return nil, err } + } - if tmplId == Custom { - return nil, ErrCustomSelected - } + templateName = TemplateName(r.TemplatePathOrUrl) + + // User should not directly select "custom" and instead should provide the + // file path or the Git URL for the template directly. + // Custom is just for internal representation purposes. + if templateName == Custom { + return nil, ErrCustomSelected + } - tmpl = Get(tmplId) - } else { - // Based on the provided template path or URL, - // configure a reader for the template. + tmpl := Get(templateName) + + // If the user directory provided a template path or URL that is not a built-in template, + // then configure a reader for the template. + if tmpl == nil { tmpl = Get(Custom) if isRepoUrl(r.TemplatePathOrUrl) { tmpl.Reader = &gitReader{ @@ -62,9 +83,10 @@ func (r Resolver) Resolve(ctx context.Context) (*Template, error) { path: r.TemplatePathOrUrl, } } + } - err := tmpl.Writer.Configure(ctx, r.ConfigFile, r.OutputDir) + err = tmpl.Writer.Configure(ctx, r.ConfigFile, r.OutputDir) if err != nil { return nil, err } diff --git a/libs/template/template.go b/libs/template/template.go index 75b913d871..46bdef57a9 100644 --- a/libs/template/template.go +++ b/libs/template/template.go @@ -96,7 +96,6 @@ func options() []cmdio.Tuple { return names } -// TODO CONTINUE defining the methods that the init command will finally rely on. func SelectTemplate(ctx context.Context) (TemplateName, error) { if !cmdio.IsPromptSupported(ctx) { return "", fmt.Errorf("please specify a template") diff --git a/libs/template/writer.go b/libs/template/writer.go index b0ec1ad46f..f0b7ae6de6 100644 --- a/libs/template/writer.go +++ b/libs/template/writer.go @@ -14,28 +14,7 @@ import ( "github.com/databricks/cli/libs/filer" ) -// TODO: Retain coverage for the missing schema test case -// func TestMaterializeForNonTemplateDirectory(t *testing.T) { -// tmpDir := t.TempDir() -// w, err := databricks.NewWorkspaceClient(&databricks.Config{}) -// require.NoError(t, err) -// ctx := root.SetWorkspaceClient(context.Background(), w) - -// tmpl := TemplateX{ -// TemplateOpts: TemplateOpts{ -// ConfigFilePath: "", -// TemplateFS: os.DirFS(tmpDir), -// OutputFiler: nil, -// }, -// } - -// // Try to materialize a non-template directory. -// err = tmpl.Materialize(ctx) -// assert.EqualError(t, err, fmt.Sprintf("not a bundle template: expected to find a template schema file at %s", schemaFileName)) -// } - -// TODO: Add tests for these writers, mocking the cmdio library -// at the same time. +// TODO: Add some golden tests for these. const ( libraryDirName = "library" templateDirName = "template" @@ -77,7 +56,7 @@ func constructOutputFiler(ctx context.Context, outputDir string) (filer.Filer, e return filer.NewLocalClient(outputDir) } -func (tmpl *defaultWriter) Configure(ctx context.Context, configPath string, outputDir string) error { +func (tmpl *defaultWriter) Configure(ctx context.Context, configPath, outputDir string) error { tmpl.configPath = configPath outputFiler, err := constructOutputFiler(ctx, outputDir) diff --git a/libs/template/writer_test.go b/libs/template/writer_test.go new file mode 100644 index 0000000000..7c19a316d7 --- /dev/null +++ b/libs/template/writer_test.go @@ -0,0 +1,50 @@ +package template + +import ( + "context" + "testing" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/dbr" + "github.com/databricks/cli/libs/filer" + "github.com/databricks/databricks-sdk-go" + workspaceConfig "github.com/databricks/databricks-sdk-go/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultWriterConfigure(t *testing.T) { + // Test on local file system. + w1 := &defaultWriter{} + err := w1.Configure(context.Background(), "/foo/bar", "/out/abc") + assert.NoError(t, err) + + assert.Equal(t, "/foo/bar", w1.configPath) + assert.IsType(t, &filer.LocalClient{}, w1.outputFiler) + + // Test on DBR + ctx := dbr.MockRuntime(context.Background(), true) + ctx = root.SetWorkspaceClient(ctx, &databricks.WorkspaceClient{ + Config: &workspaceConfig.Config{Host: "https://myhost.com"}, + }) + w2 := &defaultWriter{} + err = w2.Configure(ctx, "/foo/bar", "/Workspace/out/abc") + assert.NoError(t, err) + + assert.Equal(t, "/foo/bar", w2.configPath) + assert.IsType(t, &filer.WorkspaceFilesClient{}, w2.outputFiler) +} + +func TestMaterializeForNonTemplateDirectory(t *testing.T) { + tmpDir1 := t.TempDir() + tmpDir2 := t.TempDir() + ctx := context.Background() + + w := &defaultWriter{} + err := w.Configure(ctx, "/foo/bar", tmpDir1) + require.NoError(t, err) + + // Try to materialize a non-template directory. + err = w.Materialize(ctx, &localReader{path: tmpDir2}) + assert.EqualError(t, err, "not a bundle template: expected to find a template schema file at databricks_template_schema.json") +} From 9f83fe03622bd08e9af92b522acc5f12b4250a09 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 6 Jan 2025 15:40:11 +0530 Subject: [PATCH 08/22] fix test --- .../workspace_files_extensions_client.go | 22 +++++++++---------- .../workspace_files_extensions_client_test.go | 2 +- libs/template/writer_test.go | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/libs/filer/workspace_files_extensions_client.go b/libs/filer/workspace_files_extensions_client.go index 9ee2722e17..0127d180c6 100644 --- a/libs/filer/workspace_files_extensions_client.go +++ b/libs/filer/workspace_files_extensions_client.go @@ -16,7 +16,7 @@ import ( "github.com/databricks/databricks-sdk-go/service/workspace" ) -type workspaceFilesExtensionsClient struct { +type WorkspaceFilesExtensionsClient struct { workspaceClient *databricks.WorkspaceClient wsfs Filer @@ -32,7 +32,7 @@ type workspaceFileStatus struct { nameForWorkspaceAPI string } -func (w *workspaceFilesExtensionsClient) stat(ctx context.Context, name string) (wsfsFileInfo, error) { +func (w *WorkspaceFilesExtensionsClient) stat(ctx context.Context, name string) (wsfsFileInfo, error) { info, err := w.wsfs.Stat(ctx, name) if err != nil { return wsfsFileInfo{}, err @@ -42,7 +42,7 @@ func (w *workspaceFilesExtensionsClient) stat(ctx context.Context, name string) // This function returns the stat for the provided notebook. The stat object itself contains the path // with the extension since it is meant to be used in the context of a fs.FileInfo. -func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithExt(ctx context.Context, name string) (*workspaceFileStatus, error) { +func (w *WorkspaceFilesExtensionsClient) getNotebookStatByNameWithExt(ctx context.Context, name string) (*workspaceFileStatus, error) { ext := path.Ext(name) nameWithoutExt := strings.TrimSuffix(name, ext) @@ -104,7 +104,7 @@ func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithExt(ctx contex }, nil } -func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithoutExt(ctx context.Context, name string) (*workspaceFileStatus, error) { +func (w *WorkspaceFilesExtensionsClient) getNotebookStatByNameWithoutExt(ctx context.Context, name string) (*workspaceFileStatus, error) { stat, err := w.stat(ctx, name) if err != nil { return nil, err @@ -184,7 +184,7 @@ func newWorkspaceFilesExtensionsClient(w *databricks.WorkspaceClient, root strin filer = newWorkspaceFilesReadaheadCache(filer) } - return &workspaceFilesExtensionsClient{ + return &WorkspaceFilesExtensionsClient{ workspaceClient: w, wsfs: filer, @@ -193,7 +193,7 @@ func newWorkspaceFilesExtensionsClient(w *databricks.WorkspaceClient, root strin }, nil } -func (w *workspaceFilesExtensionsClient) ReadDir(ctx context.Context, name string) ([]fs.DirEntry, error) { +func (w *WorkspaceFilesExtensionsClient) ReadDir(ctx context.Context, name string) ([]fs.DirEntry, error) { entries, err := w.wsfs.ReadDir(ctx, name) if err != nil { return nil, err @@ -235,7 +235,7 @@ func (w *workspaceFilesExtensionsClient) ReadDir(ctx context.Context, name strin // Note: The import API returns opaque internal errors for namespace clashes // (e.g. a file and a notebook or a directory and a notebook). Thus users of this // method should be careful to avoid such clashes. -func (w *workspaceFilesExtensionsClient) Write(ctx context.Context, name string, reader io.Reader, mode ...WriteMode) error { +func (w *WorkspaceFilesExtensionsClient) Write(ctx context.Context, name string, reader io.Reader, mode ...WriteMode) error { if w.readonly { return ReadOnlyError{"write"} } @@ -244,7 +244,7 @@ func (w *workspaceFilesExtensionsClient) Write(ctx context.Context, name string, } // Try to read the file as a regular file. If the file is not found, try to read it as a notebook. -func (w *workspaceFilesExtensionsClient) Read(ctx context.Context, name string) (io.ReadCloser, error) { +func (w *WorkspaceFilesExtensionsClient) Read(ctx context.Context, name string) (io.ReadCloser, error) { // Ensure that the file / notebook exists. We do this check here to avoid reading // the content of a notebook called `foo` when the user actually wanted // to read the content of a file called `foo`. @@ -283,7 +283,7 @@ func (w *workspaceFilesExtensionsClient) Read(ctx context.Context, name string) } // Try to delete the file as a regular file. If the file is not found, try to delete it as a notebook. -func (w *workspaceFilesExtensionsClient) Delete(ctx context.Context, name string, mode ...DeleteMode) error { +func (w *WorkspaceFilesExtensionsClient) Delete(ctx context.Context, name string, mode ...DeleteMode) error { if w.readonly { return ReadOnlyError{"delete"} } @@ -320,7 +320,7 @@ func (w *workspaceFilesExtensionsClient) Delete(ctx context.Context, name string } // Try to stat the file as a regular file. If the file is not found, try to stat it as a notebook. -func (w *workspaceFilesExtensionsClient) Stat(ctx context.Context, name string) (fs.FileInfo, error) { +func (w *WorkspaceFilesExtensionsClient) Stat(ctx context.Context, name string) (fs.FileInfo, error) { info, err := w.wsfs.Stat(ctx, name) // If the file is not found, it might be a notebook. @@ -361,7 +361,7 @@ func (w *workspaceFilesExtensionsClient) Stat(ctx context.Context, name string) // Note: The import API returns opaque internal errors for namespace clashes // (e.g. a file and a notebook or a directory and a notebook). Thus users of this // method should be careful to avoid such clashes. -func (w *workspaceFilesExtensionsClient) Mkdir(ctx context.Context, name string) error { +func (w *WorkspaceFilesExtensionsClient) Mkdir(ctx context.Context, name string) error { if w.readonly { return ReadOnlyError{"mkdir"} } diff --git a/libs/filer/workspace_files_extensions_client_test.go b/libs/filer/workspace_files_extensions_client_test.go index 10a2bebf0a..9ea837fa99 100644 --- a/libs/filer/workspace_files_extensions_client_test.go +++ b/libs/filer/workspace_files_extensions_client_test.go @@ -181,7 +181,7 @@ func TestFilerWorkspaceFilesExtensionsErrorsOnDupName(t *testing.T) { root: NewWorkspaceRootPath("/dir"), } - workspaceFilesExtensionsClient := workspaceFilesExtensionsClient{ + workspaceFilesExtensionsClient := WorkspaceFilesExtensionsClient{ workspaceClient: mockedWorkspaceClient.WorkspaceClient, wsfs: &workspaceFilesClient, } diff --git a/libs/template/writer_test.go b/libs/template/writer_test.go index 7c19a316d7..8fefb9d4a9 100644 --- a/libs/template/writer_test.go +++ b/libs/template/writer_test.go @@ -32,7 +32,7 @@ func TestDefaultWriterConfigure(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "/foo/bar", w2.configPath) - assert.IsType(t, &filer.WorkspaceFilesClient{}, w2.outputFiler) + assert.IsType(t, &filer.WorkspaceFilesExtensionsClient{}, w2.outputFiler) } func TestMaterializeForNonTemplateDirectory(t *testing.T) { From 00266cb8b372fdbdec45cd2a7df3e56d2deb978c Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 6 Jan 2025 15:54:13 +0530 Subject: [PATCH 09/22] fix another bug --- cmd/bundle/init.go | 6 +++++- libs/template/resolve.go | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/cmd/bundle/init.go b/cmd/bundle/init.go index 755a09b9fb..c78ca50f2a 100644 --- a/cmd/bundle/init.go +++ b/cmd/bundle/init.go @@ -38,8 +38,12 @@ See https://docs.databricks.com/en/dev-tools/bundles/templates.html for more inf cmd.PreRunE = root.MustWorkspaceClient cmd.RunE = func(cmd *cobra.Command, args []string) error { + var templatePathOrUrl string + if len(args) > 0 { + templatePathOrUrl = args[0] + } r := template.Resolver{ - TemplatePathOrUrl: args[0], + TemplatePathOrUrl: templatePathOrUrl, ConfigFile: configFile, OutputDir: outputDir, TemplateDir: templateDir, diff --git a/libs/template/resolve.go b/libs/template/resolve.go index 88299b808a..6ec5963daf 100644 --- a/libs/template/resolve.go +++ b/libs/template/resolve.go @@ -54,10 +54,10 @@ func (r Resolver) Resolve(ctx context.Context) (*Template, error) { if err != nil { return nil, err } + } else { + templateName = TemplateName(r.TemplatePathOrUrl) } - templateName = TemplateName(r.TemplatePathOrUrl) - // User should not directly select "custom" and instead should provide the // file path or the Git URL for the template directly. // Custom is just for internal representation purposes. From 2b1c5cbe639822abb4b679baf55068b8b1d86ab2 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 6 Jan 2025 16:50:38 +0530 Subject: [PATCH 10/22] skip test on windows --- libs/template/resolve.go | 2 ++ libs/template/writer_test.go | 26 +++++++++++++++++--------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/libs/template/resolve.go b/libs/template/resolve.go index 6ec5963daf..2c3f021e2f 100644 --- a/libs/template/resolve.go +++ b/libs/template/resolve.go @@ -24,6 +24,8 @@ type Resolver struct { // Directory path within a Git repository containing the template. TemplateDir string + // Git tag or branch to download the template from. Only one of these can be + // specified. Tag string Branch string } diff --git a/libs/template/writer_test.go b/libs/template/writer_test.go index 8fefb9d4a9..9d57966ee2 100644 --- a/libs/template/writer_test.go +++ b/libs/template/writer_test.go @@ -2,6 +2,7 @@ package template import ( "context" + "runtime" "testing" "github.com/databricks/cli/cmd/root" @@ -15,24 +16,31 @@ import ( func TestDefaultWriterConfigure(t *testing.T) { // Test on local file system. - w1 := &defaultWriter{} - err := w1.Configure(context.Background(), "/foo/bar", "/out/abc") + w := &defaultWriter{} + err := w.Configure(context.Background(), "/foo/bar", "/out/abc") assert.NoError(t, err) - assert.Equal(t, "/foo/bar", w1.configPath) - assert.IsType(t, &filer.LocalClient{}, w1.outputFiler) + assert.Equal(t, "/foo/bar", w.configPath) + assert.IsType(t, &filer.LocalClient{}, w.outputFiler) +} + +func TestDefaultWriterConfigureOnDBR(t *testing.T) { + // This test is not valid on windows because a DBR image is always based on + // Linux. + if runtime.GOOS == "windows" { + t.Skip("Skipping test on Windows") + } - // Test on DBR ctx := dbr.MockRuntime(context.Background(), true) ctx = root.SetWorkspaceClient(ctx, &databricks.WorkspaceClient{ Config: &workspaceConfig.Config{Host: "https://myhost.com"}, }) - w2 := &defaultWriter{} - err = w2.Configure(ctx, "/foo/bar", "/Workspace/out/abc") + w := &defaultWriter{} + err := w.Configure(ctx, "/foo/bar", "/Workspace/out/abc") assert.NoError(t, err) - assert.Equal(t, "/foo/bar", w2.configPath) - assert.IsType(t, &filer.WorkspaceFilesExtensionsClient{}, w2.outputFiler) + assert.Equal(t, "/foo/bar", w.configPath) + assert.IsType(t, &filer.WorkspaceFilesExtensionsClient{}, w.outputFiler) } func TestMaterializeForNonTemplateDirectory(t *testing.T) { From 8607e0808c5197637ba0ef6bbf07266771805ee3 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 6 Jan 2025 17:53:25 +0530 Subject: [PATCH 11/22] add unit tests for resolver --- libs/cmdio/io.go | 10 +++ libs/template/{resolve.go => resolver.go} | 1 - libs/template/resolver_test.go | 87 +++++++++++++++++++++++ libs/template/template.go | 8 ++- libs/template/template_test.go | 15 ++++ 5 files changed, 118 insertions(+), 3 deletions(-) rename libs/template/{resolve.go => resolver.go} (97%) create mode 100644 libs/template/resolver_test.go diff --git a/libs/cmdio/io.go b/libs/cmdio/io.go index c0e9e868a8..0d7ba5da8e 100644 --- a/libs/cmdio/io.go +++ b/libs/cmdio/io.go @@ -285,3 +285,13 @@ func fromContext(ctx context.Context) *cmdIO { } return io } + +func MockContext(ctx context.Context) context.Context { + return InContext(ctx, &cmdIO{ + interactive: false, + outputFormat: flags.OutputText, + in: io.NopCloser(strings.NewReader("")), + out: io.Discard, + err: io.Discard, + }) +} diff --git a/libs/template/resolve.go b/libs/template/resolver.go similarity index 97% rename from libs/template/resolve.go rename to libs/template/resolver.go index 2c3f021e2f..8d1b26d8d7 100644 --- a/libs/template/resolve.go +++ b/libs/template/resolver.go @@ -62,7 +62,6 @@ func (r Resolver) Resolve(ctx context.Context) (*Template, error) { // User should not directly select "custom" and instead should provide the // file path or the Git URL for the template directly. - // Custom is just for internal representation purposes. if templateName == Custom { return nil, ErrCustomSelected } diff --git a/libs/template/resolver_test.go b/libs/template/resolver_test.go new file mode 100644 index 0000000000..9fde4e9e96 --- /dev/null +++ b/libs/template/resolver_test.go @@ -0,0 +1,87 @@ +package template + +import ( + "context" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTemplateResolverBothTagAndBranch(t *testing.T) { + r := Resolver{ + Tag: "tag", + Branch: "branch", + } + + _, err := r.Resolve(context.Background()) + assert.EqualError(t, err, "only one of --tag or --branch can be specified") +} + +func TestTemplateResolverErrorsWhenPromptingIsNotSupported(t *testing.T) { + r := Resolver{} + ctx := cmdio.MockContext(context.Background()) + + _, err := r.Resolve(ctx) + assert.EqualError(t, err, "prompting is not supported. Please specify the path, name or URL of the template to use") +} + +func TestTemplateResolverErrorWhenUserSelectsCustom(t *testing.T) { + r := Resolver{ + TemplatePathOrUrl: "custom", + } + + _, err := r.Resolve(context.Background()) + assert.EqualError(t, err, "custom template selected") +} + +func TestTemplateResolverForDefaultTemplates(t *testing.T) { + for _, name := range []string{ + "default-python", + "default-sql", + "dbt-sql", + } { + r := Resolver{ + TemplatePathOrUrl: name, + } + + tmpl, err := r.Resolve(context.Background()) + require.NoError(t, err) + + assert.Equal(t, &builtinReader{name: name}, tmpl.Reader) + assert.IsType(t, &writerWithTelemetry{}, tmpl.Writer) + } + + r := Resolver{ + TemplatePathOrUrl: "mlops-stacks", + ConfigFile: "/config/file", + } + + tmpl, err := r.Resolve(context.Background()) + require.NoError(t, err) + + // Assert reader and writer configuration + assert.Equal(t, "https://github.com/databricks/mlops-stacks", tmpl.Reader.(*gitReader).gitUrl) + assert.Equal(t, "/config/file", tmpl.Writer.(*writerWithTelemetry).configPath) +} + +func TestTemplateResolverForCustomTemplate(t *testing.T) { + r := Resolver{ + TemplatePathOrUrl: "https://www.example.com/abc", + Tag: "tag", + TemplateDir: "/template/dir", + ConfigFile: "/config/file", + } + + tmpl, err := r.Resolve(context.Background()) + require.NoError(t, err) + + // Assert reader configuration + assert.Equal(t, "https://www.example.com/abc", tmpl.Reader.(*gitReader).gitUrl) + assert.Equal(t, "tag", tmpl.Reader.(*gitReader).ref) + assert.Equal(t, "/template/dir", tmpl.Reader.(*gitReader).templateDir) + + // Assert writer configuration + assert.Equal(t, "/config/file", tmpl.Writer.(*defaultWriter).configPath) +} diff --git a/libs/template/template.go b/libs/template/template.go index 46bdef57a9..31c6e28144 100644 --- a/libs/template/template.go +++ b/libs/template/template.go @@ -27,7 +27,11 @@ const ( DbtSql TemplateName = "dbt-sql" MlopsStacks TemplateName = "mlops-stacks" DefaultPydabs TemplateName = "default-pydabs" - Custom TemplateName = "custom" + + // Custom represents any template that is not one of the above default + // templates. It's a catch for any custom templates that customers provide + // as a path or URL. + Custom TemplateName = "custom" ) var allTemplates = []Template{ @@ -98,7 +102,7 @@ func options() []cmdio.Tuple { func SelectTemplate(ctx context.Context) (TemplateName, error) { if !cmdio.IsPromptSupported(ctx) { - return "", fmt.Errorf("please specify a template") + return "", fmt.Errorf("prompting is not supported. Please specify the path, name or URL of the template to use") } description, err := cmdio.SelectOrdered(ctx, options(), "Template to use") if err != nil { diff --git a/libs/template/template_test.go b/libs/template/template_test.go index 73d818dfe4..cfcdd6251a 100644 --- a/libs/template/template_test.go +++ b/libs/template/template_test.go @@ -45,3 +45,18 @@ func TestBundleInitRepoName(t *testing.T) { assert.Equal(t, "invalid-url", repoName("invalid-url")) assert.Equal(t, "www.github.com", repoName("https://www.github.com")) } + +func TestTemplateTelemetryIsCapturedForAllDefaultTemplates(t *testing.T) { + for _, tmpl := range allTemplates { + w := tmpl.Writer + + if tmpl.name == Custom { + // Assert telemetry is not captured for user templates. + assert.IsType(t, &defaultWriter{}, w) + } else { + // Assert telemetry is captured for all other templates, i.e. templates + // owned by databricks. + assert.IsType(t, &writerWithTelemetry{}, w) + } + } +} From 39dfe05a7345b82b095de46ea5937c121ea2ba71 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 6 Jan 2025 18:03:11 +0530 Subject: [PATCH 12/22] add test for the get method --- libs/template/template_test.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/libs/template/template_test.go b/libs/template/template_test.go index cfcdd6251a..a51da49739 100644 --- a/libs/template/template_test.go +++ b/libs/template/template_test.go @@ -60,3 +60,30 @@ func TestTemplateTelemetryIsCapturedForAllDefaultTemplates(t *testing.T) { } } } + +func TestTemplateGet(t *testing.T) { + names := []TemplateName{ + DefaultPython, + DefaultSql, + DbtSql, + MlopsStacks, + DefaultPydabs, + Custom, + } + + for _, name := range names { + tmpl := Get(name) + assert.Equal(t, tmpl.name, name) + } + + notExist := []string{ + "/some/path", + "doesnotexist", + "https://www.someurl.com", + } + + for _, name := range notExist { + tmpl := Get(TemplateName(name)) + assert.Nil(t, tmpl) + } +} From 71057774f0cb4378473b61b0663cd9e024a30fad Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 6 Jan 2025 18:12:12 +0530 Subject: [PATCH 13/22] - --- libs/template/writer.go | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/template/writer.go b/libs/template/writer.go index f0b7ae6de6..7bdf1ba5a6 100644 --- a/libs/template/writer.go +++ b/libs/template/writer.go @@ -14,7 +14,6 @@ import ( "github.com/databricks/cli/libs/filer" ) -// TODO: Add some golden tests for these. const ( libraryDirName = "library" templateDirName = "template" From 4755db4d700dee07701e3887efcb224d1b894903 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 6 Jan 2025 18:27:15 +0530 Subject: [PATCH 14/22] custom rename to avoid break --- libs/template/resolver.go | 2 +- libs/template/resolver_test.go | 2 +- libs/template/template.go | 2 +- libs/template/template_test.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/template/resolver.go b/libs/template/resolver.go index 8d1b26d8d7..16b2e692b1 100644 --- a/libs/template/resolver.go +++ b/libs/template/resolver.go @@ -60,7 +60,7 @@ func (r Resolver) Resolve(ctx context.Context) (*Template, error) { templateName = TemplateName(r.TemplatePathOrUrl) } - // User should not directly select "custom" and instead should provide the + // User should not directly select "custom..." and instead should provide the // file path or the Git URL for the template directly. if templateName == Custom { return nil, ErrCustomSelected diff --git a/libs/template/resolver_test.go b/libs/template/resolver_test.go index 9fde4e9e96..c94548d59a 100644 --- a/libs/template/resolver_test.go +++ b/libs/template/resolver_test.go @@ -29,7 +29,7 @@ func TestTemplateResolverErrorsWhenPromptingIsNotSupported(t *testing.T) { func TestTemplateResolverErrorWhenUserSelectsCustom(t *testing.T) { r := Resolver{ - TemplatePathOrUrl: "custom", + TemplatePathOrUrl: "custom...", } _, err := r.Resolve(context.Background()) diff --git a/libs/template/template.go b/libs/template/template.go index 31c6e28144..2cb60af475 100644 --- a/libs/template/template.go +++ b/libs/template/template.go @@ -31,7 +31,7 @@ const ( // Custom represents any template that is not one of the above default // templates. It's a catch for any custom templates that customers provide // as a path or URL. - Custom TemplateName = "custom" + Custom TemplateName = "custom..." ) var allTemplates = []Template{ diff --git a/libs/template/template_test.go b/libs/template/template_test.go index a51da49739..a2221f7660 100644 --- a/libs/template/template_test.go +++ b/libs/template/template_test.go @@ -21,7 +21,7 @@ func TestTemplateOptions(t *testing.T) { {Name: "default-sql", Id: "The default SQL template for .sql files that run with Databricks SQL"}, {Name: "dbt-sql", Id: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)"}, {Name: "mlops-stacks", Id: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)"}, - {Name: "custom", Id: "Bring your own template"}, + {Name: "custom...", Id: "Bring your own template"}, } assert.Equal(t, expected, options()) } From 0a68fb34f9cdf6f8ac338706c79741f158fb4303 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 6 Jan 2025 18:30:18 +0530 Subject: [PATCH 15/22] - --- libs/template/template.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/template/template.go b/libs/template/template.go index 2cb60af475..5e7d3dcd4e 100644 --- a/libs/template/template.go +++ b/libs/template/template.go @@ -29,8 +29,8 @@ const ( DefaultPydabs TemplateName = "default-pydabs" // Custom represents any template that is not one of the above default - // templates. It's a catch for any custom templates that customers provide - // as a path or URL. + // templates. It's a catch all for any custom templates that customers provide + // as a path or URL argument. Custom TemplateName = "custom..." ) From 8202aafc782600a152c2710422f5c76fb9506ee7 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 6 Jan 2025 18:38:22 +0530 Subject: [PATCH 16/22] more test --- libs/template/resolver.go | 3 ++- libs/template/resolver_test.go | 22 +++++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/libs/template/resolver.go b/libs/template/resolver.go index 16b2e692b1..5afc6a2370 100644 --- a/libs/template/resolver.go +++ b/libs/template/resolver.go @@ -3,6 +3,7 @@ package template import ( "context" "errors" + "fmt" "github.com/databricks/cli/libs/git" ) @@ -37,7 +38,7 @@ var ErrCustomSelected = errors.New("custom template selected") // Prompts the user if needed. func (r Resolver) Resolve(ctx context.Context) (*Template, error) { if r.Tag != "" && r.Branch != "" { - return nil, errors.New("only one of --tag or --branch can be specified") + return nil, fmt.Errorf("only one of --tag or --branch can be specified") } // Git ref to use for template initialization diff --git a/libs/template/resolver_test.go b/libs/template/resolver_test.go index c94548d59a..d057e78808 100644 --- a/libs/template/resolver_test.go +++ b/libs/template/resolver_test.go @@ -66,7 +66,7 @@ func TestTemplateResolverForDefaultTemplates(t *testing.T) { assert.Equal(t, "/config/file", tmpl.Writer.(*writerWithTelemetry).configPath) } -func TestTemplateResolverForCustomTemplate(t *testing.T) { +func TestTemplateResolverForCustomUrl(t *testing.T) { r := Resolver{ TemplatePathOrUrl: "https://www.example.com/abc", Tag: "tag", @@ -77,6 +77,8 @@ func TestTemplateResolverForCustomTemplate(t *testing.T) { tmpl, err := r.Resolve(context.Background()) require.NoError(t, err) + assert.Equal(t, Custom, tmpl.name) + // Assert reader configuration assert.Equal(t, "https://www.example.com/abc", tmpl.Reader.(*gitReader).gitUrl) assert.Equal(t, "tag", tmpl.Reader.(*gitReader).ref) @@ -85,3 +87,21 @@ func TestTemplateResolverForCustomTemplate(t *testing.T) { // Assert writer configuration assert.Equal(t, "/config/file", tmpl.Writer.(*defaultWriter).configPath) } + +func TestTemplateResolverForCustomPath(t *testing.T) { + r := Resolver{ + TemplatePathOrUrl: "/custom/path", + ConfigFile: "/config/file", + } + + tmpl, err := r.Resolve(context.Background()) + require.NoError(t, err) + + assert.Equal(t, Custom, tmpl.name,) + + // Assert reader configuration + assert.Equal(t, "/custom/path", tmpl.Reader.(*localReader).path) + + // Assert writer configuration + assert.Equal(t, "/config/file", tmpl.Writer.(*defaultWriter).configPath) +} From 4743095adc6a095a80de902e44e97413025d14ba Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 6 Jan 2025 18:42:34 +0530 Subject: [PATCH 17/22] lint --- libs/template/resolver_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/template/resolver_test.go b/libs/template/resolver_test.go index d057e78808..a792908e7f 100644 --- a/libs/template/resolver_test.go +++ b/libs/template/resolver_test.go @@ -97,7 +97,7 @@ func TestTemplateResolverForCustomPath(t *testing.T) { tmpl, err := r.Resolve(context.Background()) require.NoError(t, err) - assert.Equal(t, Custom, tmpl.name,) + assert.Equal(t, Custom, tmpl.name) // Assert reader configuration assert.Equal(t, "/custom/path", tmpl.Reader.(*localReader).path) From 57a75190cc0dc5910b1aaac2a1cd69aed9846d80 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 6 Jan 2025 18:49:28 +0530 Subject: [PATCH 18/22] cleanup a bit --- libs/template/resolver_test.go | 4 ++-- libs/template/template.go | 10 +++++----- libs/template/template_test.go | 2 +- libs/template/writer.go | 8 ++++---- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/libs/template/resolver_test.go b/libs/template/resolver_test.go index a792908e7f..460912aa30 100644 --- a/libs/template/resolver_test.go +++ b/libs/template/resolver_test.go @@ -50,7 +50,7 @@ func TestTemplateResolverForDefaultTemplates(t *testing.T) { require.NoError(t, err) assert.Equal(t, &builtinReader{name: name}, tmpl.Reader) - assert.IsType(t, &writerWithTelemetry{}, tmpl.Writer) + assert.IsType(t, &writerWithFullTelemetry{}, tmpl.Writer) } r := Resolver{ @@ -63,7 +63,7 @@ func TestTemplateResolverForDefaultTemplates(t *testing.T) { // Assert reader and writer configuration assert.Equal(t, "https://github.com/databricks/mlops-stacks", tmpl.Reader.(*gitReader).gitUrl) - assert.Equal(t, "/config/file", tmpl.Writer.(*writerWithTelemetry).configPath) + assert.Equal(t, "/config/file", tmpl.Writer.(*writerWithFullTelemetry).configPath) } func TestTemplateResolverForCustomUrl(t *testing.T) { diff --git a/libs/template/template.go b/libs/template/template.go index 5e7d3dcd4e..120ca63d78 100644 --- a/libs/template/template.go +++ b/libs/template/template.go @@ -39,33 +39,33 @@ var allTemplates = []Template{ name: DefaultPython, description: "The default Python template for Notebooks / Delta Live Tables / Workflows", Reader: &builtinReader{name: "default-python"}, - Writer: &writerWithTelemetry{}, + Writer: &writerWithFullTelemetry{}, }, { name: DefaultSql, description: "The default SQL template for .sql files that run with Databricks SQL", Reader: &builtinReader{name: "default-sql"}, - Writer: &writerWithTelemetry{}, + Writer: &writerWithFullTelemetry{}, }, { name: DbtSql, description: "The dbt SQL template (databricks.com/blog/delivering-cost-effective-data-real-time-dbt-and-databricks)", Reader: &builtinReader{name: "dbt-sql"}, - Writer: &writerWithTelemetry{}, + Writer: &writerWithFullTelemetry{}, }, { name: MlopsStacks, description: "The Databricks MLOps Stacks template (github.com/databricks/mlops-stacks)", aliases: []string{"mlops-stack"}, Reader: &gitReader{gitUrl: "https://github.com/databricks/mlops-stacks", cloneFunc: git.Clone}, - Writer: &writerWithTelemetry{}, + Writer: &writerWithFullTelemetry{}, }, { name: DefaultPydabs, hidden: true, description: "The default PyDABs template", Reader: &gitReader{gitUrl: "https://databricks.github.io/workflows-authoring-toolkit/pydabs-template.git", cloneFunc: git.Clone}, - Writer: &writerWithTelemetry{}, + Writer: &writerWithFullTelemetry{}, }, { name: Custom, diff --git a/libs/template/template_test.go b/libs/template/template_test.go index a2221f7660..88e753816b 100644 --- a/libs/template/template_test.go +++ b/libs/template/template_test.go @@ -56,7 +56,7 @@ func TestTemplateTelemetryIsCapturedForAllDefaultTemplates(t *testing.T) { } else { // Assert telemetry is captured for all other templates, i.e. templates // owned by databricks. - assert.IsType(t, &writerWithTelemetry{}, w) + assert.IsType(t, &writerWithFullTelemetry{}, w) } } } diff --git a/libs/template/writer.go b/libs/template/writer.go index 7bdf1ba5a6..bfdd906fb2 100644 --- a/libs/template/writer.go +++ b/libs/template/writer.go @@ -152,15 +152,15 @@ func (tmpl *defaultWriter) Materialize(ctx context.Context, reader Reader) error } func (tmpl *defaultWriter) LogTelemetry(ctx context.Context) error { - // no-op + // TODO, only log the template name and uuid. return nil } -type writerWithTelemetry struct { +type writerWithFullTelemetry struct { defaultWriter } -func (tmpl *writerWithTelemetry) LogTelemetry(ctx context.Context) error { - // Log telemetry. TODO. +func (tmpl *writerWithFullTelemetry) LogTelemetry(ctx context.Context) error { + // TODO, log template name, uuid and enum args as well.`` return nil } From 38d47e62f20fb7aef0583f2e27ffb30b3a2e5e1b Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 6 Jan 2025 18:54:50 +0530 Subject: [PATCH 19/22] fix aliasing --- libs/template/template.go | 3 ++- libs/template/template_test.go | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/libs/template/template.go b/libs/template/template.go index 120ca63d78..e40208445a 100644 --- a/libs/template/template.go +++ b/libs/template/template.go @@ -3,6 +3,7 @@ package template import ( "context" "fmt" + "slices" "strings" "github.com/databricks/cli/libs/cmdio" @@ -120,7 +121,7 @@ func SelectTemplate(ctx context.Context) (TemplateName, error) { func Get(name TemplateName) *Template { for _, template := range allTemplates { - if template.name == name { + if template.name == name || slices.Contains(template.aliases, string(name)) { return &template } } diff --git a/libs/template/template_test.go b/libs/template/template_test.go index 88e753816b..8540b7c226 100644 --- a/libs/template/template_test.go +++ b/libs/template/template_test.go @@ -86,4 +86,7 @@ func TestTemplateGet(t *testing.T) { tmpl := Get(TemplateName(name)) assert.Nil(t, tmpl) } + + // Assert the alias works. + assert.Equal(t, Get(TemplateName("mlops-stack")).name, MlopsStacks) } From ad28ce1d37b6e1945a92cb54ee7184290449289e Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Mon, 6 Jan 2025 18:58:45 +0530 Subject: [PATCH 20/22] fix aliasing --- libs/template/template_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/template/template_test.go b/libs/template/template_test.go index 8540b7c226..cea55b424b 100644 --- a/libs/template/template_test.go +++ b/libs/template/template_test.go @@ -88,5 +88,5 @@ func TestTemplateGet(t *testing.T) { } // Assert the alias works. - assert.Equal(t, Get(TemplateName("mlops-stack")).name, MlopsStacks) + assert.Equal(t, MlopsStacks, Get(TemplateName("mlops-stack")).name) } From 2a0dbdec3e79635feb009067349c96c3727398f0 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Thu, 9 Jan 2025 16:30:14 +0530 Subject: [PATCH 21/22] remove fail reader --- libs/template/reader.go | 10 ---------- libs/template/reader_test.go | 7 ------- libs/template/template.go | 5 +++-- 3 files changed, 3 insertions(+), 19 deletions(-) diff --git a/libs/template/reader.go b/libs/template/reader.go index 56d264eddc..44369ecbb1 100644 --- a/libs/template/reader.go +++ b/libs/template/reader.go @@ -153,13 +153,3 @@ func (r *localReader) FS(ctx context.Context) (fs.FS, error) { func (r *localReader) Close() error { return nil } - -type failReader struct{} - -func (r *failReader) FS(ctx context.Context) (fs.FS, error) { - return nil, fmt.Errorf("this is a placeholder reader that always fails. Please configure a real reader.") -} - -func (r *failReader) Close() error { - return fmt.Errorf("this is a placeholder reader that always fails. Please configure a real reader.") -} diff --git a/libs/template/reader_test.go b/libs/template/reader_test.go index f1e037fca4..3dd96647b1 100644 --- a/libs/template/reader_test.go +++ b/libs/template/reader_test.go @@ -114,10 +114,3 @@ func TestLocalReader(t *testing.T) { // Assert close does not error assert.NoError(t, r.Close()) } - -func TestFailReader(t *testing.T) { - r := &failReader{} - assert.Error(t, r.Close()) - _, err := r.FS(context.Background()) - assert.Error(t, err) -} diff --git a/libs/template/template.go b/libs/template/template.go index e40208445a..30f11e54a9 100644 --- a/libs/template/template.go +++ b/libs/template/template.go @@ -71,8 +71,9 @@ var allTemplates = []Template{ { name: Custom, description: "Bring your own template", - Reader: &failReader{}, - Writer: &defaultWriter{}, + // Reader is determined at runtime based on the user input. + Reader: nil, + Writer: &defaultWriter{}, }, } From a2a3ae7154697dcfa4b620ade95e6dfe440bdaa1 Mon Sep 17 00:00:00 2001 From: Shreyas Goenka Date: Thu, 9 Jan 2025 16:34:06 +0530 Subject: [PATCH 22/22] address comments --- libs/template/reader_test.go | 42 +++++++++++++++++++--------------- libs/template/resolver_test.go | 36 ++++++++++++++++------------- 2 files changed, 43 insertions(+), 35 deletions(-) diff --git a/libs/template/reader_test.go b/libs/template/reader_test.go index 3dd96647b1..58a65f0da9 100644 --- a/libs/template/reader_test.go +++ b/libs/template/reader_test.go @@ -23,27 +23,31 @@ func TestBuiltInReader(t *testing.T) { } for _, name := range exists { - r := &builtinReader{name: name} - fs, err := r.FS(context.Background()) - assert.NoError(t, err) - assert.NotNil(t, fs) - - // Assert file content returned is accurate and every template has a welcome - // message defined. - fd, err := fs.Open("databricks_template_schema.json") - require.NoError(t, err) - b, err := io.ReadAll(fd) - require.NoError(t, err) - assert.Contains(t, string(b), "welcome_message") - assert.NoError(t, fd.Close()) + t.Run(name, func(t *testing.T) { + r := &builtinReader{name: name} + fs, err := r.FS(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, fs) + + // Assert file content returned is accurate and every template has a welcome + // message defined. + fd, err := fs.Open("databricks_template_schema.json") + require.NoError(t, err) + b, err := io.ReadAll(fd) + require.NoError(t, err) + assert.Contains(t, string(b), "welcome_message") + assert.NoError(t, fd.Close()) + }) } - r := &builtinReader{name: "doesnotexist"} - _, err := r.FS(context.Background()) - assert.EqualError(t, err, "builtin template doesnotexist not found") + t.Run("doesnotexist", func(t *testing.T) { + r := &builtinReader{name: "doesnotexist"} + _, err := r.FS(context.Background()) + assert.EqualError(t, err, "builtin template doesnotexist not found") - // Close should not error. - assert.NoError(t, r.Close()) + // Close should not error. + assert.NoError(t, r.Close()) + }) } func TestGitUrlReader(t *testing.T) { @@ -56,7 +60,7 @@ func TestGitUrlReader(t *testing.T) { cloneFunc := func(ctx context.Context, url, reference, targetPath string) error { numCalls++ args = []string{url, reference, targetPath} - err := os.MkdirAll(filepath.Join(targetPath, "a/b/c"), 0o755) + err := os.MkdirAll(filepath.Join(targetPath, "a", "b", "c"), 0o755) require.NoError(t, err) testutil.WriteFile(t, filepath.Join(targetPath, "a", "b", "c", "somefile"), "somecontent") return nil diff --git a/libs/template/resolver_test.go b/libs/template/resolver_test.go index 460912aa30..96a232a603 100644 --- a/libs/template/resolver_test.go +++ b/libs/template/resolver_test.go @@ -42,28 +42,32 @@ func TestTemplateResolverForDefaultTemplates(t *testing.T) { "default-sql", "dbt-sql", } { - r := Resolver{ - TemplatePathOrUrl: name, - } + t.Run(name, func(t *testing.T) { + r := Resolver{ + TemplatePathOrUrl: name, + } - tmpl, err := r.Resolve(context.Background()) - require.NoError(t, err) + tmpl, err := r.Resolve(context.Background()) + require.NoError(t, err) - assert.Equal(t, &builtinReader{name: name}, tmpl.Reader) - assert.IsType(t, &writerWithFullTelemetry{}, tmpl.Writer) + assert.Equal(t, &builtinReader{name: name}, tmpl.Reader) + assert.IsType(t, &writerWithFullTelemetry{}, tmpl.Writer) + }) } - r := Resolver{ - TemplatePathOrUrl: "mlops-stacks", - ConfigFile: "/config/file", - } + t.Run("mlops-stacks", func(t *testing.T) { + r := Resolver{ + TemplatePathOrUrl: "mlops-stacks", + ConfigFile: "/config/file", + } - tmpl, err := r.Resolve(context.Background()) - require.NoError(t, err) + tmpl, err := r.Resolve(context.Background()) + require.NoError(t, err) - // Assert reader and writer configuration - assert.Equal(t, "https://github.com/databricks/mlops-stacks", tmpl.Reader.(*gitReader).gitUrl) - assert.Equal(t, "/config/file", tmpl.Writer.(*writerWithFullTelemetry).configPath) + // Assert reader and writer configuration + assert.Equal(t, "https://github.com/databricks/mlops-stacks", tmpl.Reader.(*gitReader).gitUrl) + assert.Equal(t, "/config/file", tmpl.Writer.(*writerWithFullTelemetry).configPath) + }) } func TestTemplateResolverForCustomUrl(t *testing.T) {