Skip to content

Commit

Permalink
feat: add PostgreSQL as storage backend (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
radiohertz authored Jan 20, 2025
1 parent 55bb787 commit ba55bcb
Show file tree
Hide file tree
Showing 14 changed files with 830 additions and 0 deletions.
35 changes: 35 additions & 0 deletions .github/workflows/on-push-pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,38 @@ jobs:
run: echo "NANODEP_MYSQL_STORAGE_TEST_DSN=nanodep:nanodep@tcp(localhost:$PORT)/nanodep" >> $GITHUB_ENV

- run: go test -v ./storage/mysql

pgsql-test:
runs-on: 'ubuntu-latest'
needs: format-build-test
services:
postgres:
image: postgres:13.16
env:
POSTGRES_DB: nanodep
POSTGRES_USER: nanodep
POSTGRES_PASSWORD: nanodep
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
env:
PGPASSWORD: nanodep
PORT: 5432
steps:
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0

- uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2
with:
go-version: '1.21.x'

- name: pgsql schema
run: psql -h localhost -U nanodep -d nanodep -f ./storage/pgsql/schema.sql

- name: setup test dsn
run: echo "NANODEP_PSQL_STORAGE_TEST_DSN=postgres://nanodep:@localhost/nanodep?sslmode=disable" >> $GITHUB_ENV

- run: go test -v ./storage/pgsql
3 changes: 3 additions & 0 deletions cli/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/micromdm/nanodep/storage/file"
"github.com/micromdm/nanodep/storage/inmem"
"github.com/micromdm/nanodep/storage/mysql"
"github.com/micromdm/nanodep/storage/pgsql"
)

