Skip to content

Commit

Permalink
Make tests and lambdas run with same permissions
Browse files Browse the repository at this point in the history
  • Loading branch information
Rikuoja committed Jan 27, 2025
1 parent 5eed7f2 commit a08e494
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 37 deletions.
43 changes: 23 additions & 20 deletions database/db_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,26 @@ class User(enum.Enum):
READ = "DB_SECRET_R_ARN"


env_credentials = {
User.SU: {
"username": os.environ.get("SU_USER"),
"password": os.environ.get("SU_USER_PW"),
},
User.ADMIN: {
"username": os.environ.get("ADMIN_USER"),
"password": os.environ.get("ADMIN_USER_PW"),
},
User.READ_WRITE: {
"username": os.environ.get("RW_USER"),
"password": os.environ.get("RW_USER_PW"),
},
User.READ: {
"username": os.environ.get("R_USER"),
"password": os.environ.get("R_USER_PW"),
},
}


class Db(enum.Enum):
MAINTENANCE = 1
MAIN = 2
Expand All @@ -26,14 +46,14 @@ def __init__(self, user: Optional[User] = None):
If user is not specified, requires that the lambda function has *all* user
privileges and secrets specified in lambda function os.environ.
"""
# if user is not specified, iterate through all users
users: List | Type[User] = [user] if user else User
if os.environ.get("READ_FROM_AWS", "1") == "1":
session = boto3.session.Session()
client = session.client(
service_name="secretsmanager",
region_name=os.environ.get("AWS_REGION_NAME"),
)
# if user is not specified, iterate through all users
users: List | Type[User] = [user] if user else User
self._users = {
user: json.loads(
client.get_secret_value(SecretId=os.environ[user.value])[
Expand All @@ -43,24 +63,7 @@ def __init__(self, user: Optional[User] = None):
for user in users
}
else:
self._users = {
User.SU: {
"username": os.environ.get("SU_USER"),
"password": os.environ.get("SU_USER_PW"),
},
User.ADMIN: {
"username": os.environ.get("ADMIN_USER"),
"password": os.environ.get("ADMIN_USER_PW"),
},
User.READ_WRITE: {
"username": os.environ.get("RW_USER"),
"password": os.environ.get("RW_USER_PW"),
},
User.READ: {
"username": os.environ.get("R_USER"),
"password": os.environ.get("R_USER_PW"),
},
}
self._users = {user: env_credentials[user] for user in users}
self._dbs = {
Db.MAIN: os.environ["DB_MAIN_NAME"],
Db.MAINTENANCE: os.environ["DB_MAINTENANCE_NAME"],
Expand Down
2 changes: 1 addition & 1 deletion database/mml_loader/mml_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def save_geometries(self, geoms: Dict) -> str:
def handler(event, _) -> Response:
"""Handler which is called when accessing the endpoint."""
response: Response = {"statusCode": 200, "body": json.dumps("")}
db_helper = DatabaseHelper(user=User.READ_WRITE)
db_helper = DatabaseHelper(user=User.ADMIN)
api_key = os.environ.get("MML_APIKEY")
if not api_key:
raise ValueError(
Expand Down
15 changes: 10 additions & 5 deletions database/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from alembic.operations import ops
from alembic.script import ScriptDirectory
from base import PROJECT_SRID
from db_helper import DatabaseHelper
from db_helper import DatabaseHelper, User
from db_manager import db_manager
from dotenv import load_dotenv
from geoalchemy2.shape import from_shape
Expand Down Expand Up @@ -474,13 +474,18 @@ def assert_database_is_alright(


@pytest.fixture(scope="module")
def connection_string(hame_database_created) -> str:
return DatabaseHelper().get_connection_string()
def admin_connection_string(hame_database_created) -> str:
return DatabaseHelper(user=User.ADMIN).get_connection_string()


@pytest.fixture(scope="module")
def session(connection_string):
engine = sqlalchemy.create_engine(connection_string)
def rw_connection_string(hame_database_created) -> str:
return DatabaseHelper(user=User.READ_WRITE).get_connection_string()


@pytest.fixture(scope="module")
def session(admin_connection_string):
engine = sqlalchemy.create_engine(admin_connection_string)
session = sessionmaker(bind=engine)
yield session()

Expand Down
8 changes: 4 additions & 4 deletions database/test/test_koodistot_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,9 @@ def changed_mock_koodistot(requests_mock, mock_koodistot) -> None:


@pytest.fixture(scope="module")
def loader(connection_string) -> KoodistotLoader:
def loader(admin_connection_string) -> KoodistotLoader:
return KoodistotLoader(
connection_string,
admin_connection_string,
api_url="http://mock.url",
)

Expand Down Expand Up @@ -699,15 +699,15 @@ def test_save_objects(loader, koodistot_data, main_db_params):


def test_save_changed_objects(
changed_koodistot_data, connection_string, main_db_params
changed_koodistot_data, admin_connection_string, main_db_params
):
# The database is already populated in the first test. Because
# connection string (and therefore hame_database_created)
# has module scope, the database persists between tests.
assert_data_is_imported(main_db_params)
# check that a new loader adds one object to the database
loader = KoodistotLoader(
connection_string,
admin_connection_string,
api_url="http://mock.url",
)
loader.save_objects(changed_koodistot_data)
Expand Down
4 changes: 2 additions & 2 deletions database/test/test_mml_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def match_request_body(request: _RequestObjectProxy):

@pytest.fixture()
def loader(
connection_string: str,
admin_connection_string: str,
municipality_instance: codes.Municipality,
administrative_region_instance: codes.AdministrativeRegion,
) -> MMLLoader:
return MMLLoader(connection_string, api_key="mock_apikey")
return MMLLoader(admin_connection_string, api_key="mock_apikey")


def test_get_geometries(mock_mml: Callable, loader: MMLLoader):
Expand Down
8 changes: 4 additions & 4 deletions database/test/test_ryhti_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def mock_xroad_ryhti_update_existing_plan_matter(

@pytest.fixture(scope="function")
def client_with_plan_data(
connection_string: str, complete_test_plan: models.Plan
rw_connection_string: str, complete_test_plan: models.Plan
) -> RyhtiClient:
"""
Return RyhtiClient that has plan data read in.
Expand All @@ -353,7 +353,7 @@ def client_with_plan_data(
"""
# Let's mock production x-road with gispo organization client here.
client = RyhtiClient(
connection_string,
rw_connection_string,
public_api_url="http://mock.url",
xroad_server_address="http://mock2.url",
xroad_instance="FI",
Expand All @@ -370,7 +370,7 @@ def client_with_plan_data(
@pytest.fixture(scope="function")
def client_with_plan_data_in_proposal_phase(
session: Session,
connection_string: str,
rw_connection_string: str,
complete_test_plan: models.Plan,
plan_proposal_status_instance: codes.LifeCycleStatus,
) -> RyhtiClient:
Expand All @@ -391,7 +391,7 @@ def client_with_plan_data_in_proposal_phase(

# Let's mock production x-road with gispo organization client here.
client = RyhtiClient(
connection_string,
rw_connection_string,
public_api_url="http://mock.url",
xroad_server_address="http://mock2.url",
xroad_instance="FI",
Expand Down
2 changes: 1 addition & 1 deletion infra/lambda.tf
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ resource "aws_lambda_function" "mml_loader" {
DB_MAIN_NAME = var.hame_db_name
DB_MAINTENANCE_NAME = "postgres"
READ_FROM_AWS = 1
DB_SECRET_RW_ARN = aws_secretsmanager_secret.hame-db-rw.arn
DB_SECRET_ADMIN_ARN = aws_secretsmanager_secret.hame-db-admin.arn
MML_APIKEY = var.mml_apikey
}
}
Expand Down

0 comments on commit a08e494

Please sign in to comment.