From a7d0e2c127d6c8ed4811d38ca8bb3308f21ece3e Mon Sep 17 00:00:00 2001 From: rickbrouwer Date: Mon, 16 Sep 2024 15:20:12 +0200 Subject: [PATCH] use oauth2 Signed-off-by: rickbrouwer --- pkg/scalers/mssql_scaler.go | 72 ++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/pkg/scalers/mssql_scaler.go b/pkg/scalers/mssql_scaler.go index f6274cf916f..0b80512b248 100644 --- a/pkg/scalers/mssql_scaler.go +++ b/pkg/scalers/mssql_scaler.go @@ -52,7 +52,7 @@ func (m *mssqlMetadata) Validate() error { return nil } -func NewMSSQLScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { +func NewMSSQLScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) { metricType, err := GetMetricTargetType(config) if err != nil { return nil, fmt.Errorf("error getting scaler metric type: %w", err) @@ -65,17 +65,20 @@ func NewMSSQLScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { return nil, fmt.Errorf("error parsing mssql metadata: %w", err) } - conn, err := newMSSQLConnection(meta, logger) + scaler := &mssqlScaler{ + metricType: metricType, + metadata: meta, + logger: logger, + } + + conn, err := newMSSQLConnection(ctx, scaler) if err != nil { return nil, fmt.Errorf("error establishing mssql connection: %w", err) } - return &mssqlScaler{ - metricType: metricType, - metadata: meta, - connection: conn, - logger: logger, - }, nil + scaler.connection = conn + + return scaler, nil } func parseMSSQLMetadata(config *scalersconfig.ScalerConfig) (mssqlMetadata, error) { @@ -99,25 +102,26 @@ func parseMSSQLMetadata(config *scalersconfig.ScalerConfig) (mssqlMetadata, erro return meta, nil } -func newMSSQLConnection(meta mssqlMetadata, logger logr.Logger) (*sql.DB, error) { - connStr := getMSSQLConnectionString(meta) +func newMSSQLConnection(ctx context.Context, s *mssqlScaler) (*sql.DB, error) { + connStr := getMSSQLConnectionString(ctx, s) db, err := sql.Open("sqlserver", connStr) if err != nil { - logger.Error(err, fmt.Sprintf("Found error opening mssql: %s", err)) + s.logger.Error(err, "Found error opening mssql") return nil, err } err = db.Ping() if err != nil { - logger.Error(err, fmt.Sprintf("Found error pinging mssql: %s", err)) + s.logger.Error(err, "Found error pinging mssql") return nil, err } return db, nil } -func getMSSQLConnectionString(meta mssqlMetadata) string { +func getMSSQLConnectionString(ctx context.Context, s *mssqlScaler) string { + meta := s.metadata if meta.ConnectionString != "" { return meta.ConnectionString } @@ -142,9 +146,28 @@ func getMSSQLConnectionString(meta mssqlMetadata) string { connectionURL.Host = meta.Host } + if meta.WorkloadIdentityResource != "" { + token := s.getOAuthToken(ctx) + connectionURL.RawQuery += fmt.Sprintf("&access_token=%s", token) + } + return connectionURL.String() } +func (s *mssqlScaler) getOAuthToken(ctx context.Context) string { + if s.azureOAuth == nil { + s.azureOAuth = azure.NewAzureADWorkloadIdentityTokenProvider(ctx, s.metadata.WorkloadIdentityClientID, s.metadata.WorkloadIdentityTenantID, s.metadata.WorkloadIdentityAuthorityHost, s.metadata.WorkloadIdentityResource) + } + + err := s.azureOAuth.Refresh() + if err != nil { + fmt.Println("Error fetching OAuth token:", err) + return "" + } + + return s.azureOAuth.OAuthToken() +} + func (s *mssqlScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec { externalMetric := &v2.ExternalMetricSource{ Metric: v2.MetricIdentifier{ @@ -174,29 +197,6 @@ func (s *mssqlScaler) GetMetricsAndActivity(ctx context.Context, metricName stri func (s *mssqlScaler) getQueryResult(ctx context.Context) (float64, error) { var value float64 - // If using Azure Workload Identity, refresh the token - if s.metadata.WorkloadIdentityResource != "" { - if s.azureOAuth == nil { - s.azureOAuth = azure.NewAzureADWorkloadIdentityTokenProvider(ctx, s.metadata.WorkloadIdentityClientID, s.metadata.WorkloadIdentityTenantID, s.metadata.WorkloadIdentityAuthorityHost, s.metadata.WorkloadIdentityResource) - } - - err := s.azureOAuth.Refresh() - if err != nil { - return 0, fmt.Errorf("error refreshing Azure AD token: %w", err) - } - - // Set the access token for the database connection - err = s.connection.PingContext(ctx) - if err != nil { - return 0, fmt.Errorf("error pinging database: %w", err) - } - - _, err = s.connection.ExecContext(ctx, "SET NOCOUNT ON; DECLARE @AccessToken NVARCHAR(MAX) = ?; EXEC sp_set_session_context @key=N'access_token', @value=@AccessToken;", s.azureOAuth.OAuthToken()) - if err != nil { - return 0, fmt.Errorf("error setting access token: %w", err) - } - } - err := s.connection.QueryRowContext(ctx, s.metadata.Query).Scan(&value) switch { case err == sql.ErrNoRows: