-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathexperiment.py
41 lines (30 loc) · 854 Bytes
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
from sacred import Experiment
from sacred.observers import MongoObserver
from sklearn import svm, datasets, model_selection
from dotenv import load_dotenv
# Name of the MongoDB database handed to the observer.
db_name = "pydata_berlin"

# Connection URI handed to the observer. With None the observer presumably
# falls back to its default MongoDB host -- TODO confirm.
# NOTE(review): `os` and `load_dotenv` are imported but unused in the visible
# code; the URI may have been meant to come from the environment.
mongo_uri = None

# The experiment object that the config, captured and main functions in this
# file are registered on.
ex = Experiment('svm')

# Attach a MongoDB observer built from the two settings above.
ex.observers.append(MongoObserver.create(url=mongo_uri, db_name=db_name))
@ex.config
def cfg():
    """Default configuration of the SVM experiment."""
    C = 1.0  # value passed as `C` to svm.SVC in get_model
    gamma = 0.7  # value passed as `gamma` to svm.SVC in get_model
    kernel = "rbf"  # value passed as `kernel` to svm.SVC in get_model
    seed = 42  # not read by the code in this file; presumably sacred's experiment seed
@ex.capture
def get_model(C, gamma, kernel):
    """Build an unfitted ``svm.SVC`` from the given hyperparameters.

    The function is decorated with ``ex.capture``; ``run`` calls it without
    arguments, so the values are expected to come from the experiment
    configuration.

    Args:
        C: Forwarded to ``svm.SVC`` as ``C``.
        gamma: Forwarded to ``svm.SVC`` as ``gamma``.
        kernel: Forwarded to ``svm.SVC`` as ``kernel``.

    Returns:
        The newly constructed ``svm.SVC`` instance.
    """
    classifier = svm.SVC(
        C=C,
        kernel=kernel,
        gamma=gamma,
    )
    return classifier
@ex.automain  # Use automain to enable command line integration.
def run():
    """Fit the configured SVC on the breast-cancer data and score it.

    Loads the scikit-learn breast-cancer dataset, holds out 20% of the
    samples as a test split, fits the model returned by ``get_model`` on the
    remaining samples and evaluates it on the held-out part.

    Returns:
        The value of the fitted model's ``score`` on the test split.
    """
    features, labels = datasets.load_breast_cancer(return_X_y=True)

    # No explicit random_state is given, so the split relies on the global
    # random state.
    split = model_selection.train_test_split(features, labels, test_size=0.2)
    train_features, test_features, train_labels, test_labels = split

    model = get_model()
    model.fit(train_features, train_labels)

    test_score = model.score(test_features, test_labels)
    return test_score