Skip to content

Commit

Permalink
Merge pull request #2702 from openziti/add-verify-ext-jwt-oidc
Browse files Browse the repository at this point in the history
Add verify ext jwt OIDC
  • Loading branch information
dovholuknf authored Jan 29, 2025
2 parents 86ef90e + f4fcebc commit ec06ce2
Show file tree
Hide file tree
Showing 15 changed files with 769 additions and 121 deletions.
2 changes: 1 addition & 1 deletion controller/oidc_auth/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ const (
AuthMethodSecondaryExtJwt = "ejs"
)

// NewNativeOnlyOP creates an OIDC Provider that allows native clients and only the AutCode PKCE flow.
// NewNativeOnlyOP creates an OIDC Provider that allows native clients and only the AuthCode PKCE flow.
func NewNativeOnlyOP(ctx context.Context, env model.Env, config Config) (http.Handler, error) {
cert, kid, method := env.GetServerCert()
config.Storage = NewStorage(kid, cert.Leaf.PublicKey, cert.PrivateKey, method, &config, env)
Expand Down
122 changes: 122 additions & 0 deletions internal/cobra/cobra-utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
Copyright NetFoundry Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package cobra

import (
"fmt"
"github.com/openziti/ziti/ziti/cmd/consts"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)

func AddFlagAnnotation(cmd *cobra.Command, flagName string, key string, value string) error {
flag := cmd.Flags().Lookup(flagName)
if flag == nil {
return fmt.Errorf("flag %q not found", flagName)
}
if flag.Annotations == nil {
flag.Annotations = make(map[string][]string)
}
flag.Annotations[key] = append(flag.Annotations[key], value)
return nil
}

func GetFlagsForAnnotation(cmd *cobra.Command, annotation string) string {
var result string
maxLength := MaxFlagNameLength(cmd)
cmd.Flags().VisitAll(func(flag *pflag.Flag) {
if flag.Annotations[annotation] != nil {
if flag.Shorthand != "" {
result += fmt.Sprintf(" -%s, --%-*s %s\n", flag.Shorthand, maxLength, flag.Name, flag.Usage)
} else {
result += fmt.Sprintf(" --%-*s %s\n", maxLength, flag.Name, flag.Usage)
}
}
})

return result
}

func GetFlagsWithoutAnnotations(cmd *cobra.Command, annotations ...string) string {
var result string
maxLength := MaxFlagNameLength(cmd)

// Create a map of the provided annotations for quick lookup
annotationMap := make(map[string]bool)
for _, annotation := range annotations {
annotationMap[annotation] = true
}

cmd.Flags().VisitAll(func(flag *pflag.Flag) {
hasAnnotation := false
for ann := range flag.Annotations {
if annotationMap[ann] {
hasAnnotation = true
break
}
}
if !hasAnnotation {
if flag.Shorthand != "" {
result += fmt.Sprintf(" -%s, --%-*s %s\n", flag.Shorthand, maxLength, flag.Name, flag.Usage)
} else {
result += fmt.Sprintf(" --%-*s %s\n", maxLength, flag.Name, flag.Usage)
}
}
})

return result
}

func MaxFlagNameLength(cmd *cobra.Command) int {
// Calculate the maximum flag length across ALL flags
maxLength := 0
cmd.Flags().VisitAll(func(flag *pflag.Flag) {
length := len(flag.Name)
if flag.Shorthand != "" {
length += 4 // Account for "-x, "
}
if length > maxLength {
maxLength = length
}
})
return maxLength
}

func SetHelpTemplate(cmd *cobra.Command) {
l := GetFlagsForAnnotation(cmd, cmdconsts.LoginFlagKey)
c := GetFlagsForAnnotation(cmd, cmdconsts.CommonFlagKey)
u := GetFlagsWithoutAnnotations(cmd, cmdconsts.LoginFlagKey, cmdconsts.CommonFlagKey)

cmd.SetHelpTemplate(`{{.Long}}
Usage:
{{.UseLine}}
Available Commands:
{{range .Commands}}{{if (or .IsAvailableCommand (eq .Name "help"))}}
{{rpad .Name .NamePadding }} {{.Short}}{{end}}{{end}}
Flags:
` + u + `
Flags related to logging in:
` + l + `
Common flags for all commands:
` + c + `
Use "{{.CommandPath}} [command] --help" for more information about a command.
`)
}
43 changes: 43 additions & 0 deletions internal/rest/client/helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
Copyright NetFoundry Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package client

import (
"context"
"github.com/michaelquigley/pfxlog"
"github.com/openziti/edge-api/rest_client_api_client"
"github.com/openziti/edge-api/rest_client_api_client/external_jwt_signer"
"github.com/openziti/edge-api/rest_model"
internalconsts "github.com/openziti/ziti/internal/rest/consts"
)

func ExternalJWTSignerFromFilter(client *rest_client_api_client.ZitiEdgeClient, filter string) *rest_model.ClientExternalJWTSignerDetail {
params := &external_jwt_signer.ListExternalJWTSignersParams{
Filter: &filter,
Context: context.Background(),
}
params.SetTimeout(internalconsts.DefaultTimeout)
resp, err := client.ExternalJWTSigner.ListExternalJWTSigners(params)
if err != nil {
pfxlog.Logger().Errorf("Could not obtain an ID for the external jwt signer with filter %s: %v", filter, err)
return nil
}
if resp == nil || resp.Payload == nil || resp.Payload.Data == nil || len(resp.Payload.Data) == 0 {
return nil
}
return resp.Payload.Data[0]
}
7 changes: 7 additions & 0 deletions internal/rest/consts/consts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package internal_consts

import "time"

const (
DefaultTimeout = 5 * time.Second
)
87 changes: 13 additions & 74 deletions internal/rest/mgmt/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,11 @@ package mgmt

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"github.com/openziti/edge-api/rest_management_api_client"
"github.com/openziti/edge-api/rest_management_api_client/auth_policy"
"github.com/openziti/edge-api/rest_management_api_client/certificate_authority"
"github.com/openziti/edge-api/rest_management_api_client/config"
rest_mgmt "github.com/openziti/edge-api/rest_management_api_client/current_api_session"
"github.com/openziti/edge-api/rest_management_api_client/edge_router"
"github.com/openziti/edge-api/rest_management_api_client/edge_router_policy"
"github.com/openziti/edge-api/rest_management_api_client/external_jwt_signer"
Expand All @@ -36,24 +32,16 @@ import (
"github.com/openziti/edge-api/rest_management_api_client/service_edge_router_policy"
"github.com/openziti/edge-api/rest_management_api_client/service_policy"
"github.com/openziti/edge-api/rest_model"
"github.com/openziti/edge-api/rest_util"
"github.com/openziti/ziti/ziti/util"
"github.com/openziti/ziti/internal/rest/consts"
log "github.com/sirupsen/logrus"
"net/http"
"os"
"time"
)

const (
DefaultTimeout = 5 * time.Second
)

func IdentityFromFilter(client *rest_management_api_client.ZitiEdgeManagement, filter string) *rest_model.IdentityDetail {
params := &identity.ListIdentitiesParams{
Filter: &filter,
Context: context.Background(),
}
params.SetTimeout(DefaultTimeout)
params.SetTimeout(internal_consts.DefaultTimeout)
resp, err := client.Identity.ListIdentities(params, nil)
if err != nil {
log.Debugf("Could not obtain an ID for the identity with filter %s: %v", filter, err)
Expand All @@ -71,7 +59,7 @@ func ServiceFromFilter(client *rest_management_api_client.ZitiEdgeManagement, fi
Filter: &filter,
Context: context.Background(),
}
params.SetTimeout(DefaultTimeout)
params.SetTimeout(internal_consts.DefaultTimeout)
resp, err := client.Service.ListServices(params, nil)
if err != nil {
log.Debugf("Could not obtain an ID for the service with filter %s: %v", filter, err)
Expand All @@ -88,7 +76,7 @@ func ServicePolicyFromFilter(client *rest_management_api_client.ZitiEdgeManageme
Filter: &filter,
Context: context.Background(),
}
params.SetTimeout(DefaultTimeout)
params.SetTimeout(internal_consts.DefaultTimeout)
resp, err := client.ServicePolicy.ListServicePolicies(params, nil)
if err != nil {
log.Errorf("Could not obtain an ID for the service policy with filter %s: %v", filter, err)
Expand All @@ -105,7 +93,7 @@ func AuthPolicyFromFilter(client *rest_management_api_client.ZitiEdgeManagement,
Filter: &filter,
Context: context.Background(),
}
params.SetTimeout(DefaultTimeout)
params.SetTimeout(internal_consts.DefaultTimeout)
resp, err := client.AuthPolicy.ListAuthPolicies(params, nil)
if err != nil {
log.Errorf("Could not obtain an ID for the auth policy with filter %s: %v", filter, err)
Expand All @@ -122,7 +110,7 @@ func CertificateAuthorityFromFilter(client *rest_management_api_client.ZitiEdgeM
Filter: &filter,
Context: context.Background(),
}
params.SetTimeout(DefaultTimeout)
params.SetTimeout(internal_consts.DefaultTimeout)
resp, err := client.CertificateAuthority.ListCas(params, nil)
if err != nil {
log.Errorf("Could not obtain an ID for the certificate authority with filter %s: %v", filter, err)
Expand All @@ -139,7 +127,7 @@ func ConfigTypeFromFilter(client *rest_management_api_client.ZitiEdgeManagement,
Filter: &filter,
Context: context.Background(),
}
params.SetTimeout(DefaultTimeout)
params.SetTimeout(internal_consts.DefaultTimeout)
resp, err := client.Config.ListConfigTypes(params, nil)
if err != nil {
log.Errorf("Could not obtain an ID for the config type with filter %s: %v", filter, err)
Expand All @@ -156,7 +144,7 @@ func ConfigFromFilter(client *rest_management_api_client.ZitiEdgeManagement, fil
Filter: &filter,
Context: context.Background(),
}
params.SetTimeout(DefaultTimeout)
params.SetTimeout(internal_consts.DefaultTimeout)
resp, err := client.Config.ListConfigs(params, nil)
if err != nil {
log.Errorf("Could not obtain an ID for the config with filter %s: %v", filter, err)
Expand All @@ -173,7 +161,7 @@ func ExternalJWTSignerFromFilter(client *rest_management_api_client.ZitiEdgeMana
Filter: &filter,
Context: context.Background(),
}
params.SetTimeout(DefaultTimeout)
params.SetTimeout(internal_consts.DefaultTimeout)
resp, err := client.ExternalJWTSigner.ListExternalJWTSigners(params, nil)
if err != nil {
log.Errorf("Could not obtain an ID for the external jwt signer with filter %s: %v", filter, err)
Expand All @@ -190,7 +178,7 @@ func PostureCheckFromFilter(client *rest_management_api_client.ZitiEdgeManagemen
Filter: &filter,
Context: context.Background(),
}
params.SetTimeout(DefaultTimeout)
params.SetTimeout(internal_consts.DefaultTimeout)
resp, err := client.PostureChecks.ListPostureChecks(params, nil)
if err != nil {
log.Errorf("Could not obtain an ID for the posture check with filter %s: %v", filter, err)
Expand All @@ -206,7 +194,7 @@ func EdgeRouterPolicyFromFilter(client *rest_management_api_client.ZitiEdgeManag
params := &edge_router_policy.ListEdgeRouterPoliciesParams{
Filter: &filter,
}
params.SetTimeout(DefaultTimeout)
params.SetTimeout(internal_consts.DefaultTimeout)
resp, err := client.EdgeRouterPolicy.ListEdgeRouterPolicies(params, nil)
if err != nil {
log.Errorf("Could not obtain an ID for the edge router policies with filter %s: %v", filter, err)
Expand All @@ -222,7 +210,7 @@ func EdgeRouterFromFilter(client *rest_management_api_client.ZitiEdgeManagement,
params := &edge_router.ListEdgeRoutersParams{
Filter: &filter,
}
params.SetTimeout(DefaultTimeout)
params.SetTimeout(internal_consts.DefaultTimeout)
resp, err := client.EdgeRouter.ListEdgeRouters(params, nil)
if err != nil {
log.Errorf("Could not obtain an ID for the edge routers with filter %s: %v", filter, err)
Expand All @@ -238,7 +226,7 @@ func ServiceEdgeRouterPolicyFromFilter(client *rest_management_api_client.ZitiEd
params := &service_edge_router_policy.ListServiceEdgeRouterPoliciesParams{
Filter: &filter,
}
params.SetTimeout(DefaultTimeout)
params.SetTimeout(internal_consts.DefaultTimeout)
resp, err := client.ServiceEdgeRouterPolicy.ListServiceEdgeRouterPolicies(params, nil)
if err != nil {
log.Errorf("Could not obtain an ID for the ServiceEdgeRouterPolicy routers with filter %s: %v", filter, err)
Expand All @@ -253,52 +241,3 @@ func ServiceEdgeRouterPolicyFromFilter(client *rest_management_api_client.ZitiEd
func NameFilter(name string) string {
return fmt.Sprintf("name = \"%s\"", name)
}

func NewClient() (*rest_management_api_client.ZitiEdgeManagement, error) {
cachedCreds, _, loadErr := util.LoadRestClientConfig()
if loadErr != nil {
return nil, loadErr
}

cachedId := cachedCreds.EdgeIdentities[cachedCreds.Default] //only support default for now
if cachedId == nil {
return nil, errors.New("no identity found")
}

caPool := x509.NewCertPool()
if _, cacertErr := os.Stat(cachedId.CaCert); cacertErr == nil {
rootPemData, err := os.ReadFile(cachedId.CaCert)
if err != nil {
return nil, err
}
caPool.AppendCertsFromPEM(rootPemData)
} else {
return nil, errors.New("CA cert file not found in config file")
}

tlsConfig := &tls.Config{
RootCAs: caPool,
}

transport := &http.Transport{
TLSClientConfig: tlsConfig,
}

// Assign the transport to the default HTTP client
http.DefaultClient = &http.Client{
Transport: transport,
}
c, e := rest_util.NewEdgeManagementClientWithToken(http.DefaultClient, cachedId.Url, cachedId.Token)
if e != nil {
return nil, e
}

apiSessionParams := &rest_mgmt.GetCurrentAPISessionParams{
Context: context.Background(),
}
_, authErr := c.CurrentAPISession.GetCurrentAPISession(apiSessionParams, nil)
if authErr != nil {
return nil, errors.New("client not authenticated. login with 'ziti edge login' before executing")
}
return c, nil
}
Loading

0 comments on commit ec06ce2

Please sign in to comment.