Skip to content

Commit

Permalink
feat: update repo
Browse files Browse the repository at this point in the history
  • Loading branch information
Zygimantass committed Jan 3, 2025
1 parent ec9a667 commit 5a0a773
Show file tree
Hide file tree
Showing 8 changed files with 364 additions and 132 deletions.
44 changes: 24 additions & 20 deletions core/provider/digitalocean/digitalocean_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package digitalocean
import (
"context"
"fmt"
"strings"

"github.com/digitalocean/godo"
"github.com/puzpuzpuz/xsync/v3"
"go.uber.org/zap"

xsync "github.com/puzpuzpuz/xsync/v3"

"github.com/skip-mev/petri/core/v2/provider"
"github.com/skip-mev/petri/core/v2/util"
"golang.org/x/crypto/ssh"
Expand All @@ -27,30 +29,33 @@ type Provider struct {

userIPs []string

sshPubKey, sshPrivKey, sshFingerprint string
sshKeyPair *SSHKeyPair

droplets *xsync.MapOf[string, *godo.Droplet]
containers *xsync.MapOf[string, string]
sshClients *xsync.MapOf[string, *ssh.Client]

firewallID string
}

// NewDigitalOceanProvider creates a provider that implements the Provider interface for DigitalOcean.
// Token is the DigitalOcean API token
func NewDigitalOceanProvider(ctx context.Context, logger *zap.Logger, providerName string, token string) (*Provider, error) {
func NewDigitalOceanProvider(ctx context.Context, logger *zap.Logger, providerName string, token string, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) {
doClient := godo.NewFromToken(token)

sshPubKey, sshPrivKey, sshFingerprint, err := makeSSHKeyPair()
if err != nil {
return nil, err
if sshKeyPair == nil {
newSshKeyPair, err := MakeSSHKeyPair()
if err != nil {
return nil, err
}
sshKeyPair = newSshKeyPair
}

userIPs, err := getUserIPs(ctx)
if err != nil {
return nil, err
}

userIPs = append(userIPs, additionalUserIPS...)

digitalOceanProvider := &Provider{
logger: logger.Named("digitalocean_provider"),
name: providerName,
Expand All @@ -59,17 +64,10 @@ func NewDigitalOceanProvider(ctx context.Context, logger *zap.Logger, providerNa

userIPs: userIPs,

droplets: xsync.NewMapOf[string, *godo.Droplet](),
containers: xsync.NewMapOf[string, string](),
sshClients: xsync.NewMapOf[string, *ssh.Client](),

sshPubKey: sshPubKey,
sshPrivKey: sshPrivKey,
sshFingerprint: sshFingerprint,
sshKeyPair: sshKeyPair,
}

logger.Debug("petri tag", zap.String("tag", digitalOceanProvider.petriTag))

_, err = digitalOceanProvider.createTag(ctx, digitalOceanProvider.petriTag)
if err != nil {
return nil, err
Expand All @@ -81,9 +79,15 @@ func NewDigitalOceanProvider(ctx context.Context, logger *zap.Logger, providerNa
}

digitalOceanProvider.firewallID = firewall.ID
_, err = digitalOceanProvider.createSSHKey(ctx, sshPubKey)
if err != nil {
return nil, err

//TODO(Zygimantass): TOCTOU issue
if key, _, err := doClient.Keys.GetByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil {
_, err = digitalOceanProvider.createSSHKey(ctx, sshKeyPair.PublicKey)
if err != nil {
if !strings.Contains(err.Error(), "422") {
return nil, err
}
}
}

return digitalOceanProvider, nil
Expand Down Expand Up @@ -135,7 +139,7 @@ func (p *Provider) teardownFirewall(ctx context.Context) error {
}

func (p *Provider) teardownSSHKey(ctx context.Context) error {
res, err := p.doClient.Keys.DeleteByFingerprint(ctx, p.sshFingerprint)
res, err := p.doClient.Keys.DeleteByFingerprint(ctx, p.sshKeyPair.Fingerprint)
if err != nil {
return err
}
Expand Down
56 changes: 32 additions & 24 deletions core/provider/digitalocean/droplet.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"github.com/skip-mev/petri/core/v2/provider"
"github.com/skip-mev/petri/core/v2/util"

"strconv"

_ "embed"
)

Expand All @@ -27,21 +29,29 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe
return nil, fmt.Errorf("failed to validate task definition: %w", err)
}

doConfig, ok := definition.ProviderSpecificConfig.(DigitalOceanTaskConfig)
var doConfig DigitalOceanTaskConfig
doConfig = definition.ProviderSpecificConfig

if !ok {
return nil, fmt.Errorf("could not cast provider specific config to DigitalOceanConfig")
if err := doConfig.ValidateBasic(); err != nil {
return nil, fmt.Errorf("could not cast digitalocean specific config: %w", err)
}

imageId, err := strconv.ParseInt(doConfig["image_id"], 10, 64)

if err != nil {
return nil, fmt.Errorf("failed to parse image ID: %w", err)
}

req := &godo.DropletCreateRequest{
Name: fmt.Sprintf("%s-%s", p.petriTag, definition.Name),
Region: doConfig.Region,
Size: doConfig.Size,
Region: doConfig["region"],
Size: doConfig["size"],
Image: godo.DropletCreateImage{
ID: doConfig.ImageID,
ID: int(imageId),
},
SSHKeys: []godo.DropletCreateSSHKey{
{
Fingerprint: p.sshFingerprint,
Fingerprint: p.sshKeyPair.Fingerprint,
},
},
Tags: []string{p.petriTag},
Expand Down Expand Up @@ -100,13 +110,13 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe
}

func (p *Provider) deleteDroplet(ctx context.Context, name string) error {
cachedDroplet, ok := p.droplets.Load(name)
droplet, err := p.getDroplet(ctx, name)

if !ok {
return fmt.Errorf("could not find droplet %s", name)
if err != nil {
return err
}

res, err := p.doClient.Droplets.Delete(ctx, cachedDroplet.ID)
res, err := p.doClient.Droplets.Delete(ctx, droplet.ID)
if err != nil {
return err
}
Expand All @@ -118,17 +128,11 @@ func (p *Provider) deleteDroplet(ctx context.Context, name string) error {
return nil
}

func (p *Provider) getDroplet(ctx context.Context, name string, returnOnCacheHit bool) (*godo.Droplet, error) {
cachedDroplet, ok := p.droplets.Load(name)
if !ok {
return nil, fmt.Errorf("could not find droplet %s", name)
}

if ok && returnOnCacheHit {
return cachedDroplet, nil
}
func (p *Provider) getDroplet(ctx context.Context, name string) (*godo.Droplet, error) {
// TODO(Zygimantass): this change assumes that all Petri droplets are unique by name
// which should be technically true, but there might be edge cases where it's not.
droplets, res, err := p.doClient.Droplets.ListByName(ctx, name, nil)

droplet, res, err := p.doClient.Droplets.Get(ctx, cachedDroplet.ID)
if err != nil {
return nil, err
}
Expand All @@ -137,7 +141,11 @@ func (p *Provider) getDroplet(ctx context.Context, name string, returnOnCacheHit
return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode)
}

return droplet, nil
if len(droplets) == 0 {
return nil, fmt.Errorf("could not find droplet")
}

return &droplets[0], nil
}

func (p *Provider) getDropletDockerClient(ctx context.Context, taskName string) (*dockerclient.Client, error) {
Expand All @@ -155,7 +163,7 @@ func (p *Provider) getDropletDockerClient(ctx context.Context, taskName string)
}

func (p *Provider) getDropletSSHClient(ctx context.Context, taskName string) (*ssh.Client, error) {
if _, ok := p.droplets.Load(taskName); !ok {
if _, err := p.getDroplet(ctx, taskName); err != nil {
return nil, fmt.Errorf("droplet %s does not exist", taskName)
}

Expand All @@ -172,7 +180,7 @@ func (p *Provider) getDropletSSHClient(ctx context.Context, taskName string) (*s
return nil, err
}

parsedSSHKey, err := ssh.ParsePrivateKey([]byte(p.sshPrivKey))
parsedSSHKey, err := ssh.ParsePrivateKey([]byte(p.sshKeyPair.PrivateKey))
if err != nil {
return nil, err
}
Expand Down
48 changes: 35 additions & 13 deletions core/provider/digitalocean/ssh.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package digitalocean

import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
Expand All @@ -15,30 +16,51 @@ import (
"golang.org/x/crypto/ssh"
)

func makeSSHKeyPair() (string, string, string, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
return "", "", "", err
}
type SSHKeyPair struct {
PublicKey string
PrivateKey string
Fingerprint string
}

// generate and write private key as PEM
var privKeyBuf strings.Builder
func ParseSSHKeyPair(privKey []byte) (*SSHKeyPair, error) {
block, _ := pem.Decode(privKey)

privateKeyPEM := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}
if err := pem.Encode(&privKeyBuf, privateKeyPEM); err != nil {
return "", "", "", err
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, err
}

// generate and write public key
pub, err := ssh.NewPublicKey(&privateKey.PublicKey)
if err != nil {
return "", "", "", err
return nil, err
}

var pubKeyBuf strings.Builder
pubKeyBuf.Write(ssh.MarshalAuthorizedKey(pub))

return pubKeyBuf.String(), privKeyBuf.String(), ssh.FingerprintLegacyMD5(pub), nil
return &SSHKeyPair{
PublicKey: pubKeyBuf.String(),
PrivateKey: string(privKey),
Fingerprint: ssh.FingerprintLegacyMD5(pub),
}, nil
}

func MakeSSHKeyPair() (*SSHKeyPair, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
return nil, err
}

// generate and write private key as PEM
var privKeyBuf bytes.Buffer

privateKeyPEM := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}
if err := pem.Encode(&privKeyBuf, privateKeyPEM); err != nil {
return nil, err
}

return ParseSSHKeyPair(privKeyBuf.Bytes())
}

func getUserIPs(ctx context.Context) (ips []string, err error) {
Expand All @@ -56,7 +78,7 @@ func getUserIPs(ctx context.Context) (ips []string, err error) {

ips = append(ips, strings.Trim(string(ifconfigIoIp), "\n"))

res, err = http.Get("https://ifconfig.co")
res, err = http.Get("https://ipinfo.io/ip")
if err != nil {
return ips, err
}
Expand Down
Loading

0 comments on commit 5a0a773

Please sign in to comment.