Skip to content

Commit

Permalink
[poc] Import Connexion source in pcluster
Browse files Browse the repository at this point in the history
  • Loading branch information
demartinofra committed Dec 2, 2024
1 parent bfe0773 commit b5f7d65
Show file tree
Hide file tree
Showing 44 changed files with 9,272 additions and 47 deletions.
8 changes: 5 additions & 3 deletions cli/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,13 @@ def readme():
"aws-cdk.aws-ssm~=" + CDK_VERSION,
"aws-cdk.aws-sqs~=" + CDK_VERSION,
"aws-cdk.aws-cloudformation~=" + CDK_VERSION,
"werkzeug~=2.0",
"connexion~=2.13.0",
"flask>=2.2.5,<2.3",
# "werkzeug~=2.0",
# "connexion~=2.13.0",
# "flask>=2.2.5,<2.3",
"jmespath~=0.10",
"jsii==1.85.0",
"werkzeug~=3.0",
"flask~=3.0",
]

LAMBDA_REQUIRES = [
Expand Down
110 changes: 76 additions & 34 deletions cli/src/pcluster/api/awslambda/serverless_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,22 @@
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
"""
This module converts an AWS API Gateway proxied request to a WSGI request.
Inspired by: https://github.com/miserlou/zappa
Author: Logan Raarup <[email protected]>
"""
import base64
import io
import json
import os
import sys
from urllib.parse import urlencode, unquote, unquote_plus

from werkzeug.datastructures import Headers, MultiDict, iter_multi_items
from werkzeug.datastructures import Headers, iter_multi_items
from werkzeug.http import HTTP_STATUS_CODES
from werkzeug.urls import url_encode, url_unquote, url_unquote_plus
from werkzeug.wrappers import Response

# List of MIME types that should not be base64 encoded. MIME types within `text/*`
Expand Down Expand Up @@ -95,8 +102,11 @@ def encode_query_string(event):
if not params:
params = ""
if is_alb_event(event):
params = MultiDict((url_unquote_plus(k), url_unquote_plus(v)) for k, v in iter_multi_items(params))
return url_encode(params)
params = [
(unquote_plus(k), unquote_plus(v))
for k, v in iter_multi_items(params)
]
return urlencode(params, doseq=True)


def get_script_name(headers, request_context):
Expand All @@ -108,7 +118,7 @@ def get_script_name(headers, request_context):
"1",
]

if headers.get("Host", "").endswith(".amazonaws.com") and not strip_stage_path:
if "amazonaws.com" in headers.get("Host", "") and not strip_stage_path:
script_name = "/{}".format(request_context.get("stage", ""))
else:
script_name = ""
Expand Down Expand Up @@ -138,7 +148,7 @@ def setup_environ_items(environ, headers):
def generate_response(response, event):
returndict = {"statusCode": response.status_code}

if "multiValueHeaders" in event:
if "multiValueHeaders" in event and event["multiValueHeaders"]:
returndict["multiValueHeaders"] = group_headers(response.headers)
else:
returndict["headers"] = split_headers(response.headers)
Expand All @@ -152,24 +162,40 @@ def generate_response(response, event):

if response.data:
mimetype = response.mimetype or "text/plain"
if (mimetype.startswith("text/") or mimetype in TEXT_MIME_TYPES) and not response.headers.get(
"Content-Encoding", ""
):
if (
mimetype.startswith("text/") or mimetype in TEXT_MIME_TYPES
) and not response.headers.get("Content-Encoding", ""):
returndict["body"] = response.get_data(as_text=True)
returndict["isBase64Encoded"] = False
else:
returndict["body"] = base64.b64encode(response.data).decode("utf-8")
returndict["body"] = base64.b64encode(
response.data).decode("utf-8")
returndict["isBase64Encoded"] = True

return returndict


def strip_express_gateway_query_params(path):
"""Contrary to regular AWS lambda HTTP events, Express Gateway
(https://github.com/ExpressGateway/express-gateway-plugin-lambda)
adds query parameters to the path, which we need to strip.
"""
if "?" in path:
path = path.split("?")[0]
return path


def handle_request(app, event, context):
if event.get("source") in ["aws.events", "serverless-plugin-warmup"]:
print("Lambda warming event received, skipping handler")
return {}

if event.get("version") is None and event.get("isBase64Encoded") is None and not is_alb_event(event):
if (
event.get("version") is None
and event.get("isBase64Encoded") is None
and event.get("requestPath") is not None
and not is_alb_event(event)
):
return handle_lambda_integration(app, event, context)

if event.get("version") == "2.0":
Expand All @@ -179,7 +205,7 @@ def handle_request(app, event, context):


def handle_payload_v1(app, event, context):
if "multiValueHeaders" in event:
if "multiValueHeaders" in event and event["multiValueHeaders"]:
headers = Headers(event["multiValueHeaders"])
else:
headers = Headers(event["headers"])
Expand All @@ -189,35 +215,39 @@ def handle_payload_v1(app, event, context):
# If a user is using a custom domain on API Gateway, they may have a base
# path in their URL. This allows us to strip it out via an optional
# environment variable.
path_info = event["path"]
path_info = strip_express_gateway_query_params(event["path"])
base_path = os.environ.get("API_GATEWAY_BASE_PATH")
if base_path:
script_name = "/" + base_path

if path_info.startswith(script_name):
path_info = path_info[len(script_name) :] # noqa: E203
path_info = path_info[len(script_name):]

body = event["body"] or ""
body = event.get("body") or ""
body = get_body_bytes(event, body)

