-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathpredict.php
51 lines (33 loc) · 1.42 KB
/
predict.php
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
<?php
/**
 * Batch churn prediction script.
 *
 * Reads customer rows from the `customers` table of database.sqlite, runs the
 * selected columns through the persisted estimator stored in model.rbx, and,
 * after interactive confirmation, writes each prediction into the `churn`
 * column of the customer row with the matching id.
 */
include __DIR__ . '/vendor/autoload.php';
use Rubix\ML\Loggers\Screen;
use Rubix\ML\Extractors\SQLTable;
use Rubix\ML\Extractors\ColumnPicker;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\PersistentModel;
use Rubix\ML\Persisters\Filesystem;

ini_set('memory_limit', '-1');

$logger = new Screen();

$logger->info('Loading data into memory');

$connection = new PDO('sqlite:database.sqlite');

// Make failures throw so the transaction below can be rolled back reliably,
// regardless of the PHP version's default error mode.
$connection->setAttribute(PDO::ATTR_ERRMODE, PDO::ERRMODE_EXCEPTION);

$extractor = new SQLTable($connection, 'customers');

// Column order matters: 'Id' is listed first so it becomes feature 0 below.
$extractor = new ColumnPicker($extractor, [
    'Id', 'Gender', 'SeniorCitizen', 'Partner', 'Dependents', 'MonthsInService', 'Phone',
    'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection',
    'TechSupport', 'TV', 'Movies', 'Contract', 'PaperlessBilling', 'PaymentMethod',
    'MonthlyCharges', 'TotalCharges', 'Region',
]);

$dataset = Unlabeled::fromIterator($extractor);

// Set the customer ids aside so predictions can be written back to the right
// rows, then drop them so the id is not passed to the model as a feature.
$ids = $dataset->feature(0);

$dataset->dropFeature(0);

$logger->info('Loading model into memory');

$estimator = PersistentModel::load(new Filesystem('model.rbx'));

$logger->info('Making predictions');

$predictions = $estimator->predict($dataset);

// readline() returns false on EOF (e.g. non-interactive stdin); treat that as "no".
$answer = readline('Save predictions to database? (y|[n]): ');

if ($answer !== false && strtolower(trim($answer)) === 'y') {
    $statement = $connection->prepare('UPDATE customers SET churn=? WHERE id=?');

    // Apply all updates in a single transaction. In autocommit mode SQLite
    // commits (and syncs) after every UPDATE, which is very slow for large
    // tables. Rolling back on failure avoids leaving the table half-updated.
    $connection->beginTransaction();

    try {
        foreach ($predictions as $i => $prediction) {
            $statement->execute([$prediction, $ids[$i]]);
        }

        $connection->commit();
    } catch (Throwable $e) {
        $connection->rollBack();

        throw $e;
    }

    $logger->info('Predictions saved to database');
}