// Storage parses a storage name and dsn to determine which and return a storage backend.
Expand All @@ -35,6 +36,8 @@ func Storage(storageName, dsn, options string) (storage.AllStorage, error) {
store = inmem.New()
case "mysql":
store, err = mysql.New(mysql.WithDSN(dsn))
case "pgsql":
store, err = pgsql.New(pgsql.WithDSN(dsn))
default:
return nil, fmt.Errorf("unknown storage: %q", storageName)
}
Expand Down
11 changes: 11 additions & 0 deletions docs/operations-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@ Configures the MySQL storage backend. The `-dsn` flag should be in the [format t
*Example:* `-storage mysql -dsn nanodep:nanodep/mydepdb`

##### pgsql storage backend

* `-storage pgsql`

Configures the PostgreSQL storage backend. The `-storage-dsn` flag should be in the [format the SQL driver expects](https://pkg.go.dev/github.com/lib/pq#pkg-overview). PostgreSQL 9.5 or later is required.

> [!TIP]
> Be sure to create the storage tables with the [schema.sql](../storage/pgsql/schema.sql) file.
*Example:* `-storage pgsql -storage-dsn postgres://postgres:toor@localhost:5432/nanodep`

#### -version

* print version
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.17
require (
github.com/go-sql-driver/mysql v1.8.1
github.com/gomodule/oauth1 v0.2.0
github.com/lib/pq v1.10.9
github.com/micromdm/nanolib v0.2.0
github.com/peterbourgon/diskv/v3 v3.0.1
github.com/smallstep/pkcs7 v0.1.1
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ github.com/gomodule/oauth1 v0.2.0/go.mod h1:4r/a8/3RkhMBxJQWL5qzbOEcaQmNPIkNoI7P
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/micromdm/nanolib v0.2.0 h1:g5GHQuUpS82WIAB15LyenjF/0/WSUNJMe5XZfCJSXq4=
github.com/micromdm/nanolib v0.2.0/go.mod h1:FwBKCvvphgYvbdUZ+qw5kay7NHJcg6zPi8W7kXNajmE=
github.com/peterbourgon/diskv/v3 v3.0.1 h1:x06SQA46+PKIUftmEujdwSEpIx8kR+M9eLYsUxeYveU=
Expand Down
3 changes: 3 additions & 0 deletions storage/pgsql/generate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package pgsql

//go:generate sqlc generate
243 changes: 243 additions & 0 deletions storage/pgsql/pgsql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
package pgsql

import (
"context"
"database/sql"
_ "embed"
"errors"
"fmt"
"time"

_ "github.com/lib/pq"
"github.com/micromdm/nanodep/client"
"github.com/micromdm/nanodep/storage"
"github.com/micromdm/nanodep/storage/pgsql/sqlc"
)

// PSQL implements storage.AllStorage using PSQL.
type PSQLStorage struct {
db *sql.DB
q *sqlc.Queries
}

type config struct {
driver string
dsn string
db *sql.DB
}

// Function callback to configure PSQLStorage
type Option func(*config)

// WithDSN sets the data source name
func WithDSN(dsn string) Option {
return func(c *config) {
c.dsn = dsn
}
}

// WithDriver sets the driver
func WithDriver(driver string) Option {
return func(c *config) {
c.driver = driver
}
}

// WithDB sets the db
func WithDB(db *sql.DB) Option {
return func(c *config) {
c.db = db

}
}

// Create a new PSQLStorage instance
func New(opts ...Option) (*PSQLStorage, error) {
cfg := &config{driver: "postgres"}
for _, opt := range opts {
opt(cfg)
}

var err error
if cfg.db == nil {
cfg.db, err = sql.Open(cfg.driver, cfg.dsn)
if err != nil {
return nil, err
}
}
if err = cfg.db.Ping(); err != nil {
return nil, err
}
return &PSQLStorage{db: cfg.db, q: sqlc.New(cfg.db)}, nil

}

// RetrieveAuthTokens reads the DEP OAuth tokens for name (DEP name).
func (s *PSQLStorage) RetrieveAuthTokens(ctx context.Context, name string) (*client.OAuth1Tokens, error) {
tokenRow, err := s.q.GetAuthTokens(ctx, name)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("%v: %w", err, storage.ErrNotFound)
}
return nil, err
}
if !tokenRow.ConsumerKey.Valid { // all auth token fields are set together
return nil, fmt.Errorf("consumer key not valid: %w", storage.ErrNotFound)
}

return &client.OAuth1Tokens{
ConsumerKey: tokenRow.ConsumerKey.String,
ConsumerSecret: tokenRow.ConsumerSecret.String,
AccessToken: tokenRow.AccessToken.String,
AccessSecret: tokenRow.AccessSecret.String,
AccessTokenExpiry: tokenRow.AccessTokenExpiry.Time,
}, nil
}

// StoreAuthTokens saves the DEP OAuth tokens for the DEP name.
func (s *PSQLStorage) StoreAuthTokens(ctx context.Context, name string, tokens *client.OAuth1Tokens) error {
return s.q.StoreAuthTokens(ctx, sqlc.StoreAuthTokensParams{
Name: name,
ConsumerKey: sql.NullString{String: tokens.ConsumerKey, Valid: true},
ConsumerSecret: sql.NullString{String: tokens.ConsumerSecret, Valid: true},
AccessToken: sql.NullString{String: tokens.AccessToken, Valid: true},
AccessSecret: sql.NullString{String: tokens.AccessSecret, Valid: true},
AccessTokenExpiry: sql.NullTime{Time: tokens.AccessTokenExpiry, Valid: true},
})
}

// RetrieveConfig reads the JSON DEP config of a DEP name.
//
// Returns (nil, nil) if the DEP name does not exist, or if the config
// for the DEP name does not exist.
func (s *PSQLStorage) RetrieveConfig(ctx context.Context, name string) (*client.Config, error) {
baseURL, err := s.q.GetConfigBaseURL(ctx, name)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
// If the DEP name does not exist, then the config does not exist.
return nil, nil
}
return nil, err
}
if !baseURL.Valid {
// If the config_base_url is NULL, then config does not exist.
return nil, nil
}
return &client.Config{
BaseURL: baseURL.String,
}, nil
}

// StoreConfig saves the DEP config for name (DEP name).
func (s *PSQLStorage) StoreConfig(ctx context.Context, name string, config *client.Config) error {
return s.q.StoreConfig(ctx, sqlc.StoreConfigParams{
Name: name,
ConfigBaseUrl: sql.NullString{String: config.BaseURL, Valid: true},
})
}

// RetrieveAssignerProfile reads the assigner profile UUID and its timestamp for name (DEP name).
//
// Returns an empty profile UUID if it does not exist.
func (s *PSQLStorage) RetrieveAssignerProfile(ctx context.Context, name string) (profileUUID string, modTime time.Time, err error) {
assignerRow, err := s.q.GetAssignerProfile(ctx, name)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
// an 'empty' profile UUID is valid, return nil error
return "", time.Time{}, nil
}
return "", time.Time{}, err
}
if assignerRow.AssignerProfileUuid.Valid {
profileUUID = assignerRow.AssignerProfileUuid.String
}
if assignerRow.AssignerProfileUuidAt.Valid {
modTime = assignerRow.AssignerProfileUuidAt.Time
}
return
}

