Skip to content

Commit

Permalink
Merge pull request #107 from skip-mev/mergify/bp/release/v1.x.x/pr-104
Browse files Browse the repository at this point in the history
[BLO-820] feat: add ValidateBasic checks for definitions and configs (backport #104)
  • Loading branch information
Zygimantass authored Nov 5, 2024
2 parents eb61461 + 3654781 commit 6fe29e2
Show file tree
Hide file tree
Showing 11 changed files with 401 additions and 2 deletions.
83 changes: 83 additions & 0 deletions core/provider/definitions.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package provider

import (
"fmt"
"strconv"
)

// VolumeDefinition defines the configuration for a volume. Some providers might not support creating a volume,
// but this detail is abstracted away when CreateTask is used
type VolumeDefinition struct {
Expand All @@ -8,6 +13,22 @@ type VolumeDefinition struct {
Size string
}

func (v *VolumeDefinition) ValidateBasic() error {
if v.Name == "" {
return fmt.Errorf("volume name cannot be empty")
}

if v.MountPath == "" {
return fmt.Errorf("mount path cannot be empty")
}

if v.Size == "" {
return fmt.Errorf("size cannot be empty")
}

return nil
}

// TaskDefinition defines the configuration for a task
type TaskDefinition struct {
Name string // Name is used when generating volumes, etc. - additional resources for the container
Expand All @@ -24,9 +45,71 @@ type TaskDefinition struct {
ProviderSpecificConfig interface{}
}

func (t *TaskDefinition) ValidateBasic() error {
if t.Name == "" {
return fmt.Errorf("name cannot be empty")
}

if t.ContainerName == "" {
return fmt.Errorf("container name cannot be empty")
}

if err := t.Image.ValidateBasic(); err != nil {
return fmt.Errorf("image definition is invalid: %w", err)
}

for _, port := range t.Ports {
if port == "" {
return fmt.Errorf("port cannot be empty")
}

portInt, err := strconv.ParseUint(port, 10, 64)

if err != nil {
return fmt.Errorf("port must be a valid unsigned integer")
}

if portInt > 65535 {
return fmt.Errorf("port must be less than 65535")
}
}

for _, v := range t.Sidecars {
if err := v.ValidateBasic(); err != nil {
return fmt.Errorf("sidecar is invalid: %w", err)
}
}

return nil
}

// ImageDefinition defines the details of a specific Docker image
type ImageDefinition struct {
Image string
UID string
GID string
}

func (i *ImageDefinition) ValidateBasic() error {
if i.Image == "" {
return fmt.Errorf("image cannot be empty")
}

if i.UID == "" {
return fmt.Errorf("uid cannot be empty")
}

if _, err := strconv.ParseUint(i.UID, 10, 64); err != nil {
return fmt.Errorf("uid must be a valid unsigned integer")
}

if i.GID == "" {
return fmt.Errorf("gid cannot be empty")
}

if _, err := strconv.ParseUint(i.GID, 10, 64); err != nil {
return fmt.Errorf("gid must be a valid unsigned integer")
}

return nil
}
203 changes: 203 additions & 0 deletions core/provider/definitions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
package provider_test

import (
"github.com/skip-mev/petri/core/v2/provider"
"github.com/stretchr/testify/assert"
"testing"
)

var validImageDefinition = provider.ImageDefinition{
Image: "test",
GID: "1000",
UID: "1000",
}

func TestImageDefinitionValidation(t *testing.T) {
tcs := []struct {
name string
def provider.ImageDefinition
expectPass bool
}{
{
name: "valid",
def: provider.ImageDefinition{
Image: "test",
GID: "1000",
UID: "1000",
},
expectPass: true,
},
{
name: "empty image",
def: provider.ImageDefinition{
Image: "",
GID: "1000",
UID: "1000",
},
},
{
name: "empty uid",
def: provider.ImageDefinition{
Image: "test",
GID: "1000",
UID: "",
},
expectPass: false,
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
err := tc.def.ValidateBasic()
if tc.expectPass {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
}
})
}
}

func TestVolumeDefinitionValidation(t *testing.T) {
tcs := []struct {
name string
def provider.VolumeDefinition
expectPass bool
}{
{
name: "valid",
def: provider.VolumeDefinition{
MountPath: "/tmp",
Name: "test",
Size: "100",
},
expectPass: true,
},
{
name: "empty mountpath",
def: provider.VolumeDefinition{
MountPath: "",
Name: "test",
Size: "100",
},
expectPass: false,
},
{
name: "empty name",
def: provider.VolumeDefinition{
MountPath: "/tmp",
Name: "",
Size: "100",
},
expectPass: false,
},
{
name: "empty size",
def: provider.VolumeDefinition{
MountPath: "/tmp",
Name: "test",
Size: "",
},
expectPass: false,
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
err := tc.def.ValidateBasic()
if tc.expectPass {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
}
})
}
}

func TestTaskDefinitionValidation(t *testing.T) {
tcs := []struct {
name string
def provider.TaskDefinition
expectPass bool
}{
{
name: "valid",
def: provider.TaskDefinition{
Name: "test",
ContainerName: "test",
Image: validImageDefinition,
},
expectPass: true,
},
{
name: "no name",
def: provider.TaskDefinition{
ContainerName: "test",
Image: validImageDefinition,
},
expectPass: false,
},
{
name: "no container name",
def: provider.TaskDefinition{
Name: "test",
Image: validImageDefinition,
},
expectPass: false,
},
{
name: "no image",
def: provider.TaskDefinition{
Name: "test",
ContainerName: "test",
},
expectPass: false,
},
{
name: "invalid image",
def: provider.TaskDefinition{
Name: "test",
ContainerName: "test",
Image: provider.ImageDefinition{
Image: "",
},
},
expectPass: false,
},
{
name: "invalid port",
def: provider.TaskDefinition{
Name: "test",
ContainerName: "test",
Image: validImageDefinition,
Ports: []string{"", "100000"},
},
expectPass: false,
},
{
name: "invalid sidecar",
def: provider.TaskDefinition{
Name: "test",
ContainerName: "test",
Image: validImageDefinition,
Sidecars: []provider.TaskDefinition{
{
Name: "test",
},
},
},
expectPass: false,
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
err := tc.def.ValidateBasic()
if tc.expectPass {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
}
})
}
}
4 changes: 4 additions & 0 deletions core/provider/digitalocean/droplet.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ import _ "embed"
var dockerCloudInit string

