"""Prophet.py: train, evaluate and pickle Prophet models for stock price data."""
import pandas as pd
from prophet import Prophet
import pickle
from sklearn.metrics import mean_squared_error, mean_absolute_error
from GetData import get_data
def train_prophet_model(data):
    """Fit a fresh Prophet model with default settings and return it.

    Args:
        data: Training frame handed straight to ``Prophet.fit``; the caller
            in this file passes a two-column frame named ``ds`` and ``y``.

    Returns:
        The fitted ``Prophet`` instance.
    """
    forecaster = Prophet()
    forecaster.fit(data)
    return forecaster
def evaluate_prophet_model(model, data, target):
    """Score a fitted model against observed values and print the metrics.

    Args:
        model: Fitted model exposing ``predict(df)`` that returns a frame
            with a ``yhat`` column (a ``Prophet`` instance in this file).
        data: Frame with a ``ds`` date column and a ``y`` column holding the
            observed values. Rows need not be in chronological order.
        target: Label printed alongside the metrics.

    Returns:
        ``(mse, mae)``: mean squared error and mean absolute error between
        the observed ``y`` and the predicted ``yhat``.
    """
    forecast = model.predict(data)
    y_pred = forecast['yhat'].to_numpy()

    # Prophet.predict hands back its rows sorted by ``ds`` (stable sort)
    # rather than in the caller's order. Reorder the observations the same
    # way so each y lines up with its yhat; for input that is already
    # chronological this permutation is the identity.
    chronological = pd.to_datetime(data['ds']).to_numpy().argsort(kind='stable')
    y_true = data['y'].to_numpy()[chronological]

    mse = mean_squared_error(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)

    print(f"Target variable: {target}")
    print(f"MSE: {mse:.4f}")
    print(f"MAE: {mae:.4f}")
    return mse, mae
def main():
    """Train, evaluate and save one Prophet model per price column.

    Loads the price table via ``get_data()``, fits a separate model for its
    ``High``, ``Low`` and ``Close`` columns (with ``Date`` as the time axis),
    prints MSE/MAE for each, and finally writes every model to
    ``prophet_model_<target>.pkl`` in the current working directory.

    NOTE: each model is scored on the same rows it was fitted on, so the
    printed errors are in-sample figures, not held-out forecast accuracy.
    """
    df = get_data()
    processed_df = df[['Date', 'High', 'Low', 'Close']]

    prophet_models = {}
    for target, column in (('high', 'High'), ('low', 'Low'), ('close', 'Close')):
        # The models are fitted on a frame whose time axis is named 'ds' and
        # whose value column is named 'y'.
        target_data = processed_df[['Date', column]].rename(
            columns={'Date': 'ds', column: 'y'}
        )
        prophet_models[target] = train_prophet_model(target_data)
        evaluate_prophet_model(prophet_models[target], target_data, target)

    # Files are written only after every model has been trained and scored,
    # so a failure above leaves no partial set of model files behind.
    # NOTE(review): Prophet's documentation advises its JSON serializer over
    # pickle for saving models; pickle is kept here because the .pkl files
    # are this script's existing output format - confirm with consumers
    # before switching.
    for target, model in prophet_models.items():
        filename = f'prophet_model_{target}.pkl'
        with open(filename, 'wb') as file:
            pickle.dump(model, file)
        print(f"Saved Prophet model for {target} as {filename}")
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()