Skip to content

Commit

Permalink
Add azure-workload auth to MSSQL scaler
Browse files Browse the repository at this point in the history
Signed-off-by: Rick Brouwer <[email protected]>
  • Loading branch information
rickbrouwer committed Sep 16, 2024
1 parent 85d4dca commit 5a08a0d
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 235 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Here is an overview of all new **experimental** features:
- **GCP Scalers**: Added custom time horizon in GCP scalers ([#5778](https://github.com/kedacore/keda/issues/5778))
- **GitHub Scaler**: Fixed pagination, fetching repository list ([#5738](https://github.com/kedacore/keda/issues/5738))
- **Kafka**: Fix logic to scale to zero on invalid offset even with earliest offsetResetPolicy ([#5689](https://github.com/kedacore/keda/issues/5689))
- **MSSQL Scaler**: Add azure-workload auth ([#6104](https://github.com/kedacore/keda/issues/6104))
- **RabbitMQ Scaler**: Add connection name for AMQP ([#5958](https://github.com/kedacore/keda/issues/5958))
- TODO ([#XXX](https://github.com/kedacore/keda/issues/XXX))

Expand Down
238 changes: 89 additions & 149 deletions pkg/scalers/mssql_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,73 +3,55 @@ package scalers
import (
"context"
"database/sql"
"errors"
"fmt"
"net"
"net/url"
"strconv"

// mssql driver required for this scaler
_ "github.com/denisenkom/go-mssqldb"
"github.com/go-logr/logr"
v2 "k8s.io/api/autoscaling/v2"
"k8s.io/metrics/pkg/apis/external_metrics"

"github.com/kedacore/keda/v2/apis/keda/v1alpha1"
"github.com/kedacore/keda/v2/pkg/scalers/azure"
"github.com/kedacore/keda/v2/pkg/scalers/scalersconfig"
)

var (
// ErrMsSQLNoQuery is returned when "query" is missing from the config.
ErrMsSQLNoQuery = errors.New("no query given")

// ErrMsSQLNoTargetValue is returned when "targetValue" is missing from the config.
ErrMsSQLNoTargetValue = errors.New("no targetValue given")
)

// mssqlScaler exposes a data pointer to mssqlMetadata and sql.DB connection
type mssqlScaler struct {
metricType v2.MetricTargetType
metadata *mssqlMetadata
metadata mssqlMetadata
connection *sql.DB
logger logr.Logger
azureOAuth *azure.ADWorkloadIdentityTokenProvider
}

// mssqlMetadata defines metadata used by KEDA to query a Microsoft SQL database
type mssqlMetadata struct {
// The connection string used to connect to the MSSQL database.
// Both URL syntax (sqlserver://host?database=dbName) and OLEDB syntax is supported.
// +optional
connectionString string
// The username credential for connecting to the MSSQL instance, if not specified in the connection string.
// +optional
username string
// The password credential for connecting to the MSSQL instance, if not specified in the connection string.
// +optional
password string
// The hostname of the MSSQL instance endpoint, if not specified in the connection string.
// +optional
host string
// The port number of the MSSQL instance endpoint, if not specified in the connection string.
// +optional
port int
// The name of the database to query, if not specified in the connection string.
// +optional
database string
// The T-SQL query to run against the target database - e.g. SELECT COUNT(*) FROM table.
// +required
query string
// The threshold that is used as targetAverageValue in the Horizontal Pod Autoscaler.
// +required
targetValue float64
// The threshold that is used in activation phase
// +optional
activationTargetValue float64
// The index of the scaler inside the ScaledObject
// +internal
triggerIndex int
ConnectionString string `keda:"name=connectionString,order=authParams;resolvedEnv;triggerMetadata,optional"`
Username string `keda:"name=username,order=authParams;triggerMetadata,optional"`
Password string `keda:"name=password,order=authParams;resolvedEnv;triggerMetadata,optional"`
Host string `keda:"name=host,order=authParams;triggerMetadata,optional"`
Port int `keda:"name=port,order=authParams;triggerMetadata,optional"`
Database string `keda:"name=database,order=authParams;triggerMetadata,optional"`
Query string `keda:"name=query,order=triggerMetadata"`
TargetValue float64 `keda:"name=targetValue,order=triggerMetadata"`
ActivationTargetValue float64 `keda:"name=activationTargetValue,order=triggerMetadata,optional,default=0"`

TriggerIndex int

WorkloadIdentityResource string `keda:"name=WorkloadIdentityResource,order=authParams;triggerMetadata,optional"`
WorkloadIdentityClientID string
WorkloadIdentityTenantID string
WorkloadIdentityAuthorityHost string
}

func (m *mssqlMetadata) Validate() error {
if m.ConnectionString == "" && m.Host == "" {
return fmt.Errorf("must provide either connectionstring or host")
}
return nil
}

// NewMSSQLScaler creates a new mssql scaler
func NewMSSQLScaler(config *scalersconfig.ScalerConfig) (Scaler, error) {
metricType, err := GetMetricTargetType(config)
if err != nil {
Expand All @@ -96,85 +78,28 @@ func NewMSSQLScaler(config *scalersconfig.ScalerConfig) (Scaler, error) {
}, nil
}

// parseMSSQLMetadata takes a ScalerConfig and returns a mssqlMetadata or an error if the config is invalid
func parseMSSQLMetadata(config *scalersconfig.ScalerConfig) (*mssqlMetadata, error) {
func parseMSSQLMetadata(config *scalersconfig.ScalerConfig) (mssqlMetadata, error) {
meta := mssqlMetadata{}

// Query
if val, ok := config.TriggerMetadata["query"]; ok {
meta.query = val
} else {
return nil, ErrMsSQLNoQuery
err := config.TypedConfig(&meta)
if err != nil {
return meta, err
}

// Target query value
if val, ok := config.TriggerMetadata["targetValue"]; ok {
targetValue, err := strconv.ParseFloat(val, 64)
if err != nil {
return nil, fmt.Errorf("targetValue parsing error %w", err)
}
meta.targetValue = targetValue
} else {
if config.AsMetricSource {
meta.targetValue = 0
} else {
return nil, ErrMsSQLNoTargetValue
}
}
meta.TriggerIndex = config.TriggerIndex

// Activation target value
meta.activationTargetValue = 0
if val, ok := config.TriggerMetadata["activationTargetValue"]; ok {
activationTargetValue, err := strconv.ParseFloat(val, 64)
if err != nil {
return nil, fmt.Errorf("activationTargetValue parsing error %w", err)
if config.PodIdentity.Provider == v1alpha1.PodIdentityProviderAzureWorkload {
if config.AuthParams["workloadIdentityResource"] != "" {
meta.WorkloadIdentityClientID = config.PodIdentity.GetIdentityID()
meta.WorkloadIdentityTenantID = config.PodIdentity.GetIdentityTenantID()
meta.WorkloadIdentityAuthorityHost = config.PodIdentity.GetIdentityAuthorityHost()
meta.WorkloadIdentityResource = config.AuthParams["workloadIdentityResource"]
}
meta.activationTargetValue = activationTargetValue
}

// Connection string, which can either be provided explicitly or via the helper fields
switch {
case config.AuthParams["connectionString"] != "":
meta.connectionString = config.AuthParams["connectionString"]
case config.TriggerMetadata["connectionStringFromEnv"] != "":
meta.connectionString = config.ResolvedEnv[config.TriggerMetadata["connectionStringFromEnv"]]
default:
meta.connectionString = ""
var err error

host, err := GetFromAuthOrMeta(config, "host")
if err != nil {
return nil, err
}
meta.host = host

var paramPort string
paramPort, _ = GetFromAuthOrMeta(config, "port")
if paramPort != "" {
port, err := strconv.Atoi(paramPort)
if err != nil {
return nil, fmt.Errorf("port parsing error %w", err)
}
meta.port = port
}

meta.username, _ = GetFromAuthOrMeta(config, "username")

// database is optional in SQL s
meta.database, _ = GetFromAuthOrMeta(config, "database")

if config.AuthParams["password"] != "" {
meta.password = config.AuthParams["password"]
} else if config.TriggerMetadata["passwordFromEnv"] != "" {
meta.password = config.ResolvedEnv[config.TriggerMetadata["passwordFromEnv"]]
}
}
meta.triggerIndex = config.TriggerIndex
return &meta, nil
return meta, nil
}

// newMSSQLConnection returns a new, opened SQL connection for the provided mssqlMetadata
func newMSSQLConnection(meta *mssqlMetadata, logger logr.Logger) (*sql.DB, error) {
func newMSSQLConnection(meta mssqlMetadata, logger logr.Logger) (*sql.DB, error) {
connStr := getMSSQLConnectionString(meta)

db, err := sql.Open("sqlserver", connStr)
Expand All @@ -192,46 +117,40 @@ func newMSSQLConnection(meta *mssqlMetadata, logger logr.Logger) (*sql.DB, error
return db, nil
}

// getMSSQLConnectionString returns a connection string from a mssqlMetadata
func getMSSQLConnectionString(meta *mssqlMetadata) string {
var connStr string

if meta.connectionString != "" {
connStr = meta.connectionString
} else {
query := url.Values{}
if meta.database != "" {
query.Add("database", meta.database)
}
func getMSSQLConnectionString(meta mssqlMetadata) string {
if meta.ConnectionString != "" {
return meta.ConnectionString
}

connectionURL := &url.URL{Scheme: "sqlserver", RawQuery: query.Encode()}
if meta.username != "" {
if meta.password != "" {
connectionURL.User = url.UserPassword(meta.username, meta.password)
} else {
connectionURL.User = url.User(meta.username)
}
}
query := url.Values{}
if meta.Database != "" {
query.Add("database", meta.Database)
}

if meta.port > 0 {
connectionURL.Host = net.JoinHostPort(meta.host, fmt.Sprintf("%d", meta.port))
connectionURL := &url.URL{Scheme: "sqlserver", RawQuery: query.Encode()}
if meta.Username != "" {
if meta.Password != "" {
connectionURL.User = url.UserPassword(meta.Username, meta.Password)
} else {
connectionURL.Host = meta.host
connectionURL.User = url.User(meta.Username)
}
}

connStr = connectionURL.String()
if meta.Port > 0 {
connectionURL.Host = net.JoinHostPort(meta.Host, fmt.Sprintf("%d", meta.Port))
} else {
connectionURL.Host = meta.Host
}

return connStr
return connectionURL.String()
}

// GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler
func (s *mssqlScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec {
externalMetric := &v2.ExternalMetricSource{
Metric: v2.MetricIdentifier{
Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, "mssql"),
Name: GenerateMetricNameWithIndex(s.metadata.TriggerIndex, "mssql"),
},
Target: GetMetricTargetMili(s.metricType, s.metadata.targetValue),
Target: GetMetricTargetMili(s.metricType, s.metadata.TargetValue),
}

metricSpec := v2.MetricSpec{
Expand All @@ -241,7 +160,6 @@ func (s *mssqlScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec {
return []v2.MetricSpec{metricSpec}
}

// GetMetricsAndActivity returns a value for a supported metric or an error if there is a problem getting the metric
func (s *mssqlScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) {
num, err := s.getQueryResult(ctx)
if err != nil {
Expand All @@ -250,25 +168,47 @@ func (s *mssqlScaler) GetMetricsAndActivity(ctx context.Context, metricName stri

metric := GenerateMetricInMili(metricName, num)

return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.activationTargetValue, nil
return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.ActivationTargetValue, nil
}

// getQueryResult returns the result of the scaler query
func (s *mssqlScaler) getQueryResult(ctx context.Context) (float64, error) {
var value float64
err := s.connection.QueryRowContext(ctx, s.metadata.query).Scan(&value)

// If using Azure Workload Identity, refresh the token
if s.metadata.WorkloadIdentityResource != "" {
if s.azureOAuth == nil {
s.azureOAuth = azure.NewAzureADWorkloadIdentityTokenProvider(ctx, s.metadata.WorkloadIdentityClientID, s.metadata.WorkloadIdentityTenantID, s.metadata.WorkloadIdentityAuthorityHost, s.metadata.WorkloadIdentityResource)
}

err := s.azureOAuth.Refresh()
if err != nil {
return 0, fmt.Errorf("error refreshing Azure AD token: %w", err)
}

// Set the access token for the database connection
err = s.connection.PingContext(ctx)
if err != nil {
return 0, fmt.Errorf("error pinging database: %w", err)
}

_, err = s.connection.ExecContext(ctx, "SET NOCOUNT ON; DECLARE @AccessToken NVARCHAR(MAX) = ?; EXEC sp_set_session_context @key=N'access_token', @value=@AccessToken;", s.azureOAuth.OAuthToken())
if err != nil {
return 0, fmt.Errorf("error setting access token: %w", err)
}
}

err := s.connection.QueryRowContext(ctx, s.metadata.Query).Scan(&value)
switch {
case err == sql.ErrNoRows:
case err == sql.ErrNoRows:
value = 0
case err != nil:
case err != nil:
s.logger.Error(err, fmt.Sprintf("Could not query mssql database: %s", err))
return 0, err
}

return value, nil
}

// Close closes the mssql database connections
func (s *mssqlScaler) Close(context.Context) error {
err := s.connection.Close()
if err != nil {
Expand Down
Loading

0 comments on commit 5a08a0d

Please sign in to comment.