-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathaws_mnist.py
128 lines (105 loc) · 3.81 KB
/
aws_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from jinja2 import Environment, FileSystemLoader
from kfp import compiler
from kfp import dsl
from kfp.components import load_component_from_text
from kfp.aws import use_aws_secret
from helpers.image import component_image_name
from helpers.tmp import get_tmp_dir
# Shared Jinja2 environment used by the component factories below to look up
# their YAML templates. The loader path is relative, so templates are found
# under "components-templates" in the current working directory.
# NOTE(review): "# nosec" presumably silences a security-linter warning about
# the Environment() call (e.g. autoescape being off) — confirm it is still needed.
env = Environment( # nosec
    # Jinja2 loader API reference. NOTE(review): the anchor points at
    # DictLoader, while the loader actually used here is FileSystemLoader.
    # https://jinja.palletsprojects.com/en/2.11.x/api/#jinja2.DictLoader
    loader=FileSystemLoader("./components-templates"),
)
def downloadOp(datasets_dir):
    """Build the dataset-download component from its Jinja template.

    Renders "download-datasets-s3.yaml.j2" with the image name returned by
    ``component_image_name("download-datasets-s3")`` and with
    ``output_dir`` set to *datasets_dir*, then loads the rendered text as a
    component.

    Args:
        datasets_dir: Value substituted for ``output_dir`` in the template.

    Returns:
        Whatever ``load_component_from_text`` returns for the rendered
        definition; the pipeline calls it to create the download task.
    """
    rendered_spec = env.get_template("download-datasets-s3.yaml.j2").render(
        image=component_image_name("download-datasets-s3"),
        output_dir=datasets_dir,
    )
    return load_component_from_text(rendered_spec)
def trainOp(model_dir):
    """Build the MNIST training component from its Jinja template.

    Args:
        model_dir: Value substituted for ``model_dir`` in the template.

    Returns:
        Whatever ``load_component_from_text`` returns for the rendered
        "mnist-train.yaml.j2" definition; the pipeline calls it to create
        the training task.
    """
    # Look the template up first, then collect the values it is rendered with.
    train_template = env.get_template("mnist-train.yaml.j2")
    context = {
        "image": component_image_name("mnist-train"),
        "model_dir": model_dir,
        # Log directory comes from the project helper, not from the caller.
        "log_dir": get_tmp_dir("log"),
    }
    return load_component_from_text(train_template.render(**context))
def evaluateOp():
    """Build the MNIST evaluation component from its Jinja template.

    The only value rendered into "mnist-evaluate.yaml.j2" is the image name
    returned by ``component_image_name("mnist-evaluate")``.

    Returns:
        Whatever ``load_component_from_text`` returns for the rendered
        definition; the pipeline calls it to create the evaluation task.
    """
    evaluate_template = env.get_template("mnist-evaluate.yaml.j2")
    image_name = component_image_name("mnist-evaluate")
    rendered_spec = evaluate_template.render(image=image_name)
    return load_component_from_text(rendered_spec)
def exportOp():
    """Build the S3 upload component from its Jinja template.

    Returns:
        Whatever ``load_component_from_text`` returns for the rendered
        "upload-s3.yaml.j2" definition; the pipeline calls it to create
        the export task.
    """
    # NOTE(review): the image key is "upload" although the template file is
    # named "upload-s3" — confirm this mismatch is intentional.
    return load_component_from_text(
        env.get_template("upload-s3.yaml.j2").render(
            image=component_image_name("upload"),
        )
    )
@dsl.pipeline(
    name="mnist_pipeline",
    description="Train an mnist fashion classification model and export to AWS S3",
)
def pipeline(  # nosec
    aws_secret_name: str = "aws-s3-data-secret-kfaas-demo",
    model_name: str = "mnist-fashion",
    model_version: str = "1",
    epochs: int = 10,
    bucket: str = "kfaas-demo-data-sandbox",
    bucket_dir_model: str = "demo/models",
    bucket_dir_tensorboard: str = "demo/tensorboard",
):
    """Define the download -> train -> evaluate -> export pipeline.

    All four tasks mount one volume, created by a ``dsl.VolumeOp``, at
    ``/mnt``. The volume is deleted only after the last task that mounts it
    (export) has finished.

    Args:
        aws_secret_name: Secret name passed to ``use_aws_secret``; applied to
            the train and export tasks.
        model_name: Path segment used in both bucket directories.
        model_version: Path segment used in both bucket directories.
        epochs: Forwarded to the training component as ``epochs``.
        bucket: Forwarded to the train and export components as ``bucket``.
        bucket_dir_model: Prefix of the ``bucketdir`` given to the export task.
        bucket_dir_tensorboard: Prefix of the ``bucketdir`` given to the
            training task (suffixed with model name, version and run id).
    """
    mnt_path = "/mnt"
    datasets_dir = "/mnt/datasets"
    model_dir = "/mnt/model"

    mnist_data_s3_urls = [
        "https://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz",
        "https://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz",
        "https://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz",
        "https://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz",
    ]

    # Op modifier that is applied to the tasks talking to the bucket.
    aws_secret = use_aws_secret(
        aws_secret_name, "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"
    )

    # Volume shared by every task in the pipeline.
    vop = dsl.VolumeOp(
        name="create_volume",
        resource_name="data-volume",
        size="1Gi",
        modes=dsl.VOLUME_MODE_RWO,
    )

    # The download component receives its URLs as one comma-separated string.
    download_op = downloadOp(datasets_dir)
    download_task = download_op(urls=(",").join(mnist_data_s3_urls)).add_pvolumes(
        {mnt_path: vop.volume}
    )

    train_op = trainOp(model_dir)
    train_task = (
        train_op(
            datadir=datasets_dir,
            epochs=epochs,
            bucket=bucket,
            bucketdir=f"{bucket_dir_tensorboard}/{model_name}/{model_version}/{dsl.RUN_ID_PLACEHOLDER}",
        )
        .add_pvolumes({mnt_path: download_task.pvolume})
        .apply(aws_secret)
    )
    train_task.after(download_task)

    evaluate_op = evaluateOp()
    evaluate_task = evaluate_op(datadir=datasets_dir, modeldir=model_dir).add_pvolumes(
        {mnt_path: train_task.pvolume}
    )
    evaluate_task.after(train_task)

    export_op = exportOp()
    export_task = (
        export_op(
            srcdir=model_dir,
            bucket=bucket,
            bucketdir=f"{bucket_dir_model}/{model_name}/{model_version}",
        )
        .add_pvolumes({mnt_path: evaluate_task.pvolume})
        .apply(aws_secret)
    )
    export_task.after(evaluate_task)

    # Delete the volume only once export is done. Export still mounts it to
    # read model_dir, so tying deletion to evaluate_task would let the delete
    # step run concurrently with the export step.
    vop.delete().after(export_task)
if __name__ == "__main__":
    # Local import: only this entry point needs it.
    import os

    # Compile the pipeline into a .zip package next to this script.
    # Swap only the trailing extension: str.replace(".py", ".zip") would also
    # rewrite any ".py" occurring earlier in the path (e.g. a ".pyenv" dir).
    package_name = os.path.splitext(__file__)[0] + ".zip"
    compiler.Compiler().compile(pipeline, package_name)