-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementation of case function. #695
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl | |
|
||
import org.apache.spark.sql.{QueryTest, Row} | ||
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} | ||
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder} | ||
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder} | ||
import org.apache.spark.sql.catalyst.plans.logical._ | ||
import org.apache.spark.sql.streaming.StreamTest | ||
|
||
|
@@ -19,11 +19,13 @@ class FlintSparkPPLFiltersITSuite | |
|
||
/** Test table and index name */ | ||
private val testTable = "spark_catalog.default.flint_ppl_test" | ||
private val duplicationTable = "spark_catalog.default.flint_ppl_test_duplication_table" | ||
|
||
override def beforeAll(): Unit = { | ||
super.beforeAll() | ||
// Create test table | ||
createPartitionedStateCountryTable(testTable) | ||
createDuplicationNullableTable(duplicationTable) | ||
} | ||
|
||
protected override def afterEach(): Unit = { | ||
|
@@ -348,4 +350,107 @@ class FlintSparkPPLFiltersITSuite | |
assert(compareByString(expectedPlan) === compareByString(logicalPlan)) | ||
} | ||
|
||
test("case function used as filter") { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you also add some additional test cases here that use more complicated compositions of PPL commands? |
||
val frame = sql(s""" | ||
| source = $testTable case(country = 'USA', 'The United States of America' else 'Other country') = 'The United States of America' | ||
| """.stripMargin) | ||
|
||
// Retrieve the results | ||
val results: Array[Row] = frame.collect() | ||
// Define the expected results | ||
val expectedResults: Array[Row] = Array( | ||
Row("Jake", 70, "California", "USA", 2023, 4), | ||
Row("Hello", 30, "New York", "USA", 2023, 4)) | ||
|
||
// Compare the results | ||
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) | ||
|
||
assert(results.sorted.sameElements(expectedResults.sorted)) | ||
|
||
// Retrieve the logical plan | ||
val logicalPlan: LogicalPlan = frame.queryExecution.logical | ||
// Define the expected logical plan | ||
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) | ||
val conditionValueSequence = Seq( | ||
( | ||
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("country"), Literal("USA"))), | ||
Literal("The United States of America"))) | ||
val elseValue = Literal("Other country") | ||
val caseFunction = CaseWhen(conditionValueSequence, elseValue) | ||
val filterExpr = EqualTo(caseFunction, Literal("The United States of America")) | ||
val filterPlan = Filter(filterExpr, table) | ||
val projectList = Seq(UnresolvedStar(None)) | ||
val expectedPlan = Project(projectList, filterPlan) | ||
// Compare the two plans | ||
assert(compareByString(expectedPlan) === compareByString(logicalPlan)) | ||
} | ||
|
||
test("case function used as filter complex filter") { | ||
val frame = sql(s""" | ||
| source = $duplicationTable | ||
| | eval factor = case(id > 15, id - 14, isnull(name), id - 7, id < 3, id + 1 else 1) | ||
| | where case(factor = 2, 'even', factor = 4, 'even', factor = 6, 'even', factor = 8, 'even' else 'odd') = 'even' | ||
| | stats count() by factor | ||
| """.stripMargin) | ||
|
||
// Retrieve the results | ||
val results: Array[Row] = frame.collect() // count(), factor | ||
// Define the expected results | ||
val expectedResults: Array[Row] = Array(Row(1, 4), Row(1, 6), Row(2, 2)) | ||
|
||
// Compare the results | ||
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](1)) | ||
assert(results.sorted.sameElements(expectedResults.sorted)) | ||
|
||
// Retrieve the logical plan | ||
val logicalPlan: LogicalPlan = frame.queryExecution.logical | ||
// Define the expected logical plan | ||
val table = | ||
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_duplication_table")) | ||
|
||
// case function used in eval command | ||
val conditionValueEval = Seq( | ||
( | ||
EqualTo(Literal(true), GreaterThan(UnresolvedAttribute("id"), Literal(15))), | ||
UnresolvedFunction("-", Seq(UnresolvedAttribute("id"), Literal(14)), isDistinct = false)), | ||
( | ||
EqualTo( | ||
Literal(true), | ||
UnresolvedFunction("isnull", Seq(UnresolvedAttribute("name")), isDistinct = false)), | ||
UnresolvedFunction("-", Seq(UnresolvedAttribute("id"), Literal(7)), isDistinct = false)), | ||
( | ||
EqualTo(Literal(true), LessThan(UnresolvedAttribute("id"), Literal(3))), | ||
UnresolvedFunction("+", Seq(UnresolvedAttribute("id"), Literal(1)), isDistinct = false))) | ||
val aliasCaseFactor = Alias(CaseWhen(conditionValueEval, Literal(1)), "factor")() | ||
val evalProject = Project(Seq(UnresolvedStar(None), aliasCaseFactor), table) | ||
|
||
// case in where clause | ||
val conditionValueWhere = Seq( | ||
( | ||
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(2))), | ||
Literal("even")), | ||
( | ||
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(4))), | ||
Literal("even")), | ||
( | ||
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(6))), | ||
Literal("even")), | ||
( | ||
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("factor"), Literal(8))), | ||
Literal("even"))) | ||
val caseFunctionWhere = CaseWhen(conditionValueWhere, Literal("odd")) | ||
val filterPlan = Filter(EqualTo(caseFunctionWhere, Literal("even")), evalProject) | ||
|
||
val aggregation = Aggregate( | ||
Seq(Alias(UnresolvedAttribute("factor"), "factor")()), | ||
Seq( | ||
Alias( | ||
UnresolvedFunction("COUNT", Seq(UnresolvedStar(None)), isDistinct = false), | ||
"count()")(), | ||
Alias(UnresolvedAttribute("factor"), "factor")()), | ||
filterPlan) | ||
val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregation) | ||
// Compare the two plans | ||
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -245,6 +245,29 @@ See the next samples of PPL queries : | |
- `source = table | where ispresent(b)` | ||
- `source = table | where isnull(coalesce(a, b)) | fields a,b,c | head 3` | ||
- `source = table | where isempty(a)` | ||
- `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'` | ||
- | ||
``` | ||
source = table | eval status_category = | ||
case(a >= 200 AND a < 300, 'Success', | ||
a >= 300 AND a < 400, 'Redirection', | ||
a >= 400 AND a < 500, 'Client Error', | ||
a >= 500, 'Server Error' | ||
else 'Incorrect HTTP status code') | ||
| where case(a >= 200 AND a < 300, 'Success', | ||
a >= 300 AND a < 400, 'Redirection', | ||
a >= 400 AND a < 500, 'Client Error', | ||
a >= 500, 'Server Error' | ||
else 'Incorrect HTTP status code' | ||
) = 'Incorrect HTTP status code' | ||
``` | ||
- | ||
``` | ||
source = table | ||
| eval factor = case(a > 15, a - 14, isnull(b), a - 7, a < 3, a + 1 else 1) | ||
| where case(factor = 2, 'even', factor = 4, 'even', factor = 6, 'even', factor = 8, 'even' else 'odd') = 'even' | ||
| stats count() by factor | ||
``` | ||
|
||
**Filters With Logical Conditions** | ||
- `source = table | where c = 'test' AND a = 1 | fields a,b,c` | ||
|
@@ -265,6 +288,31 @@ Assumptions: `a`, `b`, `c` are existing fields in `table` | |
- `source = table | eval f = ispresent(a)` | ||
- `source = table | eval r = coalesce(a, b, c) | fields r` | ||
- `source = table | eval e = isempty(a) | fields e` | ||
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one', a = 2, 'two', a = 3, 'three', a = 4, 'four', a = 5, 'five', a = 6, 'six', a = 7, 'se7en', a = 8, 'eight', a = 9, 'nine')` | ||
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else 'unknown')` | ||
- `source = table | eval f = case(a = 0, 'zero', a = 1, 'one' else concat(a, ' is an incorrect binary digit'))` | ||
- | ||
``` | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please add more examples of using case in other scenarios, including where clauses and stats commands. |
||
source = table | eval status_category = | ||
case(a >= 200 AND a < 300, 'Success', | ||
a >= 300 AND a < 400, 'Redirection', | ||
a >= 400 AND a < 500, 'Client Error', | ||
a >= 500, 'Server Error' | ||
else 'Unknown' | ||
) | ||
``` | ||
- | ||
``` | ||
source = table | where ispresent(a) | | ||
eval status_category = | ||
case(a >= 200 AND a < 300, 'Success', | ||
a >= 300 AND a < 400, 'Redirection', | ||
a >= 400 AND a < 500, 'Client Error', | ||
a >= 500, 'Server Error' | ||
else 'Incorrect HTTP status code' | ||
) | ||
| stats count() by status_category | ||
``` | ||
|
||
Limitation: Overriding an existing field is unsupported; the following queries throw exceptions with "Reference 'a' is ambiguous" | ||
- `source = table | eval a = 10 | fields a,b,c` | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also add more complex test cases, including compositions of additional commands such as where conditions and stats?
Thanks.