diff --git a/network-discovery/cmd/main.go b/network-discovery/cmd/main.go index 129e1e2..6bcff18 100644 --- a/network-discovery/cmd/main.go +++ b/network-discovery/cmd/main.go @@ -2,12 +2,10 @@ package main import ( "context" - "flag" "fmt" "log/slog" "os" "os/signal" - "strings" "syscall" "github.com/netboxlabs/diode-sdk-go/diode" @@ -24,57 +22,15 @@ const DefaultAppName = "network-discovery" // set via ldflags -X option at build time var version = "unknown" -func newLogger(logLevel string, logFormat string) *slog.Logger { - var l slog.Level - switch strings.ToUpper(logLevel) { - case "DEBUG": - l = slog.LevelDebug - case "INFO": - l = slog.LevelInfo - case "WARN": - l = slog.LevelWarn - case "ERROR": - l = slog.LevelError - default: - l = slog.LevelDebug - } - - var h slog.Handler - switch strings.ToUpper(logFormat) { - case "TEXT": - h = slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: l, AddSource: false}) - case "JSON": - h = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: l, AddSource: false}) - default: - h = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: l, AddSource: false}) - } - - return slog.New(h) -} - func main() { - configPath := flag.String("config", "", "path to the configuration file (required)") - - flag.Parse() - - if *configPath == "" { - fmt.Fprintf(os.Stderr, "Usage of network-discovery:\n") - flag.PrintDefaults() - os.Exit(1) - } - if _, err := os.Stat(*configPath); os.IsNotExist(err) { - fmt.Printf("configuration file '%s' does not exist\n", *configPath) - os.Exit(1) - } - - yamlData, err := os.ReadFile(*configPath) + fileData, err := config.RequireConfig() if err != nil { - fmt.Printf("error reading configuration file: %v\n", err) + fmt.Printf("%v\n", err) os.Exit(1) } - config := config.Config{ + c := config.Config{ Network: config.Network{ Config: config.StartupConfig{ Host: "0.0.0.0", @@ -84,16 +40,16 @@ func main() { }}, } - if err = yaml.Unmarshal(yamlData, &config); err != nil { + if err = yaml.Unmarshal(fileData, &c); err != nil { fmt.Printf("error parsing configuration file: %v\n", err) os.Exit(1) } client, err := diode.NewClient( - config.Network.Config.Target, + c.Network.Config.Target, DefaultAppName, version, - diode.WithAPIKey(config.Network.Config.APIKey), + diode.WithAPIKey(c.Network.Config.APIKey), ) if err != nil { fmt.Printf("error creating diode client: %v\n", err) @@ -101,7 +57,7 @@ func main() { } ctx := context.Background() - logger := newLogger(config.Network.Config.LogLevel, config.Network.Config.LogFormat) + logger := config.NewLogger(c.Network.Config.LogLevel, c.Network.Config.LogFormat) policyManager := policy.Manager{} err = policyManager.Configure(ctx, logger, client) @@ -111,7 +67,7 @@ func main() { } server := server.Server{} - server.Configure(logger, &policyManager, version, config.Network.Config) + server.Configure(logger, &policyManager, version, c.Network.Config) // handle signals done := make(chan bool, 1) diff --git a/network-discovery/config/utils.go b/network-discovery/config/utils.go new file mode 100644 index 0000000..d214bae --- /dev/null +++ b/network-discovery/config/utils.go @@ -0,0 +1,61 @@ +package config + +import ( + "flag" + "fmt" + "log/slog" + "os" + "strings" +) + +func NewLogger(logLevel string, logFormat string) *slog.Logger { + var l slog.Level + switch strings.ToUpper(logLevel) { + case "DEBUG": + l = slog.LevelDebug + case "INFO": + l = slog.LevelInfo + case "WARN": + l = slog.LevelWarn + case "ERROR": + l = slog.LevelError + default: + l = slog.LevelDebug + } + + var h slog.Handler + switch strings.ToUpper(logFormat) { + case "TEXT": + h = slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: l, AddSource: false}) + case "JSON": + h = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: l, AddSource: false}) + default: + h = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: l, AddSource: false}) + } + + return slog.New(h) +} + +func RequireConfig() ([]byte, error) { + + configPath := flag.String("config", "", "path to the configuration file (required)") + + flag.Parse() + + if *configPath == "" { + fmt.Fprintf(os.Stderr, "Usage of network-discovery:\n") + flag.PrintDefaults() + return nil, fmt.Errorf("") + + } + if _, err := os.Stat(*configPath); os.IsNotExist(err) { + return nil, fmt.Errorf("configuration file '%s' does not exist\n", *configPath) + } + + fileData, err := os.ReadFile(*configPath) + if err != nil { + return nil, fmt.Errorf("error reading configuration file: %v\n", err) + } + + return fileData, nil +} diff --git a/network-discovery/config/utils_test.go b/network-discovery/config/utils_test.go new file mode 100644 index 0000000..8ddad9f --- /dev/null +++ b/network-discovery/config/utils_test.go @@ -0,0 +1,132 @@ +package config_test + +import ( + "flag" + "log/slog" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netboxlabs/orb-discovery/network-discovery/config" +) + +func TestNewLogger(t *testing.T) { + tests := []struct { + desc string + loggingLevel string + loggingFormat string + }{ + { + desc: "with debug level and json format", + loggingLevel: "debug", + loggingFormat: "json", + }, + { + desc: "with debug level and text format", + loggingLevel: "debug", + loggingFormat: "text", + }, + { + desc: "with info level and json format", + loggingLevel: "info", + loggingFormat: "json", + }, + { + desc: "with info level and text format", + loggingLevel: "warn", + loggingFormat: "json", + }, + { + desc: "with error level and text format", + loggingLevel: "error", + loggingFormat: "text", + }, + { + desc: "with error level and empty format", + loggingLevel: "error", + loggingFormat: "", + }, + { + desc: "with empty level and text format", + loggingLevel: "", + loggingFormat: "text", + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + log := config.NewLogger(tt.loggingLevel, tt.loggingFormat) + require.NotNil(t, log) + + handlerOK := false + if tt.loggingFormat == "text" { + _, handlerOK = log.Handler().(*slog.TextHandler) + } else { + _, handlerOK = log.Handler().(*slog.JSONHandler) + } + assert.True(t, handlerOK) + }) + } +} + +func TestRequireConfig(t *testing.T) { + t.Run("No Config Path Provided", func(t *testing.T) { + // Simulate no flags passed + os.Args = []string{"network-discovery"} + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) // Reset flags + + data, err := config.RequireConfig() + assert.Nil(t, data) + assert.EqualError(t, err, "") + }) + + t.Run("Config File Does Not Exist", func(t *testing.T) { + // Simulate a non-existent file + os.Args = []string{"network-discovery", "-config", "/non/existent/path"} + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) // Reset flags + + data, err := config.RequireConfig() + assert.Nil(t, data) + assert.EqualError(t, err, "configuration file '/non/existent/path' does not exist\n") + }) + + t.Run("Error Reading Config File", func(t *testing.T) { + // Create a file and simulate an error by removing it before reading + tmpFile, err := os.CreateTemp("", "test-config") + assert.NoError(t, err) + tmpFilePath := tmpFile.Name() + tmpFile.Close() + + // Remove the file to simulate read error + os.Remove(tmpFilePath) + + os.Args = []string{"network-discovery", "-config", tmpFilePath} + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) // Reset flags + + data, err := config.RequireConfig() + assert.Nil(t, data) + assert.Contains(t, err.Error(), "does not exist") + }) + + t.Run("Valid Config File", func(t *testing.T) { + // Create a temporary file with valid content + tmpFile, err := os.CreateTemp("", "test-config") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // Write YAML content to the file + content := "network:\n policies:\n discovery_1:\n config:\n schedule: '* * * * *'" + _, err = tmpFile.WriteString(content) + assert.NoError(t, err) + tmpFile.Close() + + // Pass the file path as a flag + os.Args = []string{"network-discovery", "-config", tmpFile.Name()} + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) // Reset flags + + data, err := config.RequireConfig() + assert.NoError(t, err) + assert.Equal(t, content, string(data)) + }) +}