Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of case function. #695

Merged
merged 2 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -548,4 +548,27 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| )
|)""".stripMargin)
}

protected def createTableHttpLog(testTable: String): Unit = {
  // Create a minimal HTTP access-log table used by the case() function tests.
  val ddl =
    s"""
       | CREATE TABLE $testTable
       |(
       |  id INT,
       |  status_code INT,
       |  request_path STRING,
       |  timestamp STRING
       |)
       | USING $tableType $tableOptions
       |""".stripMargin
  sql(ddl)

  // Seed six requests; row 2 intentionally carries a null status_code so tests
  // can exercise the case() else-branch.
  val seed =
    s"""
       | INSERT INTO $testTable
       | VALUES (1, 200, '/home', '2023-10-01 10:00:00'),
       | (2, null, '/about', '2023-10-01 10:05:00'),
       | (3, 500, '/contact', '2023-10-01 10:10:00'),
       | (4, 301, '/home', '2023-10-01 10:15:00'),
       | (5, 200, '/services', '2023-10-01 10:20:00'),
       | (6, 403, '/home', '2023-10-01 10:25:00')
       | """.stripMargin
  sql(seed)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq

import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, LessThan, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, LessThan, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort}
import org.apache.spark.sql.streaming.StreamTest

Expand All @@ -21,12 +21,14 @@ class FlintSparkPPLEvalITSuite

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"
private val testTableHttpLog = "spark_catalog.default.flint_ppl_test_http_log"

override def beforeAll(): Unit = {
super.beforeAll()

// Create test table
createPartitionedStateCountryTable(testTable)
createTableHttpLog(testTableHttpLog)
}

protected override def afterEach(): Unit = {
Expand Down Expand Up @@ -504,7 +506,134 @@ class FlintSparkPPLEvalITSuite
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("eval case function") {
  // Verifies the PPL case() function inside an eval command: status codes are
  // mapped to categories, and the else-branch computes a value via concat().
  // NOTE(review): a reviewer asked for compositions with where/stats — see
  // "eval case function in complex pipeline".
  val frame = sql(s"""
     | source = $testTableHttpLog |
     | eval status_category =
     | case(status_code >= 200 AND status_code < 300, 'Success',
     | status_code >= 300 AND status_code < 400, 'Redirection',
     | status_code >= 400 AND status_code < 500, 'Client Error',
     | status_code >= 500, 'Server Error'
     | else concat('Incorrect HTTP status code for request ', request_path)
     | )
     | """.stripMargin)

  val results: Array[Row] = frame.collect()
  // Row 2 has a null status_code, so every range condition is non-true and the
  // else-branch (concat with the request path) is taken.
  val expectedResults: Array[Row] =
    Array(
      Row(1, 200, "/home", "2023-10-01 10:00:00", "Success"),
      Row(
        2,
        null,
        "/about",
        "2023-10-01 10:05:00",
        "Incorrect HTTP status code for request /about"),
      Row(3, 500, "/contact", "2023-10-01 10:10:00", "Server Error"),
      Row(4, 301, "/home", "2023-10-01 10:15:00", "Redirection"),
      Row(5, 200, "/services", "2023-10-01 10:20:00", "Success"),
      Row(6, 403, "/home", "2023-10-01 10:25:00", "Client Error"))
  // Compare order-insensitively, keyed on id.
  implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getInt(0))
  assert(results.sorted.sameElements(expectedResults.sorted))
  val expectedColumns =
    Array[String]("id", "status_code", "request_path", "timestamp", "status_category")
  assert(frame.columns.sameElements(expectedColumns))

  val logicalPlan: LogicalPlan = frame.queryExecution.logical

  // Expected unresolved plan: Project(*) over Project(*, CASE WHEN ... AS
  // status_category) over the table. case() conditions are wrapped in
  // `EqualTo(true, ...)` by the PPL translation (see helper below).
  val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_http_log"))
  val conditionValueSequence = Seq(
    (graterOrEqualAndLessThan("status_code", 200, 300), Literal("Success")),
    (graterOrEqualAndLessThan("status_code", 300, 400), Literal("Redirection")),
    (graterOrEqualAndLessThan("status_code", 400, 500), Literal("Client Error")),
    (
      EqualTo(
        Literal(true),
        GreaterThanOrEqual(UnresolvedAttribute("status_code"), Literal(500))),
      Literal("Server Error")))
  val elseValue = UnresolvedFunction(
    "concat",
    Seq(
      Literal("Incorrect HTTP status code for request "),
      UnresolvedAttribute("request_path")),
    isDistinct = false)
  val caseFunction = CaseWhen(conditionValueSequence, elseValue)
  val aliasStatusCategory = Alias(caseFunction, "status_category")()
  val evalProjectList = Seq(UnresolvedStar(None), aliasStatusCategory)
  val evalProject = Project(evalProjectList, table)
  val expectedPlan = Project(Seq(UnresolvedStar(None)), evalProject)
  comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("eval case function in complex pipeline") {
// Exercises case() composed with other PPL commands: where (null filtering)
// followed by stats (group-by on the case()-derived column).
val frame = sql(s"""
| source = $testTableHttpLog
| | where ispresent(status_code)
| | eval status_category =
| case(status_code >= 200 AND status_code < 300, 'Success',
| status_code >= 300 AND status_code < 400, 'Redirection',
| status_code >= 400 AND status_code < 500, 'Client Error',
| status_code >= 500, 'Server Error'
| else 'Unknown'
| )
| | stats count() by status_category
| """.stripMargin)

// Row 2 (null status_code) is removed by ispresent, so only five rows are
// aggregated: two Success, one each of Redirection/Client Error/Server Error.
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row(1L, "Redirection"),
Row(1L, "Client Error"),
Row(1L, "Server Error"),
Row(2L, "Success"))
// Compare order-insensitively, keyed on the category name (column 1).
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getString(1))
assert(results.sorted.sameElements(expectedResults.sorted))
val expectedColumns = Array[String]("count()", "status_category")
assert(frame.columns.sameElements(expectedColumns))

val logicalPlan: LogicalPlan = frame.queryExecution.logical

// Expected unresolved plan:
// Project(*) -> Aggregate(by status_category) -> Project(* + CASE alias)
//   -> Filter(isnotnull(status_code)) -> table.
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_http_log"))
// PPL `ispresent(x)` is planned as the `isnotnull` function here.
val filter = Filter(
UnresolvedFunction(
"isnotnull",
Seq(UnresolvedAttribute("status_code")),
isDistinct = false),
table)
// case() conditions are wrapped as EqualTo(true, <condition>) by the planner.
val conditionValueSequence = Seq(
(graterOrEqualAndLessThan("status_code", 200, 300), Literal("Success")),
(graterOrEqualAndLessThan("status_code", 300, 400), Literal("Redirection")),
(graterOrEqualAndLessThan("status_code", 400, 500), Literal("Client Error")),
(
EqualTo(
Literal(true),
GreaterThanOrEqual(UnresolvedAttribute("status_code"), Literal(500))),
Literal("Server Error")))
val elseValue = Literal("Unknown")
val caseFunction = CaseWhen(conditionValueSequence, elseValue)
val aliasStatusCategory = Alias(caseFunction, "status_category")()
val evalProjectList = Seq(UnresolvedStar(None), aliasStatusCategory)
val evalProject = Project(evalProjectList, filter)
// stats count() by status_category: grouping key plus COUNT(*) aliased to
// the PPL-visible column name "count()".
val aggregation = Aggregate(
Seq(Alias(UnresolvedAttribute("status_category"), "status_category")()),
Seq(
Alias(
UnresolvedFunction("COUNT", Seq(UnresolvedStar(None)), isDistinct = false),
"count()")(),
Alias(UnresolvedAttribute("status_category"), "status_category")()),
evalProject)
val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregation)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

// Builds the planner's expected shape for a PPL case() range condition:
// EqualTo(true, fieldName >= min AND fieldName < max). The EqualTo(true, ...)
// wrapper mirrors how the PPL-to-Catalyst translation emits case() conditions
// (see the expected plans in the tests above).
private def graterOrEqualAndLessThan(fieldName: String, min: Int, max: Int) = {
  // Fix: the lower bound previously hard-coded UnresolvedAttribute("status_code"),
  // silently ignoring `fieldName`; both bounds now use the parameter. Current
  // callers all pass "status_code", so existing tests are unaffected.
  val and = And(
    GreaterThanOrEqual(UnresolvedAttribute(fieldName), Literal(min)),
    LessThan(UnresolvedAttribute(fieldName), Literal(max)))
  EqualTo(Literal(true), and)
}

// TODO: excluded fields are not supported yet

ignore("test single eval expression with excluded fields") {
val frame = sql(s"""
| source = $testTable | eval new_field = "New Field" | fields - age
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest

