New trendline ppl command (SMA only) (#833)
* WIP trendline command

Signed-off-by: Kacper Trochimiak <[email protected]>

* wip

Signed-off-by: Kacper Trochimiak <[email protected]>

* trendline supports sorting

Signed-off-by: Kacper Trochimiak <[email protected]>

* run scalafmtAll

Signed-off-by: Kacper Trochimiak <[email protected]>

* return null when there are too few data points

Signed-off-by: Kacper Trochimiak <[email protected]>

* sbt scalafmtAll

Signed-off-by: Kacper Trochimiak <[email protected]>

* Remove WMA references

Signed-off-by: Hendrik Saly <[email protected]>

* trendline - sortByField as Optional<Field>

Signed-off-by: Kacper Trochimiak <[email protected]>

* introduce TrendlineStrategy

Signed-off-by: Kacper Trochimiak <[email protected]>

* keywordsCanBeId -> replace SMA with trendlineType

Signed-off-by: Kacper Trochimiak <[email protected]>

* handle trendline alias as qualifiedName instead of fieldExpression

Signed-off-by: Kacper Trochimiak <[email protected]>

* Add docs

Signed-off-by: Hendrik Saly <[email protected]>

* Make alias optional

Signed-off-by: Hendrik Saly <[email protected]>

* Adapt tests for optional alias

Signed-off-by: Hendrik Saly <[email protected]>

* Add logical plan unit tests

Signed-off-by: Hendrik Saly <[email protected]>

* Add missing license headers

Signed-off-by: Hendrik Saly <[email protected]>

* Fix docs

Signed-off-by: Hendrik Saly <[email protected]>

* numberOfDataPoints must be 1 or greater

Signed-off-by: Hendrik Saly <[email protected]>

* Rename TrendlineStrategy to  TrendlineCatalystUtils

Signed-off-by: Hendrik Saly <[email protected]>

* Validate TrendlineType early and pass around enum type

Signed-off-by: Hendrik Saly <[email protected]>

* Add trendline chaining test

Signed-off-by: Hendrik Saly <[email protected]>

* Fix compile errors

Signed-off-by: Hendrik Saly <[email protected]>

* Fix imports

Signed-off-by: Hendrik Saly <[email protected]>

* Fix imports

Signed-off-by: Hendrik Saly <[email protected]>

---------

Signed-off-by: Kacper Trochimiak <[email protected]>
Signed-off-by: Hendrik Saly <[email protected]>
Co-authored-by: Kacper Trochimiak <[email protected]>
salyh and kt-eliatra authored Nov 1, 2024
1 parent faae818 commit bdb4848
Showing 16 changed files with 698 additions and 44 deletions.
1 change: 1 addition & 0 deletions docs/ppl-lang/PPL-Example-Commands.md
@@ -60,6 +60,7 @@ _- **Limitation: new field added by eval command with a function cannot be dropp
- `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b < '2024-09-10' or b > '2025-09-10'
- `source = table | where cidrmatch(ip, '192.169.1.0/24')`
- `source = table | where cidrmatch(ipv6, '2003:db8::/32')`
- `source = table | trendline sma(2, temperature) as temp_trend`

```sql
source = table | eval status_category =
3 changes: 2 additions & 1 deletion docs/ppl-lang/README.md
@@ -69,7 +69,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md).
- [`subquery commands`](ppl-subquery-command.md)

- [`correlation commands`](ppl-correlation-command.md)


- [`trendline commands`](ppl-trendline-command.md)

* **Functions**

60 changes: 60 additions & 0 deletions docs/ppl-lang/ppl-trendline-command.md
@@ -0,0 +1,60 @@
## PPL trendline Command

**Description**
Use the ``trendline`` command to calculate moving averages of fields.


### Syntax
`TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...`

* [+|-]: optional. A plus [+] stands for ascending order with NULL/MISSING first; a minus [-] stands for descending order with NULL/MISSING last. **Default:** ascending order with NULL/MISSING first.
* sort-field: mandatory when sorting is used. The field to sort by.
* number-of-datapoints: mandatory. The number of datapoints used to calculate the moving average (must be greater than zero).
* field: mandatory. The name of the field for which the moving average is calculated.
* alias: optional. The name of the resulting column containing the moving average.

At the moment, only the Simple Moving Average (SMA) type is supported.

It is calculated as follows:

f[i]: The value of field 'f' in the i-th data-point
n: The number of data-points in the moving window (period)
t: The current time index

SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t
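
As a rough illustration of the formula above, the following plain-Java sketch computes an SMA over an in-memory list. It is not the code added by this commit (which builds Spark window expressions); the class name `SmaSketch` and the method `sma` are hypothetical, and the sketch assumes the input contains no null values. It returns `null` for rows where fewer than `n` datapoints are available, mirroring the behavior described in the commit message.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative only: a plain-Java model of the SMA semantics, not the commit's implementation. */
final class SmaSketch {

    /**
     * Returns one output per input row: null while fewer than n rows have been
     * seen, otherwise the mean of the current row and the n-1 rows before it.
     * Assumes the input list contains no null elements.
     */
    static List<Double> sma(List<Double> values, int n) {
        if (n < 1) {
            throw new IllegalArgumentException("number-of-datapoints must be 1 or greater");
        }
        List<Double> result = new ArrayList<>();
        double windowSum = 0.0;
        for (int t = 0; t < values.size(); t++) {
            windowSum += values.get(t);
            if (t >= n) {
                // Drop the row that just left the n-row window.
                windowSum -= values.get(t - n);
            }
            if (t + 1 < n) {
                result.add(null); // too few datapoints so far
            } else {
                result.add(windowSum / n);
            }
        }
        return result;
    }
}
```

Calling `SmaSketch.sma(List.of(12.0, 12.0, 13.0, 14.0, 15.0), 2)` yields `[null, 12.0, 12.5, 13.5, 14.5]`, which matches the `temp_trend` column in Example 1.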

### Example 1: Calculate simple moving average for a timeseries of temperatures

The example calculates the simple moving average over temperatures using two datapoints.

PPL query:

os> source=t | trendline sma(2, temperature) as temp_trend;
fetched rows / total rows = 5/5
+-----------+---------+--------------------+----------+
|temperature|device-id|           timestamp|temp_trend|
+-----------+---------+--------------------+----------+
|         12|     1492|2023-04-06 17:07:...|      NULL|
|         12|     1492|2023-04-06 17:07:...|      12.0|
|         13|      256|2023-04-06 17:07:...|      12.5|
|         14|      257|2023-04-06 17:07:...|      13.5|
|         15|      258|2023-04-06 17:07:...|      14.5|
+-----------+---------+--------------------+----------+

### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting

The example calculates two simple moving averages over temperatures, using two and three datapoints, with rows sorted in descending order by device-id.

PPL query:

os> source=t | trendline sort - device-id sma(2, temperature) as temp_trend_2 sma(3, temperature) as temp_trend_3;
fetched rows / total rows = 5/5
+-----------+---------+--------------------+------------+------------------+
|temperature|device-id|           timestamp|temp_trend_2|      temp_trend_3|
+-----------+---------+--------------------+------------+------------------+
|         15|      258|2023-04-06 17:07:...|        NULL|              NULL|
|         14|      257|2023-04-06 17:07:...|        14.5|              NULL|
|         13|      256|2023-04-06 17:07:...|        13.5|              14.0|
|         12|     1492|2023-04-06 17:07:...|        12.5|              13.0|
|         12|     1492|2023-04-06 17:07:...|        12.0|12.333333333333334|
+-----------+---------+--------------------+------------+------------------+
@@ -0,0 +1,247 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLTrendlineITSuite
extends QueryTest
with LogicalPlanTestUtils
with FlintPPLSuite
with StreamTest {

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"

override def beforeAll(): Unit = {
super.beforeAll()

// Create test table
createPartitionedStateCountryTable(testTable)
}

protected override def afterEach(): Unit = {
super.afterEach()
// Stop all streaming jobs if any
spark.streams.active.foreach { job =>
job.stop()
job.awaitTermination()
}
}

test("test trendline sma command without fields command and without alias") {
val frame = sql(s"""
| source = $testTable | sort - age | trendline sma(2, age)
| """.stripMargin)

assert(
frame.columns.sameElements(
Array("name", "age", "state", "country", "year", "month", "age_trendline")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jake", 70, "California", "USA", 2023, 4, null),
Row("Hello", 30, "New York", "USA", 2023, 4, 50.0),
Row("John", 25, "Ontario", "Canada", 2023, 4, 27.5),
Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val ageField = UnresolvedAttribute("age")
val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table)
val countWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val smaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), smaWindow)
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")())
val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test trendline sma command with fields command") {
val frame = sql(s"""
| source = $testTable | trendline sort - age sma(3, age) as age_sma | fields name, age, age_sma
| """.stripMargin)

assert(frame.columns.sameElements(Array("name", "age", "age_sma")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jake", 70, null),
Row("Hello", 30, null),
Row("John", 25, 41.666666666666664),
Row("Jane", 20, 25.0))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val nameField = UnresolvedAttribute("name")
val ageField = UnresolvedAttribute("age")
val ageSmaField = UnresolvedAttribute("age_sma")
val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table)
val countWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
val smaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow)
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")())
val expectedPlan =
Project(Seq(nameField, ageField, ageSmaField), Project(trendlineProjectList, sort))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test multiple trendline sma commands") {
val frame = sql(s"""
| source = $testTable | trendline sort + age sma(2, age) as two_points_sma sma(3, age) as three_points_sma | fields name, age, two_points_sma, three_points_sma
| """.stripMargin)

assert(frame.columns.sameElements(Array("name", "age", "two_points_sma", "three_points_sma")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jane", 20, null, null),
Row("John", 25, 22.5, null),
Row("Hello", 30, 27.5, 25.0),
Row("Jake", 70, 50.0, 41.666666666666664))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val nameField = UnresolvedAttribute("name")
val ageField = UnresolvedAttribute("age")
val ageTwoPointsSmaField = UnresolvedAttribute("two_points_sma")
val ageThreePointsSmaField = UnresolvedAttribute("three_points_sma")
val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table)
val twoPointsCountWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val twoPointsSmaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val threePointsCountWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
val threePointsSmaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
val twoPointsCaseWhen = CaseWhen(
Seq((LessThan(twoPointsCountWindow, Literal(2)), Literal(null))),
twoPointsSmaWindow)
val threePointsCaseWhen = CaseWhen(
Seq((LessThan(threePointsCountWindow, Literal(3)), Literal(null))),
threePointsSmaWindow)
val trendlineProjectList = Seq(
UnresolvedStar(None),
Alias(twoPointsCaseWhen, "two_points_sma")(),
Alias(threePointsCaseWhen, "three_points_sma")())
val expectedPlan = Project(
Seq(nameField, ageField, ageTwoPointsSmaField, ageThreePointsSmaField),
Project(trendlineProjectList, sort))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test trendline sma command on evaluated column") {
val frame = sql(s"""
| source = $testTable | eval doubled_age = age * 2 | trendline sort + age sma(2, doubled_age) as doubled_age_sma | fields name, doubled_age, doubled_age_sma
| """.stripMargin)

assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_sma")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jane", 40, null),
Row("John", 50, 45.0),
Row("Hello", 60, 55.0),
Row("Jake", 140, 100.0))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val nameField = UnresolvedAttribute("name")
val ageField = UnresolvedAttribute("age")
val doubledAgeField = UnresolvedAttribute("doubled_age")
val doubledAgeSmaField = UnresolvedAttribute("doubled_age_sma")
val evalProject = Project(
Seq(
UnresolvedStar(None),
Alias(
UnresolvedFunction("*", Seq(ageField, Literal(2)), isDistinct = false),
"doubled_age")()),
table)
val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, evalProject)
val countWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val doubleAgeSmaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(doubledAgeField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val caseWhen =
CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), doubleAgeSmaWindow)
val trendlineProjectList =
Seq(UnresolvedStar(None), Alias(caseWhen, "doubled_age_sma")())
val expectedPlan = Project(
Seq(nameField, doubledAgeField, doubledAgeSmaField),
Project(trendlineProjectList, sort))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test trendline sma command chaining") {
val frame = sql(s"""
| source = $testTable | eval age_1 = age, age_2 = age | trendline sort - age_1 sma(3, age_1) | trendline sort + age_2 sma(3, age_2)
| """.stripMargin)

assert(
frame.columns.sameElements(
Array(
"name",
"age",
"state",
"country",
"year",
"month",
"age_1",
"age_2",
"age_1_trendline",
"age_2_trendline")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, null, 25.0),
Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, null, 41.666666666666664),
Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20, 20, 25.0, null),
Row("John", 25, "Ontario", "Canada", 2023, 4, 25, 25, 41.666666666666664, null))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))
}
}
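
The expected logical plans in the suite above share one shape: an `AVG` window over the current row and the `n-1` preceding rows, wrapped in a `CASE WHEN` that yields `NULL` while a `COUNT(1)` over the same frame is below `n`. The sketch below renders that shape as a SQL-like string to make it easier to read. It is an interpretation of the test expectations, not output produced by the commit; the class `TrendlineSqlSketch` and method `smaExpression` are hypothetical, and the generated text has not been verified against Spark.

```java
/** Illustrative only: renders the plan shape asserted in the tests as SQL-like text. */
final class TrendlineSqlSketch {

    /**
     * Builds a CASE WHEN expression for sma(n, field).
     * The frame spans the current row plus the n-1 rows before it, matching
     * SpecifiedWindowFrame(RowFrame, Literal(-(n-1)), CurrentRow) in the tests.
     */
    static String smaExpression(int n, String field) {
        if (n < 1) {
            throw new IllegalArgumentException("number-of-datapoints must be 1 or greater");
        }
        String frame = "ROWS BETWEEN " + (n - 1) + " PRECEDING AND CURRENT ROW";
        return "CASE WHEN COUNT(1) OVER (" + frame + ") < " + n
                + " THEN NULL ELSE AVG(" + field + ") OVER (" + frame + ") END";
    }
}
```

For `sma(2, age)` this produces `CASE WHEN COUNT(1) OVER (ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) < 2 THEN NULL ELSE AVG(age) OVER (ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) END`. Note that the window itself carries no ordering; in the tests, row order comes from the `Sort` node placed beneath the trendline projection.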
4 changes: 4 additions & 0 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
@@ -38,6 +38,7 @@ AD: 'AD';
ML: 'ML';
FILLNULL: 'FILLNULL';
FLATTEN: 'FLATTEN';
TRENDLINE: 'TRENDLINE';

//Native JOIN KEYWORDS
JOIN: 'JOIN';
@@ -90,6 +91,9 @@ FIELDSUMMARY: 'FIELDSUMMARY';
INCLUDEFIELDS: 'INCLUDEFIELDS';
NULLS: 'NULLS';

//TRENDLINE KEYWORDS
SMA: 'SMA';

// ARGUMENT KEYWORDS
KEEPEMPTY: 'KEEPEMPTY';
CONSECUTIVE: 'CONSECUTIVE';
14 changes: 14 additions & 0 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
@@ -54,6 +54,7 @@ commands
| fillnullCommand
| fieldsummaryCommand
| flattenCommand
| trendlineCommand
;

commandName
@@ -84,6 +85,7 @@ commandName
| FILLNULL
| FIELDSUMMARY
| FLATTEN
| TRENDLINE
;

searchCommand
@@ -252,6 +254,17 @@ flattenCommand
: FLATTEN fieldExpression
;

trendlineCommand
: TRENDLINE (SORT sortField)? trendlineClause (trendlineClause)*
;

trendlineClause
: trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)?
;

trendlineType
: SMA
;

kmeansCommand
: KMEANS (kmeansParameter)*
@@ -1131,4 +1144,5 @@ keywordsCanBeId
| ANTI
| BETWEEN
| CIDRMATCH
| trendlineType
;
@@ -111,6 +111,10 @@ public T visitLookup(Lookup node, C context) {
return visitChildren(node, context);
}

public T visitTrendline(Trendline node, C context) {
return visitChildren(node, context);
}

public T visitCorrelation(Correlation node, C context) {
return visitChildren(node, context);
}