Skip to content

Commit

Permalink
use oauth2
Browse files Browse the repository at this point in the history
Signed-off-by: rickbrouwer <[email protected]>
  • Loading branch information
rickbrouwer committed Sep 16, 2024
1 parent 9eee167 commit a7d0e2c
Showing 1 changed file with 36 additions and 36 deletions.
72 changes: 36 additions & 36 deletions pkg/scalers/mssql_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (m *mssqlMetadata) Validate() error {
return nil
}

func NewMSSQLScaler(config *scalersconfig.ScalerConfig) (Scaler, error) {
func NewMSSQLScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) {
metricType, err := GetMetricTargetType(config)
if err != nil {
return nil, fmt.Errorf("error getting scaler metric type: %w", err)
Expand All @@ -65,17 +65,20 @@ func NewMSSQLScaler(config *scalersconfig.ScalerConfig) (Scaler, error) {
return nil, fmt.Errorf("error parsing mssql metadata: %w", err)
}

conn, err := newMSSQLConnection(meta, logger)
scaler := &mssqlScaler{
metricType: metricType,
metadata: meta,
logger: logger,
}

conn, err := newMSSQLConnection(ctx, scaler)
if err != nil {
return nil, fmt.Errorf("error establishing mssql connection: %w", err)
}

return &mssqlScaler{
metricType: metricType,
metadata: meta,
connection: conn,
logger: logger,
}, nil
scaler.connection = conn

return scaler, nil
}

func parseMSSQLMetadata(config *scalersconfig.ScalerConfig) (mssqlMetadata, error) {
Expand All @@ -99,25 +102,26 @@ func parseMSSQLMetadata(config *scalersconfig.ScalerConfig) (mssqlMetadata, erro
return meta, nil
}

func newMSSQLConnection(meta mssqlMetadata, logger logr.Logger) (*sql.DB, error) {
connStr := getMSSQLConnectionString(meta)
func newMSSQLConnection(ctx context.Context, s *mssqlScaler) (*sql.DB, error) {
connStr := getMSSQLConnectionString(ctx, s)

db, err := sql.Open("sqlserver", connStr)
if err != nil {
logger.Error(err, fmt.Sprintf("Found error opening mssql: %s", err))
s.logger.Error(err, "Found error opening mssql")
return nil, err
}

err = db.Ping()
if err != nil {
logger.Error(err, fmt.Sprintf("Found error pinging mssql: %s", err))
s.logger.Error(err, "Found error pinging mssql")
return nil, err
}

return db, nil
}

func getMSSQLConnectionString(meta mssqlMetadata) string {
func getMSSQLConnectionString(ctx context.Context, s *mssqlScaler) string {
meta := s.metadata
if meta.ConnectionString != "" {
return meta.ConnectionString
}
Expand All @@ -142,9 +146,28 @@ func getMSSQLConnectionString(meta mssqlMetadata) string {
connectionURL.Host = meta.Host
}

if meta.WorkloadIdentityResource != "" {
token := s.getOAuthToken(ctx)
connectionURL.RawQuery += fmt.Sprintf("&access_token=%s", token)
}

return connectionURL.String()
}

func (s *mssqlScaler) getOAuthToken(ctx context.Context) string {
if s.azureOAuth == nil {
s.azureOAuth = azure.NewAzureADWorkloadIdentityTokenProvider(ctx, s.metadata.WorkloadIdentityClientID, s.metadata.WorkloadIdentityTenantID, s.metadata.WorkloadIdentityAuthorityHost, s.metadata.WorkloadIdentityResource)
}

err := s.azureOAuth.Refresh()
if err != nil {
fmt.Println("Error fetching OAuth token:", err)
return ""
}

return s.azureOAuth.OAuthToken()
}

func (s *mssqlScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec {
externalMetric := &v2.ExternalMetricSource{
Metric: v2.MetricIdentifier{
Expand Down Expand Up @@ -174,29 +197,6 @@ func (s *mssqlScaler) GetMetricsAndActivity(ctx context.Context, metricName stri
func (s *mssqlScaler) getQueryResult(ctx context.Context) (float64, error) {
var value float64

// If using Azure Workload Identity, refresh the token
if s.metadata.WorkloadIdentityResource != "" {
if s.azureOAuth == nil {
s.azureOAuth = azure.NewAzureADWorkloadIdentityTokenProvider(ctx, s.metadata.WorkloadIdentityClientID, s.metadata.WorkloadIdentityTenantID, s.metadata.WorkloadIdentityAuthorityHost, s.metadata.WorkloadIdentityResource)
}

err := s.azureOAuth.Refresh()
if err != nil {
return 0, fmt.Errorf("error refreshing Azure AD token: %w", err)
}

// Set the access token for the database connection
err = s.connection.PingContext(ctx)
if err != nil {
return 0, fmt.Errorf("error pinging database: %w", err)
}

_, err = s.connection.ExecContext(ctx, "SET NOCOUNT ON; DECLARE @AccessToken NVARCHAR(MAX) = ?; EXEC sp_set_session_context @key=N'access_token', @value=@AccessToken;", s.azureOAuth.OAuthToken())
if err != nil {
return 0, fmt.Errorf("error setting access token: %w", err)
}
}

err := s.connection.QueryRowContext(ctx, s.metadata.Query).Scan(&value)
switch {
case err == sql.ErrNoRows:
Expand Down

0 comments on commit a7d0e2c

Please sign in to comment.