Skip to content

Commit

Permalink
fix: ignore case when matching function name (#16912)
Browse files Browse the repository at this point in the history
* add method to check sugar function

* compare sugar functions using unicase ascii

* add sqllogictest

* migrate builtin function lookup to unicase

* fix issues

* fix build issue
  • Loading branch information
notauserx authored Dec 27, 2024
1 parent 1f9a4eb commit bebf2af
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 84 deletions.
8 changes: 8 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ tower = { version = "0.5.1", features = ["util"] }
tower-service = "0.3.3"
twox-hash = "1.6.3"
typetag = "0.2.3"
unicase = "2.8.0"
unicode-segmentation = "1.10.1"
unindent = "0.2"
url = "2.3.1"
Expand Down
1 change: 1 addition & 0 deletions src/query/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ siphasher = { workspace = true }
strength_reduce = { workspace = true }
stringslice = { workspace = true }
twox-hash = { workspace = true }
unicase = { workspace = true }

[dev-dependencies]
comfy-table = { workspace = true }
Expand Down
81 changes: 44 additions & 37 deletions src/query/functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@
use aggregates::AggregateFunctionFactory;
use ctor::ctor;
use databend_common_expression::FunctionRegistry;
use unicase::Ascii;

pub mod aggregates;
mod cast_rules;
pub mod scalars;
pub mod srfs;

pub fn is_builtin_function(name: &str) -> bool {
BUILTIN_FUNCTIONS.contains(name)
|| AggregateFunctionFactory::instance().contains(name)
let name = Ascii::new(name);
BUILTIN_FUNCTIONS.contains(name.into_inner())
|| AggregateFunctionFactory::instance().contains(name.into_inner())
|| GENERAL_WINDOW_FUNCTIONS.contains(&name)
|| GENERAL_LAMBDA_FUNCTIONS.contains(&name)
|| GENERAL_SEARCH_FUNCTIONS.contains(&name)
Expand All @@ -45,56 +47,61 @@ pub fn is_builtin_function(name: &str) -> bool {
// The plan of search function, async function and udf contains some arguments defined in meta,
// which may be modified by user at any time. Those functions are not suitable for caching.
pub fn is_cacheable_function(name: &str) -> bool {
BUILTIN_FUNCTIONS.contains(name)
|| AggregateFunctionFactory::instance().contains(name)
let name = Ascii::new(name);
BUILTIN_FUNCTIONS.contains(name.into_inner())
|| AggregateFunctionFactory::instance().contains(name.into_inner())
|| GENERAL_WINDOW_FUNCTIONS.contains(&name)
|| GENERAL_LAMBDA_FUNCTIONS.contains(&name)
}

#[ctor]
pub static BUILTIN_FUNCTIONS: FunctionRegistry = builtin_functions();

pub const ASYNC_FUNCTIONS: [&str; 2] = ["nextval", "dict_get"];
pub const ASYNC_FUNCTIONS: [Ascii<&str>; 2] = [Ascii::new("nextval"), Ascii::new("dict_get")];

pub const GENERAL_WINDOW_FUNCTIONS: [&str; 13] = [
"row_number",
"rank",
"dense_rank",
"percent_rank",
"lag",
"lead",
"first_value",
"first",
"last_value",
"last",
"nth_value",
"ntile",
"cume_dist",
pub const GENERAL_WINDOW_FUNCTIONS: [Ascii<&str>; 13] = [
Ascii::new("row_number"),
Ascii::new("rank"),
Ascii::new("dense_rank"),
Ascii::new("percent_rank"),
Ascii::new("lag"),
Ascii::new("lead"),
Ascii::new("first_value"),
Ascii::new("first"),
Ascii::new("last_value"),
Ascii::new("last"),
Ascii::new("nth_value"),
Ascii::new("ntile"),
Ascii::new("cume_dist"),
];

pub const RANK_WINDOW_FUNCTIONS: [&str; 5] =
["first_value", "first", "last_value", "last", "nth_value"];

pub const GENERAL_LAMBDA_FUNCTIONS: [&str; 16] = [
"array_transform",
"array_apply",
"array_map",
"array_filter",
"array_reduce",
"json_array_transform",
"json_array_apply",
"json_array_map",
"json_array_filter",
"json_array_reduce",
"map_filter",
"map_transform_keys",
"map_transform_values",
"json_map_filter",
"json_map_transform_keys",
"json_map_transform_values",
pub const GENERAL_LAMBDA_FUNCTIONS: [Ascii<&str>; 16] = [
Ascii::new("array_transform"),
Ascii::new("array_apply"),
Ascii::new("array_map"),
Ascii::new("array_filter"),
Ascii::new("array_reduce"),
Ascii::new("json_array_transform"),
Ascii::new("json_array_apply"),
Ascii::new("json_array_map"),
Ascii::new("json_array_filter"),
Ascii::new("json_array_reduce"),
Ascii::new("map_filter"),
Ascii::new("map_transform_keys"),
Ascii::new("map_transform_values"),
Ascii::new("json_map_filter"),
Ascii::new("json_map_transform_keys"),
Ascii::new("json_map_transform_values"),
];

pub const GENERAL_SEARCH_FUNCTIONS: [&str; 3] = ["match", "query", "score"];
pub const GENERAL_SEARCH_FUNCTIONS: [Ascii<&str>; 3] = [
Ascii::new("match"),
Ascii::new("query"),
Ascii::new("score"),
];

fn builtin_functions() -> FunctionRegistry {
let mut registry = FunctionRegistry::empty();
Expand Down
1 change: 1 addition & 0 deletions src/query/sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ serde = { workspace = true }
sha2 = { workspace = true }
simsearch = { workspace = true }
tokio = { workspace = true }
unicase = { workspace = true }
url = { workspace = true }

[lints]
Expand Down
117 changes: 70 additions & 47 deletions src/query/sql/src/planner/semantic/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ use itertools::Itertools;
use jsonb::keypath::KeyPath;
use jsonb::keypath::KeyPaths;
use simsearch::SimSearch;
use unicase::Ascii;

use super::name_resolution::NameResolutionContext;
use super::normalize_identifier;
Expand Down Expand Up @@ -194,7 +195,7 @@ pub struct TypeChecker<'a> {
// This is used to check if there is nested aggregate function.
in_aggregate_function: bool,

// true if current expr is inside an window function.
// true if current expr is inside a window function.
// This is used to allow aggregation function in window's aggregate function.
in_window_function: bool,
forbid_udf: bool,
Expand Down Expand Up @@ -731,8 +732,9 @@ impl<'a> TypeChecker<'a> {
} => {
let func_name = normalize_identifier(name, self.name_resolution_ctx).to_string();
let func_name = func_name.as_str();
let uni_case_func_name = Ascii::new(func_name);
if !is_builtin_function(func_name)
&& !Self::all_sugar_functions().contains(&func_name)
&& !Self::all_sugar_functions().contains(&uni_case_func_name)
{
if let Some(udf) = self.resolve_udf(*span, func_name, args)? {
return Ok(udf);
Expand All @@ -743,15 +745,35 @@ impl<'a> TypeChecker<'a> {
.all_function_names()
.into_iter()
.chain(AggregateFunctionFactory::instance().registered_names())
.chain(GENERAL_WINDOW_FUNCTIONS.iter().cloned().map(str::to_string))
.chain(GENERAL_LAMBDA_FUNCTIONS.iter().cloned().map(str::to_string))
.chain(GENERAL_SEARCH_FUNCTIONS.iter().cloned().map(str::to_string))
.chain(ASYNC_FUNCTIONS.iter().cloned().map(str::to_string))
.chain(
GENERAL_WINDOW_FUNCTIONS
.iter()
.cloned()
.map(|ascii| ascii.into_inner().to_string()),
)
.chain(
GENERAL_LAMBDA_FUNCTIONS
.iter()
.cloned()
.map(|ascii| ascii.into_inner().to_string()),
)
.chain(
GENERAL_SEARCH_FUNCTIONS
.iter()
.cloned()
.map(|ascii| ascii.into_inner().to_string()),
)
.chain(
ASYNC_FUNCTIONS
.iter()
.cloned()
.map(|ascii| ascii.into_inner().to_string()),
)
.chain(
Self::all_sugar_functions()
.iter()
.cloned()
.map(str::to_string),
.map(|ascii| ascii.into_inner().to_string()),
);
let mut engine: SimSearch<String> = SimSearch::new();
for func_name in all_funcs {
Expand Down Expand Up @@ -779,15 +801,15 @@ impl<'a> TypeChecker<'a> {
// check window function legal
if window.is_some()
&& !AggregateFunctionFactory::instance().contains(func_name)
&& !GENERAL_WINDOW_FUNCTIONS.contains(&func_name)
&& !GENERAL_WINDOW_FUNCTIONS.contains(&uni_case_func_name)
{
return Err(ErrorCode::SemanticError(
"only window and aggregate functions allowed in window syntax",
)
.set_span(*span));
}
// check lambda function legal
if lambda.is_some() && !GENERAL_LAMBDA_FUNCTIONS.contains(&func_name) {
if lambda.is_some() && !GENERAL_LAMBDA_FUNCTIONS.contains(&uni_case_func_name) {
return Err(ErrorCode::SemanticError(
"only lambda functions allowed in lambda syntax",
)
Expand All @@ -796,7 +818,7 @@ impl<'a> TypeChecker<'a> {

let args: Vec<&Expr> = args.iter().collect();

if GENERAL_WINDOW_FUNCTIONS.contains(&func_name) {
if GENERAL_WINDOW_FUNCTIONS.contains(&uni_case_func_name) {
// general window function
if window.is_none() {
return Err(ErrorCode::SemanticError(format!(
Expand Down Expand Up @@ -862,7 +884,7 @@ impl<'a> TypeChecker<'a> {
// aggregate function
Box::new((new_agg_func.into(), data_type))
}
} else if GENERAL_LAMBDA_FUNCTIONS.contains(&func_name) {
} else if GENERAL_LAMBDA_FUNCTIONS.contains(&uni_case_func_name) {
if lambda.is_none() {
return Err(ErrorCode::SemanticError(format!(
"function {func_name} must have a lambda expression",
Expand All @@ -871,8 +893,8 @@ impl<'a> TypeChecker<'a> {
}
let lambda = lambda.as_ref().unwrap();
self.resolve_lambda_function(*span, func_name, &args, lambda)?
} else if GENERAL_SEARCH_FUNCTIONS.contains(&func_name) {
match func_name {
} else if GENERAL_SEARCH_FUNCTIONS.contains(&uni_case_func_name) {
match func_name.to_lowercase().as_str() {
"score" => self.resolve_score_search_function(*span, func_name, &args)?,
"match" => self.resolve_match_search_function(*span, func_name, &args)?,
"query" => self.resolve_query_search_function(*span, func_name, &args)?,
Expand All @@ -884,7 +906,7 @@ impl<'a> TypeChecker<'a> {
.set_span(*span));
}
}
} else if ASYNC_FUNCTIONS.contains(&func_name) {
} else if ASYNC_FUNCTIONS.contains(&uni_case_func_name) {
self.resolve_async_function(*span, func_name, &args)?
} else if BUILTIN_FUNCTIONS
.get_property(func_name)
Expand Down Expand Up @@ -1445,7 +1467,7 @@ impl<'a> TypeChecker<'a> {
self.in_window_function = false;

// If { IGNORE | RESPECT } NULLS is not specified, the default is RESPECT NULLS
// (i.e. a NULL value will be returned if the expression contains a NULL value and it is the first value in the expression).
// (i.e. a NULL value will be returned if the expression contains a NULL value, and it is the first value in the expression).
let ignore_null = if let Some(ignore_null) = window_ignore_null {
*ignore_null
} else {
Expand Down Expand Up @@ -2090,7 +2112,7 @@ impl<'a> TypeChecker<'a> {
param_count: usize,
span: Span,
) -> Result<()> {
// json lambda functions are casted to array or map, ignored here.
// json lambda functions are cast to array or map, ignored here.
let expected_count = if func_name == "array_reduce" {
2
} else if func_name.starts_with("array") {
Expand Down Expand Up @@ -3124,37 +3146,38 @@ impl<'a> TypeChecker<'a> {
Ok(Box::new((subquery_expr.into(), data_type)))
}

pub fn all_sugar_functions() -> &'static [&'static str] {
&[
"current_catalog",
"database",
"currentdatabase",
"current_database",
"version",
"user",
"currentuser",
"current_user",
"current_role",
"connection_id",
"timezone",
"nullif",
"ifnull",
"nvl",
"nvl2",
"is_null",
"is_error",
"error_or",
"coalesce",
"last_query_id",
"array_sort",
"array_aggregate",
"to_variant",
"try_to_variant",
"greatest",
"least",
"stream_has_data",
"getvariable",
]
pub fn all_sugar_functions() -> &'static [Ascii<&'static str>] {
static FUNCTIONS: &[Ascii<&'static str>] = &[
Ascii::new("current_catalog"),
Ascii::new("database"),
Ascii::new("currentdatabase"),
Ascii::new("current_database"),
Ascii::new("version"),
Ascii::new("user"),
Ascii::new("currentuser"),
Ascii::new("current_user"),
Ascii::new("current_role"),
Ascii::new("connection_id"),
Ascii::new("timezone"),
Ascii::new("nullif"),
Ascii::new("ifnull"),
Ascii::new("nvl"),
Ascii::new("nvl2"),
Ascii::new("is_null"),
Ascii::new("is_error"),
Ascii::new("error_or"),
Ascii::new("coalesce"),
Ascii::new("last_query_id"),
Ascii::new("array_sort"),
Ascii::new("array_aggregate"),
Ascii::new("to_variant"),
Ascii::new("try_to_variant"),
Ascii::new("greatest"),
Ascii::new("least"),
Ascii::new("stream_has_data"),
Ascii::new("getvariable"),
];
FUNCTIONS
}

fn try_rewrite_sugar_function(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ select * from Student
statement ok
set unquoted_ident_case_sensitive = 1

statement ok
SELECT VERSION()

statement error (?s)1025,.*Unknown table `default`\.`default`\.student \.
INSERT INTO student VALUES(1)

Expand Down

0 comments on commit bebf2af

Please sign in to comment.