diff --git a/R/arg_check_ml.R b/R/arg_check_ml.R index 41f01c8..f8ab172 100644 --- a/R/arg_check_ml.R +++ b/R/arg_check_ml.R @@ -702,7 +702,7 @@ NULL #' @keywords internal #' @param y_default_eval [chr] y value of default evaluation plot. It can be #' "avg_runtime_sec" or one of the following performance metrics: -#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", or "avg_nmcc" +#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", or "avg_nmcc" #' .checkArgYDefaultEval <- function(y_default_eval) { if (!is.character(y_default_eval)) { @@ -710,11 +710,11 @@ NULL } if (!(y_default_eval %in% - c("avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_nmcc")) + c("avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", "avg_nmcc")) ) { stop(paste( "`y_default_eval` must be one of:", - "'avg_f1_score', 'avg_log2_apop', 'avg_bal_acc', 'avg_nmcc'." + "'avg_f1_score', 'avg_log2_apop', 'avg_bal_acc', 'avg_mcc', 'avg_nmcc'." )) } } diff --git a/R/core_ml.R b/R/core_ml.R index 5846b40..3f5fc12 100644 --- a/R/core_ml.R +++ b/R/core_ml.R @@ -391,16 +391,15 @@ getConfusionMatrix <- function(test_data_plus_predictions) { return(CM) } -#' .calculatenMCC() +#' .calculateMCC() #' -#' Returns the normalized (to a 0 to 1 scale instead of -1 to 1) Matthews -#' correlation coefficient (nMCC) based on the AMR phenotype predictions by an +#' Returns the Matthews correlation coefficient (MCC) +#' based on the AMR phenotype predictions by an #' ML model compared against the actual values. #' #' @inheritParams getConfusionMatrix -#' @return Normalized (to a 0 to 1 scale instead of -1 to 1) Matthews -#' correlation coefficient (nMCC) -.calculatenMCC <- function(test_data_plus_predictions) { +#' @return Matthews correlation coefficient (MCC), range -1 to 1 +.calculateMCC <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) target_var <- .getTargetVarName(test_data_plus_predictions) @@ -408,11 +407,24 @@ getConfusionMatrix <- function(test_data_plus_predictions) { mcc <- test_data_plus_predictions |> yardstick::mcc(truth = !!target_var, estimate = .pred_class) |> dplyr::select(.estimate) |> - as.numeric() + as.numeric() |> + round(2) - nmcc <- (mcc + 1) / 2 + return(mcc) +} - return(round(nmcc, 2)) +#' .calculatenMCC() +#' +#' Returns the normalized (0 to 1) Matthews correlation coefficient (nMCC) +#' based on the AMR phenotype predictions by an ML model compared against +#' the actual values. +#' +#' @inheritParams getConfusionMatrix +#' @return Normalized Matthews correlation coefficient (nMCC), range 0 to 1 +.calculatenMCC <- function(test_data_plus_predictions) { + mcc <- .calculateMCC(test_data_plus_predictions) + nmcc <- round((mcc + 1) / 2, 2) + return(nmcc) } #' .calculateF1() @@ -607,12 +619,12 @@ getConfusionMatrix <- function(test_data_plus_predictions) { #' calculateEvalMets() #' #' Returns the F1 score, area under the precision-recall curve (AUPRC), balanced -#' accuracy, normalized (to a 0 to 1 scale instead of -1 to 1) Matthews -#' correlation coefficient (nMCC), and log2(AUPRC/prior) based on the AMR +#' accuracy, Matthews correlation coefficient (MCC), normalized MCC (nMCC), +#' and log2(AUPRC/prior) based on the AMR #' phenotype predictions by an ML model compared against the actual values. #' #' @inheritParams getConfusionMatrix -#' @return F1 score, AUPRC, balanced accuracy, nMCC, and log2(AUPRC/prior) +#' @return F1 score, AUPRC, balanced accuracy, MCC, nMCC, and log2(AUPRC/prior) #' @export calculateEvalMets <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) @@ -622,10 +634,11 @@ calculateEvalMets <- function(test_data_plus_predictions) { bal_acc <- .calculateBalAcc(test_data_plus_predictions) sens <- .calculateSensitivity(test_data_plus_predictions) spec <- .calculateSpecificity(test_data_plus_predictions) + mcc <- .calculateMCC(test_data_plus_predictions) nmcc <- .calculatenMCC(test_data_plus_predictions) log2_apop <- .calculateLog2APOP(test_data_plus_predictions) - return(c(f1, auprc, bal_acc, nmcc, log2_apop)) + return(c(f1, auprc, bal_acc, mcc, nmcc, log2_apop)) } #' extractTopFeats() diff --git a/R/generate_matrices_ml.R b/R/generate_matrices_ml.R index 67ba14c..5baac2f 100644 --- a/R/generate_matrices_ml.R +++ b/R/generate_matrices_ml.R @@ -288,14 +288,14 @@ skipImbalancedMatrix <- function(genome_ids, if (group_type %in% c("drug_class", "drug_class_year", "drug_class_country")) { genome_ids <- DBI::dbGetQuery(con, sprintf(" WITH class_phenotypes AS ( - SELECT \"genome_drug.genome_id\" AS genome_id, + SELECT \"genome.genome_id\" AS genome_id, MAX(CASE WHEN \"genome_drug.resistant_phenotype\" = 'Resistant' THEN 1 ELSE 0 END) AS any_resistant, MIN(CASE WHEN \"genome_drug.resistant_phenotype\" = 'Susceptible' THEN 1 ELSE 0 END) AS all_susceptible FROM metadata WHERE %s - GROUP BY \"genome_drug.genome_id\" + GROUP BY \"genome.genome_id\" ) SELECT genome_id FROM class_phenotypes @@ -303,7 +303,7 @@ skipImbalancedMatrix <- function(genome_ids, ", condition_string))[[1]] } else { genome_ids <- DBI::dbGetQuery(con, sprintf(" - SELECT DISTINCT \"genome_drug.genome_id\" + SELECT DISTINCT \"genome.genome_id\" FROM metadata WHERE %s AND \"genome_drug.resistant_phenotype\" IN ('Resistant','Susceptible') @@ -499,7 +499,7 @@ skipImbalancedMatrix <- function(genome_ids, " COPY ( SELECT - f.\"genome_drug.genome_id\" AS genome_id, + f.\"genome.genome_id\" AS genome_id, %s AS feature_id, MAX(CAST(%s AS DOUBLE)) AS value, %s AS \"genome_drug.resistant_phenotype\" @@ -507,12 +507,12 @@ skipImbalancedMatrix <- function(genome_ids, FROM %s JOIN selected_genomes USING (genome_id) JOIN keep_features kf ON %s = kf.feature_id - JOIN metadata f ON genome_id = f.\"genome_drug.genome_id\" + JOIN metadata f ON genome_id = f.\"genome.genome_id\" WHERE %s AND f.\"genome_drug.resistant_phenotype\" IN ('Resistant','Susceptible') %s - GROUP BY f.\"genome_drug.genome_id\", %s %s - ORDER BY f.\"genome_drug.genome_id\", %s + GROUP BY f.\"genome.genome_id\", %s %s + ORDER BY f.\"genome.genome_id\", %s ) TO '%s' (FORMAT 'parquet', COMPRESSION 'zstd') @@ -733,7 +733,7 @@ skipImbalancedMatrix <- function(genome_ids, DBI::dbDisconnect(con0, shutdown = FALSE) classes <- metadata_all |> - dplyr::select(genome_drug.genome_id, resistant_classes) |> + dplyr::select(genome.genome_id, resistant_classes) |> dplyr::distinct() |> dplyr::group_by(resistant_classes) |> dplyr::count() |> @@ -745,7 +745,7 @@ skipImbalancedMatrix <- function(genome_ids, genomes_to_keep <- metadata_all |> dplyr::filter(resistant_classes %in% classes) |> - dplyr::pull(genome_drug.genome_id) + dplyr::pull(genome.genome_id) # Build one matrix per feature type and matrix type for (ftype in names(feature_types)) { @@ -825,21 +825,21 @@ skipImbalancedMatrix <- function(genome_ids, " COPY ( SELECT - f.\"genome_drug.genome_id\" AS genome_id, + f.\"genome.genome_id\" AS genome_id, %s AS feature_id, MAX(CAST(%s AS DOUBLE)) AS value, resistant_classes FROM %s JOIN selected_genomes USING (genome_id) JOIN keep_features kf ON %s = kf.feature_id - JOIN metadata f ON genome_id = f.\"genome_drug.genome_id\" + JOIN metadata f ON genome_id = f.\"genome.genome_id\" WHERE resistant_classes <> 'Intermediate' GROUP BY - f.\"genome_drug.genome_id\", + f.\"genome.genome_id\", %s, resistant_classes ORDER BY - f.\"genome_drug.genome_id\", + f.\"genome.genome_id\", %s ) TO '%s' diff --git a/R/globals.R b/R/globals.R index 131d016..e1e3ab5 100644 --- a/R/globals.R +++ b/R/globals.R @@ -44,8 +44,12 @@ utils::globalVariables(c( "idx_strat", "model", "neg_log10_adj_p", + "mcc", "nmcc", "num_obs", + "seed", + "sens", + "spec", "output_prefix", "p_value", "pair_id", diff --git a/R/ife_ml.R b/R/ife_ml.R index 3d72a6d..110a17d 100644 --- a/R/ife_ml.R +++ b/R/ife_ml.R @@ -38,7 +38,7 @@ removeTopFeats <- function(ml_input_tibble, top_feat_tibble) { #' runIFE #' Removes top features identified by ML models and retrains iteratively; -#' returns nMCC at each iteration. +#' returns MCC at each iteration. #' #' @param ml_input_tibble An ML-ready tibble generated by `loadMLInputTibble()` #' @param by_num [bool] Set to `TRUE` if removing top features as a percentage @@ -83,6 +83,7 @@ runIFE <- function( num_obs_vec <- c() res_prop_vec <- c() fit_mixture_vec <- c() + mcc_vec <- c() nmcc_vec <- c() n_feats_removed_vec <- c() total_feats_removed_vec <- c() @@ -154,6 +155,9 @@ runIFE <- function( fit_mixture_vec[i] <- ml_res$performance_tibble |> dplyr::select(fit_mixture) |> as.numeric() + mcc_vec[i] <- ml_res$performance_tibble |> + dplyr::select(mcc) |> + as.numeric() nmcc_vec[i] <- ml_res$performance_tibble |> dplyr::select(nmcc) |> as.numeric() @@ -247,7 +251,7 @@ runIFE <- function( percent_removed = c(0, percent_removal_vec), removal_type = rep(removal_type, length(num_obs_vec)), num_obs = num_obs_vec, res_prop = res_prop_vec, - fit_mixture = fit_mixture_vec, nmcc = nmcc_vec, + fit_mixture = fit_mixture_vec, mcc = mcc_vec, nmcc = nmcc_vec, n_feats_removed = n_feats_removed_vec, total_feats_removed = total_feats_removed_vec, run_time_sec = run_time_vec ) diff --git a/R/plot_ml.R b/R/plot_ml.R index 90121f3..538c0e2 100644 --- a/R/plot_ml.R +++ b/R/plot_ml.R @@ -78,7 +78,7 @@ plotTopFeatsVI <- function(fit, n_top_feats = 10) { #' or "n_fold" #' @param y_default_eval [chr] y value of default evaluation plot. It can be #' "avg_runtime_sec" or one of the following performance metrics: -#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", or "avg_nmcc" +#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", or "avg_nmcc" #' @param xlab [chr] Label for x axis #' @param ylab [chr] Label for y axis #' @return A `ggplot2` scatterplot (performance metric or runtime vs. diff --git a/R/run_ml_pipeline.R b/R/run_ml_pipeline.R index ad9f691..a21aa8b 100644 --- a/R/run_ml_pipeline.R +++ b/R/run_ml_pipeline.R @@ -297,10 +297,11 @@ runMLPipeline <- function( log2_apop <- .calculateLog2APOP(test_data_plus_predictions) } + mcc <- .calculateMCC(test_data_plus_predictions) nmcc <- .calculatenMCC(test_data_plus_predictions) if (verbose) { - message(paste("Normalized Matthews correlation coefficient:", nmcc)) + message(paste("Matthews correlation coefficient:", mcc, "| nMCC:", nmcc)) } top_feat_tibble <- extractTopFeats(fit, @@ -362,7 +363,7 @@ runMLPipeline <- function( performance_tibble <- tibble::tibble( num_obs = num_obs_ml_input_tibble, n_feat = getNumFeat(ml_input_tibble), model, train_prop = split[1], - val_prop = split[2], n_fold, nmcc, run_time_sec, + val_prop = split[2], n_fold, mcc, nmcc, run_time_sec, seed, date = as.character(Sys.Date()) ) @@ -383,18 +384,20 @@ runMLPipeline <- function( ) |> tibble::add_column(bal_acc, .after = "nmcc") |> tibble::add_column(f1, .after = "nmcc") |> - tibble::add_column(log2_apop, .after = "nmcc") + tibble::add_column(log2_apop, .after = "nmcc") |> + tibble::add_column(sens, .after = "nmcc") |> + tibble::add_column(spec, .after = "nmcc") } if (model == "LR") { performance_tibble <- performance_tibble |> - tibble::add_column(fit_penalty, .before = "nmcc") |> - tibble::add_column(fit_mixture, .before = "nmcc") + tibble::add_column(fit_penalty, .before = "mcc") |> + tibble::add_column(fit_mixture, .before = "mcc") } else if (model == "RF" || model == "BT") { performance_tibble <- performance_tibble |> - tibble::add_column(fit_trees, .before = "nmcc") |> - tibble::add_column(fit_mtry, .before = "nmcc") |> - tibble::add_column(fit_min_n, .before = "nmcc") + tibble::add_column(fit_trees, .before = "mcc") |> + tibble::add_column(fit_mtry, .before = "mcc") |> + tibble::add_column(fit_min_n, .before = "mcc") } if (external_test_data) { diff --git a/README.Rmd b/README.Rmd index 0783647..68a00f0 100644 --- a/README.Rmd +++ b/README.Rmd @@ -103,7 +103,7 @@ This uses specific matrices to test whether ML models can predict resistance aga - **Data preparation**: Load Parquet files and prepare ML-ready datasets - **Model training**: User-customizable logistic regression via tidymodels -- **Evaluation**: nMCC, F1, balanced accuracy, AuPRC, and confusion matrices +- **Evaluation**: MCC, nMCC, F1, balanced accuracy, AuPRC, and confusion matrices - **Feature importance**: Extract and rank predictive features See the [package vignette](https://jravilab.github.io/amRml/articles/intro.html) for detailed usage. diff --git a/README.md b/README.md index 0099d07..3940387 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,7 @@ associated with MDR. - **Data preparation**: Load Parquet files and prepare ML-ready datasets - **Model training**: User-customizable logistic regression via tidymodels -- **Evaluation**: nMCC, F1, balanced accuracy, AuPRC, and confusion +- **Evaluation**: MCC, nMCC, F1, balanced accuracy, AuPRC, and confusion matrices - **Feature importance**: Extract and rank predictive features diff --git a/doc/intro.R b/doc/intro.R index 3782033..d7ca536 100644 --- a/doc/intro.R +++ b/doc/intro.R @@ -112,7 +112,8 @@ knitr::opts_chunk$set( ## ----metrics------------------------------------------------------------------ # # Individual metrics -# nmcc <- calculatenMCC(predictions) # Normalized MCC (0-1 scale) +# mcc <- calculateMCC(predictions) # Matthews correlation coefficient (-1 to 1) +# nmcc <- calculatenMCC(predictions) # Normalized MCC (0 to 1) # f1 <- calculateF1(predictions) # F1 score # bal_acc <- calculateBalAcc(predictions) # Balanced accuracy # auprc <- calculateAUPRC(predictions) # Area under PR curve @@ -150,7 +151,7 @@ knitr::opts_chunk$set( # verbose = TRUE # ) # -# # Results include nMCC at each iteration +# # Results include MCC and nMCC at each iteration # ife_results$ife_performance_tibble # ife_results$feats_removed # If return_feats = TRUE # @@ -233,8 +234,8 @@ knitr::opts_chunk$set( # ) # # # 5. Compare real vs baseline performance -# cat("Real nMCC:", results$performance_tibble$nmcc, "\n") -# cat("Baseline nMCC:", baseline_results$performance_tibble$nmcc, "\n") +# cat("Real MCC:", results$performance_tibble$mcc, "| nMCC:", results$performance_tibble$nmcc, "\n") +# cat("Baseline MCC:", baseline_results$performance_tibble$mcc, "| nMCC:", baseline_results$performance_tibble$nmcc, "\n") # # # 6. Run iterative feature elimination # ife_results <- runIFE( diff --git a/vignettes/intro.Rmd b/vignettes/intro.Rmd index af5bc8e..c40caba 100644 --- a/vignettes/intro.Rmd +++ b/vignettes/intro.Rmd @@ -31,7 +31,7 @@ After data curation with `amRdata`, use `generateMLInputs()` to create ML-ready The DuckDB with the parquet views is created with `cleanData()` ```{r generate-matrices} -# Generate all ML input matrices from curated data +# Example uses Shigella flexneri (Sfl); replace paths with your own species and drug. generateMLInputs( parquet_duckdb_path = "results/Sfl_parquet.duckdb", out_path = "results/", @@ -52,7 +52,6 @@ This generates: For classical train/validation/test splits instead of cross-validation: ```{r generate-splits} -# 70% train, 15% validation, 15% test generateMLInputs( parquet_duckdb_path = "results/Sfl_parquet.duckdb", out_path = "results/", @@ -88,7 +87,6 @@ The package expects Parquet files in long (sparse) format with these columns: `loadMLInputTibble()` converts this to wide format (one row per genome, one column per feature) for ML. ```{r load-data} -# Load a generated matrix ml_tibble <- loadMLInputTibble( parquet_path = "results/matrix/Sfl_drug_AMP_domains_binary_sparse.parquet" ) @@ -134,7 +132,7 @@ results$top_feat_tibble | `model` | Model type (LR) | | `train_prop`, `val_prop` | Train/validation split proportions | | `fit_penalty`, `fit_mixture` | Fitted hyperparameters | -| `nmcc`, `f1`, `bal_acc`, `log2_apop` | Performance metrics | +| `mcc`, `nmcc`, `f1`, `bal_acc`, `log2_apop` | Performance metrics | | `run_time_sec` | Runtime in seconds | **`top_feat_tibble`** - ranked feature importance: @@ -201,7 +199,8 @@ predictions <- predict(fit, test_data) ```{r metrics} # Individual metrics -nmcc <- calculatenMCC(predictions) # Normalized MCC (0-1 scale) +mcc <- calculateMCC(predictions) # Matthews correlation coefficient (-1 to 1) +nmcc <- calculatenMCC(predictions) # Normalized MCC (0 to 1) f1 <- calculateF1(predictions) # F1 score bal_acc <- calculateBalAcc(predictions) # Balanced accuracy auprc <- calculateAUPRC(predictions) # Area under PR curve @@ -216,7 +215,8 @@ conf_mat <- getConfusionMatrix(predictions) | Metric | Function | Description | |--------|----------|-------------| -| nMCC | `calculatenMCC()` | Normalized Matthews correlation coefficient (0-1) | +| MCC | `calculateMCC()` | Matthews correlation coefficient (-1 to 1) | +| nMCC | `calculatenMCC()` | Normalized Matthews correlation coefficient (0 to 1) | | F1 | `calculateF1()` | Harmonic mean of precision and recall | | Balanced accuracy | `calculateBalAcc()` | Average of sensitivity and specificity | | AUPRC | `calculateAUPRC()` | Area under precision-recall curve | @@ -251,7 +251,7 @@ ife_results <- runIFE( verbose = TRUE ) -# Results include nMCC at each iteration +# Results include MCC and nMCC at each iteration ife_results$ife_performance_tibble ife_results$feats_removed @@ -264,19 +264,19 @@ ml_tibble_reduced <- removeTopFeats(ml_tibble, top_features) ### Precision-recall curve ```{r plot-prc} -test_data_plus_predictions <- readr::read_tsv(results / ML_pred / Sfl_drug_AMP_domains_binary_prediction.tsv) +test_data_plus_predictions <- readr::read_tsv("results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv") plotPRC(test_data_plus_predictions) ``` ### ROC curve ```{r plot-roc} -test_data_plus_predictions <- readr::read_tsv(results / ML_pred / Sfl_drug_AMP_domains_binary_prediction.tsv) +test_data_plus_predictions <- readr::read_tsv("results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv") plotROC(test_data_plus_predictions) ``` ### Variable importance plot ```{r plot-vi} -topfeat <- readr::read_tsv(results / ML_top_features / Sfl_drug_AMP_domains_binary_top_features.tsv) +topfeat <- readr::read_tsv("results/ML_top_features/Sfl_drug_AMP_domains_binary_top_features.tsv") plotTopFeatsVI(topfeat) ``` ### Baseline comparison barplot @@ -306,7 +306,7 @@ As a non-ML baseline, run Fisher's exact tests with multiple testing correction: ```{r fisher} # Complete Fisher pipeline fisher_results <- runFishers( - matrix_path = "results/matrix/Cje_drug_CIP_genes_binary_sparse.parquet", + matrix_path = "results/matrix/Sfl_drug_AMP_genes_binary_sparse.parquet", Q = 0.05, alternative = "two.sided", susceptible_label = "Susceptible",