Skip to content

Commit

Permalink
Use provider and project from context (#2096)
Browse files Browse the repository at this point in the history
- Remove or deprecate provider and project proto fields
- Add additional HTTP mapping for endpoints containing provider

Prerequisite for #2078
  • Loading branch information
eleftherias authored Jan 10, 2024
1 parent 56f795b commit f6ed42c
Show file tree
Hide file tree
Showing 16 changed files with 2,814 additions and 1,538 deletions.
5 changes: 1 addition & 4 deletions cmd/cli/app/artifact/artifact_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,7 @@ func listCommand(ctx context.Context, cmd *cobra.Command, conn *grpc.ClientConn)

artifactList, err := client.ListArtifacts(ctx, &minderv1.ListArtifactsRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep those until we decide to delete them from the payload and leave the context only
Provider: provider,
ProjectId: project,
From: fromFilter,
From: fromFilter,
},
)

Expand Down
19 changes: 5 additions & 14 deletions cmd/cli/app/provider/provider_enroll.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,7 @@ func EnrollProviderCommand(ctx context.Context, cmd *cobra.Command, conn *grpc.C
if token != "" {
// use pat for enrollment
_, err := client.StoreProviderToken(context.Background(), &minderv1.StoreProviderTokenRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep those until we decide to delete them from the payload and leave the context only
Provider: provider,
ProjectId: project,
Context: &minderv1.Context{Provider: &provider, Project: &project},
AccessToken: token,
Owner: &owner,
})
Expand All @@ -117,12 +114,9 @@ func EnrollProviderCommand(ctx context.Context, cmd *cobra.Command, conn *grpc.C

resp, err := client.GetAuthorizationURL(ctx, &minderv1.GetAuthorizationURLRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep those until we decide to delete them from the payload and leave the context only
Provider: provider,
ProjectId: project,
Cli: true,
Port: int32(port),
Owner: &owner,
Cli: true,
Port: int32(port),
Owner: &owner,
})
if err != nil {
return cli.MessageAndError("error getting authorization URL", err)
Expand Down Expand Up @@ -196,10 +190,7 @@ func callBackServer(ctx context.Context, cmd *cobra.Command, provider string, pr

// todo: check if token has been created. We need an endpoint to pass an state and check if token is created
res, err := client.VerifyProviderTokenFrom(clientCtx, &minderv1.VerifyProviderTokenFromRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep those until we decide to delete them from the payload and leave the context only
Provider: provider,
ProjectId: project,
Context: &minderv1.Context{Provider: &provider, Project: &project},
Timestamp: timestamppb.New(t),
})
if err == nil && res.Status == "OK" {
Expand Down
3 changes: 0 additions & 3 deletions cmd/cli/app/quickstart/quickstart.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,6 @@ func quickstartCommand(_ context.Context, cmd *cobra.Command, conn *grpc.ClientC
// Get the list of all registered repositories
listResp, err := repoClient.ListRepositories(ctx, &minderv1.ListRepositoriesRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep this until we decide to delete them from the payload and rely only on the context
Provider: provider,
ProjectId: project,
})
if err != nil {
return cli.MessageAndError("Error getting list of repos", err)
Expand Down
4 changes: 1 addition & 3 deletions cmd/cli/app/repo/repo_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ func deleteCommand(ctx context.Context, cmd *cobra.Command, conn *grpc.ClientCon
// delete repo by name
resp, err := client.DeleteRepositoryByName(ctx, &minderv1.DeleteRepositoryByNameRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep this until we decide to remove it from the proto and rely on the context instead
Provider: provider,
Name: name,
Name: name,
})
if err != nil {
return cli.MessageAndError("Error deleting repo by name", err)
Expand Down
4 changes: 1 addition & 3 deletions cmd/cli/app/repo/repo_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ func getCommand(ctx context.Context, cmd *cobra.Command, conn *grpc.ClientConn)
// check repo by name
resp, err := client.GetRepositoryByName(ctx, &minderv1.GetRepositoryByNameRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep this until we decide to remove it from the proto and rely on the context instead
Provider: provider,
Name: name,
Name: name,
})
if err != nil {
return cli.MessageAndError("Error getting repo by name", err)
Expand Down
3 changes: 0 additions & 3 deletions cmd/cli/app/repo/repo_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ func listCommand(ctx context.Context, cmd *cobra.Command, conn *grpc.ClientConn)

resp, err := client.ListRepositories(ctx, &minderv1.ListRepositoriesRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep this until we decide to delete them from the payload and rely only on the context
Provider: provider,
ProjectId: project,
})
if err != nil {
return cli.MessageAndError("Error listing repositories", err)
Expand Down
11 changes: 1 addition & 10 deletions cmd/cli/app/repo/repo_register.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,6 @@ func fetchAlreadyRegisteredRepos(ctx context.Context, provider, project string,
sets.Set[string], error) {
alreadyRegisteredRepos, err := client.ListRepositories(ctx, &minderv1.ListRepositoriesRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep this until we decide to delete them from the payload and rely only on the context
Provider: provider,
ProjectId: project,
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -135,9 +132,6 @@ func fetchRemoteRepositoriesFromProvider(ctx context.Context, provider, project
[]*minderv1.UpstreamRepositoryRef, error) {
remoteListResp, err := client.ListRemoteRepositoriesFromProvider(ctx, &minderv1.ListRemoteRepositoriesFromProviderRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep this until we decide to delete them from the payload and rely only on the context
Provider: provider,
ProjectId: project,
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -232,10 +226,7 @@ func registerSelectedRepos(
repo := selectedRepos[idx]

result, err := client.RegisterRepository(context.Background(), &minderv1.RegisterRepositoryRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep this until we decide to delete them from the payload and rely only on the context
Provider: provider,
ProjectId: project,
Context: &minderv1.Context{Provider: &provider, Project: &project},
Repository: repo,
})
if err != nil {
Expand Down
27 changes: 8 additions & 19 deletions docs/docs/ref/proto.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 8 additions & 25 deletions internal/controlplane/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,6 @@ import (
pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1"
)

// ProjectIDGetter is an interface that can be implemented by a request
type ProjectIDGetter interface {
// GetProjectId returns the project ID
GetProjectId() string
}

// ProviderNameGetter is an interface that can be implemented by a request
type ProviderNameGetter interface {
// GetProvider returns the provider name
GetProvider() string
}

// HasProtoContext is an interface that can be implemented by a request
type HasProtoContext interface {
GetContext() *pb.Context
Expand All @@ -56,18 +44,13 @@ func providerError(err error) error {
return fmt.Errorf("provider error: %w", err)
}

func getProjectFromRequestOrDefault(ctx context.Context, in ProjectIDGetter) (uuid.UUID, error) {
func getProjectFromRequestOrDefault(ctx context.Context, in HasProtoContext) (uuid.UUID, error) {
var requestedProject string

// Prefer the context message from the protobuf
pbContext, ok := in.(HasProtoContext)
if ok && pbContext.GetContext().GetProject() != "" {
requestedProject = pbContext.GetContext().GetProject()
} else if in.GetProjectId() != "" {
requestedProject = in.GetProjectId()
}

if requestedProject == "" {
if in.GetContext().GetProject() != "" {
requestedProject = in.GetContext().GetProject()
} else {
proj, err := auth.GetDefaultProject(ctx)
if err != nil {
return uuid.UUID{}, status.Errorf(codes.InvalidArgument, "cannot infer project id: %s", err)
Expand All @@ -86,15 +69,15 @@ func getProjectFromRequestOrDefault(ctx context.Context, in ProjectIDGetter) (uu
func getProviderFromRequestOrDefault(
ctx context.Context,
store db.Store,
in ProviderNameGetter,
in HasProtoContext,
projectId uuid.UUID,
) (db.Provider, error) {
providers, err := store.ListProvidersByProjectID(ctx, projectId)
if err != nil {
return db.Provider{}, status.Errorf(codes.InvalidArgument, "cannot retrieve providers: %s", err)
}
// if we do not have a provider name, check if we can infer it
if in.GetProvider() == "" {
if in.GetContext().GetProvider() == "" {
if len(providers) == 1 {
return providers[0], nil
}
Expand All @@ -103,12 +86,12 @@ func getProviderFromRequestOrDefault(
}

matchesName := func(provider db.Provider) bool {
return provider.Name == in.GetProvider()
return provider.Name == in.GetContext().GetProvider()
}

i := slices.IndexFunc(providers, matchesName)
if i == -1 {
return db.Provider{}, util.UserVisibleError(codes.InvalidArgument, "invalid provider name: %s", in.GetProvider())
return db.Provider{}, util.UserVisibleError(codes.InvalidArgument, "invalid provider name: %s", in.GetContext().GetProvider())
}
return providers[i], nil
}
18 changes: 9 additions & 9 deletions internal/controlplane/handlers_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,19 @@ func (s *Server) GetAuthorizationURL(ctx context.Context,
return nil, err
}

// Configure tracing
// trace call to AuthCodeURL
span := trace.SpanFromContext(ctx)
span.SetName("server.GetAuthorizationURL")
span.SetAttributes(attribute.Key("provider").String(req.Provider))
defer span.End()

// get provider info
provider, err := getProviderFromRequestOrDefault(ctx, s.store, req, projectID)
if err != nil {
return nil, providerError(fmt.Errorf("provider error: %w", err))
}

// Configure tracing
// trace call to AuthCodeURL
span := trace.SpanFromContext(ctx)
span.SetName("server.GetAuthorizationURL")
span.SetAttributes(attribute.Key("provider").String(provider.Name))
defer span.End()

// Create a new OAuth2 config for the given provider
oauthConfig, err := auth.NewOAuthConfig(provider.Name, req.Cli)
if err != nil {
Expand Down Expand Up @@ -156,7 +156,7 @@ func (s *Server) ExchangeCodeForTokenCLI(ctx context.Context,
}

// generate a new OAuth2 config for the given provider
oauthConfig, err := auth.NewOAuthConfig(in.Provider, true)
oauthConfig, err := auth.NewOAuthConfig(provider.Name, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -264,7 +264,7 @@ func (s *Server) StoreProviderToken(ctx context.Context,
}

// validate token
err = auth.ValidateProviderToken(ctx, in.Provider, in.AccessToken)
err = auth.ValidateProviderToken(ctx, provider.Name, in.AccessToken)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid token provided")
}
Expand Down
11 changes: 8 additions & 3 deletions internal/controlplane/handlers_oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ func TestGetAuthorizationURL(t *testing.T) {
projectID := uuid.New()
port := sql.NullInt32{Int32: 8080, Valid: true}
providerID := uuid.New()
providerName := "github"
projectIdStr := projectID.String()

testCases := []struct {
name string
Expand All @@ -99,9 +101,12 @@ func TestGetAuthorizationURL(t *testing.T) {
{
name: "Success",
req: &pb.GetAuthorizationURLRequest{
Provider: "github",
Port: 8080,
Cli: true,
Context: &pb.Context{
Provider: &providerName,
Project: &projectIdStr,
},
Port: 8080,
Cli: true,
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
Expand Down
12 changes: 6 additions & 6 deletions internal/controlplane/handlers_repositories.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (s *Server) RegisterRepository(ctx context.Context,
// publish a reconcile event for the registered repositories
log.Printf("publishing register event for repository: %s/%s", r.Owner, r.Name)

msg, err := reconcilers.NewRepoReconcilerMessage(in.Provider, r.RepoId, projectID)
msg, err := reconcilers.NewRepoReconcilerMessage(provider.Name, r.RepoId, projectID)
if err != nil {
log.Printf("error creating reconciler event: %v", err)
return response, nil
Expand Down Expand Up @@ -364,16 +364,16 @@ func (s *Server) ListRemoteRepositoriesFromProvider(
return nil, err
}

zerolog.Ctx(ctx).Debug().
Str("provider", in.Provider).
Str("projectID", projectID.String()).
Msgf("listing repositories for provider: %s", in.Provider)

provider, err := getProviderFromRequestOrDefault(ctx, s.store, in, projectID)
if err != nil {
return nil, providerError(fmt.Errorf("provider error: %w", err))
}

zerolog.Ctx(ctx).Debug().
Str("provider", provider.Name).
Str("projectID", projectID.String()).
Msg("listing repositories")

// FIXME: this is a hack to get the owner filter from the request
_, owner_filter, err := s.getProviderAccessToken(ctx, provider.Name, projectID, true)

Expand Down
Loading

0 comments on commit f6ed42c

Please sign in to comment.