From 255e9f064571c538957c8d18680844e48fd4b1dc Mon Sep 17 00:00:00 2001 From: Gavin Ray Date: Tue, 3 Dec 2024 16:46:25 -0500 Subject: [PATCH] Phoenix update support for "IN" operator --- .../phoenix/NoRelationshipsQueryGenerator.kt | 3 +- .../phoenix/PhoenixDataConnectorService.kt | 2 +- .../phoenix/PhoenixJDBCSchemaGenerator.kt | 10 +- .../src/main/kotlin/io/hasura/ndc/ir/Query.kt | 2 + .../io/hasura/ndc/sqlgen/BaseGenerator.kt | 321 +++++++++++++----- 5 files changed, 244 insertions(+), 94 deletions(-) diff --git a/ndc-connector-phoenix/src/main/kotlin/io/hasura/phoenix/NoRelationshipsQueryGenerator.kt b/ndc-connector-phoenix/src/main/kotlin/io/hasura/phoenix/NoRelationshipsQueryGenerator.kt index 3029d98..4368bed 100644 --- a/ndc-connector-phoenix/src/main/kotlin/io/hasura/phoenix/NoRelationshipsQueryGenerator.kt +++ b/ndc-connector-phoenix/src/main/kotlin/io/hasura/phoenix/NoRelationshipsQueryGenerator.kt @@ -26,6 +26,7 @@ object NoRelationshipsQueryGenerator : BaseQueryGenerator() { ApplyBinaryComparisonOperator.IN -> col.`in`(value) ApplyBinaryComparisonOperator.IS_NULL -> col.isNull ApplyBinaryComparisonOperator.LIKE -> col.like(value as Field) + ApplyBinaryComparisonOperator.CONTAINS -> col.contains(value as Field) } } @@ -50,7 +51,7 @@ object NoRelationshipsQueryGenerator : BaseQueryGenerator() { DSL.table(DSL.unquotedName(request.collection)) ).apply { if (request.query.predicate != null) { - where(expressionToCondition(request.query.predicate!!, request)) + where(expressionToConditionPhoenixNoTableNameInPredicates(request.query.predicate!!, request)) } if (request.query.order_by != null) { orderBy( diff --git a/ndc-connector-phoenix/src/main/kotlin/io/hasura/phoenix/PhoenixDataConnectorService.kt b/ndc-connector-phoenix/src/main/kotlin/io/hasura/phoenix/PhoenixDataConnectorService.kt index 87399fc..e92d5e7 100644 --- a/ndc-connector-phoenix/src/main/kotlin/io/hasura/phoenix/PhoenixDataConnectorService.kt +++ b/ndc-connector-phoenix/src/main/kotlin/io/hasura/phoenix/PhoenixDataConnectorService.kt @@ -65,7 +65,7 @@ class PhoenixDataConnectorService @Inject constructor( override val jooqDialect = SQLDialect.DEFAULT override val jooqSettings = commonDSLContextSettings - .withRenderQuotedNames(RenderQuotedNames.EXPLICIT_DEFAULT_QUOTED) + .withRenderQuotedNames(RenderQuotedNames.EXPLICIT_DEFAULT_UNQUOTED) .withRenderOptionalAsKeywordForFieldAliases(RenderOptionalKeyword.ON) diff --git a/ndc-connector-phoenix/src/main/kotlin/io/hasura/phoenix/PhoenixJDBCSchemaGenerator.kt b/ndc-connector-phoenix/src/main/kotlin/io/hasura/phoenix/PhoenixJDBCSchemaGenerator.kt index f8d71cc..4773f3c 100644 --- a/ndc-connector-phoenix/src/main/kotlin/io/hasura/phoenix/PhoenixJDBCSchemaGenerator.kt +++ b/ndc-connector-phoenix/src/main/kotlin/io/hasura/phoenix/PhoenixJDBCSchemaGenerator.kt @@ -66,6 +66,7 @@ object PhoenixJDBCSchemaGenerator : JDBCSchemaGenerator() { "_eq" to ComparisonOperatorDefinition.Equal, "_contains" to ComparisonOperatorDefinition.Custom(argument_type = Type.Named(NDCScalar.STRING.name)), "_like" to ComparisonOperatorDefinition.Custom(argument_type = Type.Named(NDCScalar.STRING.name)), + "_in" to ComparisonOperatorDefinition.In ), aggregate_functions = mapOf( "min" to AggregateFunctionDefinition(result_type = Type.Named(NDCScalar.STRING.name)), @@ -78,7 +79,8 @@ object PhoenixJDBCSchemaGenerator : JDBCSchemaGenerator() { "_lt" to ComparisonOperatorDefinition.Custom(argument_type = Type.Named(NDCScalar.DATETIME.name)), "_gte" to ComparisonOperatorDefinition.Custom(argument_type = Type.Named(NDCScalar.DATETIME.name)), "_lte" to ComparisonOperatorDefinition.Custom(argument_type = Type.Named(NDCScalar.DATETIME.name)), - "_eq" to ComparisonOperatorDefinition.Equal + "_eq" to ComparisonOperatorDefinition.Equal, + "_in" to ComparisonOperatorDefinition.In ), aggregate_functions = emptyMap() ), @@ -88,7 +90,8 @@ object PhoenixJDBCSchemaGenerator : JDBCSchemaGenerator() { "_lt" to ComparisonOperatorDefinition.Custom(argument_type = Type.Named(NDCScalar.DATE.name)), "_gte" to ComparisonOperatorDefinition.Custom(argument_type = Type.Named(NDCScalar.DATE.name)), "_lte" to ComparisonOperatorDefinition.Custom(argument_type = Type.Named(NDCScalar.DATE.name)), - "_eq" to ComparisonOperatorDefinition.Equal + "_eq" to ComparisonOperatorDefinition.Equal, + "_in" to ComparisonOperatorDefinition.In ), aggregate_functions = emptyMap() ), @@ -98,7 +101,8 @@ object PhoenixJDBCSchemaGenerator : JDBCSchemaGenerator() { "_lt" to ComparisonOperatorDefinition.Custom(argument_type = Type.Named(NDCScalar.TIME.name)), "_gte" to ComparisonOperatorDefinition.Custom(argument_type = Type.Named(NDCScalar.TIME.name)), "_lte" to ComparisonOperatorDefinition.Custom(argument_type = Type.Named(NDCScalar.TIME.name)), - "_eq" to ComparisonOperatorDefinition.Equal + "_eq" to ComparisonOperatorDefinition.Equal, + "_in" to ComparisonOperatorDefinition.In ), aggregate_functions = emptyMap() ), diff --git a/ndc-ir/src/main/kotlin/io/hasura/ndc/ir/Query.kt b/ndc-ir/src/main/kotlin/io/hasura/ndc/ir/Query.kt index d75247c..f0d24b6 100644 --- a/ndc-ir/src/main/kotlin/io/hasura/ndc/ir/Query.kt +++ b/ndc-ir/src/main/kotlin/io/hasura/ndc/ir/Query.kt @@ -178,6 +178,8 @@ enum class ApplyBinaryComparisonOperator { IS_NULL, @JsonProperty("_like") LIKE, + @JsonProperty("_contains") + CONTAINS, } enum class ApplyUnaryComparisonOperator { diff --git a/ndc-sqlgen/src/main/kotlin/io/hasura/ndc/sqlgen/BaseGenerator.kt b/ndc-sqlgen/src/main/kotlin/io/hasura/ndc/sqlgen/BaseGenerator.kt index b15bf14..4135d64 100644 --- a/ndc-sqlgen/src/main/kotlin/io/hasura/ndc/sqlgen/BaseGenerator.kt +++ b/ndc-sqlgen/src/main/kotlin/io/hasura/ndc/sqlgen/BaseGenerator.kt @@ -25,7 +25,23 @@ sealed interface BaseGenerator { ) } - abstract fun buildComparison(col: Field, operator: ApplyBinaryComparisonOperator, value: Field): Condition + fun buildComparison( + col: Field, + operator: ApplyBinaryComparisonOperator, + value: Field + ): Condition { + return when (operator) { + ApplyBinaryComparisonOperator.EQ -> col.eq(value) + ApplyBinaryComparisonOperator.GT -> col.gt(value) + ApplyBinaryComparisonOperator.GTE -> col.ge(value) + ApplyBinaryComparisonOperator.LT -> col.lt(value) + ApplyBinaryComparisonOperator.LTE -> col.le(value) + ApplyBinaryComparisonOperator.IN -> col.`in`(value) + ApplyBinaryComparisonOperator.IS_NULL -> col.isNull + ApplyBinaryComparisonOperator.LIKE -> col.like(value as Field) + ApplyBinaryComparisonOperator.CONTAINS -> col.contains(value as Field) + } + } private fun getCollectionForCompCol( col: ComparisonColumn, @@ -38,7 +54,8 @@ sealed interface BaseGenerator { if (col.path.isNotEmpty()) { // Traverse the relationship path to get to the current collection name val targetCollection = col.path.fold("") { acc, pathElement -> - val rel = request.collection_relationships[pathElement.relationship] ?: throw Exception("Relationship not found") + val rel = request.collection_relationships[pathElement.relationship] + ?: throw Exception("Relationship not found") rel.target_collection } targetCollection @@ -49,9 +66,13 @@ sealed interface BaseGenerator { } } - fun argumentToCondition(request: QueryRequest, argument: Map.Entry, overrideCollection: String) - = argumentToCondition(request.copy(collection = overrideCollection), argument) - fun argumentToCondition(request: QueryRequest, argument: Map.Entry) : Condition { + fun argumentToCondition( + request: QueryRequest, + argument: Map.Entry, + overrideCollection: String + ) = argumentToCondition(request.copy(collection = overrideCollection), argument) + + fun argumentToCondition(request: QueryRequest, argument: Map.Entry): Condition { val compVal = when (val arg = argument.value) { is Argument.Variable -> ComparisonValue.VariableComp(arg.name) is Argument.Literal -> ComparisonValue.ScalarComp(arg.value) @@ -66,8 +87,8 @@ sealed interface BaseGenerator { } // override request collection for expressionToCondition evaluation - fun expressionToCondition( e: Expression, request: QueryRequest, overrideCollection: String) - = expressionToCondition(e, request.copy(collection = overrideCollection)) + fun expressionToCondition(e: Expression, request: QueryRequest, overrideCollection: String) = + expressionToCondition(e, request.copy(collection = overrideCollection)) // Convert a WHERE-like expression IR into a JOOQ Condition @@ -81,36 +102,33 @@ sealed interface BaseGenerator { request: QueryRequest ): Condition { + fun splitCollectionName(collectionName: String): List { + return collectionName.split(".") + } + return when (e) { - // The negation of a single subexpression - is Expression.Not -> DSL.not(expressionToCondition(e.expression,request)) + is Expression.Not -> DSL.not(expressionToCondition(e.expression, request)) - // A conjunction of several subexpressions is Expression.And -> when (e.expressions.size) { 0 -> DSL.trueCondition() - else -> DSL.and(e.expressions.map { expressionToCondition(it,request) }) + else -> DSL.and(e.expressions.map { expressionToCondition(it, request) }) } - // A disjunction of several subexpressions is Expression.Or -> when (e.expressions.size) { 0 -> DSL.falseCondition() - else -> DSL.or(e.expressions.map { expressionToCondition(it,request) }) + else -> DSL.or(e.expressions.map { expressionToCondition(it, request) }) } - // Test the specified column against a single value using a particular binary comparison operator is Expression.ApplyBinaryComparison -> { val column = DSL.field( DSL.name( - listOf( - getCollectionForCompCol(e.column, request), - e.column.name - ) + splitCollectionName(getCollectionForCompCol(e.column, request)) + e.column.name ) ) val comparisonValue = when (val v = e.value) { is ComparisonValue.ColumnComp -> { - val col = getCollectionForCompCol(v.column, request) - DSL.field(DSL.name(listOf(col, v.column.name))) + val col = splitCollectionName(getCollectionForCompCol(v.column, request)) + DSL.field(DSL.name(col + v.column.name)) } is ComparisonValue.ScalarComp -> DSL.inline(v.value) @@ -119,91 +137,211 @@ sealed interface BaseGenerator { return buildComparison(column, e.operator, comparisonValue) } - // Test the specified column against a particular unary comparison operator is Expression.ApplyUnaryComparison -> { - val baseCond = run { - val column = DSL.field(DSL.name(listOf(request.collection, e.column))) - when (e.operator) { - ApplyUnaryComparisonOperator.IS_NULL -> column.isNull - } + val column = DSL.field(DSL.name(splitCollectionName(request.collection) + e.column)) + when (e.operator) { + ApplyUnaryComparisonOperator.IS_NULL -> column.isNull } - baseCond } - // Test the specified column against an array of values using a particular binary comparison operator is Expression.ApplyBinaryArrayComparison -> { - val baseCond = run { - val column = DSL.field( - DSL.name( - listOf( - getCollectionForCompCol(e.column, request), - e.column.name + val column = DSL.field( + DSL.name( + splitCollectionName(getCollectionForCompCol(e.column, request)) + e.column.name + ) + ) + when (e.operator) { + ApplyBinaryArrayComparisonOperator.IN -> { + when { + e.values.isEmpty() -> column.`in`( + DSL.select(DSL.nullCondition()) + .where(DSL.inline(1).eq(DSL.inline(0))) ) + + else -> column.`in`(DSL.list(e.values.map { + when (it) { + is ComparisonValue.ScalarComp -> DSL.inline(it.value) + is ComparisonValue.VariableComp -> DSL.field(DSL.name(listOf("vars", it.name))) + is ComparisonValue.ColumnComp -> { + val col = splitCollectionName(getCollectionForCompCol(it.column, request)) + DSL.field(DSL.name(col + it.column.name)) + } + } + })) + } + } + } + } + + is Expression.Exists -> { + when (val inTable = e.in_collection) { + is ExistsInCollection.Related -> { + val relOrig = request.collection_relationships[inTable.relationship] + ?: throw Exception("Exists relationship not found") + val rel = relOrig.copy(arguments = relOrig.arguments + inTable.arguments) + DSL.exists( + DSL + .selectOne() + .from( + DSL.table(DSL.name(splitCollectionName(rel.target_collection))) + ) + .where( + DSL.and( + listOf( + expressionToCondition( + e.predicate, + request, + rel.target_collection + ) + ) + + rel.column_mapping.map { (sourceCol, targetCol) -> + DSL.field(DSL.name(splitCollectionName(request.collection) + sourceCol)) + .eq(DSL.field(DSL.name(splitCollectionName(rel.target_collection) + targetCol))) + } + rel.arguments.map { + argumentToCondition( + request, + it, + rel.target_collection + ) + } + ) + ) ) - ) - when (e.operator) { - ApplyBinaryArrayComparisonOperator.IN -> { - when { - // Generate "IN (SELECT NULL WHERE 1 = 0)" for easier debugging - e.values.isEmpty() -> column.`in`( - DSL.select(DSL.nullCondition()) - .where(DSL.inline(1).eq(DSL.inline(0))) + } + + is ExistsInCollection.Unrelated -> { + val condition = mkSQLJoin( + Relationship( + target_collection = inTable.collection, + arguments = inTable.arguments, + column_mapping = emptyMap(), + relationship_type = RelationshipType.Array + ), + request.collection + ) + DSL.exists( + DSL + .selectOne() + .from( + DSL.table(DSL.name(splitCollectionName(inTable.collection))) ) + .where( + listOf( + expressionToCondition( + e.predicate, + request, + inTable.collection + ), condition + ) + ) + ) + } + } + } + } + } + + // TODO: Fix this later + // There are 2 problems here: + // 1) We need to allow passing some function to handle how (and if) the table name is prefixed to columns + // 2) The handling of "IN" operator is monkey-patched here, and should be handled in a more general way. + // NDC spec removed "ApplyBinaryArrayComparison" from the IR + fun expressionToConditionPhoenixNoTableNameInPredicates( + e: Expression, + request: QueryRequest + ): Condition { + + fun splitCollectionName(collectionName: String): List { + return collectionName.split(".") + } - else -> { - - // TODO: swtich map to local context as map will need to be separate for the select column comparisions - // TODO: is it safe to assume that cols will all be from one collections? - column.`in`(DSL.list(e.values.map { - when (it) { - is ComparisonValue.ScalarComp -> DSL.inline(it.value) - is ComparisonValue.VariableComp -> DSL.field(DSL.name(listOf("vars", it.name))) - is ComparisonValue.ColumnComp -> { - val col = getCollectionForCompCol(it.column, request) - DSL.field(DSL.name(listOf(col, it.column.name))) - } + return when (e) { + is Expression.Not -> DSL.not(expressionToConditionPhoenixNoTableNameInPredicates(e.expression, request)) + + is Expression.And -> when (e.expressions.size) { + 0 -> DSL.trueCondition() + else -> DSL.and(e.expressions.map { expressionToConditionPhoenixNoTableNameInPredicates(it, request) }) + } + + is Expression.Or -> when (e.expressions.size) { + 0 -> DSL.falseCondition() + else -> DSL.or(e.expressions.map { expressionToConditionPhoenixNoTableNameInPredicates(it, request) }) + } + + is Expression.ApplyBinaryComparison -> { + val column = DSL.field(DSL.name(e.column.name)) + val comparisonValue = when (val v = e.value) { + is ComparisonValue.ColumnComp -> { + DSL.field(DSL.name( v.column.name)) + } + is ComparisonValue.ScalarComp -> DSL.inline(v.value) + is ComparisonValue.VariableComp -> DSL.field(DSL.name(listOf("vars", v.name))) + } + + if (e.operator == ApplyBinaryComparisonOperator.IN) { + if (e.value is ComparisonValue.ScalarComp) { + val valueList = (e.value as ComparisonValue.ScalarComp).value as List + return column.`in`(valueList.map(DSL::inline)) + } + } + + return buildComparison(column, e.operator, comparisonValue as Field) + } + + is Expression.ApplyUnaryComparison -> { + val column = DSL.field(DSL.name(e.column)) + when (e.operator) { + ApplyUnaryComparisonOperator.IS_NULL -> column.isNull + } + } + + is Expression.ApplyBinaryArrayComparison -> { + val column = DSL.field( + DSL.name(e.column.name) + ) + when (e.operator) { + ApplyBinaryArrayComparisonOperator.IN -> { + when { + e.values.isEmpty() -> column.`in`( + DSL.select(DSL.nullCondition()) + .where(DSL.inline(1).eq(DSL.inline(0))) + ) + + else -> column.`in`(DSL.list(e.values.map { + when (it) { + is ComparisonValue.ScalarComp -> { + when (it.value) { + is String -> DSL.inline(it.value, String::class.java) + is Int -> DSL.inline(it.value, Int::class.java) + is Long -> DSL.inline(it.value, Long::class.java) + is Double -> DSL.inline(it.value, Double::class.java) + is Float -> DSL.inline(it.value, Float::class.java) + is Boolean -> DSL.inline(it.value, Boolean::class.java) + else -> DSL.inline(it.value) } - })) + } + is ComparisonValue.VariableComp -> DSL.field(DSL.name(listOf("vars", it.name))) + is ComparisonValue.ColumnComp -> { + DSL.field(DSL.name( it.column.name)) + } } - } + })) } } } - baseCond } - // Test if a row exists that matches the where subexpression in the specified table (in_table) - // - // where ( - // exists ( - // select 1 "one" - // from "AwsDataCatalog"."chinook"."album" - // where ( - // "AwsDataCatalog"."chinook"."album"."artistid" = "artist_base_fields_0"."artistid" - // and "AwsDataCatalog"."chinook"."album"."title" = 'For Those About To Rock We Salute You' - // and exists ( - // select 1 "one" - // from "AwsDataCatalog"."chinook"."track" - // where ( - // "AwsDataCatalog"."chinook"."track"."albumid" = "albumid" - // and "AwsDataCatalog"."chinook"."track"."name" = 'For Those About To Rock (We Salute You)' - // ) - // ) - // ) - // ) - // ) is Expression.Exists -> { when (val inTable = e.in_collection) { - // The table is related to the current table via the relationship name specified in relationship - // (this means it should be joined to the current table via the relationship) is ExistsInCollection.Related -> { - val relOrig = request.collection_relationships[inTable.relationship] ?: throw Exception("Exists relationship not found") + val relOrig = request.collection_relationships[inTable.relationship] + ?: throw Exception("Exists relationship not found") val rel = relOrig.copy(arguments = relOrig.arguments + inTable.arguments) DSL.exists( DSL .selectOne() .from( - DSL.table(DSL.name(rel.target_collection)) + DSL.table(DSL.name(splitCollectionName(rel.target_collection))) ) .where( DSL.and( @@ -214,17 +352,21 @@ sealed interface BaseGenerator { rel.target_collection ) ) + - rel.column_mapping.map { (sourceCol, targetCol) -> - DSL.field(DSL.name(listOf(request.collection, sourceCol))) - .eq(DSL.field(DSL.name(listOf(rel.target_collection, targetCol)))) - } + rel.arguments.map {argumentToCondition(request, it, rel.target_collection) } + rel.column_mapping.map { (sourceCol, targetCol) -> + DSL.field(DSL.name(splitCollectionName(request.collection) + sourceCol)) + .eq(DSL.field(DSL.name(splitCollectionName(rel.target_collection) + targetCol))) + } + rel.arguments.map { + argumentToCondition( + request, + it, + rel.target_collection + ) + } ) ) ) } - // The table specified by table is unrelated to the current table and therefore is not explicitly joined to the current table - // (this means it should be joined to the current table via a subquery) is ExistsInCollection.Unrelated -> { val condition = mkSQLJoin( Relationship( @@ -239,7 +381,7 @@ sealed interface BaseGenerator { DSL .selectOne() .from( - DSL.table(DSL.name(inTable.collection)) + DSL.table(DSL.name(splitCollectionName(inTable.collection))) ) .where( listOf( @@ -256,4 +398,5 @@ sealed interface BaseGenerator { } } } -} + +} \ No newline at end of file