diff --git a/README.md b/README.md index 71642fe2..c14fd472 100644 --- a/README.md +++ b/README.md @@ -110,3 +110,4 @@ https://golangci-lint.run/welcome/install/#local-installation to see additional | OIDC_TOKEN_VALIDITY_IN_SECONDS | The time in seconds an OIDC token is valid | 3600 | | OIDC_TOKEN_LEEWAY_IN_SECONDS | The leeway to use when validating OIDC tokens | 300 | | SDS_OAUTH2_CLIENT_ID | The OAuth2 Client ID used when setting up IAP on the SDS | | +| CIR_OAUTH2_CLIENT_ID | The OAuth2 Client ID used when setting up IAP on the CIR | | diff --git a/authentication/auth.go b/authentication/auth.go index 8d47b2f6..162695bd 100644 --- a/authentication/auth.go +++ b/authentication/auth.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "github.com/ONSdigital/eq-questionnaire-launcher/oidc" "io" "math/rand" "net/http" @@ -573,6 +574,9 @@ func getRequiredSchemaMetadata(launcherSchema surveys.LauncherSchema) ([]Metadat func getSchema(launcherSchema surveys.LauncherSchema) (QuestionnaireSchema, string) { var url string + var schema QuestionnaireSchema + + client := clients.GetHTTPClient() if launcherSchema.URL != "" { url = launcherSchema.URL @@ -581,6 +585,12 @@ func getSchema(launcherSchema surveys.LauncherSchema) (QuestionnaireSchema, stri log.Println("Collection Instrument ID: ", launcherSchema.CIRInstrumentID) url = fmt.Sprintf("%s/v2/retrieve_collection_instrument?guid=%s", hostURL, launcherSchema.CIRInstrumentID) + + _, err := oidc.ConfigureClientAuthentication(client, "CIR_OAUTH2_CLIENT_ID") + if err != nil { + log.Print(err) + return schema, fmt.Sprintf("Unable to generate CIR authentication credentials %s", url) + } } else { hostURL := settings.Get("SURVEY_RUNNER_SCHEMA_URL") @@ -590,8 +600,7 @@ func getSchema(launcherSchema surveys.LauncherSchema) (QuestionnaireSchema, stri log.Println("Loading metadata from schema:", url) - var schema QuestionnaireSchema - resp, err := clients.GetHTTPClient().Get(url) + resp, err := client.Get(url) if err != nil { log.Println("Failed to load schema from:", url) return schema, fmt.Sprintf("Failed to load Schema from %s", url) diff --git a/oidc/oidc.go b/oidc/oidc.go index 456e8434..313ccdf6 100644 --- a/oidc/oidc.go +++ b/oidc/oidc.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "net/http" "strconv" "time" @@ -15,19 +16,19 @@ import ( "google.golang.org/api/option" ) -func GenerateIdToken() (oauth2.TokenSource, error) { +func generateIdToken(clientIdName string) (oauth2.TokenSource, error) { oidcBackend := settings.Get("OIDC_TOKEN_BACKEND") if oidcBackend == "gcp" { - audience := settings.Get("SDS_OAUTH2_CLIENT_ID") + audience := settings.Get(clientIdName) if audience == "" { - return nil, fmt.Errorf("SDS_OAUTH2_CLIENT_ID not set") + return nil, fmt.Errorf("%s not set", clientIdName) } - return getGCPIdToken(audience) + return getGCPIdToken(audience, clientIdName) } return nil, nil } -func cachedWithTTL(fn func(audience string) (oauth2.TokenSource, error)) func(audience string) (oauth2.TokenSource, error) { +func cachedWithTTL(fn func(audience string, clientIdName string) (oauth2.TokenSource, error)) func(audience string, clientIdName string) (oauth2.TokenSource, error) { validitySeconds, _ := strconv.Atoi(settings.Get("OIDC_TOKEN_VALIDITY_IN_SECONDS")) leewaySeconds, _ := strconv.Atoi(settings.Get("OIDC_TOKEN_LEEWAY_IN_SECONDS")) @@ -35,13 +36,13 @@ func cachedWithTTL(fn func(audience string) (oauth2.TokenSource, error)) func(au // Create cache with default expiration of TTL seconds and cleanup interval of 1 minute ttlCache := cache.New(time.Duration(ttl)*time.Second, time.Minute) - cachedFunc := func(audience string) (oauth2.TokenSource, error) { + cachedFunc := func(audience string, clientIdName string) (oauth2.TokenSource, error) { cachedSource, found := ttlCache.Get(audience) if found { - log.Printf("Found cached GCP ID token source for audience: %s", audience) + log.Printf("Found cached GCP ID token source for %s audience: %s", clientIdName, audience) return cachedSource.(oauth2.TokenSource), nil } - tokenSource, err := fn(audience) + tokenSource, err := fn(audience, clientIdName) if err != nil { return nil, err } @@ -53,27 +54,41 @@ func cachedWithTTL(fn func(audience string) (oauth2.TokenSource, error)) func(au // uses the Google Cloud metadata server environment to create an identity token that can be added to a HTTP request // based off https://cloud.google.com/docs/authentication/get-id-token#go -func getIdTokenFromMetadataServer(audience string) (oauth2.TokenSource, error) { +func getIdTokenFromMetadataServer(audience string, clientIdName string) (oauth2.TokenSource, error) { ctx := context.Background() // Construct the GoogleCredentials object which obtains the default configuration from your working environment. credentials, err := google.FindDefaultCredentials(ctx) if err != nil { - return nil, fmt.Errorf("failed to generate default credentials: %w", err) + return nil, fmt.Errorf("failed to generate default credentials for %s: %w", clientIdName, err) } ts, err := idtoken.NewTokenSource(ctx, audience, option.WithCredentials(credentials)) if err != nil { - return nil, fmt.Errorf("failed to create NewTokenSource: %w", err) + return nil, fmt.Errorf("failed to create NewTokenSource for %s: %w", clientIdName, err) } // Generate the ID token. _, err = ts.Token() if err != nil { - return nil, fmt.Errorf("failed to receive token: %w", err) + return nil, fmt.Errorf("failed to receive token for %s: %w", clientIdName, err) } - log.Printf("Succesfully generated GCP ID token for audience: %s", audience) + log.Printf("Successfully generated GCP ID token for %s audience: %s", clientIdName, audience) return ts, nil } +func ConfigureClientAuthentication(client *http.Client, clientIdName string) (*http.Client, error) { + tokenSource, err := generateIdToken(clientIdName) + if err != nil { + return client, err + } + + if tokenSource != nil { + client.Transport = &oauth2.Transport{ + Source: tokenSource, + } + } + return client, nil +} + var getGCPIdToken = cachedWithTTL(getIdTokenFromMetadataServer) diff --git a/settings/settings.go b/settings/settings.go index 3ce354ea..c9e086f9 100644 --- a/settings/settings.go +++ b/settings/settings.go @@ -28,6 +28,7 @@ func init() { setSetting("OIDC_TOKEN_LEEWAY_IN_SECONDS", "300") setSetting("OIDC_TOKEN_BACKEND", "local") setSetting("SDS_OAUTH2_CLIENT_ID", "") + setSetting("CIR_OAUTH2_CLIENT_ID", "") } // Get returns the value for the specified named setting diff --git a/surveys/surveys.go b/surveys/surveys.go index 8c6f776d..c3025e7b 100644 --- a/surveys/surveys.go +++ b/surveys/surveys.go @@ -3,17 +3,15 @@ package surveys import ( "encoding/json" "errors" - "golang.org/x/text/cases" - "golang.org/x/text/language" - "io" - "log" - "fmt" "github.com/AreaHQ/jsonhal" "github.com/ONSdigital/eq-questionnaire-launcher/clients" "github.com/ONSdigital/eq-questionnaire-launcher/oidc" "github.com/ONSdigital/eq-questionnaire-launcher/settings" - "golang.org/x/oauth2" + "golang.org/x/text/cases" + "golang.org/x/text/language" + "io" + "log" "sort" ) @@ -143,16 +141,19 @@ func getAvailableSchemasFromRegister() []LauncherSchema { } func GetAvailableSchemasFromCIR() []CIMetadata { - ciMetadataList := []CIMetadata{} - hostURL := settings.Get("CIR_API_BASE_URL") - log.Printf("CIR API Base URL: %s", hostURL) + client, err := oidc.ConfigureClientAuthentication(clients.GetHTTPClient(), "CIR_OAUTH2_CLIENT_ID") + if err != nil { + log.Print(err) + return ciMetadataList + } + log.Printf("CIR API Base URL: %s", hostURL) url := fmt.Sprintf("%s/v2/ci_metadata", hostURL) - resp, err := clients.GetHTTPClient().Get(url) + resp, err := client.Get(url) if err != nil || resp.StatusCode != 200 { log.Print(err) return ciMetadataList @@ -236,18 +237,10 @@ func GetSupplementaryDataSets(surveyId string, periodId string) ([]DatasetMetada datasetList := []DatasetMetadata{} hostURL := settings.Get("SDS_API_BASE_URL") - client := clients.GetHTTPClient() - tokenSource, err := oidc.GenerateIdToken() - + client, err := oidc.ConfigureClientAuthentication(clients.GetHTTPClient(), "SDS_OAUTH2_CLIENT_ID") if err != nil { log.Print(err) - return datasetList, errors.New("unable to generate authentication credentials") - } - - if tokenSource != nil { - client.Transport = &oauth2.Transport{ - Source: tokenSource, - } + return datasetList, errors.New("unable to generate SDS authentication credentials") } log.Printf("SDS API Base URL: %s", hostURL)