Skip to content

Commit

Permalink
Authenticate with CIR using OIDC (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
liamtoozer authored May 8, 2024
1 parent ac0f892 commit e249f5c
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 35 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,4 @@ https://golangci-lint.run/welcome/install/#local-installation to see additional
| OIDC_TOKEN_VALIDITY_IN_SECONDS | The time in seconds an OIDC token is valid | 3600 |
| OIDC_TOKEN_LEEWAY_IN_SECONDS | The leeway to use when validating OIDC tokens | 300 |
| SDS_OAUTH2_CLIENT_ID | The OAuth2 Client ID used when setting up IAP on the SDS | |
| CIR_OAUTH2_CLIENT_ID | The OAuth2 Client ID used when setting up IAP on the CIR | |
13 changes: 11 additions & 2 deletions authentication/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"github.com/ONSdigital/eq-questionnaire-launcher/oidc"
"io"
"math/rand"
"net/http"
Expand Down Expand Up @@ -573,6 +574,9 @@ func getRequiredSchemaMetadata(launcherSchema surveys.LauncherSchema) ([]Metadat

func getSchema(launcherSchema surveys.LauncherSchema) (QuestionnaireSchema, string) {
var url string
var schema QuestionnaireSchema

client := clients.GetHTTPClient()

if launcherSchema.URL != "" {
url = launcherSchema.URL
Expand All @@ -581,6 +585,12 @@ func getSchema(launcherSchema surveys.LauncherSchema) (QuestionnaireSchema, stri

log.Println("Collection Instrument ID: ", launcherSchema.CIRInstrumentID)
url = fmt.Sprintf("%s/v2/retrieve_collection_instrument?guid=%s", hostURL, launcherSchema.CIRInstrumentID)

_, err := oidc.ConfigureClientAuthentication(client, "CIR_OAUTH2_CLIENT_ID")
if err != nil {
log.Print(err)
return schema, fmt.Sprintf("Unable to generate CIR authentication credentials %s", url)
}
} else {
hostURL := settings.Get("SURVEY_RUNNER_SCHEMA_URL")

Expand All @@ -590,8 +600,7 @@ func getSchema(launcherSchema surveys.LauncherSchema) (QuestionnaireSchema, stri

log.Println("Loading metadata from schema:", url)

var schema QuestionnaireSchema
resp, err := clients.GetHTTPClient().Get(url)
resp, err := client.Get(url)
if err != nil {
log.Println("Failed to load schema from:", url)
return schema, fmt.Sprintf("Failed to load Schema from %s", url)
Expand Down
41 changes: 28 additions & 13 deletions oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log"
"net/http"
"strconv"
"time"

Expand All @@ -15,33 +16,33 @@ import (
"google.golang.org/api/option"
)

func GenerateIdToken() (oauth2.TokenSource, error) {
func generateIdToken(clientIdName string) (oauth2.TokenSource, error) {
oidcBackend := settings.Get("OIDC_TOKEN_BACKEND")
if oidcBackend == "gcp" {
audience := settings.Get("SDS_OAUTH2_CLIENT_ID")
audience := settings.Get(clientIdName)
if audience == "" {
return nil, fmt.Errorf("SDS_OAUTH2_CLIENT_ID not set")
return nil, fmt.Errorf("%s not set", clientIdName)
}
return getGCPIdToken(audience)
return getGCPIdToken(audience, clientIdName)
}
return nil, nil
}

func cachedWithTTL(fn func(audience string) (oauth2.TokenSource, error)) func(audience string) (oauth2.TokenSource, error) {
func cachedWithTTL(fn func(audience string, clientIdName string) (oauth2.TokenSource, error)) func(audience string, clientIdName string) (oauth2.TokenSource, error) {
validitySeconds, _ := strconv.Atoi(settings.Get("OIDC_TOKEN_VALIDITY_IN_SECONDS"))
leewaySeconds, _ := strconv.Atoi(settings.Get("OIDC_TOKEN_LEEWAY_IN_SECONDS"))

ttl := validitySeconds - leewaySeconds
// Create cache with default expiration of TTL seconds and cleanup interval of 1 minute
ttlCache := cache.New(time.Duration(ttl)*time.Second, time.Minute)

cachedFunc := func(audience string) (oauth2.TokenSource, error) {
cachedFunc := func(audience string, clientIdName string) (oauth2.TokenSource, error) {
cachedSource, found := ttlCache.Get(audience)
if found {
log.Printf("Found cached GCP ID token source for audience: %s", audience)
log.Printf("Found cached GCP ID token source for %s audience: %s", clientIdName, audience)
return cachedSource.(oauth2.TokenSource), nil
}
tokenSource, err := fn(audience)
tokenSource, err := fn(audience, clientIdName)
if err != nil {
return nil, err
}
Expand All @@ -53,27 +54,41 @@ func cachedWithTTL(fn func(audience string) (oauth2.TokenSource, error)) func(au

// uses the Google Cloud metadata server environment to create an identity token that can be added to a HTTP request
// based off https://cloud.google.com/docs/authentication/get-id-token#go
func getIdTokenFromMetadataServer(audience string) (oauth2.TokenSource, error) {
func getIdTokenFromMetadataServer(audience string, clientIdName string) (oauth2.TokenSource, error) {
ctx := context.Background()
// Construct the GoogleCredentials object which obtains the default configuration from your working environment.
credentials, err := google.FindDefaultCredentials(ctx)
if err != nil {
return nil, fmt.Errorf("failed to generate default credentials: %w", err)
return nil, fmt.Errorf("failed to generate default credentials for %s: %w", clientIdName, err)
}

ts, err := idtoken.NewTokenSource(ctx, audience, option.WithCredentials(credentials))
if err != nil {
return nil, fmt.Errorf("failed to create NewTokenSource: %w", err)
return nil, fmt.Errorf("failed to create NewTokenSource for %s: %w", clientIdName, err)
}

// Generate the ID token.
_, err = ts.Token()
if err != nil {
return nil, fmt.Errorf("failed to receive token: %w", err)
return nil, fmt.Errorf("failed to receive token for %s: %w", clientIdName, err)
}
log.Printf("Succesfully generated GCP ID token for audience: %s", audience)
log.Printf("Successfully generated GCP ID token for %s audience: %s", clientIdName, audience)

return ts, nil
}

func ConfigureClientAuthentication(client *http.Client, clientIdName string) (*http.Client, error) {
tokenSource, err := generateIdToken(clientIdName)
if err != nil {
return client, err
}

if tokenSource != nil {
client.Transport = &oauth2.Transport{
Source: tokenSource,
}
}
return client, nil
}

var getGCPIdToken = cachedWithTTL(getIdTokenFromMetadataServer)
1 change: 1 addition & 0 deletions settings/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ func init() {
setSetting("OIDC_TOKEN_LEEWAY_IN_SECONDS", "300")
setSetting("OIDC_TOKEN_BACKEND", "local")
setSetting("SDS_OAUTH2_CLIENT_ID", "")
setSetting("CIR_OAUTH2_CLIENT_ID", "")
}

// Get returns the value for the specified named setting
Expand Down
33 changes: 13 additions & 20 deletions surveys/surveys.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@ package surveys
import (
"encoding/json"
"errors"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"io"
"log"

"fmt"
"github.com/AreaHQ/jsonhal"
"github.com/ONSdigital/eq-questionnaire-launcher/clients"
"github.com/ONSdigital/eq-questionnaire-launcher/oidc"
"github.com/ONSdigital/eq-questionnaire-launcher/settings"
"golang.org/x/oauth2"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"io"
"log"
"sort"
)

Expand Down Expand Up @@ -143,16 +141,19 @@ func getAvailableSchemasFromRegister() []LauncherSchema {
}

func GetAvailableSchemasFromCIR() []CIMetadata {

ciMetadataList := []CIMetadata{}

hostURL := settings.Get("CIR_API_BASE_URL")

log.Printf("CIR API Base URL: %s", hostURL)
client, err := oidc.ConfigureClientAuthentication(clients.GetHTTPClient(), "CIR_OAUTH2_CLIENT_ID")
if err != nil {
log.Print(err)
return ciMetadataList
}

log.Printf("CIR API Base URL: %s", hostURL)
url := fmt.Sprintf("%s/v2/ci_metadata", hostURL)

resp, err := clients.GetHTTPClient().Get(url)
resp, err := client.Get(url)
if err != nil || resp.StatusCode != 200 {
log.Print(err)
return ciMetadataList
Expand Down Expand Up @@ -236,18 +237,10 @@ func GetSupplementaryDataSets(surveyId string, periodId string) ([]DatasetMetada
datasetList := []DatasetMetadata{}
hostURL := settings.Get("SDS_API_BASE_URL")

client := clients.GetHTTPClient()
tokenSource, err := oidc.GenerateIdToken()

client, err := oidc.ConfigureClientAuthentication(clients.GetHTTPClient(), "SDS_OAUTH2_CLIENT_ID")
if err != nil {
log.Print(err)
return datasetList, errors.New("unable to generate authentication credentials")
}

if tokenSource != nil {
client.Transport = &oauth2.Transport{
Source: tokenSource,
}
return datasetList, errors.New("unable to generate SDS authentication credentials")
}

log.Printf("SDS API Base URL: %s", hostURL)
Expand Down

0 comments on commit e249f5c

Please sign in to comment.