-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update API to deploy separate endpoints by run ID #2
Changes from 12 commits
eb1f53e
472c528
7518222
d5f943a
abe5403
d8bf46f
67e36ab
f23e58b
bc8c7f5
2a770c0
ec7c692
723e9ef
e6a69a9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
source("renv/activate.R") | ||
source("renv/activate.R") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,101 +1,232 @@ | ||
# Setup ------------------------------------------------------------------------ | ||
library(arrow) | ||
library(assertthat) | ||
library(aws.s3) | ||
library(ccao) | ||
library(dplyr) | ||
library(lightsnip) | ||
library(tibble) | ||
library(plumber) | ||
library(purrr) | ||
library(rapidoc) | ||
library(vetiver) | ||
source("generics.R") | ||
|
||
# Define constants | ||
dvc_bucket_pre_2024 <- "s3://ccao-data-dvc-us-east-1" | ||
dvc_bucket_post_2024 <- "s3://ccao-data-dvc-us-east-1/files/md5" | ||
base_url_prefix <- "/predict" | ||
|
||
# Read AWS creds from Docker secrets | ||
if (file.exists("/run/secrets/ENV_FILE")) { | ||
readRenviron("/run/secrets/ENV_FILE") | ||
} else { | ||
} else if (file.exists("secrets/ENV_FILE")) { | ||
readRenviron("secrets/ENV_FILE") | ||
} | ||
readRenviron(".env") | ||
|
||
# Get the model run attributes at runtime from env vars | ||
dvc_bucket <- Sys.getenv("AWS_S3_DVC_BUCKET") | ||
run_bucket <- Sys.getenv("AWS_S3_MODEL_BUCKET") | ||
run_id <- Sys.getenv("AWS_S3_MODEL_RUN_ID") | ||
run_year <- Sys.getenv("AWS_S3_MODEL_YEAR") | ||
dfsnow marked this conversation as resolved.
Show resolved
Hide resolved
|
||
api_port <- as.numeric(Sys.getenv("API_PORT", unset = "3636")) | ||
|
||
|
||
# Download Files --------------------------------------------------------------- | ||
|
||
# Grab model fit and recipe objects | ||
temp_file_fit <- tempfile(fileext = ".zip") | ||
aws.s3::save_object( | ||
object = file.path( | ||
run_bucket, "workflow/fit", | ||
paste0("year=", run_year), | ||
paste0(run_id, ".zip") | ||
default_run_id_var_name <- "AWS_S3_DEFAULT_MODEL_RUN_ID" | ||
default_run_id <- Sys.getenv(default_run_id_var_name) | ||
|
||
# The list of runs that will be deployed as possible model endpoints | ||
valid_runs <- rbind( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (non-blocking): It might be worth just updating the If this sounds like a good idea, let's push it to another separate PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good call, issue opened here: ccao-data/data-architecture#428 |
||
c( | ||
run_id = "2022-04-27-keen-gabe", | ||
year = "2022", | ||
dvc_bucket = dvc_bucket_pre_2024, | ||
predictors_only = FALSE | ||
), | ||
file = temp_file_fit | ||
) | ||
|
||
temp_file_recipe <- tempfile(fileext = ".rds") | ||
aws.s3::save_object( | ||
object = file.path( | ||
run_bucket, "workflow/recipe", | ||
paste0("year=", run_year), | ||
paste0(run_id, ".rds") | ||
c( | ||
run_id = "2023-03-14-clever-damani", | ||
year = "2023", | ||
dvc_bucket = dvc_bucket_pre_2024, | ||
predictors_only = FALSE | ||
), | ||
c( | ||
run_id = "2024-02-06-relaxed-tristan", | ||
year = "2024", | ||
dvc_bucket = dvc_bucket_post_2024, | ||
predictors_only = TRUE | ||
), | ||
file = temp_file_recipe | ||
c( | ||
run_id = "2024-03-17-stupefied-maya", | ||
year = "2024", | ||
dvc_bucket = dvc_bucket_post_2024, | ||
predictors_only = TRUE | ||
) | ||
) %>% | ||
as_tibble() | ||
|
||
assert_that( | ||
default_run_id %in% valid_runs$run_id, | ||
msg = sprintf( | ||
"%s must be a valid run_id - got '%s', expected one of: %s", | ||
default_run_id_var_name, | ||
default_run_id, | ||
paste(valid_runs$run_id, collapse = ", ") | ||
) | ||
) | ||
|
||
# Grab metadata file for the specified run | ||
metadata <- read_parquet( | ||
file.path( | ||
run_bucket, "metadata", | ||
paste0("year=", run_year), | ||
paste0(run_id, ".parquet") | ||
# Given a run ID and year, return a model object that can be used to power a | ||
# Plumber/vetiver API endpoint | ||
get_model_from_run <- function(run_id, year, dvc_bucket, predictors_only) { | ||
# Download Files ------------------------------------------------------------- | ||
|
||
# Grab model fit and recipe objects | ||
temp_file_fit <- tempfile(fileext = ".zip") | ||
aws.s3::save_object( | ||
object = file.path( | ||
run_bucket, "workflow/fit", | ||
paste0("year=", year), | ||
paste0(run_id, ".zip") | ||
), | ||
file = temp_file_fit | ||
) | ||
) | ||
|
||
# Load the training data used for this model | ||
training_data_md5 <- metadata$dvc_md5_training_data | ||
training_data <- read_parquet( | ||
file.path( | ||
dvc_bucket, | ||
substr(training_data_md5, 1, 2), | ||
substr(training_data_md5, 3, nchar(training_data_md5)) | ||
temp_file_recipe <- tempfile(fileext = ".rds") | ||
aws.s3::save_object( | ||
object = file.path( | ||
run_bucket, "workflow/recipe", | ||
paste0("year=", year), | ||
paste0(run_id, ".rds") | ||
), | ||
file = temp_file_recipe | ||
) | ||
|
||
# Grab metadata file for the specified run | ||
metadata <- read_parquet( | ||
file.path( | ||
run_bucket, "metadata", | ||
paste0("year=", year), | ||
paste0(run_id, ".parquet") | ||
) | ||
) | ||
) | ||
|
||
# Load the training data used for this model | ||
training_data_md5 <- metadata$dvc_md5_training_data | ||
training_data <- read_parquet( | ||
file.path( | ||
dvc_bucket, | ||
substr(training_data_md5, 1, 2), | ||
substr(training_data_md5, 3, nchar(training_data_md5)) | ||
) | ||
) | ||
|
||
# Load Model ------------------------------------------------------------------- | ||
|
||
# Load fit and recipe from file | ||
fit <- lightsnip::lgbm_load(temp_file_fit) | ||
recipe <- readRDS(temp_file_recipe) | ||
# Load Model ----------------------------------------------------------------- | ||
|
||
# Extract a sample row of predictors to use for the API docs | ||
predictors <- recipe$var_info %>% | ||
filter(role == "predictor") %>% | ||
pull(variable) | ||
ptype_tbl <- training_data %>% | ||
filter(meta_pin == "15251030220000") %>% | ||
select(all_of(predictors)) | ||
ptype <- vetiver_create_ptype(model = fit, save_prototype = ptype_tbl) | ||
# Load fit and recipe from file | ||
fit <- lightsnip::lgbm_load(temp_file_fit) | ||
recipe <- readRDS(temp_file_recipe) | ||
|
||
# Extract a sample row of data to use for the API docs | ||
ptype_tbl <- training_data %>% | ||
filter(meta_pin == "15251030220000") | ||
|
||
# Create API ------------------------------------------------------------------- | ||
# If the model recipe is configured to allow it, strip all chars except | ||
# for the predictors from the example row | ||
if (predictors_only) { | ||
predictors <- recipe$var_info %>% | ||
filter(role == "predictor") %>% | ||
pull(variable) | ||
ptype_tbl <- ptype_tbl %>% | ||
filter(meta_pin == "15251030220000") %>% | ||
select(all_of(predictors)) | ||
} | ||
|
||
# Create model object and populate metadata | ||
model <- vetiver_model(fit, "LightGBM", save_prototype = ptype) | ||
model$recipe <- recipe | ||
model$pv$round_type <- metadata$pv_round_type | ||
model$pv$round_break <- metadata$pv_round_break[[1]] | ||
model$pv$round_to_nearest <- metadata$pv_round_to_nearest[[1]] | ||
ptype <- vetiver_create_ptype(model = fit, save_prototype = ptype_tbl) | ||
|
||
# Start API | ||
pr() %>% | ||
vetiver_api(model) %>% | ||
pr_run( | ||
host = "0.0.0.0", | ||
port = api_port | ||
|
||
# Create API ----------------------------------------------------------------- | ||
|
||
# Create model object and populate metadata | ||
model <- vetiver_model(fit, "LightGBM", save_prototype = ptype) | ||
model$recipe <- recipe | ||
model$pv$round_type <- metadata$pv_round_type | ||
model$pv$round_break <- metadata$pv_round_break[[1]] | ||
model$pv$round_to_nearest <- metadata$pv_round_to_nearest[[1]] | ||
|
||
return(model) | ||
} | ||
|
||
# Filter the valid runs for the run marked as default | ||
default_run <- valid_runs %>% | ||
dplyr::filter(run_id == default_run_id) %>% | ||
dplyr::slice_head() | ||
|
||
# Retrieve paths and model objects for all endpoints to be deployed | ||
all_endpoints <- list() | ||
for (i in seq_len(nrow(valid_runs))) { | ||
run <- valid_runs[i, ] | ||
model <- get_model_from_run( | ||
run$run_id, run$year, run$dvc_bucket, run$predictors_only | ||
) | ||
all_endpoints <- append(all_endpoints, list(list( | ||
path = glue::glue("{base_url_prefix}/{run$run_id}"), | ||
model = model | ||
))) | ||
# If this is the default endpoint, add an extra entry for it | ||
if (run$run_id == default_run$run_id) { | ||
all_endpoints <- append(all_endpoints, list(list( | ||
path = glue::glue("{base_url_prefix}"), | ||
model = model | ||
))) | ||
} | ||
} | ||
|
||
# Instantiate a Plumber router for the API. Note that we have to use direct | ||
# Plumber calls instead of using vetiver to define the API since vetiver | ||
# currently has bad support for deploying multiple models on the same API | ||
router <- pr() %>% | ||
plumber::pr_set_debug(rlang::is_interactive()) %>% | ||
plumber::pr_set_serializer(plumber::serializer_unboxed_json(null = "null")) | ||
jeancochrane marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Add Plumber POST enpdoints for each model | ||
for (i in seq_along(all_endpoints)) { | ||
endpoint <- all_endpoints[[i]] | ||
router <- plumber::pr_post( | ||
router, endpoint$path, handler_predict(endpoint$model) | ||
) | ||
} | ||
|
||
# Define a function to override the openapi spec for the API, using | ||
# each model's prototype for docs and examples | ||
modify_spec <- function(spec) { | ||
spec$info$title <- "CCAO Residential AVM API" | ||
spec$info$description <- ( | ||
"API for returning predicted values using CCAO residential AVMs" | ||
) | ||
|
||
for (i in seq_along(all_endpoints)) { | ||
endpoint <- all_endpoints[[i]] | ||
ptype <- endpoint$model$prototype | ||
path <- endpoint$path | ||
orig_post <- pluck(spec, "paths", path, "post") | ||
spec$paths[[path]]$post <- list( | ||
summary = glue_spec_summary(ptype), | ||
requestBody = map_request_body(ptype), | ||
responses = orig_post$responses | ||
) | ||
} | ||
|
||
return(spec) | ||
} | ||
|
||
router <- plumber::pr_set_api_spec(router, api = modify_spec) %>% | ||
plumber::pr_set_docs( | ||
"rapidoc", | ||
header_color = "#F2C6AC", | ||
primary_color = "#8C2D2D", | ||
use_path_in_nav_bar = TRUE, | ||
show_method_in_nav_bar = "as-plain-text" | ||
Comment on lines
+223
to
+224
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems like the rapidoc UI layout changed with the version update, so we need these two params in order to properly display all of the available endpoints (docs). |
||
) | ||
|
||
# Start API | ||
pr_run( | ||
router, | ||
host = "0.0.0.0", | ||
port = api_port | ||
) |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't remember why we're running There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done in d8bf46f, let's see what happens! |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,11 +10,6 @@ vetiver_create_meta._lgb.Booster <- function(model, metadata) { | |
} | ||
|
||
|
||
handler_startup._lgb.Booster <- function(vetiver_model) { | ||
attach_pkgs(vetiver_model$metadata$required_pkgs) | ||
} | ||
Comment on lines
-13
to
-15
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Turns out |
||
|
||
|
||
handler_predict._lgb.Booster <- function(vetiver_model, ...) { | ||
ptype <- vetiver_model$ptype | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whoa, I didn't realize that DVC had changed the file locations :/