diff --git a/core/provider/definitions.go b/core/provider/definitions.go index 2d9680a..ed9a9dd 100644 --- a/core/provider/definitions.go +++ b/core/provider/definitions.go @@ -1,5 +1,10 @@ package provider +import ( + "fmt" + "strconv" +) + // VolumeDefinition defines the configuration for a volume. Some providers might not support creating a volume, // but this detail is abstracted away when CreateTask is used type VolumeDefinition struct { @@ -8,6 +13,22 @@ type VolumeDefinition struct { Size string } +func (v *VolumeDefinition) ValidateBasic() error { + if v.Name == "" { + return fmt.Errorf("volume name cannot be empty") + } + + if v.MountPath == "" { + return fmt.Errorf("mount path cannot be empty") + } + + if v.Size == "" { + return fmt.Errorf("size cannot be empty") + } + + return nil +} + // TaskDefinition defines the configuration for a task type TaskDefinition struct { Name string // Name is used when generating volumes, etc. - additional resources for the container @@ -24,9 +45,71 @@ type TaskDefinition struct { ProviderSpecificConfig interface{} } +func (t *TaskDefinition) ValidateBasic() error { + if t.Name == "" { + return fmt.Errorf("name cannot be empty") + } + + if t.ContainerName == "" { + return fmt.Errorf("container name cannot be empty") + } + + if err := t.Image.ValidateBasic(); err != nil { + return fmt.Errorf("image definition is invalid: %w", err) + } + + for _, port := range t.Ports { + if port == "" { + return fmt.Errorf("port cannot be empty") + } + + portInt, err := strconv.ParseUint(port, 10, 64) + + if err != nil { + return fmt.Errorf("port must be a valid unsigned integer") + } + + if portInt > 65535 { + return fmt.Errorf("port must be less than 65535") + } + } + + for _, v := range t.Sidecars { + if err := v.ValidateBasic(); err != nil { + return fmt.Errorf("sidecar is invalid: %w", err) + } + } + + return nil +} + // ImageDefinition defines the details of a specific Docker image type ImageDefinition struct { Image string UID string GID string } + +func (i *ImageDefinition) ValidateBasic() error { + if i.Image == "" { + return fmt.Errorf("image cannot be empty") + } + + if i.UID == "" { + return fmt.Errorf("uid cannot be empty") + } + + if _, err := strconv.ParseUint(i.UID, 10, 64); err != nil { + return fmt.Errorf("uid must be a valid unsigned integer") + } + + if i.GID == "" { + return fmt.Errorf("gid cannot be empty") + } + + if _, err := strconv.ParseUint(i.GID, 10, 64); err != nil { + return fmt.Errorf("gid must be a valid unsigned integer") + } + + return nil +} diff --git a/core/provider/definitions_test.go b/core/provider/definitions_test.go new file mode 100644 index 0000000..a0d6e7d --- /dev/null +++ b/core/provider/definitions_test.go @@ -0,0 +1,203 @@ +package provider_test + +import ( + "github.com/skip-mev/petri/core/v2/provider" + "github.com/stretchr/testify/assert" + "testing" +) + +var validImageDefinition = provider.ImageDefinition{ + Image: "test", + GID: "1000", + UID: "1000", +} + +func TestImageDefinitionValidation(t *testing.T) { + tcs := []struct { + name string + def provider.ImageDefinition + expectPass bool + }{ + { + name: "valid", + def: provider.ImageDefinition{ + Image: "test", + GID: "1000", + UID: "1000", + }, + expectPass: true, + }, + { + name: "empty image", + def: provider.ImageDefinition{ + Image: "", + GID: "1000", + UID: "1000", + }, + }, + { + name: "empty uid", + def: provider.ImageDefinition{ + Image: "test", + GID: "1000", + UID: "", + }, + expectPass: false, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + err := tc.def.ValidateBasic() + if tc.expectPass { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + } + }) + } +} + +func TestVolumeDefinitionValidation(t *testing.T) { + tcs := []struct { + name string + def provider.VolumeDefinition + expectPass bool + }{ + { + name: "valid", + def: provider.VolumeDefinition{ + MountPath: "/tmp", + Name: "test", + Size: "100", + }, + expectPass: true, + }, + { + name: "empty mountpath", + def: provider.VolumeDefinition{ + MountPath: "", + Name: "test", + Size: "100", + }, + expectPass: false, + }, + { + name: "empty name", + def: provider.VolumeDefinition{ + MountPath: "/tmp", + Name: "", + Size: "100", + }, + expectPass: false, + }, + { + name: "empty size", + def: provider.VolumeDefinition{ + MountPath: "/tmp", + Name: "test", + Size: "", + }, + expectPass: false, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + err := tc.def.ValidateBasic() + if tc.expectPass { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + } + }) + } +} + +func TestTaskDefinitionValidation(t *testing.T) { + tcs := []struct { + name string + def provider.TaskDefinition + expectPass bool + }{ + { + name: "valid", + def: provider.TaskDefinition{ + Name: "test", + ContainerName: "test", + Image: validImageDefinition, + }, + expectPass: true, + }, + { + name: "no name", + def: provider.TaskDefinition{ + ContainerName: "test", + Image: validImageDefinition, + }, + expectPass: false, + }, + { + name: "no container name", + def: provider.TaskDefinition{ + Name: "test", + Image: validImageDefinition, + }, + expectPass: false, + }, + { + name: "no image", + def: provider.TaskDefinition{ + Name: "test", + ContainerName: "test", + }, + expectPass: false, + }, + { + name: "invalid image", + def: provider.TaskDefinition{ + Name: "test", + ContainerName: "test", + Image: provider.ImageDefinition{ + Image: "", + }, + }, + expectPass: false, + }, + { + name: "invalid port", + def: provider.TaskDefinition{ + Name: "test", + ContainerName: "test", + Image: validImageDefinition, + Ports: []string{"", "100000"}, + }, + expectPass: false, + }, + { + name: "invalid sidecar", + def: provider.TaskDefinition{ + Name: "test", + ContainerName: "test", + Image: validImageDefinition, + Sidecars: []provider.TaskDefinition{ + { + Name: "test", + }, + }, + }, + expectPass: false, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + err := tc.def.ValidateBasic() + if tc.expectPass { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + } + }) + } +} diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go index 4c7b6cd..a92f8e9 100644 --- a/core/provider/digitalocean/droplet.go +++ b/core/provider/digitalocean/droplet.go @@ -17,6 +17,10 @@ import _ "embed" var dockerCloudInit string func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDefinition) (*godo.Droplet, error) { + if err := definition.ValidateBasic(); err != nil { + return nil, fmt.Errorf("failed to validate task definition: %w", err) + } + doConfig, ok := definition.ProviderSpecificConfig.(DigitalOceanTaskConfig) if !ok { diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go index fcb01bf..a6c1f95 100644 --- a/core/provider/digitalocean/task.go +++ b/core/provider/digitalocean/task.go @@ -22,6 +22,10 @@ import ( ) func (p *Provider) CreateTask(ctx context.Context, logger *zap.Logger, definition provider.TaskDefinition) (string, error) { + if err := definition.ValidateBasic(); err != nil { + return "", fmt.Errorf("failed to validate task definition: %w", err) + } + if definition.ProviderSpecificConfig == nil { return "", fmt.Errorf("digitalocean specific config is nil") } @@ -355,6 +359,10 @@ func (p *Provider) RunCommand(ctx context.Context, taskName string, command []st } func (p *Provider) RunCommandWhileStopped(ctx context.Context, taskName string, definition provider.TaskDefinition, command []string) (string, string, int, error) { + if err := definition.ValidateBasic(); err != nil { + return "", "", 0, fmt.Errorf("failed to validate task definition: %w", err) + } + dockerClient, err := p.getDropletDockerClient(ctx, taskName) if err != nil { diff --git a/core/provider/docker/task.go b/core/provider/docker/task.go index 6697ade..a407593 100644 --- a/core/provider/docker/task.go +++ b/core/provider/docker/task.go @@ -21,6 +21,10 @@ import ( ) func (p *Provider) CreateTask(ctx context.Context, logger *zap.Logger, definition provider.TaskDefinition) (string, error) { + if err := definition.ValidateBasic(); err != nil { + return "", fmt.Errorf("failed to validate task definition: %w", err) + } + logger = logger.Named("docker_provider") if err := p.pullImage(ctx, definition.Image.Image); err != nil { @@ -222,6 +226,10 @@ func (p *Provider) RunCommand(ctx context.Context, id string, command []string) } func (p *Provider) RunCommandWhileStopped(ctx context.Context, id string, definition provider.TaskDefinition, command []string) (string, string, int, error) { + if err := definition.ValidateBasic(); err != nil { + return "", "", 0, fmt.Errorf("failed to validate task definition: %w", err) + } + p.logger.Debug("running command while stopped", zap.String("id", id), zap.Strings("command", command)) status, err := p.GetTaskStatus(ctx, id) diff --git a/core/provider/docker/volume.go b/core/provider/docker/volume.go index 5409d42..af989c5 100644 --- a/core/provider/docker/volume.go +++ b/core/provider/docker/volume.go @@ -22,6 +22,10 @@ import ( // CreateVolume is an idempotent operation func (p *Provider) CreateVolume(ctx context.Context, definition provider.VolumeDefinition) (string, error) { + if err := definition.ValidateBasic(); err != nil { + return "", fmt.Errorf("failed to validate volume definition: %w", err) + } + p.logger.Debug("creating volume", zap.String("name", definition.Name), zap.String("size", definition.Size)) existingVolume, err := p.dockerClient.VolumeInspect(ctx, definition.Name) diff --git a/core/provider/task.go b/core/provider/task.go index 7234ceb..61103b7 100644 --- a/core/provider/task.go +++ b/core/provider/task.go @@ -3,12 +3,17 @@ package provider import ( "context" "errors" + "fmt" "go.uber.org/zap" "sync" ) // CreateTask creates a task structure and sets up its underlying workload on a provider, including sidecars if there are any in the definition func CreateTask(ctx context.Context, logger *zap.Logger, provider Provider, definition TaskDefinition) (*Task, error) { + if err := definition.ValidateBasic(); err != nil { + return nil, fmt.Errorf("failed to validate task definition: %w", err) + } + task := &Task{ Provider: provider, Definition: definition, diff --git a/core/types/chain.go b/core/types/chain.go index 3b742db..3f8f266 100644 --- a/core/types/chain.go +++ b/core/types/chain.go @@ -2,6 +2,7 @@ package types import ( "context" + "fmt" rpchttp "github.com/cometbft/cometbft/rpc/client/http" "github.com/cosmos/cosmos-sdk/client" codectypes "github.com/cosmos/cosmos-sdk/codec/types" @@ -10,6 +11,9 @@ import ( "google.golang.org/grpc" ) +// GenesisModifier is a function that takes in genesis bytes and returns modified genesis bytes +type GenesisModifier func([]byte) ([]byte, error) + // ChainI is an interface for a logical chain type ChainI interface { Init(context.Context) error @@ -71,5 +75,60 @@ type ChainConfig struct { NodeDefinitionModifier NodeDefinitionModifier // NodeDefinitionModifier is a function that modifies a node's definition } -// GenesisModifier is a function that takes in genesis bytes and returns modified genesis bytes -type GenesisModifier func([]byte) ([]byte, error) +func (c *ChainConfig) ValidateBasic() error { + if c.Denom == "" { + return fmt.Errorf("denom cannot be empty") + } + + if c.Decimals == 0 { + return fmt.Errorf("decimals cannot be 0") + } + + if c.NumValidators == 0 { + return fmt.Errorf("num validators cannot be 0") + } + + if c.BinaryName == "" { + return fmt.Errorf("binary name cannot be empty") + } + + if c.GasPrices == "" { + return fmt.Errorf("gas prices cannot be empty") + } + + if c.GasAdjustment == 0 { + return fmt.Errorf("gas adjustment cannot be 0") + } + + if err := c.Image.ValidateBasic(); err != nil { + return fmt.Errorf("image definition is invalid: %w", err) + } + + if c.SidecarImage.Image != "" { + if err := c.SidecarImage.ValidateBasic(); err != nil { + return fmt.Errorf("sidecar image definition is invalid: %w", err) + } + } + + if c.Bech32Prefix == "" { + return fmt.Errorf("bech32 prefix cannot be empty") + } + + if c.CoinType == "" { + return fmt.Errorf("coin type cannot be empty") + } + + if c.HDPath == "" { + return fmt.Errorf("HD path cannot be empty") + } + + if c.ChainId == "" { + return fmt.Errorf("chain ID cannot be empty") + } + + if c.NodeCreator == nil { + return fmt.Errorf("node creator cannot be nil") + } + + return nil +} diff --git a/core/types/node.go b/core/types/node.go index be332b3..dc49368 100644 --- a/core/types/node.go +++ b/core/types/node.go @@ -2,6 +2,7 @@ package types import ( "context" + "fmt" rpchttp "github.com/cometbft/cometbft/rpc/client/http" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/skip-mev/petri/core/provider" @@ -20,6 +21,22 @@ type NodeConfig struct { Provider provider.Provider // Provider is the provider this node is running on } +func (c NodeConfig) ValidateBasic() error { + if c.Name == "" { + return fmt.Errorf("name cannot be empty") + } + + if c.Chain == nil { + return fmt.Errorf("chain cannot be nil") + } + + if c.Provider == nil { + return fmt.Errorf("provider cannot be nil") + } + + return nil +} + // NodeDefinitionModifier is a type of function that given a NodeConfig modifies the task definition. It usually // adds additional sidecars or modifies the entrypoint. This function is typically called in NodeCreator // before the task is created diff --git a/cosmos/chain/chain.go b/cosmos/chain/chain.go index 3e11151..6cc6153 100644 --- a/cosmos/chain/chain.go +++ b/cosmos/chain/chain.go @@ -39,6 +39,10 @@ var _ petritypes.ChainI = &Chain{} // CreateChain creates the Chain object and initializes the node tasks, their backing compute and the validator wallets func CreateChain(ctx context.Context, logger *zap.Logger, infraProvider provider.Provider, config petritypes.ChainConfig) (*Chain, error) { + if err := config.ValidateBasic(); err != nil { + return nil, fmt.Errorf("failed to validate chain config: %w", err) + } + var chain Chain chain.mu = sync.RWMutex{} diff --git a/cosmos/node/node.go b/cosmos/node/node.go index 5f7bde0..1e731c7 100644 --- a/cosmos/node/node.go +++ b/cosmos/node/node.go @@ -28,6 +28,10 @@ var _ petritypes.NodeCreator = CreateNode // CreateNode creates a new logical node and creates the underlying workload for it func CreateNode(ctx context.Context, logger *zap.Logger, nodeConfig petritypes.NodeConfig) (petritypes.NodeI, error) { + if err := nodeConfig.ValidateBasic(); err != nil { + return nil, fmt.Errorf("failed to validate node config: %w", err) + } + var node Node node.logger = logger.Named("node")