Skip to content

Commit

Permalink
feat: Add postgresql_security_label resource (#482)
Browse files Browse the repository at this point in the history
This PR is based on the great job @jbunting did in #365, which fixed the
quoted identifier for object name and provider in the `pg_seclabels`
table.

Additional changes:

1. Update the doc to explain the import process

---------

Co-authored-by: Jared Bunting <[email protected]>
Co-authored-by: Cyril Gaudin <[email protected]>
  • Loading branch information
3 people authored Oct 30, 2024
1 parent 31fee05 commit b202448
Show file tree
Hide file tree
Showing 12 changed files with 560 additions and 1 deletion.
2 changes: 2 additions & 0 deletions postgresql/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const (
featureFunction
featureServer
featureCreateRoleSelfGrant
featureSecurityLabel
)

var (
Expand Down Expand Up @@ -120,6 +121,7 @@ var (
// New privileges rules in version 16
// https://www.postgresql.org/docs/16/release-16.html#RELEASE-16-PRIVILEGES
featureCreateRoleSelfGrant: semver.MustParseRange(">=16.0.0"),
featureSecurityLabel: semver.MustParseRange(">=11.0.0"),
}
)

Expand Down
1 change: 1 addition & 0 deletions postgresql/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ func Provider() *schema.Provider {
"postgresql_function": resourcePostgreSQLFunction(),
"postgresql_server": resourcePostgreSQLServer(),
"postgresql_user_mapping": resourcePostgreSQLUserMapping(),
"postgresql_security_label": resourcePostgreSQLSecurityLabel(),
},

DataSourcesMap: map[string]*schema.Resource{
Expand Down
198 changes: 198 additions & 0 deletions postgresql/resource_postgresql_security_label.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package postgresql

import (
"bytes"
"database/sql"
"fmt"
"log"
"regexp"
"strings"

"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/lib/pq"
)

const (
securityLabelObjectNameAttr = "object_name"
securityLabelObjectTypeAttr = "object_type"
securityLabelProviderAttr = "label_provider"
securityLabelLabelAttr = "label"
)

func resourcePostgreSQLSecurityLabel() *schema.Resource {
return &schema.Resource{
Create: PGResourceFunc(resourcePostgreSQLSecurityLabelCreate),
Read: PGResourceFunc(resourcePostgreSQLSecurityLabelRead),
Update: PGResourceFunc(resourcePostgreSQLSecurityLabelUpdate),
Delete: PGResourceFunc(resourcePostgreSQLSecurityLabelDelete),
Importer: &schema.ResourceImporter{
StateContext: schema.ImportStatePassthroughContext,
},

Schema: map[string]*schema.Schema{
securityLabelObjectNameAttr: {
Type: schema.TypeString,
Required: true,
ForceNew: true,
Description: "The name of the existing object to apply the security label to",
},
securityLabelObjectTypeAttr: {
Type: schema.TypeString,
Required: true,
ForceNew: true,
Description: "The type of the existing object to apply the security label to",
},
securityLabelProviderAttr: {
Type: schema.TypeString,
Required: true,
ForceNew: true,
Description: "The provider to apply the security label for",
},
securityLabelLabelAttr: {
Type: schema.TypeString,
Required: true,
ForceNew: false,
Description: "The label to be applied",
},
},
}
}

func resourcePostgreSQLSecurityLabelCreate(db *DBConnection, d *schema.ResourceData) error {
if !db.featureSupported(featureSecurityLabel) {
return fmt.Errorf(
"security Label is not supported for this Postgres version (%s)",
db.version,
)
}
log.Printf("[DEBUG] PostgreSQL security label Create")
label := d.Get(securityLabelLabelAttr).(string)
if err := resourcePostgreSQLSecurityLabelUpdateImpl(db, d, pq.QuoteLiteral(label)); err != nil {
return err
}

d.SetId(generateSecurityLabelID(d))

return resourcePostgreSQLSecurityLabelReadImpl(db, d)
}

func resourcePostgreSQLSecurityLabelUpdateImpl(db *DBConnection, d *schema.ResourceData, label string) error {
b := bytes.NewBufferString("SECURITY LABEL ")

objectType := d.Get(securityLabelObjectTypeAttr).(string)
objectName := d.Get(securityLabelObjectNameAttr).(string)
provider := d.Get(securityLabelProviderAttr).(string)
fmt.Fprint(b, " FOR ", pq.QuoteIdentifier(provider))
fmt.Fprint(b, " ON ", objectType, pq.QuoteIdentifier(objectName))
fmt.Fprint(b, " IS ", label)

if _, err := db.Exec(b.String()); err != nil {
log.Printf("[WARN] PostgreSQL security label Create failed %s", err)
return fmt.Errorf("could not create security label: %w", err)
}
return nil
}

func resourcePostgreSQLSecurityLabelRead(db *DBConnection, d *schema.ResourceData) error {
if !db.featureSupported(featureSecurityLabel) {
return fmt.Errorf(
"Security Label is not supported for this Postgres version (%s)",
db.version,
)
}
log.Printf("[DEBUG] PostgreSQL security label Read")

return resourcePostgreSQLSecurityLabelReadImpl(db, d)
}

func resourcePostgreSQLSecurityLabelReadImpl(db *DBConnection, d *schema.ResourceData) error {
objectType := d.Get(securityLabelObjectTypeAttr).(string)
objectName := d.Get(securityLabelObjectNameAttr).(string)
provider := d.Get(securityLabelProviderAttr).(string)

txn, err := startTransaction(db.client, "")
if err != nil {
return err
}
defer deferredRollback(txn)

query := "SELECT objtype, provider, objname, label FROM pg_seclabels WHERE objtype = $1 and objname = $2 and provider = $3"
row := db.QueryRow(query, objectType, quoteIdentifier(objectName), quoteIdentifier(provider))

var label, newObjectName, newProvider string
err = row.Scan(&objectType, &newProvider, &newObjectName, &label)
switch {
case err == sql.ErrNoRows:
log.Printf("[WARN] PostgreSQL security label for (%s '%s') with provider %s not found", objectType, objectName, provider)
d.SetId("")
return nil
case err != nil:
return fmt.Errorf("Error reading security label: %w", err)
}

if quoteIdentifier(objectName) != newObjectName || quoteIdentifier(provider) != newProvider {
// In reality, this should never happen, but if it does, we want to make sure that the state is in sync with the remote system
// This will trigger a TF error saying that the provider has a bug if it ever happens
objectName = newObjectName
provider = newProvider
}
d.Set(securityLabelObjectTypeAttr, objectType)
d.Set(securityLabelObjectNameAttr, objectName)
d.Set(securityLabelProviderAttr, provider)
d.Set(securityLabelLabelAttr, label)
d.SetId(generateSecurityLabelID(d))

return nil
}

func resourcePostgreSQLSecurityLabelDelete(db *DBConnection, d *schema.ResourceData) error {
if !db.featureSupported(featureSecurityLabel) {
return fmt.Errorf(
"Security Label is not supported for this Postgres version (%s)",
db.version,
)
}
log.Printf("[DEBUG] PostgreSQL security label Delete")

if err := resourcePostgreSQLSecurityLabelUpdateImpl(db, d, "NULL"); err != nil {
return err
}

d.SetId("")

return nil
}

func resourcePostgreSQLSecurityLabelUpdate(db *DBConnection, d *schema.ResourceData) error {
if !db.featureSupported(featureServer) {
return fmt.Errorf(
"Security Label is not supported for this Postgres version (%s)",
db.version,
)
}
log.Printf("[DEBUG] PostgreSQL security label Update")

label := d.Get(securityLabelLabelAttr).(string)
if err := resourcePostgreSQLSecurityLabelUpdateImpl(db, d, pq.QuoteLiteral(label)); err != nil {
return err
}

return resourcePostgreSQLSecurityLabelReadImpl(db, d)
}

func generateSecurityLabelID(d *schema.ResourceData) string {
return strings.Join([]string{
d.Get(securityLabelProviderAttr).(string),
d.Get(securityLabelObjectTypeAttr).(string),
d.Get(securityLabelObjectNameAttr).(string),
}, ".")
}

func quoteIdentifier(s string) string {
var result = s
re := regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
if !re.MatchString(s) || s != strings.ToLower(s) {
result = pq.QuoteIdentifier(s)
}
return result
}
Loading

0 comments on commit b202448

Please sign in to comment.