diff --git a/pkg/aws/aws.go b/pkg/aws/aws.go index 8bc5b2cb..881e4d44 100644 --- a/pkg/aws/aws.go +++ b/pkg/aws/aws.go @@ -27,12 +27,46 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/pkg/errors" "github.com/submariner-io/admiral/pkg/reporter" "github.com/submariner-io/cloud-prepare/pkg/api" awsClient "github.com/submariner-io/cloud-prepare/pkg/aws/client" ) +type CloudOption func(*awsCloud) + +const ( + ControlPlaneSecurityGroupIDKey = "controlPlaneSecurityGroupID" + WorkerSecurityGroupIDKey = "workerSecurityGroupID" + PublicSubnetListKey = "PublicSubnetList" + VPCIDKey = "VPCID" +) + +func WithControlPlaneSecurityGroup(id string) CloudOption { + return func(cloud *awsCloud) { + cloud.cloudConfig[ControlPlaneSecurityGroupIDKey] = id + } +} + +func WithWorkerSecurityGroup(id string) CloudOption { + return func(cloud *awsCloud) { + cloud.cloudConfig[WorkerSecurityGroupIDKey] = id + } +} + +func WithPublicSubnetList(id []string) CloudOption { + return func(cloud *awsCloud) { + cloud.cloudConfig[PublicSubnetListKey] = id + } +} + +func WithVPCName(name string) CloudOption { + return func(cloud *awsCloud) { + cloud.cloudConfig[VPCIDKey] = name + } +} + const ( messageRetrieveVPCID = "Retrieving VPC ID" messageRetrievedVPCID = "Retrieved VPC ID %s" @@ -46,30 +80,45 @@ type awsCloud struct { region string nodeSGSuffix string controlPlaneSGSuffix string + cloudConfig map[string]interface{} } // NewCloud creates a new api.Cloud instance which can prepare AWS for Submariner to be deployed on it. -func NewCloud(client awsClient.Interface, infraID, region string) api.Cloud { - return &awsCloud{ - client: client, - infraID: infraID, - region: region, +func NewCloud(client awsClient.Interface, infraID, region string, opts ...CloudOption) api.Cloud { + cloud := &awsCloud{ + client: client, + infraID: infraID, + region: region, + cloudConfig: make(map[string]interface{}), + } + + for _, opt := range opts { + opt(cloud) } + + return cloud } // NewCloudFromConfig creates a new api.Cloud instance based on an AWS configuration // which can prepare AWS for Submariner to be deployed on it. -func NewCloudFromConfig(cfg *aws.Config, infraID, region string) api.Cloud { - return &awsCloud{ - client: ec2.NewFromConfig(*cfg), - infraID: infraID, - region: region, +func NewCloudFromConfig(cfg *aws.Config, infraID, region string, opts ...CloudOption) api.Cloud { + cloud := &awsCloud{ + client: ec2.NewFromConfig(*cfg), + infraID: infraID, + region: region, + cloudConfig: make(map[string]interface{}), } + + for _, opt := range opts { + opt(cloud) + } + + return cloud } // NewCloudFromSettings creates a new api.Cloud instance using the given credentials file and profile // which can prepare AWS for Submariner to be deployed on it. -func NewCloudFromSettings(credentialsFile, profile, infraID, region string) (api.Cloud, error) { +func NewCloudFromSettings(credentialsFile, profile, infraID, region string, opts ...CloudOption) (api.Cloud, error) { options := []func(*config.LoadOptions) error{config.WithRegion(region), config.WithSharedConfigProfile(profile)} if credentialsFile != DefaultCredentialsFile() { options = append(options, config.WithSharedCredentialsFiles([]string{credentialsFile})) @@ -80,7 +129,7 @@ func NewCloudFromSettings(credentialsFile, profile, infraID, region string) (api return nil, errors.Wrap(err, "error loading default config") } - return NewCloudFromConfig(&cfg, infraID, region), nil + return NewCloudFromConfig(&cfg, infraID, region, opts...), nil } // DefaultCredentialsFile returns the default credentials file name. @@ -98,13 +147,30 @@ func (ac *awsCloud) setSuffixes(vpcID string) error { return nil } - publicSubnets, err := ac.findPublicSubnets(vpcID, ac.filterByName("{infraID}*-public-{region}*")) - if err != nil { - return errors.Wrapf(err, "unable to find the public subnet") - } + var publicSubnets []types.Subnet + + if subnets, exists := ac.cloudConfig[PublicSubnetListKey]; exists { + if subnetIDs, ok := subnets.([]string); ok && len(subnetIDs) > 0 { + for _, id := range subnetIDs { + subnet, err := ac.getSubnetByID(id) + if err != nil { + return errors.Wrapf(err, "unable to find subnet with ID %s", id) + } + + publicSubnets = append(publicSubnets, *subnet) + } + } else { + return errors.New("Subnet IDs must be a valid non-empty slice of strings") + } + } else { + publicSubnets, err := ac.findPublicSubnets(vpcID, ac.filterByName("{infraID}*-public-{region}*")) + if err != nil { + return errors.Wrapf(err, "unable to find the public subnet") + } - if len(publicSubnets) == 0 { - return errors.New("no public subnet found") + if len(publicSubnets) == 0 { + return errors.New("no public subnet found") + } } pattern := fmt.Sprintf(`%s.*-subnet-public-%s.*`, regexp.QuoteMeta(ac.infraID), regexp.QuoteMeta(ac.region)) @@ -137,9 +203,11 @@ func (ac *awsCloud) OpenPorts(ports []api.PortSpec, status reporter.Interface) e return status.Error(err, "unable to retrieve the VPC ID") } - err = ac.setSuffixes(vpcID) - if err != nil { - return status.Error(err, "unable to retrieve the security group names") + if _, found := ac.cloudConfig[VPCIDKey]; !found { + err = ac.setSuffixes(vpcID) + if err != nil { + return status.Error(err, "unable to retrieve the security group names") + } } status.Success(messageRetrievedVPCID, vpcID) @@ -180,9 +248,11 @@ func (ac *awsCloud) ClosePorts(status reporter.Interface) error { return status.Error(err, "unable to retrieve the VPC ID") } - err = ac.setSuffixes(vpcID) - if err != nil { - return status.Error(err, "unable to retrieve the security group names") + if _, found := ac.cloudConfig[VPCIDKey]; !found { + err = ac.setSuffixes(vpcID) + if err != nil { + return status.Error(err, "unable to retrieve the security group names") + } } status.Success(messageRetrievedVPCID, vpcID) diff --git a/pkg/aws/aws_cloud_test.go b/pkg/aws/aws_cloud_test.go index e5c82a97..4ffac4eb 100644 --- a/pkg/aws/aws_cloud_test.go +++ b/pkg/aws/aws_cloud_test.go @@ -40,7 +40,6 @@ func testOpenPorts() { JustBeforeEach(func() { t.expectDescribeVpcs(t.vpcID) - t.expectDescribeVpcsSigs(t.vpcID) t.expectDescribePublicSubnets(t.subnets...) retError = t.cloud.OpenPorts([]api.PortSpec{ @@ -118,7 +117,6 @@ func testClosePorts() { JustBeforeEach(func() { t.expectDescribeVpcs(t.vpcID) t.expectDescribePublicSubnets(t.subnets...) - t.expectDescribeVpcsSigs(t.vpcID) t.expectDescribePublicSubnetsSigs(t.subnets...) retError = t.cloud.ClosePorts(reporter.Stdout()) diff --git a/pkg/aws/aws_suite_test.go b/pkg/aws/aws_suite_test.go index 061b7e1c..d30d2c64 100644 --- a/pkg/aws/aws_suite_test.go +++ b/pkg/aws/aws_suite_test.go @@ -32,6 +32,7 @@ import ( "github.com/submariner-io/admiral/pkg/mock" "github.com/submariner-io/cloud-prepare/pkg/aws/client/fake" "go.uber.org/mock/gomock" + "k8s.io/utils/ptr" ) const ( @@ -50,8 +51,9 @@ const ( masterSGName = infraID + "-master-sg" workerSGName = infraID + "-worker-sg" gatewaySGName = infraID + "-submariner-gw-sg" + providerAWSTagPrefix = "tag:sigs.k8s.io/cluster-api-provider-aws/cluster/" clusterFilterTagName = "tag:kubernetes.io/cluster/" + infraID - clusterFilterTagNameSigs = "tag:sigs.k8s.io/cluster-api-provider-aws/cluster/" + infraID + clusterFilterTagNameSigs = providerAWSTagPrefix + infraID ) var internalTrafficDesc = fmt.Sprintf("Should contain %q", internalTraffic) @@ -113,24 +115,8 @@ func (f *fakeAWSClientBase) expectDescribeVpcs(vpcID string) { }, types.Filter{ Name: awssdk.String(clusterFilterTagName), Values: []string{"owned"}, - })).Return(&ec2.DescribeVpcsOutput{Vpcs: vpcs}, nil).AnyTimes() -} - -func (f *fakeAWSClientBase) expectDescribeVpcsSigs(vpcID string) { - var vpcs []types.Vpc - if vpcID != "" { - vpcs = []types.Vpc{ - { - VpcId: awssdk.String(vpcID), - }, - } - } - - f.awsClient.EXPECT().DescribeVpcs(gomock.Any(), eqFilters(types.Filter{ - Name: awssdk.String("tag:Name"), - Values: []string{infraID + "-vpc"}, }, types.Filter{ - Name: awssdk.String(clusterFilterTagNameSigs), + Name: ptr.To(providerAWSTagPrefix + infraID), Values: []string{"owned"}, })).Return(&ec2.DescribeVpcsOutput{Vpcs: vpcs}, nil).AnyTimes() } diff --git a/pkg/aws/ocpgwdeployer.go b/pkg/aws/ocpgwdeployer.go index 5ec4ee84..5faf6218 100644 --- a/pkg/aws/ocpgwdeployer.go +++ b/pkg/aws/ocpgwdeployer.go @@ -69,16 +69,35 @@ func (d *ocpGatewayDeployer) Deploy(input api.GatewayDeployInput, status reporte status.Success(messageRetrievedVPCID, vpcID) - err = d.aws.setSuffixes(vpcID) - if err != nil { - return status.Error(err, "unable to retrieve the security group names") + if _, found := d.aws.cloudConfig[VPCIDKey]; !found { + err = d.aws.setSuffixes(vpcID) + if err != nil { + return status.Error(err, "unable to retrieve the security group names") + } } status.Start(messageValidatePrerequisites) - publicSubnets, err := d.aws.findPublicSubnets(vpcID, d.aws.filterByName("{infraID}*-public-{region}*")) - if err != nil { - return status.Error(err, "unable to find public subnets") + var publicSubnets []types.Subnet + + if subnets, exists := d.aws.cloudConfig[PublicSubnetListKey]; exists { + if subnetIDs, ok := subnets.([]string); ok && len(subnetIDs) > 0 { + for _, id := range subnetIDs { + subnet, err := d.aws.getSubnetByID(id) + if err != nil { + return errors.Wrapf(err, "unable to find subnet with ID %s", id) + } + + publicSubnets = append(publicSubnets, *subnet) + } + } else { + return errors.New("Subnet IDs must be a valid non-empty slice of strings") + } + } else { + publicSubnets, err = d.aws.findPublicSubnets(vpcID, d.aws.filterByName("{infraID}*-public-{region}*")) + if err != nil { + return status.Error(err, "unable to find public subnets") + } } err = d.validateDeployPrerequisites(vpcID, input, publicSubnets) @@ -97,9 +116,15 @@ func (d *ocpGatewayDeployer) Deploy(input api.GatewayDeployInput, status reporte status.Success("Created Submariner gateway security group %s", gatewaySG) + return d.processSubnets(vpcID, gatewaySG, publicSubnets, input, status) +} + +func (d *ocpGatewayDeployer) processSubnets(vpcID, gatewaySG string, publicSubnets []types.Subnet, + input api.GatewayDeployInput, status reporter.Interface, +) error { subnets, err := d.aws.getSubnetsSupportingInstanceType(publicSubnets, d.instanceType) if err != nil { - return status.Error(err, "unable to create security group") + return status.Error(err, "unable to get subnets supporting instance type") } taggedSubnets, _ := filterSubnets(subnets, func(subnet *types.Subnet) (bool, error) { @@ -315,9 +340,11 @@ func (d *ocpGatewayDeployer) Cleanup(status reporter.Interface) error { status.Success(messageRetrievedVPCID, vpcID) - err = d.aws.setSuffixes(vpcID) - if err != nil { - return status.Error(err, "unable to retrieve the security group names") + if _, found := d.aws.cloudConfig[VPCIDKey]; !found { + err = d.aws.setSuffixes(vpcID) + if err != nil { + return status.Error(err, "unable to retrieve the security group names") + } } status.Start(messageValidatePrerequisites) @@ -329,13 +356,30 @@ func (d *ocpGatewayDeployer) Cleanup(status reporter.Interface) error { status.Success(messageValidatedPrerequisites) - subnets, err := d.aws.getTaggedPublicSubnets(vpcID) - if err != nil { - return err + var publicSubnets []types.Subnet + + if subnets, exists := d.aws.cloudConfig[PublicSubnetListKey]; exists { + if subnetIDs, ok := subnets.([]string); ok && len(subnetIDs) > 0 { + for _, id := range subnetIDs { + subnet, err := d.aws.getSubnetByID(id) + if err != nil { + return errors.Wrapf(err, "unable to find subnet with ID %s", id) + } + + publicSubnets = append(publicSubnets, *subnet) + } + } else { + return errors.New("Subnet IDs must be a valid non-empty slice of strings") + } + } else { + publicSubnets, err = d.aws.getTaggedPublicSubnets(vpcID) + if err != nil { + return err + } } - for i := range subnets { - subnet := &subnets[i] + for i := range publicSubnets { + subnet := &publicSubnets[i] subnetName := extractName(subnet.Tags) status.Start("Removing gateway node for public subnet %s", subnetName) diff --git a/pkg/aws/ocpgwdeployer_test.go b/pkg/aws/ocpgwdeployer_test.go index 6a9d0585..923d2954 100644 --- a/pkg/aws/ocpgwdeployer_test.go +++ b/pkg/aws/ocpgwdeployer_test.go @@ -284,7 +284,6 @@ func newGatewayDeployerTestDriver() *gatewayDeployerTestDriver { t.expectDescribeInstances(instanceImageID) t.expectDescribeSecurityGroups(workerSGName, workerGroupID) t.expectDescribePublicSubnets(t.subnets...) - t.expectDescribeVpcsSigs(t.vpcID) t.expectDescribePublicSubnetsSigs(t.subnets...) var err error diff --git a/pkg/aws/securitygroups.go b/pkg/aws/securitygroups.go index ce361aac..7e95fb11 100644 --- a/pkg/aws/securitygroups.go +++ b/pkg/aws/securitygroups.go @@ -35,7 +35,7 @@ import ( const internalTraffic = "Internal Submariner traffic" -func (ac *awsCloud) getSecurityGroupID(vpcID, name string) (*string, error) { +func (ac *awsCloud) getSecurityGroupName(vpcID, name string) (*string, error) { group, err := ac.getSecurityGroup(vpcID, name) if err != nil { return nil, err @@ -44,6 +44,21 @@ func (ac *awsCloud) getSecurityGroupID(vpcID, name string) (*string, error) { return group.GroupId, nil } +func (ac *awsCloud) getSecurityGroupByID(groupID string) (types.SecurityGroup, error) { + output, err := ac.client.DescribeSecurityGroups(context.TODO(), &ec2.DescribeSecurityGroupsInput{ + GroupIds: []string{groupID}, + }) + if err != nil { + return types.SecurityGroup{}, errors.Wrapf(err, "unable to describe security group %s", groupID) + } + + if len(output.SecurityGroups) == 0 { + return types.SecurityGroup{}, errors.New("security group not found") + } + + return output.SecurityGroups[0], nil +} + func (ac *awsCloud) getSecurityGroup(vpcID, name string) (types.SecurityGroup, error) { filters := []types.Filter{ ec2Filter("vpc-id", vpcID), @@ -97,14 +112,34 @@ func (ac *awsCloud) createClusterSGRule(srcGroup, destGroup *string, port uint16 } func (ac *awsCloud) allowPortInCluster(vpcID string, port uint16, protocol string) error { - workerGroupID, err := ac.getSecurityGroupID(vpcID, "{infraID}"+ac.nodeSGSuffix) - if err != nil { - return err + var workerGroupID, controlPlaneGroupID *string + var err error + + if id, exists := ac.cloudConfig[WorkerSecurityGroupIDKey]; exists { + if workerGroupIDStr, ok := id.(string); ok && workerGroupIDStr != "" { + workerGroupID = &workerGroupIDStr + } else { + return errors.New("Worker Security Group ID must be a valid non-empty string") + } + } else { + workerGroupID, err = ac.getSecurityGroupName(vpcID, "{infraID}"+ac.nodeSGSuffix) + + if err != nil { + return err + } } - masterGroupID, err := ac.getSecurityGroupID(vpcID, "{infraID}"+ac.controlPlaneSGSuffix) - if err != nil { - return err + if id, exists := ac.cloudConfig[ControlPlaneSecurityGroupIDKey]; exists { + if controlPlaneGroupIDStr, ok := id.(string); ok && controlPlaneGroupIDStr != "" { + controlPlaneGroupID = &controlPlaneGroupIDStr + } else { + return errors.New("Control Plane Security Group ID must be a valid non-empty string") + } + } else { + controlPlaneGroupID, err = ac.getSecurityGroupName(vpcID, "{infraID}"+ac.controlPlaneSGSuffix) + if err != nil { + return err + } } err = ac.createClusterSGRule(workerGroupID, workerGroupID, port, protocol, fmt.Sprintf("%s between the workers", internalTraffic)) @@ -112,12 +147,14 @@ func (ac *awsCloud) allowPortInCluster(vpcID string, port uint16, protocol strin return err } - err = ac.createClusterSGRule(workerGroupID, masterGroupID, port, protocol, fmt.Sprintf("%s from worker to master nodes", internalTraffic)) + err = ac.createClusterSGRule(workerGroupID, controlPlaneGroupID, port, protocol, + fmt.Sprintf("%s from worker to master nodes", internalTraffic)) if err != nil { return err } - return ac.createClusterSGRule(masterGroupID, workerGroupID, port, protocol, fmt.Sprintf("%s from master to worker nodes", internalTraffic)) + return ac.createClusterSGRule(controlPlaneGroupID, workerGroupID, port, protocol, + fmt.Sprintf("%s from master to worker nodes", internalTraffic)) } func (ac *awsCloud) createPublicSGRule(groupID *string, port uint16, protocol, description string) error { @@ -141,7 +178,7 @@ func (ac *awsCloud) createPublicSGRule(groupID *string, port uint16, protocol, d func (ac *awsCloud) createGatewaySG(vpcID string, ports []api.PortSpec) (string, error) { groupName := ac.withAWSInfo("{infraID}-submariner-gw-sg") - gatewayGroupID, err := ac.getSecurityGroupID(vpcID, groupName) + gatewayGroupID, err := ac.getSecurityGroupName(vpcID, groupName) if err != nil { if !isNotFoundError(err) { return "", err @@ -187,7 +224,7 @@ func gatewayDeletionRetriable(err error) bool { func (ac *awsCloud) deleteGatewaySG(vpcID string) error { groupName := ac.withAWSInfo("{infraID}-submariner-gw-sg") - gatewayGroupID, err := ac.getSecurityGroupID(vpcID, groupName) + gatewayGroupID, err := ac.getSecurityGroupName(vpcID, groupName) if err != nil { if isNotFoundError(err) { return nil @@ -219,14 +256,39 @@ func (ac *awsCloud) deleteGatewaySG(vpcID string) error { } func (ac *awsCloud) revokePortsInCluster(vpcID string) error { - workerGroup, err := ac.getSecurityGroup(vpcID, "{infraID}"+ac.nodeSGSuffix) - if err != nil { - return err + var workerGroup, controlPlaneGroup types.SecurityGroup + var err error + + if id, exists := ac.cloudConfig[WorkerSecurityGroupIDKey]; exists { + if workerGroupIDStr, ok := id.(string); ok && workerGroupIDStr != "" { + workerGroup, err = ac.getSecurityGroupByID(workerGroupIDStr) + if err != nil { + return errors.Wrap(err, "unable to get Worker Security Group by ID") + } + } else { + return errors.New("Worker Security Group ID must be a valid non-empty string") + } + } else { + workerGroup, err = ac.getSecurityGroup(vpcID, "{infraID}"+ac.nodeSGSuffix) + if err != nil { + return err + } } - masterGroup, err := ac.getSecurityGroup(vpcID, "{infraID}"+ac.controlPlaneSGSuffix) - if err != nil { - return err + if id, exists := ac.cloudConfig[ControlPlaneSecurityGroupIDKey]; exists { + if controlPlaneGroupIDStr, ok := id.(string); ok && controlPlaneGroupIDStr != "" { + controlPlaneGroup, err = ac.getSecurityGroupByID(controlPlaneGroupIDStr) + if err != nil { + return errors.Wrap(err, "unable to get Control Plane Security Group by ID") + } + } else { + return errors.New("Control Plane Security Group ID must be a valid non-empty string") + } + } else { + controlPlaneGroup, err = ac.getSecurityGroup(vpcID, "{infraID}"+ac.controlPlaneSGSuffix) + if err != nil { + return err + } } err = ac.revokePortsFromGroup(&workerGroup) @@ -234,7 +296,7 @@ func (ac *awsCloud) revokePortsInCluster(vpcID string) error { return err } - return ac.revokePortsFromGroup(&masterGroup) + return ac.revokePortsFromGroup(&controlPlaneGroup) } func (ac *awsCloud) revokePortsFromGroup(group *types.SecurityGroup) error { diff --git a/pkg/aws/subnets.go b/pkg/aws/subnets.go index 7748d575..3b6a293c 100644 --- a/pkg/aws/subnets.go +++ b/pkg/aws/subnets.go @@ -123,3 +123,18 @@ func (ac *awsCloud) untagPublicSubnet(subnetID *string) error { return errors.Wrap(err, "error deleting AWS tag") } + +func (ac *awsCloud) getSubnetByID(subnetID string) (*types.Subnet, error) { + output, err := ac.client.DescribeSubnets(context.TODO(), &ec2.DescribeSubnetsInput{ + SubnetIds: []string{subnetID}, + }) + if err != nil { + return nil, errors.Wrapf(err, "unable to describe subnet %s", subnetID) + } + + if len(output.Subnets) == 0 { + return nil, errors.New("subnet not found") + } + + return &output.Subnets[0], nil +} diff --git a/pkg/aws/validations.go b/pkg/aws/validations.go index 65ab1c62..d7926ed6 100644 --- a/pkg/aws/validations.go +++ b/pkg/aws/validations.go @@ -54,9 +54,21 @@ func (ac *awsCloud) validateCreateSecGroup(vpcID string) error { } func (ac *awsCloud) validateCreateSecGroupRule(vpcID string) error { - workerGroupID, err := ac.getSecurityGroupID(vpcID, "{infraID}"+ac.nodeSGSuffix) - if err != nil { - return err + var workerGroupID *string + + if id, exists := ac.cloudConfig[WorkerSecurityGroupIDKey]; exists { + if workerGroupIDStr, ok := id.(string); ok && workerGroupIDStr != "" { + workerGroupID = &workerGroupIDStr + } else { + return errors.New("Worker Security Group ID must be a valid non-empty string") + } + } else { + var err error + + workerGroupID, err = ac.getSecurityGroupName(vpcID, "{infraID}"+ac.nodeSGSuffix) + if err != nil { + return err + } } input := &ec2.AuthorizeSecurityGroupIngressInput{ @@ -64,7 +76,7 @@ func (ac *awsCloud) validateCreateSecGroupRule(vpcID string) error { GroupId: workerGroupID, } - _, err = ac.client.AuthorizeSecurityGroupIngress(context.TODO(), input) + _, err := ac.client.AuthorizeSecurityGroupIngress(context.TODO(), input) return determinePermissionError(err, "authorize security group ingress") } @@ -90,9 +102,21 @@ func (ac *awsCloud) validateDescribeInstanceTypeOfferings() error { } func (ac *awsCloud) validateDeleteSecGroup(vpcID string) error { - workerGroupID, err := ac.getSecurityGroupID(vpcID, "{infraID}"+ac.nodeSGSuffix) - if err != nil { - return err + var workerGroupID *string + + if id, exists := ac.cloudConfig[WorkerSecurityGroupIDKey]; exists { + if workerGroupIDStr, ok := id.(string); ok && workerGroupIDStr != "" { + workerGroupID = &workerGroupIDStr + } else { + return errors.New("Worker Security Group ID must be a valid non-empty string") + } + } else { + var err error + + workerGroupID, err = ac.getSecurityGroupName(vpcID, "{infraID}"+ac.nodeSGSuffix) + if err != nil { + return err + } } input := &ec2.DeleteSecurityGroupInput{ @@ -100,15 +124,27 @@ func (ac *awsCloud) validateDeleteSecGroup(vpcID string) error { GroupId: workerGroupID, } - _, err = ac.client.DeleteSecurityGroup(context.TODO(), input) + _, err := ac.client.DeleteSecurityGroup(context.TODO(), input) return determinePermissionError(err, "delete security group") } func (ac *awsCloud) validateDeleteSecGroupRule(vpcID string) error { - workerGroupID, err := ac.getSecurityGroupID(vpcID, "{infraID}"+ac.nodeSGSuffix) - if err != nil { - return err + var workerGroupID *string + + if id, exists := ac.cloudConfig[WorkerSecurityGroupIDKey]; exists { + if workerGroupIDStr, ok := id.(string); ok && workerGroupIDStr != "" { + workerGroupID = &workerGroupIDStr + } else { + return errors.New("Worker Security Group ID must be a valid non-empty string") + } + } else { + var err error + + workerGroupID, err = ac.getSecurityGroupName(vpcID, "{infraID}"+ac.nodeSGSuffix) + if err != nil { + return err + } } input := &ec2.RevokeSecurityGroupIngressInput{ @@ -116,7 +152,7 @@ func (ac *awsCloud) validateDeleteSecGroupRule(vpcID string) error { GroupId: workerGroupID, } - _, err = ac.client.RevokeSecurityGroupIngress(context.TODO(), input) + _, err := ac.client.RevokeSecurityGroupIngress(context.TODO(), input) return determinePermissionError(err, "revoke security group ingress") } diff --git a/pkg/aws/vpcs.go b/pkg/aws/vpcs.go index 8012cb32..9e038a58 100644 --- a/pkg/aws/vpcs.go +++ b/pkg/aws/vpcs.go @@ -27,25 +27,29 @@ import ( ) func (ac *awsCloud) getVpcID() (string, error) { - ownedFilters := ac.filterByCurrentCluster() var err error var result *ec2.DescribeVpcsOutput - vpcName := ac.withAWSInfo("{infraID}-vpc") - for i := range ownedFilters { - filters := []types.Filter{ - ac.filterByName(vpcName), - ownedFilters[i], + if vpcID, exists := ac.cloudConfig[VPCIDKey]; exists { + vpcIDStr, ok := vpcID.(string) + if !ok || vpcIDStr == "" { + return "", errors.New("VPC ID needs to be a valid non-empty string") } - result, err = ac.client.DescribeVpcs(context.TODO(), &ec2.DescribeVpcsInput{Filters: filters}) - if err != nil { - return "", errors.Wrap(err, "error describing AWS VPCs") - } + return vpcIDStr, nil + } - if len(result.Vpcs) != 0 { - break - } + ownedFilters := ac.filterByCurrentCluster() + vpcName := ac.withAWSInfo("{infraID}-vpc") + + filters := []types.Filter{ + ac.filterByName(vpcName), + } + filters = append(filters, ownedFilters...) + + result, err = ac.client.DescribeVpcs(context.TODO(), &ec2.DescribeVpcsInput{Filters: filters}) + if err != nil { + return "", errors.Wrap(err, "error describing AWS VPCs") } if len(result.Vpcs) == 0 {