func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDefinition) (*godo.Droplet, error) {
if err := definition.ValidateBasic(); err != nil {
return nil, fmt.Errorf("failed to validate task definition: %w", err)
}

doConfig, ok := definition.ProviderSpecificConfig.(DigitalOceanTaskConfig)

if !ok {
Expand Down
8 changes: 8 additions & 0 deletions core/provider/digitalocean/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ import (
)

func (p *Provider) CreateTask(ctx context.Context, logger *zap.Logger, definition provider.TaskDefinition) (string, error) {
if err := definition.ValidateBasic(); err != nil {
return "", fmt.Errorf("failed to validate task definition: %w", err)
}

if definition.ProviderSpecificConfig == nil {
return "", fmt.Errorf("digitalocean specific config is nil")
}
Expand Down Expand Up @@ -355,6 +359,10 @@ func (p *Provider) RunCommand(ctx context.Context, taskName string, command []st
}

func (p *Provider) RunCommandWhileStopped(ctx context.Context, taskName string, definition provider.TaskDefinition, command []string) (string, string, int, error) {
if err := definition.ValidateBasic(); err != nil {
return "", "", 0, fmt.Errorf("failed to validate task definition: %w", err)
}

dockerClient, err := p.getDropletDockerClient(ctx, taskName)

if err != nil {
Expand Down
8 changes: 8 additions & 0 deletions core/provider/docker/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import (
)

func (p *Provider) CreateTask(ctx context.Context, logger *zap.Logger, definition provider.TaskDefinition) (string, error) {
if err := definition.ValidateBasic(); err != nil {
return "", fmt.Errorf("failed to validate task definition: %w", err)
}

logger = logger.Named("docker_provider")

if err := p.pullImage(ctx, definition.Image.Image); err != nil {
Expand Down Expand Up @@ -222,6 +226,10 @@ func (p *Provider) RunCommand(ctx context.Context, id string, command []string)
}

func (p *Provider) RunCommandWhileStopped(ctx context.Context, id string, definition provider.TaskDefinition, command []string) (string, string, int, error) {
if err := definition.ValidateBasic(); err != nil {
return "", "", 0, fmt.Errorf("failed to validate task definition: %w", err)
}

p.logger.Debug("running command while stopped", zap.String("id", id), zap.Strings("command", command))

status, err := p.GetTaskStatus(ctx, id)
Expand Down
4 changes: 4 additions & 0 deletions core/provider/docker/volume.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ import (

// CreateVolume is an idempotent operation
func (p *Provider) CreateVolume(ctx context.Context, definition provider.VolumeDefinition) (string, error) {
if err := definition.ValidateBasic(); err != nil {
return "", fmt.Errorf("failed to validate volume definition: %w", err)
}

p.logger.Debug("creating volume", zap.String("name", definition.Name), zap.String("size", definition.Size))

existingVolume, err := p.dockerClient.VolumeInspect(ctx, definition.Name)
Expand Down
5 changes: 5 additions & 0 deletions core/provider/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@ package provider
import (
"context"
"errors"
"fmt"
"go.uber.org/zap"
"sync"
)

// CreateTask creates a task structure and sets up its underlying workload on a provider, including sidecars if there are any in the definition
func CreateTask(ctx context.Context, logger *zap.Logger, provider Provider, definition TaskDefinition) (*Task, error) {
if err := definition.ValidateBasic(); err != nil {
return nil, fmt.Errorf("failed to validate task definition: %w", err)
}

task := &Task{
Provider: provider,
Definition: definition,
Expand Down
Loading

0 comments on commit 6fe29e2

Please sign in to comment.