Skip to content

Commit

Permalink
Supporting estimation of "ts"
Browse files Browse the repository at this point in the history
  • Loading branch information
oualib committed Nov 5, 2023
1 parent 8347934 commit 61eb541
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions verticapy/machine_learning/vertica/tsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def predict(
npredictions: int = 10,
output_standard_errors: bool = False,
output_index: bool = False,
output_estimated_ts: bool = False,
) -> vDataFrame:
"""
Predicts using the input relation.
Expand Down Expand Up @@ -394,7 +395,11 @@ def predict(
Boolean, whether to return estimates of the standard
error of each prediction.
output_index: bool, optional
Boolean, whether to return the index of each position.
Boolean, whether to return the index of each prediction.
output_estimated_ts: bool, optional
Boolean, whether to return the estimated abscissa of
each prediction. The real one is hard to obtain due to
interval computations.
Returns
-------
Expand All @@ -418,24 +423,31 @@ def predict(
y=y,
start=start,
npredictions=npredictions,
output_standard_errors=output_standard_errors,
output_standard_errors=(
output_standard_errors or output_index or output_estimated_ts
),
output_index=output_index,
)
no_relation = True
if not (isinstance(vdf, NoneType)):
sql += f" FROM {vdf}"
no_relation = False
if output_index:
if output_index or output_estimated_ts:
j = self.n_
if no_relation:
if not (isinstance(start, NoneType)):
j = j + start
elif not (isinstance(start, NoneType)):
j = start
if output_standard_errors and not (ar_ma):
if (output_standard_errors or output_estimated_ts) and not (ar_ma):
if not (output_standard_errors):
stde_out = ""
else:
stde_out = ", std_err"
output_standard_errors = ", std_err"
else:
output_standard_errors = ""
stde_out = ""
if ar_ma:
order_by = ""
else:
Expand All @@ -445,7 +457,21 @@ def predict(
ROW_NUMBER() OVER ({order_by}) + {j} - 1 AS idx,
prediction{output_standard_errors}
FROM ({sql}) VERTICAPY_SUBTABLE"""
return vDataFrame(sql)
if output_estimated_ts:
min_value = f"(SELECT MIN({self.ts}) FROM {self.input_relation})"
delta = f"""
(SELECT
AVG(delta)
FROM (SELECT
{self.ts} - LAG({self.ts}) OVER (ORDER BY {self.ts}) AS delta
FROM {self.input_relation}) VERTICAPY_SUBTABLE)"""
sql = f"""
SELECT
{delta} + idx * {min_value} AS {self.ts},
prediction{stde_out}
FROM ({sql}) VERTICAPY_SUBTABLE
"""
return vDataFrame(clean_query(sql))

# Model Evaluation Methods.

Expand Down

0 comments on commit 61eb541

Please sign in to comment.