diff --git a/verticapy/_utils/_sql/_check.py b/verticapy/_utils/_sql/_check.py index 85d77364d..fd8be2d19 100755 --- a/verticapy/_utils/_sql/_check.py +++ b/verticapy/_utils/_sql/_check.py @@ -35,6 +35,39 @@ def is_dql(query: str) -> bool: query = erase_comment(query) for idx, q in enumerate(query): if q not in (" ", "("): - result = query[idx:].lower().startswith(("select ", "with ")) + result = ( + query[idx:] + .lower() + .startswith( + ( + "select ", + "with ", + ) + ) + ) + break + return result + + +def is_procedure(query: str) -> bool: + """ + Returns True if the input SQL query + is a procedure. + """ + result = False + query = clean_query(query) + query = erase_comment(query) + for idx, q in enumerate(query): + if q not in (" ", "("): + result = ( + query[idx:] + .lower() + .startswith( + ( + "create procedure ", + "create or replace procedure ", + ) + ) + ) break return result diff --git a/verticapy/jupyter/extensions/sql_magic.py b/verticapy/jupyter/extensions/sql_magic.py index 88ab08974..7ecd84bd1 100644 --- a/verticapy/jupyter/extensions/sql_magic.py +++ b/verticapy/jupyter/extensions/sql_magic.py @@ -34,12 +34,14 @@ import verticapy._config.config as conf from verticapy._utils._object import create_new_vdf from verticapy._utils._sql._collect import save_verticapy_logs +from verticapy._utils._sql._check import is_procedure from verticapy._utils._sql._dblink import replace_external_queries from verticapy._utils._sql._format import ( clean_query, replace_vars_in_query, ) from verticapy._utils._sql._sys import _executeSQL +from verticapy.connection import current_cursor from verticapy.connection.global_connection import get_global_connection from verticapy.errors import QueryError @@ -48,6 +50,42 @@ if TYPE_CHECKING: from verticapy.core.vdataframe.base import vDataFrame +SPECIAL_WORDS = ( + # ML Algos + "ARIMA", + "AUTOREGRESSOR", + "BALANCE", + "BISECTING_KMEANS", + "CROSS_VALIDATE", + "DETECT_OUTLIERS", + "IFOREST", + "IMPUTE", + "KMEANS", + "KPROTOTYPES", + "LINEAR_REG", + "LOGISTIC_REG", + "MOVING_AVERAGE", + "NAIVE_BAYES", + "NORMALIZE", + "NORMALIZE_FIT", + "ONE_HOT_ENCODER_FIT", + "PCA", + "POISSON_REG", + "RF_CLASSIFIER", + "RF_REGRESSOR", + "SVD", + "SVM_CLASSIFIER", + "SVM_REGRESSOR", + "XGB_CLASSIFIER", + "XGB_REGRESSOR", + # ML Management + "CHANGE_MODEL_STATUS", + "EXPORT_MODELS", + "IMPORT_MODELS", + "REGISTER_MODEL", + "UPGRADE_MODEL", +) + @save_verticapy_logs @needs_local_scope @@ -743,6 +781,12 @@ def sql_magic( elif "-c" in options: queries = options["-c"] + # Case when it is a procedure + if is_procedure(queries): + current_cursor().execute(queries) + print("CREATE") + return + # Cleaning the Query queries = clean_query(queries) queries = replace_vars_in_query(queries, locals()["local_ns"]) @@ -816,11 +860,14 @@ def sql_magic( for i in range(n): query = queries[i] - if query.split(" ")[0]: - query_type = query.split(" ")[0].upper().replace("(", "") + query_words = query.split(" ") + idx = 0 if query_words[0] else 1 + query_type = query_words[idx].upper().replace("(", "") + if len(query_words) > 1: + query_subtype = query_words[idx + 1].upper() else: - query_type = query.split(" ")[1].upper().replace("(", "") + query_subtype = "UNDEFINED" if len(query_type) > 1 and query_type.startswith(("/*", "--")): query_type = "undefined" @@ -843,7 +890,7 @@ def sql_magic( elif (i < n - 1) or ( (i == n - 1) - and (query_type.lower() not in ("select", "with", "undefined")) + and (query_type.lower() not in ("select", "show", "with", "undefined")) ): error = "" @@ -869,25 +916,45 @@ def sql_magic( else: error = "" - try: + if query_type.lower() in ("show",): + final_result = _executeSQL( + query, method="fetchall", print_time_sql=False + ) + columns = [d.name for d in current_cursor().description] result = create_new_vdf( - query, - _is_sql_magic=True, + final_result, + usecols=columns, ) - result._vars["sql_magic_result"] = True - # Display parameters - if "-nrows" in options: - result._vars["max_rows"] = options["-nrows"] - if "-ncols" in options: - result._vars["max_columns"] = options["-ncols"] + continue - except: + is_vdf = False + if not (query_subtype.upper().startswith(SPECIAL_WORDS)): + try: + result = create_new_vdf( + query, + _is_sql_magic=True, + ) + result._vars["sql_magic_result"] = True + # Display parameters + if "-nrows" in options: + result._vars["max_rows"] = options["-nrows"] + if "-ncols" in options: + result._vars["max_columns"] = options["-ncols"] + is_vdf = True + except: + pass # we could not create a vDataFrame out of the query. + + if not (is_vdf): try: final_result = _executeSQL( query, method="fetchfirstelem", print_time_sql=False ) if final_result and conf.get_option("print_info"): print(final_result) + elif ( + query_subtype.upper().startswith(SPECIAL_WORDS) + ) and conf.get_option("print_info"): + print(query_subtype.upper()) elif conf.get_option("print_info"): print(query_type)