Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed centralize logic TODOs in ec2 #378

Merged
merged 3 commits into from
Dec 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 81 additions & 66 deletions flintrock/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,7 @@ def destroy(self):
super().destroy()
ec2 = boto3.resource(service_name='ec2', region_name=self.region)

# TODO: Centralize logic to get Flintrock base security group. (?)
flintrock_base_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': ['flintrock']},
{'Name': 'vpc-id', 'Values': [self.vpc_id]},
]))[0]
flintrock_base_group = get_base_security_group(vpc_id=self.vpc_id, region=self.region)

# We "unassign" the cluster security group here (i.e. the
# 'flintrock-clustername' group) so that we can immediately delete it once
Expand All @@ -196,13 +190,11 @@ def destroy(self):
Groups=[flintrock_base_group.id])
time.sleep(1)

# TODO: Centralize logic to get cluster security group name from cluster name.
cluster_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': ['flintrock-' + self.name]},
{'Name': 'vpc-id', 'Values': [self.vpc_id]},
]))[0]
cluster_group = get_cluster_security_group(
vpc_id=self.vpc_id,
region=self.region,
cluster_name=self.name,
)
cluster_group.delete()

(ec2.instances
Expand Down Expand Up @@ -380,13 +372,10 @@ def remove_slaves(self, *, user: str, identity_file: str, num_slaves: int):
if self.state == 'running':
super().remove_slaves(user=user, identity_file=identity_file)

# TODO: Centralize logic to get Flintrock base security group.
flintrock_base_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': ['flintrock']},
{'Name': 'vpc-id', 'Values': [self.vpc_id]},
]))[0]
flintrock_base_group = get_base_security_group(
vpc_id=self.vpc_id,
region=self.region,
)

# TODO: Is there a way to do this in one call for all instances?
for instance in removed_slave_instances:
Expand Down Expand Up @@ -490,37 +479,79 @@ def check_network_config(*, region_name: str, vpc_id: str, subnet_id: str):
)


def get_security_groups(
*,
vpc_id,
region,
security_group_names) -> "List[boto3.resource('ec2').SecurityGroup]":
BASE_SECURITY_GROUP_NAME = "flintrock"


def get_base_security_group(*, vpc_id, region):
"""
The base Flintrock group is common to all Flintrock clusters and authorizes client traffic
to them.
"""
ec2 = boto3.resource(service_name='ec2', region_name=region)
base_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': [BASE_SECURITY_GROUP_NAME]},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]
)
)
return base_group[0] if base_group else None

groups = list(

def get_cluster_security_group_name(cluster_name):
return f"flintrock-{cluster_name}"


def get_cluster_security_group(*, vpc_id, region, cluster_name):
"""
The cluster group is specific to one Flintrock cluster and authorizes intra-cluster
communication.
"""
ec2 = boto3.resource(service_name='ec2', region_name=region)
cluster_group_name = get_cluster_security_group_name(cluster_name)
cluster_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': security_group_names},
{'Name': 'group-name', 'Values': [cluster_group_name]},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]))
return cluster_group[0] if cluster_group else None


def get_security_groups(
*,
vpc_id,
region,
security_group_names,
):
ec2 = boto3.resource(service_name='ec2', region_name=region)
groups = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': security_group_names},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]
)
)
found_group_names = [group.group_name for group in groups]
missing_group_names = set(security_group_names) - set(found_group_names)
if missing_group_names:
raise Error(
"Could not find the following security group{s}: {groups}"
.format(
s='' if len(missing_group_names) == 1 else 's',
groups=', '.join(list(missing_group_names))))

groups=', '.join(list(missing_group_names)),
)
)
return groups


def get_ssh_security_group_rules(
*,
flintrock_client_cidr=None,
flintrock_client_group=None,
) -> "boto3.resource('ec2').SecurityGroup":
):
return SecurityGroupRule(
ip_protocol='tcp',
from_port=22,
Expand All @@ -531,49 +562,26 @@ def get_ssh_security_group_rules(


def get_or_create_flintrock_security_groups(
*,
cluster_name,
vpc_id,
region,
services,
ec2_authorize_access_from,
) -> "List[boto3.resource('ec2').SecurityGroup]":
*,
cluster_name,
vpc_id,
region,
services,
ec2_authorize_access_from,
):
"""
If they do not already exist, create all the security groups needed for a
Flintrock cluster.
"""
ec2 = boto3.resource(service_name='ec2', region_name=region)

# TODO: Make these into methods, since we need this logic (though simple)
# in multiple places. (?)
flintrock_group_name = 'flintrock'
cluster_group_name = 'flintrock-' + cluster_name

# The Flintrock group is common to all Flintrock clusters and authorizes client traffic
# to them.
flintrock_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': [flintrock_group_name]},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]))
flintrock_group = flintrock_group[0] if flintrock_group else None

# The cluster group is specific to one Flintrock cluster and authorizes intra-cluster
# communication.
cluster_group = list(
ec2.security_groups.filter(
Filters=[
{'Name': 'group-name', 'Values': [cluster_group_name]},
{'Name': 'vpc-id', 'Values': [vpc_id]},
]))
cluster_group = cluster_group[0] if cluster_group else None

flintrock_group = get_base_security_group(vpc_id=vpc_id, region=region)
if not flintrock_group:
flintrock_group = ec2.create_security_group(
GroupName=flintrock_group_name,
GroupName=BASE_SECURITY_GROUP_NAME,
Description="Flintrock base group",
VpcId=vpc_id)
VpcId=vpc_id,
)

# Rules for the client interacting with the cluster.
if ec2_authorize_access_from:
Expand Down Expand Up @@ -607,12 +615,19 @@ def get_or_create_flintrock_security_groups(
flintrock_client_cidr=str(IPv4Network(client_source)),
)

cluster_group_name = get_cluster_security_group_name(cluster_name)
cluster_group = get_cluster_security_group(
vpc_id=vpc_id,
region=region,
cluster_name=cluster_name,
)
# Rules for internal cluster communication.
if not cluster_group:
cluster_group = ec2.create_security_group(
GroupName=cluster_group_name,
Description="Flintrock cluster group",
VpcId=vpc_id)
VpcId=vpc_id,
)

# TODO: Don't try adding rules that already exist.
# TODO: Add rules in one shot.
Expand Down
Loading