environ = {
"CONTENT_LENGTH": str(len(body)),
"CONTENT_TYPE": headers.get("Content-Type", ""),
"PATH_INFO": url_unquote(path_info),
"PATH_INFO": unquote(path_info),
"QUERY_STRING": encode_query_string(event),
"REMOTE_ADDR": event.get("requestContext", {}).get("identity", {}).get("sourceIp", ""),
"REMOTE_USER": event.get("requestContext", {}).get("authorizer", {}).get("principalId", ""),
"REMOTE_ADDR": event.get("requestContext", {})
.get("identity", {})
.get("sourceIp", ""),
"REMOTE_USER": (event.get("requestContext", {})
.get("authorizer") or {})
.get("principalId", ""),
"REQUEST_METHOD": event.get("httpMethod", {}),
"SCRIPT_NAME": script_name,
"SERVER_NAME": headers.get("Host", "lambda"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
"SERVER_PROTOCOL": "HTTP/1.1",
"wsgi.errors": sys.stderr,
"wsgi.input": io.BytesIO(body),
"wsgi.multiprocess": False,
"wsgi.multithread": False,
"wsgi.run_once": False,
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
"wsgi.version": (1, 0),
"serverless.authorizer": event.get("requestContext", {}).get("authorizer"),
"serverless.event": event,
Expand All @@ -237,31 +267,43 @@ def handle_payload_v2(app, event, context):

script_name = get_script_name(headers, event.get("requestContext", {}))

path_info = event["rawPath"]
path_info = strip_express_gateway_query_params(event["rawPath"])
base_path = os.environ.get("API_GATEWAY_BASE_PATH")
if base_path:
script_name = "/" + base_path

if path_info.startswith(script_name):
path_info = path_info[len(script_name):]

body = event.get("body", "")
body = get_body_bytes(event, body)

headers["Cookie"] = "; ".join(event.get("cookies", []))

environ = {
"CONTENT_LENGTH": str(len(body)),
"CONTENT_LENGTH": str(len(body or "")),
"CONTENT_TYPE": headers.get("Content-Type", ""),
"PATH_INFO": url_unquote(path_info),
"PATH_INFO": unquote(path_info),
"QUERY_STRING": event.get("rawQueryString", ""),
"REMOTE_ADDR": event.get("requestContext", {}).get("http", {}).get("sourceIp", ""),
"REMOTE_USER": event.get("requestContext", {}).get("authorizer", {}).get("principalId", ""),
"REQUEST_METHOD": event.get("requestContext", {}).get("http", {}).get("method", ""),
"REMOTE_ADDR": event.get("requestContext", {})
.get("http", {})
.get("sourceIp", ""),
"REMOTE_USER": event.get("requestContext", {})
.get("authorizer", {})
.get("principalId", ""),
"REQUEST_METHOD": event.get("requestContext", {})
.get("http", {})
.get("method", ""),
"SCRIPT_NAME": script_name,
"SERVER_NAME": headers.get("Host", "lambda"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
"SERVER_PROTOCOL": "HTTP/1.1",
"wsgi.errors": sys.stderr,
"wsgi.input": io.BytesIO(body),
"wsgi.multiprocess": False,
"wsgi.multithread": False,
"wsgi.run_once": False,
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
"wsgi.version": (1, 0),
"serverless.authorizer": event.get("requestContext", {}).get("authorizer"),
"serverless.event": event,
Expand All @@ -282,7 +324,7 @@ def handle_lambda_integration(app, event, context):

script_name = get_script_name(headers, event)

path_info = event["requestPath"]
path_info = strip_express_gateway_query_params(event["requestPath"])

for key, value in event.get("path", {}).items():
path_info = path_info.replace("{%s}" % key, value)
Expand All @@ -293,23 +335,23 @@ def handle_lambda_integration(app, event, context):
body = get_body_bytes(event, body)

environ = {
"CONTENT_LENGTH": str(len(body)),
"CONTENT_LENGTH": str(len(body or "")),
"CONTENT_TYPE": headers.get("Content-Type", ""),
"PATH_INFO": url_unquote(path_info),
"QUERY_STRING": url_encode(event.get("query", {})),
"PATH_INFO": unquote(path_info),
"QUERY_STRING": urlencode(event.get("query", {}), doseq=True),
"REMOTE_ADDR": event.get("identity", {}).get("sourceIp", ""),
"REMOTE_USER": event.get("principalId", ""),
"REQUEST_METHOD": event.get("method", ""),
"SCRIPT_NAME": script_name,
"SERVER_NAME": headers.get("Host", "lambda"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
"SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
"SERVER_PROTOCOL": "HTTP/1.1",
"wsgi.errors": sys.stderr,
"wsgi.input": io.BytesIO(body),
"wsgi.multiprocess": False,
"wsgi.multithread": False,
"wsgi.run_once": False,
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
"wsgi.version": (1, 0),
"serverless.authorizer": event.get("enhancedAuthContext"),
"serverless.event": event,
Expand Down
Empty file.
16 changes: 16 additions & 0 deletions cli/src/pcluster/api/connexion/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
This module defines Connexion APIs. A connexion API takes in an OpenAPI specification and
translates the operations defined in it to a set of Connexion Operations. This set of operations
is implemented as a framework blueprint (A Flask blueprint or framework-specific equivalent),
which can be registered on the framework application.
For each operation, the API resolves the user view function to link to the operation, wraps it
with a Connexion Operation which it configures based on the OpenAPI spec, and finally adds it as
a route on the framework blueprint.
When the API is registered on the Connexion APP, the underlying framework blueprint is registered
on the framework app.
"""


from .abstract import AbstractAPI # NOQA
Loading

0 comments on commit b5f7d65

Please sign in to comment.