// StoreAssignerProfile saves the assigner profile UUID for name (DEP name).
func (s *PSQLStorage) StoreAssignerProfile(ctx context.Context, name string, profileUUID string) error {
return s.q.StoreAssignerProfile(ctx, sqlc.StoreAssignerProfileParams{
Name: name,
AssignerProfileUuid: sql.NullString{String: profileUUID, Valid: true},
})
}

// RetrieveCursor reads the reads the DEP fetch and sync cursor for name (DEP name).
//
// Returns an empty cursor if the cursor does not exist.
func (s *PSQLStorage) RetrieveCursor(ctx context.Context, name string) (string, error) {
cursor, err := s.q.GetSyncerCursor(ctx, name)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", nil
}
return "", err
}
if !cursor.Valid {
return "", nil
}
return cursor.String, nil
}

// StoreCursor saves the DEP fetch and sync cursor for name (DEP name).
func (s *PSQLStorage) StoreCursor(ctx context.Context, name, cursor string) error {
return s.q.StoreCursor(ctx, sqlc.StoreCursorParams{
Name: name,
SyncerCursor: sql.NullString{String: cursor, Valid: true},
})

}

// StoreTokenPKI stores the staging PEM bytes in pemCert and pemKey for name (DEP name).
func (s *PSQLStorage) StoreTokenPKI(ctx context.Context, name string, pemCert []byte, pemKey []byte) error {
return s.q.StoreTokenPKI(ctx, sqlc.StoreTokenPKIParams{
Name: name,
TokenpkiStagingCertPem: pemCert,
TokenpkiStagingKeyPem: pemKey,
})
}

// UpstageTokenPKI copies the staging PKI certificate and private key to the
// current PKI certificate and private key.
func (s *PSQLStorage) UpstageTokenPKI(ctx context.Context, name string) error {
err := s.q.UpstageKeypair(ctx, name)
if errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("%v: %w", err, storage.ErrNotFound)
}
return err
}

// RetrieveStagingTokenPKI returns the PEM bytes for the staged DEP
// token exchange certificate and private key using name (DEP name).
func (s *PSQLStorage) RetrieveStagingTokenPKI(ctx context.Context, name string) ([]byte, []byte, error) {
keypair, err := s.q.GetStagingKeypair(ctx, name)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil, fmt.Errorf("%v: %w", err, storage.ErrNotFound)
}
return nil, nil, err
}
if keypair.TokenpkiStagingCertPem == nil { // tokenpki_staging_cert_pem and tokenpki_staging_key_pem are set together
return nil, nil, fmt.Errorf("empty certificate: %w", storage.ErrNotFound)
}
return keypair.TokenpkiStagingCertPem, keypair.TokenpkiStagingKeyPem, nil
}

// RetrieveCurrentTokenPKI returns the PEM bytes for the previously-upstaged DEP
// token exchange certificate and private key using name (DEP name).
func (s *PSQLStorage) RetrieveCurrentTokenPKI(ctx context.Context, name string) (pemCert []byte, pemKey []byte, err error) {
keypair, err := s.q.GetCurrentKeypair(ctx, name)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil, fmt.Errorf("%v: %w", err, storage.ErrNotFound)
}
return nil, nil, err
}
if keypair.TokenpkiCertPem == nil { // tokenpki_cert_pem and tokenpki_key_pem are set together
return nil, nil, fmt.Errorf("empty certificate: %w", storage.ErrNotFound)
}
return keypair.TokenpkiCertPem, keypair.TokenpkiKeyPem, nil
}
25 changes: 25 additions & 0 deletions storage/pgsql/pgsql_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package pgsql

import (
"context"
"os"
"testing"

_ "github.com/lib/pq"

"github.com/micromdm/nanodep/storage/test"
)

func TestPSQLStorage(t *testing.T) {
testDSN := os.Getenv("NANODEP_PSQL_STORAGE_TEST_DSN")
if testDSN == "" {
t.Skip("NANODEP_PSQL_STORAGE_TEST_DSN not set")
}

s, err := New(WithDSN(testDSN))
if err != nil {
t.Fatal(err)
}

test.TestWithStorages(t, context.Background(), s)
}
Loading

0 comments on commit ba55bcb

Please sign in to comment.