Expand All @@ -19,11 +19,13 @@ class FlintSparkPPLFiltersITSuite

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"
private val duplicationTable = "spark_catalog.default.flint_ppl_test_duplication_table"

override def beforeAll(): Unit = {
super.beforeAll()
// Create test table
createPartitionedStateCountryTable(testTable)
createDuplicationNullableTable(duplicationTable)
}

protected override def afterEach(): Unit = {
Expand Down Expand Up @@ -348,4 +350,107 @@ class FlintSparkPPLFiltersITSuite
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("case function used as filter") {
  // Verifies case() used directly as a where-clause predicate: keep only rows
  // whose country maps to 'The United States of America'.
  // NOTE(review): a reviewer asked for more complex compositions — see
  // "case function used as filter complex filter".
  val frame = sql(s"""
     | source = $testTable case(country = 'USA', 'The United States of America' else 'Other country') = 'The United States of America'
     | """.stripMargin)

  // Retrieve the results
  val results: Array[Row] = frame.collect()
  // Define the expected results: only the two USA rows survive the filter.
  val expectedResults: Array[Row] = Array(
    Row("Jake", 70, "California", "USA", 2023, 4),
    Row("Hello", 30, "New York", "USA", 2023, 4))

  // Compare the results order-insensitively, keyed on the name column.
  implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))

  assert(results.sorted.sameElements(expectedResults.sorted))

  // Retrieve the logical plan
  val logicalPlan: LogicalPlan = frame.queryExecution.logical
  // Define the expected logical plan: Project(*) over Filter(CASE WHEN ... = literal).
  // The case() condition is wrapped as EqualTo(true, <condition>) by the planner.
  val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val conditionValueSequence = Seq(
    (
      EqualTo(Literal(true), EqualTo(UnresolvedAttribute("country"), Literal("USA"))),
      Literal("The United States of America")))
  val elseValue = Literal("Other country")
  val caseFunction = CaseWhen(conditionValueSequence, elseValue)
  val filterExpr = EqualTo(caseFunction, Literal("The United States of America"))
  val filterPlan = Filter(filterExpr, table)
  val projectList = Seq(UnresolvedStar(None))
  val expectedPlan = Project(projectList, filterPlan)
  // Compare the two plans
  assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("case function used as filter complex filter") {
// Exercises case() in two positions of one pipeline: deriving a column in
// eval, then filtering on a second case() over that derived column, followed
// by a stats aggregation.
val frame = sql(s"""
| source = $duplicationTable
| | eval factor = case(id > 15, id - 14, isnull(name), id - 7, id < 3, id + 1 else 1)
| | where case(factor = 2, 'even', factor = 4, 'even', factor = 6, 'even', factor = 8, 'even' else 'odd') = 'even'
| | stats count() by factor
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect() // count(), factor
// Define the expected results
val expectedResults: Array[Row] = Array(Row(1, 4), Row(1, 6), Row(2, 2))

// Compare the results order-insensitively, keyed on the factor column.
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan:
// Project(*) -> Aggregate(by factor) -> Filter(case = 'even')
//   -> Project(* + CASE alias "factor") -> table.
val table =
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_duplication_table"))

// case function used in eval command; each condition is wrapped as
// EqualTo(true, <condition>) and arithmetic appears as UnresolvedFunction.
val conditionValueEval = Seq(
(
EqualTo(Literal(true), GreaterThan(UnresolvedAttribute("id"), Literal(15))),
UnresolvedFunction("-", Seq(UnresolvedAttribute("id"), Literal(14)), isDistinct = false)),
(
EqualTo(
Literal(true),
UnresolvedFunction("isnull", Seq(UnresolvedAttribute("name")), isDistinct = false)),
UnresolvedFunction("-", Seq(UnresolvedAttribute("id"), Literal(7)), isDistinct = false)),
(
EqualTo(Literal(true), LessThan(UnresolvedAttribute("id"), Literal(3))),
UnresolvedFunction("+", Seq(UnresolvedAttribute("id"), Literal(1)), isDistinct = false)))
val aliasCaseFactor = Alias(CaseWhen(conditionValueEval, Literal(1)), "factor")()
val evalProject = Project(Seq(UnresolvedStar(None), aliasCaseFactor), table)

// case in where clause: four equality branches yielding 'even', else 'odd',
// compared against the literal 'even'.
val conditionValueWhere = Seq(
(
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(2))),
Literal("even")),
(
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(4))),
Literal("even")),
(
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(6))),
Literal("even")),
(
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(8))),
Literal("even")))
val caseFunctionWhere = CaseWhen(conditionValueWhere, Literal("odd"))
val filterPlan = Filter(EqualTo(caseFunctionWhere, Literal("even")), evalProject)

// stats count() by factor: grouping key plus COUNT(*) aliased to "count()".
val aggregation = Aggregate(
Seq(Alias(UnresolvedAttribute("factor"), "factor")()),
Seq(
Alias(
UnresolvedFunction("COUNT", Seq(UnresolvedStar(None)), isDistinct = false),
"count()")(),
Alias(UnresolvedAttribute("factor"), "factor")()),
filterPlan)
val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregation)
// Compare the two plans
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
}
48 changes: 48 additions & 0 deletions ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,29 @@ See the next samples of PPL queries :
- `source = table | where ispresent(b)`
- `source = table | where isnull(coalesce(a, b)) | fields a,b,c | head 3`
- `source = table | where isempty(a)`
- `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'`
-
```
source = table | eval status_category =
case(a >= 200 AND a < 300, 'Success',
a >= 300 AND a < 400, 'Redirection',
a >= 400 AND a < 500, 'Client Error',
a >= 500, 'Server Error'
else 'Incorrect HTTP status code')
| where case(a >= 200 AND a < 300, 'Success',
a >= 300 AND a < 400, 'Redirection',
a >= 400 AND a < 500, 'Client Error',
a >= 500, 'Server Error'
else 'Incorrect HTTP status code'
) = 'Incorrect HTTP status code'
```
-
```
source = table
| eval factor = case(a > 15, a - 14, isnull(b), a - 7, a < 3, a + 1 else 1)
| where case(factor = 2, 'even', factor = 4, 'even', factor = 6, 'even', factor = 8, 'even' else 'odd') = 'even'
| stats count() by factor
```

**Filters With Logical Conditions**
- `source = table | where c = 'test' AND a = 1 | fields a,b,c`
Expand All @@ -265,6 +288,31 @@ Assumptions: `a`, `b`, `c` are existing fields in `table`
- `source = table | eval f = ispresent(a)`
- `source = table | eval r = coalesce(a, b, c) | fields r`
- `source = table | eval e = isempty(a) | fields e`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one', a = 2, 'two', a = 3, 'three', a = 4, 'four', a = 5, 'five', a = 6, 'six', a = 7, 'se7en', a = 8, 'eight', a = 9, 'nine')`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else 'unknown')`
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else concat(a, ' is an incorrect binary digit'))`
-
```
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plz add more examples for using case in other use cases that include where clause / stats commands

source = table | eval status_category =
case(a >= 200 AND a < 300, 'Success',
a >= 300 AND a < 400, 'Redirection',
a >= 400 AND a < 500, 'Client Error',
a >= 500, 'Server Error'
else 'Unknown'
)
```
-
```
source = table | where ispresent(a) |
eval status_category =
case(a >= 200 AND a < 300, 'Success',
a >= 300 AND a < 400, 'Redirection',
a >= 400 AND a < 500, 'Client Error',
a >= 500, 'Server Error'
else 'Incorrect HTTP status code'
)
| stats count() by status_category
```

Limitation: Overriding existing field is unsupported, following queries throw exceptions with "Reference 'a' is ambiguous"
- `source = table | eval a = 10 | fields a,b,c`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ ANOMALY_SCORE_THRESHOLD: 'ANOMALY_SCORE_THRESHOLD';

// COMPARISON FUNCTION KEYWORDS
CASE: 'CASE';
ELSE: 'ELSE';
IN: 'IN';

// LOGICAL KEYWORDS
Expand Down
Loading
Loading