diff --git a/verticapy/tests_new/machine_learning/metrics/test_classification_metrics.py b/verticapy/tests_new/machine_learning/metrics/test_classification_metrics.py
index 5e5ffa8f3..46a5d0015 100644
--- a/verticapy/tests_new/machine_learning/metrics/test_classification_metrics.py
+++ b/verticapy/tests_new/machine_learning/metrics/test_classification_metrics.py
@@ -252,6 +252,7 @@ def python_metrics(y_true, y_pred, average="binary", metric_name=""):
         ("n", "negative_likelihood_ratio"),
         ("n", "positive_likelihood_ratio"),
         ("y", "precision_score"),
+        ("y", "average_precision_score"),
         ("n", "prevalence_threshold"),
         ("y", "recall_score"),
         ("n", "specificity_score"),
@@ -321,18 +322,30 @@ def get_vertica_metrics():
             "prc_auc_score",
             "best_cutoff",
             "log_loss",
+            "average_precision_score",
         ]:
             if compute_method == "binary":
-                _vpy_res = getattr(vpy_metrics, metric_name)(
-                    "y_true", "y_prob", vdf, average=compute_method, **func_args
-                )
+                if metric_name == "average_precision_score":
+                    _vpy_res = getattr(vpy_metrics, metric_name)(
+                        "y_true_num",
+                        "y_pred_num",
+                        vdf,
+                        average=compute_method,
+                        pos_label="1",
+                    )
+                else:
+                    _vpy_res = getattr(vpy_metrics, metric_name)(
+                        "y_true", "y_prob", vdf, average=compute_method, **func_args
+                    )
             else:
                 _vpy_res = getattr(vpy_metrics, metric_name)(
                     "y_true_num",
                     ["y_prob0", "y_prob1", "y_prob2"],
                     vdf,
                     average=compute_method,
-                    labels=labels_num,
+                    labels=[str(label_num) for label_num in labels_num]
+                    if metric_name == "average_precision_score"
+                    else labels_num,
                 )
             # rounding as best_cutoff metrics value precisions are upto 2/3 decimals
             _vpy_res = (
@@ -418,6 +431,18 @@ def get_python_metrics():
                     _y_true_num.ravel(), y_prob.ravel()
                 )
                 _skl_res = skl_metrics.auc(fpr, tpr)
+            elif metric_name in ["average_precision_score"]:
+                if compute_method == "binary":
+                    _skl_res = getattr(skl_metrics, metric_name)(
+                        y_true_num, y_pred_num, pos_label=1
+                    )
+                else:
+                    _skl_res = getattr(skl_metrics, metric_name)(
+                        y_true_num,
+                        y_prob,
+                        average=compute_method,
+                        pos_label=1,
+                    )
             else:
                 _skl_res = getattr(skl_metrics, metric_name)(
                     y_true, y_pred, labels=labels
@@ -435,8 +460,6 @@ def get_python_metrics():
                 _skl_res = skl_metrics.auc(recall, precision)
             else:
                 y_true_num = label_binarize(y_true, classes=[0, 1, 2])
-                print()
-                print(y_true_num)
                 fpr, tpr, thresholds = skl_metrics.roc_curve(
                     y_true_num, y_prob, pos_label="b"
                 )
@@ -456,6 +479,13 @@ def get_python_metrics():
             # rounding as best_cutoff metrics value precisions are upto 2/3 decimals
             _skl_res = round(_skl_res, 2)
+        elif metric_name in ["average_precision_score"]:
+            _skl_res = getattr(skl_metrics, metric_name)(
+                y_true_num,
+                y_prob,
+                average=None,
+                pos_label=1,
+            )
         else:
             _skl_res = python_metrics(
                 y_true, y_pred, average=compute_method, metric_name=metric_name
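
Note (not part of the patch): a minimal standalone sketch of how scikit-learn's average_precision_score behaves on the two paths the new branches exercise, the binary case with pos_label versus binarized multiclass labels with an average mode. The arrays and values below are illustrative only; the variable names simply mirror those used in the test.

# Illustrative-only check of sklearn.metrics.average_precision_score
# for the two compute paths tested above (binary vs. multiclass).
import numpy as np
from sklearn import metrics as skl_metrics
from sklearn.preprocessing import label_binarize

# Binary path: scores/predictions for the positive class, with pos_label
# selecting which label is treated as positive.
y_true_num = np.array([0, 1, 1, 0, 1])
y_pred_num = np.array([0, 1, 0, 0, 1])
print(skl_metrics.average_precision_score(y_true_num, y_pred_num, pos_label=1))

# Multiclass path: binarize the labels into an indicator matrix, then
# average the per-class AP values ("macro"/"weighted"), or pass
# average=None to get one AP value per class.
y_true_mc = np.array([0, 2, 1, 2, 0])
y_prob = np.array(
    [
        [0.7, 0.2, 0.1],
        [0.1, 0.3, 0.6],
        [0.2, 0.6, 0.2],
        [0.2, 0.2, 0.6],
        [0.6, 0.3, 0.1],
    ]
)
y_true_bin = label_binarize(y_true_mc, classes=[0, 1, 2])
print(skl_metrics.average_precision_score(y_true_bin, y_prob, average="macro"))
print(skl_metrics.average_precision_score(y_true_bin, y_prob, average=None))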