From 4a707e62c096eaea93969df08b478ed835ea42e6 Mon Sep 17 00:00:00 2001 From: Abhirupa Ghosh <100681585+AbhirupaGhosh@users.noreply.github.com> Date: Tue, 17 Feb 2026 11:09:23 -0700 Subject: [PATCH 1/7] Rename and update MCC calculation functions Some necessary updates --- R/core_ml.R | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/R/core_ml.R b/R/core_ml.R index db874db..38e8d2e 100644 --- a/R/core_ml.R +++ b/R/core_ml.R @@ -380,27 +380,26 @@ getConfusionMatrix <- function(test_data_plus_predictions) { return(CM) } -#' .calculatenMCC() +#' .calculateMCC() #' -#' Returns the normalized (to a 0 to 1 scale instead of -1 to 1) Matthews -#' correlation coefficient (nMCC) based on the AMR phenotype predictions by an +#' Returns the Matthews correlation coefficient (MCC) +#' based on the AMR phenotype predictions by an #' ML model compared against the actual values. #' #' @inheritParams getConfusionMatrix -#' @return Normalized (to a 0 to 1 scale instead of -1 to 1) Matthews -#' correlation coefficient (nMCC) -.calculatenMCC <- function(test_data_plus_predictions) { +#' @return Matthews correlation coefficient (MCC) +.calculateMCC <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) target_var <- .getTargetVarName(test_data_plus_predictions) mcc <- test_data_plus_predictions |> yardstick::mcc(truth = !!target_var, estimate = .pred_class) |> - dplyr::select(.estimate) |> as.numeric() + dplyr::select(.estimate) |> as.numeric() |> round(2) - nmcc <- (mcc + 1) / 2 + # nmcc <- (mcc + 1) / 2 - return(round(nmcc, 2)) + return(mcc) } #' .calculateF1() @@ -560,12 +559,12 @@ getConfusionMatrix <- function(test_data_plus_predictions) { #' calculateEvalMets() #' #' Returns the F1 score, area under the precision-recall curve (AUPRC), balanced -#' accuracy, normalized (to a 0 to 1 scale instead of -1 to 1) Matthews -#' correlation coefficient (nMCC), and log2(AUPRC/prior) based on the AMR 
+#' accuracy, Matthews correlation coefficient (MCC), +#' and log2(AUPRC/prior) based on the AMR #' phenotype predictions by an ML model compared against the actual values. #' #' @inheritParams getConfusionMatrix -#' @return F1 score, AUPRC, balanced accuracy, nMCC, and log2(AUPRC/prior) +#' @return F1 score, AUPRC, balanced accuracy, MCC, and log2(AUPRC/prior) #' @export calculateEvalMets <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) @@ -575,10 +574,10 @@ calculateEvalMets <- function(test_data_plus_predictions) { bal_acc <- .calculateBalAcc(test_data_plus_predictions) sens <- .calculateSensitivity(test_data_plus_predictions) spec <- .calculateSpecificity(test_data_plus_predictions) - nmcc <- .calculatenMCC(test_data_plus_predictions) + mcc <- .calculateMCC(test_data_plus_predictions) log2_apop <- .calculateLog2APOP(test_data_plus_predictions) - return(c(f1, auprc, bal_acc, nmcc, log2_apop)) + return(c(f1, auprc, bal_acc, mcc, log2_apop)) } #' extractTopFeats() From ca6bf15de7ce3e198db243cf4958073a6341d116 Mon Sep 17 00:00:00 2001 From: AbhirupaGhosh Date: Tue, 17 Feb 2026 18:11:44 +0000 Subject: [PATCH 2/7] Style code (GHA) --- R/core_ml.R | 224 +++++++++++-------- R/generate_matrices_ml.R | 197 ++++++++++------- R/globals.R | 2 - R/plot_ml.R | 1 - R/prep_ml.R | 6 +- R/run_ML.R | 450 +++++++++++++++++++++------------------ R/run_ml_pipeline.R | 41 ++-- vignettes/intro.Rmd | 30 +-- 8 files changed, 547 insertions(+), 404 deletions(-) diff --git a/R/core_ml.R b/R/core_ml.R index 38e8d2e..8d84b47 100644 --- a/R/core_ml.R +++ b/R/core_ml.R @@ -73,7 +73,8 @@ NULL #' @return An `rsplit` object #' @export splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280) { - .checkArgTibble(ml_input_tibble, ml = TRUE); .checkArgSplit(split) + .checkArgTibble(ml_input_tibble, ml = TRUE) + .checkArgSplit(split) .checkArgSeed(seed) set.seed(seed) @@ -85,7 +86,7 @@ splitMLInputTibble <- 
function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280 # If in CV mode: # Still retain a stratified testing holdout purely for final reporting metrics; # CV is only performed on the training portion. - prop_train_for_holdout <- 0.8 # 80 percent train, 20 percent reserved test + prop_train_for_holdout <- 0.8 # 80 percent train, 20 percent reserved test data_split <- rsample::initial_split( ml_input_tibble, prop = prop_train_for_holdout, @@ -115,7 +116,8 @@ splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280 #' @return A `recipe` object #' @export buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { - .checkArgTibble(train_data, ml = TRUE); .checkArgUsePCA(use_pca) + .checkArgTibble(train_data, ml = TRUE) + .checkArgUsePCA(use_pca) .checkArgPCAThreshold(pca_threshold) target_var <- .getTargetVarName(train_data) |> as.character() @@ -124,8 +126,10 @@ buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { nm <- names(train_data) id_cols <- setdiff(nm[grepl("^genome", nm)], target_var) - rec <- recipes::recipe(formula = stats::reformulate(".", response = target_var), - data = train_data) + rec <- recipes::recipe( + formula = stats::reformulate(".", response = target_var), + data = train_data + ) # Only update roles if we actually have ID columns to mark as metadata if (length(id_cols) > 0) { @@ -146,7 +150,6 @@ buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { } - #' buildLRModel() #' #' Builds a logistic regression model. 
@@ -158,13 +161,17 @@ buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { buildLRModel <- function(multi_class = FALSE) { .checkArgMultiClass(multi_class) - if(!multi_class) { - lr_mod <- parsnip::logistic_reg(penalty = hardhat::tune(), - mixture = hardhat::tune()) |> + if (!multi_class) { + lr_mod <- parsnip::logistic_reg( + penalty = hardhat::tune(), + mixture = hardhat::tune() + ) |> parsnip::set_engine(engine = "glmnet") - } else if(multi_class) { - lr_mod <- parsnip::multinom_reg(penalty = hardhat::tune(), - mixture = hardhat::tune()) |> + } else if (multi_class) { + lr_mod <- parsnip::multinom_reg( + penalty = hardhat::tune(), + mixture = hardhat::tune() + ) |> parsnip::set_engine(engine = "glmnet") } @@ -181,9 +188,11 @@ buildLRModel <- function(multi_class = FALSE) { #' @return A `workflow` object #' @export buildWflow <- function(parsnip_mod, recipe) { - .checkArgParsnipMod(parsnip_mod); .checkArgRecipe(recipe) + .checkArgParsnipMod(parsnip_mod) + .checkArgRecipe(recipe) - wflow <- workflows::workflow() |> workflows::add_model(parsnip_mod) |> + wflow <- workflows::workflow() |> + workflows::add_model(parsnip_mod) |> workflows::add_recipe(recipe) return(wflow) @@ -203,21 +212,21 @@ buildWflow <- function(parsnip_mod, recipe) { #' @return A logistic regression tuning grid as a tibble #' @export buildTuningGrid <- function( - model = "LR", - penalty_vec = 10^seq(-4, -1, length.out = 10), - mix_vec = 0:5 / 5 + model = "LR", + penalty_vec = 10^seq(-4, -1, length.out = 10), + mix_vec = 0:5 / 5 ) { .checkArgModel(model) - + if (model == "LR") { .checkArgPenaltyVec(penalty_vec) .checkArgMixVec(mix_vec) - + penalty <- rep(penalty_vec, each = length(mix_vec)) mixture <- rep(mix_vec, length(penalty_vec)) grid <- tibble::tibble(penalty, mixture) } - + return(grid) } @@ -237,13 +246,14 @@ buildTuningGrid <- function( #' @export tuneGrid <- function(wflow, data_split, grid = buildTuningGrid(model = "LR"), n_fold = 5) { - .checkArgTibble(grid); 
.checkArgWflow(wflow) + .checkArgTibble(grid) + .checkArgWflow(wflow) .checkArgDataSplit(data_split) split_class <- class(data_split)[1] # Always do CV on the training portion of the split - train_df <- rsample::training(data_split) + train_df <- rsample::training(data_split) target_var <- .getTargetVarName(train_df) if (identical(split_class, "initial_split")) { @@ -259,9 +269,9 @@ tuneGrid <- function(wflow, data_split, grid = buildTuningGrid(model = "LR"), tune_res <- tune::tune_grid( wflow, resamples = resamples, - grid = grid, - control = tune::control_grid(save_pred = TRUE), - metrics = yardstick::metric_set( + grid = grid, + control = tune::control_grid(save_pred = TRUE), + metrics = yardstick::metric_set( yardstick::f_meas, yardstick::pr_auc, yardstick::spec, @@ -286,7 +296,8 @@ tuneGrid <- function(wflow, data_split, grid = buildTuningGrid(model = "LR"), #' @return Best model workflow #' @export selectBestModel <- function(tune_res, wflow, select_best_metric = "mcc") { - .checkArgTuneRes(tune_res); .checkArgWflow(wflow) + .checkArgTuneRes(tune_res) + .checkArgWflow(wflow) .checkArgSelectBestMetric(select_best_metric) best_mod <- tune::select_best(tune_res, metric = select_best_metric) @@ -306,7 +317,8 @@ selectBestModel <- function(tune_res, wflow, select_best_metric = "mcc") { #' @return Best model fit #' @export fitBestModel <- function(final_mod, train_data) { - .checkArgWflow(final_mod); .checkArgTibble(train_data, ml = TRUE) + .checkArgWflow(final_mod) + .checkArgTibble(train_data, ml = TRUE) fit <- final_mod |> parsnip::fit(data = train_data) @@ -324,8 +336,7 @@ fitBestModel <- function(final_mod, train_data) { model <- class(fit$fit$actions$model$spec)[1] - if(model %in% c("logistic_reg", "multinom_reg")) { - + if (model %in% c("logistic_reg", "multinom_reg")) { penalty <- fit$fit$fit$spec$args$penalty mixture <- tryCatch( @@ -334,7 +345,6 @@ fitBestModel <- function(final_mod, train_data) { ) tibble::tibble(penalty = penalty, mixture = mixture) - } 
else { stop("The `fit` object provided must correspond to 'logistic_reg' or 'multinom_reg'.") } @@ -353,7 +363,8 @@ fitBestModel <- function(final_mod, train_data) { #' labels #' @export predictML <- function(fit, test_data) { - .checkArgWflow(fit); .checkArgTibble(test_data, ml = TRUE) + .checkArgWflow(fit) + .checkArgTibble(test_data, ml = TRUE) test_data_plus_predictions <- parsnip::augment(fit, test_data) @@ -382,7 +393,7 @@ getConfusionMatrix <- function(test_data_plus_predictions) { #' .calculateMCC() #' -#' Returns the Matthews correlation coefficient (MCC) +#' Returns the Matthews correlation coefficient (MCC) #' based on the AMR phenotype predictions by an #' ML model compared against the actual values. #' @@ -395,9 +406,11 @@ getConfusionMatrix <- function(test_data_plus_predictions) { mcc <- test_data_plus_predictions |> yardstick::mcc(truth = !!target_var, estimate = .pred_class) |> - dplyr::select(.estimate) |> as.numeric() |> round(2) + dplyr::select(.estimate) |> + as.numeric() |> + round(2) - # nmcc <- (mcc + 1) / 2 + # nmcc <- (mcc + 1) / 2 return(mcc) } @@ -412,15 +425,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateF1 <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." 
+ )) } f1 <- test_data_plus_predictions |> - yardstick::f_meas(truth = genome_drug.resistant_phenotype, - estimate = .pred_class) |> dplyr::select(.estimate) |> as.numeric() |> + yardstick::f_meas( + truth = genome_drug.resistant_phenotype, + estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> round(2) return(f1) @@ -436,16 +455,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateAUPRC <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." + )) } auprc <- test_data_plus_predictions |> yardstick::pr_auc( - truth = genome_drug.resistant_phenotype, .pred_Resistant) |> - dplyr::select(.estimate) |> as.numeric() |> round(2) + truth = genome_drug.resistant_phenotype, .pred_Resistant + ) |> + dplyr::select(.estimate) |> + as.numeric() |> + round(2) return(auprc) } @@ -460,26 +484,33 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateLog2APOP <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." 
+ )) } auprc <- .calculateAUPRC(test_data_plus_predictions) prior <- sum( - test_data_plus_predictions$genome_drug.resistant_phenotype == "Resistant") / + test_data_plus_predictions$genome_drug.resistant_phenotype == "Resistant" + ) / nrow(test_data_plus_predictions) - if(prior > 0.3 && prior < 0.7) { - warning(paste("Classes are roughly balanced.", - "Calculation of log2(AUPRC/prior) may be inappropriate.")) - } else if(prior >= 0.7) { - warning(paste("Classes are imbalanced toward the resistant phenotype.", - "Calculation of log2(AUPRC/prior) may be inappropriate.")) + if (prior > 0.3 && prior < 0.7) { + warning(paste( + "Classes are roughly balanced.", + "Calculation of log2(AUPRC/prior) may be inappropriate." + )) + } else if (prior >= 0.7) { + warning(paste( + "Classes are imbalanced toward the resistant phenotype.", + "Calculation of log2(AUPRC/prior) may be inappropriate." + )) } - log2_apop <- log2(auprc/prior) |> round(2) + log2_apop <- log2(auprc / prior) |> round(2) return(log2_apop) } @@ -494,16 +525,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateBalAcc <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." 
+ )) } bal_acc <- test_data_plus_predictions |> yardstick::bal_accuracy( - truth = genome_drug.resistant_phenotype, estimate = .pred_class) |> - dplyr::select(.estimate) |> as.numeric() |> round(2) + truth = genome_drug.resistant_phenotype, estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> + round(2) return(bal_acc) } @@ -518,15 +554,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateSensitivity <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." + )) } sens <- test_data_plus_predictions |> - yardstick::sens(truth = genome_drug.resistant_phenotype, - estimate = .pred_class) |> dplyr::select(.estimate) |> as.numeric() |> + yardstick::sens( + truth = genome_drug.resistant_phenotype, + estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> round(2) return(sens) @@ -542,15 +584,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateSpecificity <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." 
+ )) } spec <- test_data_plus_predictions |> - yardstick::spec(truth = genome_drug.resistant_phenotype, - estimate = .pred_class) |> dplyr::select(.estimate) |> as.numeric() |> + yardstick::spec( + truth = genome_drug.resistant_phenotype, + estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> round(2) return(spec) @@ -559,7 +607,7 @@ getConfusionMatrix <- function(test_data_plus_predictions) { #' calculateEvalMets() #' #' Returns the F1 score, area under the precision-recall curve (AUPRC), balanced -#' accuracy, Matthews correlation coefficient (MCC), +#' accuracy, Matthews correlation coefficient (MCC), #' and log2(AUPRC/prior) based on the AMR #' phenotype predictions by an ML model compared against the actual values. #' @@ -597,30 +645,36 @@ calculateEvalMets <- function(test_data_plus_predictions) { #' `Importance`, and a column for `Sign` (or, for multi-class, a tibble with #' per-class columns of importance scores for each `Variable`) #' @export -extractTopFeats <- function(fit, prop_vi_top_feats = c(0, 1), - n_top_feats = NA) { +extractTopFeats <- function( + fit, prop_vi_top_feats = c(0, 1), + n_top_feats = NA +) { .checkArgWflow(fit) - if(!is.na(n_top_feats)) {prop_vi_top_feats <- NA} + if (!is.na(n_top_feats)) { + prop_vi_top_feats <- NA + } # Arg checking for every permutation of `prop_vi_top_feats` and `n_top_feats` - if(is.na(n_top_feats) & any(!is.na(prop_vi_top_feats))) { + if (is.na(n_top_feats) & any(!is.na(prop_vi_top_feats))) { .checkArgPropVITopFeats(prop_vi_top_feats) - } else if(any(is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { + } else if (any(is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { .checkArgNTopFeats(n_top_feats) - } else if(any(!is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { + } else if (any(!is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { stop("Set either `n_top_feats` or `prop_vi_top_feats` to `NA` but not both.") - } else if(any(is.na(prop_vi_top_feats)) & is.na(n_top_feats)) { + } else if 
(any(is.na(prop_vi_top_feats)) & is.na(n_top_feats)) { stop("Please specify either `n_top_feats` or `prop_vi_top_feats`.") } - feats_arranged <- fit |> workflowsets::extract_fit_parsnip() |> vip::vi() |> + feats_arranged <- fit |> + workflowsets::extract_fit_parsnip() |> + vip::vi() |> dplyr::arrange(dplyr::desc(Importance)) - if(!is.na(n_top_feats)) { + if (!is.na(n_top_feats)) { top_feats_and_VIs <- feats_arranged |> dplyr::slice(1:n_top_feats) - } else if(any(!is.na(prop_vi_top_feats))) { + } else if (any(!is.na(prop_vi_top_feats))) { cum_vi_lower <- prop_vi_top_feats[1] * sum(feats_arranged$Importance) cum_vi_upper <- prop_vi_top_feats[2] * sum(feats_arranged$Importance) @@ -637,9 +691,11 @@ extractTopFeats <- function(fit, prop_vi_top_feats = c(0, 1), # Take a different approach if using multi-class (the previous code would give # a less meaningful result). - if(class(fit$fit$actions$model$spec)[1] == "multinom_reg") { - warning(paste("Extracting top features from a multi-class model.", - "The `prop_vi_top_feats` and `n_top_feats` arguments do not apply.")) + if (class(fit$fit$actions$model$spec)[1] == "multinom_reg") { + warning(paste( + "Extracting top features from a multi-class model.", + "The `prop_vi_top_feats` and `n_top_feats` arguments do not apply." 
+ )) fit_penalty <- .getFitHps(fit)["penalty"] |> as.numeric() glmnet_fit <- parsnip::extract_fit_engine(fit) diff --git a/R/generate_matrices_ml.R b/R/generate_matrices_ml.R index bb1dc68..b19beef 100644 --- a/R/generate_matrices_ml.R +++ b/R/generate_matrices_ml.R @@ -156,7 +156,6 @@ skipImbalancedMatrix <- function(genome_ids, split, stratify_by = NULL, verbosity = c("minimal", "debug")) { - verbosity <- match.arg(verbosity) log <- .make_logger(verbosity) @@ -197,8 +196,10 @@ skipImbalancedMatrix <- function(genome_ids, if (!dir.exists(matrix_path)) dir.create(matrix_path, recursive = TRUE) log("info", paste0("Matrix output directory: ", matrix_path)) - log("debug", paste0("Stratification: ", - ifelse(is.null(stratify_column), "None", stratify_column))) + log("debug", paste0( + "Stratification: ", + ifelse(is.null(stratify_column), "None", stratify_column) + )) # Feature and matrix types feature_types <- list( @@ -220,9 +221,11 @@ skipImbalancedMatrix <- function(genome_ids, # Safe DBI-quoting quote_condition <- function(group_cols, group_values, con) { - ids <- vapply(group_cols, - function(col) DBI::dbQuoteIdentifier(con, col), - character(1)) + ids <- vapply( + group_cols, + function(col) DBI::dbQuoteIdentifier(con, col), + character(1) + ) vals <- vapply( group_cols, function(col) { @@ -256,7 +259,6 @@ skipImbalancedMatrix <- function(genome_ids, log("debug", paste0("Found ", nrow(all_groups), " groups for type: ", group_type)) for (i in seq_len(nrow(all_groups))) { - # New connection for this group con <- DBI::dbConnect(duckdb::duckdb(), parquet_duckdb_path) @@ -268,13 +270,14 @@ skipImbalancedMatrix <- function(genome_ids, condition_string <- quote_condition(group_cols, group_values, con) # Strat filter - strat_filter <- if (!is.null(stratify_column)) + strat_filter <- if (!is.null(stratify_column)) { sprintf("AND \"%s\" IS NOT NULL AND \"%s\" != ''", stratify_column, stratify_column) - else "" + } else { + "" + } # Genome selection logic if (group_type 
%in% c("drug_class", "drug_class_year", "drug_class_country")) { - genome_ids <- DBI::dbGetQuery(con, sprintf(" WITH class_phenotypes AS ( SELECT \"genome_drug.genome_id\" AS genome_id, @@ -290,7 +293,6 @@ skipImbalancedMatrix <- function(genome_ids, FROM class_phenotypes WHERE any_resistant = 1 OR all_susceptible = 1 ", condition_string))[[1]] - } else { genome_ids <- DBI::dbGetQuery(con, sprintf(" SELECT DISTINCT \"genome_drug.genome_id\" @@ -310,19 +312,24 @@ skipImbalancedMatrix <- function(genome_ids, ", condition_string)) phenotype_summary <- paste( - apply(phenotype_counts_all, 1, - function(row) paste0(row["phenotype"], "=", row["count"])), + apply( + phenotype_counts_all, 1, + function(row) paste0(row["phenotype"], "=", row["count"]) + ), collapse = "; " ) # Apply skip logic if (skipImbalancedMatrix(genome_ids, phenotype_counts_all, n_fold, split, - verbosity = verbosity)) { - + verbosity = verbosity + )) { readr::write_lines( - sprintf("%s\tToo few samples for CV/split\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), - log_path, append = TRUE + sprintf( + "%s\tToo few samples for CV/split\t%d\t%s", + group_label, length(genome_ids), phenotype_summary + ), + log_path, + append = TRUE ) DBI::dbDisconnect(con, shutdown = FALSE) @@ -331,9 +338,12 @@ skipImbalancedMatrix <- function(genome_ids, if (length(genome_ids) < 40) { readr::write_lines( - sprintf("%s\tToo few observations\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), - log_path, append = TRUE + sprintf( + "%s\tToo few observations\t%d\t%s", + group_label, length(genome_ids), phenotype_summary + ), + log_path, + append = TRUE ) DBI::dbDisconnect(con, shutdown = FALSE) @@ -351,9 +361,12 @@ skipImbalancedMatrix <- function(genome_ids, if (nrow(phen2) < 2) { readr::write_lines( - sprintf("%s\tOnly one phenotype class\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), - log_path, append = TRUE + sprintf( + "%s\tOnly one phenotype class\t%d\t%s", + 
group_label, length(genome_ids), phenotype_summary + ), + log_path, + append = TRUE ) DBI::dbDisconnect(con, shutdown = FALSE) @@ -363,13 +376,14 @@ skipImbalancedMatrix <- function(genome_ids, # Create selected_genomes DBI::dbExecute(con, "CREATE OR REPLACE TEMP TABLE selected_genomes (genome_id VARCHAR)") DBI::dbWriteTable(con, "selected_genomes", - data.frame(genome_id = genome_ids), append = TRUE) + data.frame(genome_id = genome_ids), + append = TRUE + ) # Feature and matrix generation steps for (ftype in names(feature_types)) { - fview <- feature_types[[ftype]]$view - fid <- feature_types[[ftype]]$id_col + fid <- feature_types[[ftype]]$id_col # binary view DBI::dbExecute(con, sprintf(" @@ -389,13 +403,14 @@ skipImbalancedMatrix <- function(genome_ids, } for (mtype in names(matrix_types)) { - binary_only <- matrix_types[[mtype]]$binary_only if (ftype == "struct" && !binary_only) next - mview <- sprintf("%s_%s", ftype, - ifelse(grepl("binary", mtype), "binary", "counts")) - value_col <- matrix_types[[mtype]]$value_col + mview <- sprintf( + "%s_%s", ftype, + ifelse(grepl("binary", mtype), "binary", "counts") + ) + value_col <- matrix_types[[mtype]]$value_col filter_clause <- matrix_types[[mtype]]$filter # select features with non-zero variance @@ -409,29 +424,38 @@ skipImbalancedMatrix <- function(genome_ids, keep_features <- DBI::dbGetQuery(con, keep_query)[["feature_id"]] if (length(keep_features) == 0) { - log("info", paste0("All features filtered for ", - ftype, " - ", mtype, " - ", group_label)) + log("info", paste0( + "All features filtered for ", + ftype, " - ", mtype, " - ", group_label + )) next } - DBI::dbExecute(con, - "CREATE OR REPLACE TEMP TABLE keep_features (feature_id VARCHAR)") + DBI::dbExecute( + con, + "CREATE OR REPLACE TEMP TABLE keep_features (feature_id VARCHAR)" + ) DBI::dbWriteTable(con, - "keep_features", - data.frame(feature_id = keep_features), - append = TRUE) + "keep_features", + data.frame(feature_id = keep_features), + append = 
TRUE + ) mtype_label <- matrix_types[[mtype]]$label - long_out_path <- file.path(matrix_path, - sprintf("%s_%s_%s_%s_%s_sparse.parquet", - bug, group_type, group_label, ftype, mtype_label)) + long_out_path <- file.path( + matrix_path, + sprintf( + "%s_%s_%s_%s_%s_sparse.parquet", + bug, group_type, group_label, ftype, mtype_label + ) + ) long_out_path_sql <- gsub("\\\\", "/", long_out_path) # phenotype case phenotype_case <- if (group_type %in% - c("drug_class", "drug_class_year", "drug_class_country")) { + c("drug_class", "drug_class_year", "drug_class_country")) { " CASE WHEN MAX(CASE WHEN f.\"genome_drug.resistant_phenotype\"='Resistant' @@ -451,13 +475,20 @@ skipImbalancedMatrix <- function(genome_ids, " } - strat_col_select <- if (!is.null(stratify_by)) - sprintf(", f.\"%s\"", stratify_column) else "" + strat_col_select <- if (!is.null(stratify_by)) { + sprintf(", f.\"%s\"", stratify_column) + } else { + "" + } - strat_col_group <- if (!is.null(stratify_by)) - sprintf(", f.\"%s\"", stratify_column) else "" + strat_col_group <- if (!is.null(stratify_by)) { + sprintf(", f.\"%s\"", stratify_column) + } else { + "" + } - copy_sql <- sprintf(" + copy_sql <- sprintf( + " COPY ( SELECT f.\"genome_drug.genome_id\" AS genome_id, @@ -478,18 +509,21 @@ skipImbalancedMatrix <- function(genome_ids, TO '%s' (FORMAT 'parquet', COMPRESSION 'zstd') ", - fid, value_col, phenotype_case, strat_col_select, - mview, fid, condition_string, - strat_filter, fid, strat_col_group, fid, - long_out_path_sql) + fid, value_col, phenotype_case, strat_col_select, + mview, fid, condition_string, + strat_filter, fid, strat_col_group, fid, + long_out_path_sql + ) ok <- try(DBI::dbExecute(con, copy_sql), silent = TRUE) # On copy failure, log + continue without stopping entire pipeline if (inherits(ok, "try-error")) { readr::write_lines( - sprintf("%s\tCOPY_failed\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), + sprintf( + "%s\tCOPY_failed\t%d\t%s", + group_label, 
length(genome_ids), phenotype_summary + ), log_path, append = TRUE ) @@ -530,7 +564,7 @@ skipImbalancedMatrix <- function(genome_ids, # Normalize paths to forward slashes for consistency matrix_path <- gsub("\\\\", "/", file.path(path, paste0("matrix_", stratify_by))) - LOO_path <- gsub("\\\\", "/", file.path(path, paste0("LOO_matrix_", stratify_by))) + LOO_path <- gsub("\\\\", "/", file.path(path, paste0("LOO_matrix_", stratify_by))) if (!dir.exists(matrix_path)) { log("info", paste0("The matrix directory ", matrix_path, " does not exist.")) @@ -626,9 +660,11 @@ skipImbalancedMatrix <- function(genome_ids, out_file <- gsub("\\\\", "/", file.path( LOO_path, - paste0(sub_prefix, "_", stratify_by, "_", - drug_class, "_leaveout_", leave_one_out, "_", - sub_feature, "_sparse.parquet") + paste0( + sub_prefix, "_", stratify_by, "_", + drug_class, "_leaveout_", leave_one_out, "_", + sub_feature, "_sparse.parquet" + ) )) arrow::write_parquet(combined, out_file) created <<- c(created, out_file) @@ -702,7 +738,7 @@ skipImbalancedMatrix <- function(genome_ids, # Build one matrix per feature type and matrix type for (ftype in names(feature_types)) { fview <- feature_types[[ftype]]$view - fid <- feature_types[[ftype]]$id_col + fid <- feature_types[[ftype]]$id_col for (mtype in names(matrix_types)) { binary_only <- matrix_types[[mtype]]$binary_only @@ -722,8 +758,9 @@ skipImbalancedMatrix <- function(genome_ids, # Selected genomes DBI::dbExecute(con, "CREATE OR REPLACE TEMP TABLE selected_genomes (genome_id VARCHAR)") DBI::dbWriteTable(con, "selected_genomes", - data.frame(genome_id = genomes_to_keep), - append = TRUE) + data.frame(genome_id = genomes_to_keep), + append = TRUE + ) # Binary view DBI::dbExecute(con, sprintf(" @@ -763,13 +800,15 @@ skipImbalancedMatrix <- function(genome_ids, DBI::dbExecute(con, "CREATE OR REPLACE TEMP TABLE keep_features (feature_id VARCHAR)") DBI::dbWriteTable(con, "keep_features", - data.frame(feature_id = keep_features), - append = TRUE) + 
data.frame(feature_id = keep_features), + append = TRUE + ) + - - copy_sql <- sprintf(" + copy_sql <- sprintf( + " COPY ( - SELECT + SELECT f.\"genome_drug.genome_id\" AS genome_id, %s AS feature_id, MAX(CAST(%s AS DOUBLE)) AS value, @@ -779,26 +818,26 @@ skipImbalancedMatrix <- function(genome_ids, JOIN keep_features kf ON %s = kf.feature_id JOIN metadata f ON genome_id = f.\"genome_drug.genome_id\" WHERE resistant_classes <> 'Intermediate' - GROUP BY - f.\"genome_drug.genome_id\", - %s, + GROUP BY + f.\"genome_drug.genome_id\", + %s, resistant_classes - ORDER BY - f.\"genome_drug.genome_id\", + ORDER BY + f.\"genome_drug.genome_id\", %s ) TO '%s' (FORMAT 'parquet', COMPRESSION 'zstd') - ", - fid, # %s -> feature_id expression column name - value_col, # %s -> value column to CAST - mview, # %s -> source view (binary or counts) - fid, # %s -> join to keep_features - fid, # %s -> group by feature id - fid, # %s -> order by feature id - out_file_sql # %s -> destination parquet file + ", + fid, # %s -> feature_id expression column name + value_col, # %s -> value column to CAST + mview, # %s -> source view (binary or counts) + fid, # %s -> join to keep_features + fid, # %s -> group by feature id + fid, # %s -> order by feature id + out_file_sql # %s -> destination parquet file ) - + ok <- try(DBI::dbExecute(con, copy_sql), silent = TRUE) if (inherits(ok, "try-error")) { log("info", paste0("COPY failed for MDR matrix: ", out_file)) diff --git a/R/globals.R b/R/globals.R index a6595d2..131d016 100644 --- a/R/globals.R +++ b/R/globals.R @@ -8,7 +8,6 @@ "_PACKAGE" utils::globalVariables(c( - # Prediction columns from tidymodels ".estimate", ".pred_Resistant", @@ -52,7 +51,6 @@ utils::globalVariables(c( "pair_id", "parts", "phenotype", - "precision", "prefix", "prefix_key", diff --git a/R/plot_ml.R b/R/plot_ml.R index 071e1a1..90121f3 100644 --- a/R/plot_ml.R +++ b/R/plot_ml.R @@ -214,7 +214,6 @@ plotFishers <- function( alpha = 0.05, label_top_n = 5 ) { - required_cols <- 
c("gene", "adj_p_value", "sig_after_bh") missing_cols <- setdiff(required_cols, colnames(fisher_df)) diff --git a/R/prep_ml.R b/R/prep_ml.R index d47c160..4a5954e 100644 --- a/R/prep_ml.R +++ b/R/prep_ml.R @@ -111,8 +111,10 @@ loadMLInputTibble <- function(parquet_path) { if (exists(".ml_logger")) { log <- .ml_logger("minimal") - log("debug", paste0("ML tibble constructed: ", nrow(ml_input_tibble), - " genomes x ", getNumFeat(ml_input_tibble), " features")) + log("debug", paste0( + "ML tibble constructed: ", nrow(ml_input_tibble), + " genomes x ", getNumFeat(ml_input_tibble), " features" + )) } if (anyDuplicated(dplyr::pull(ml_input_tibble, genome_id)) != 0) { diff --git a/R/run_ML.R b/R/run_ML.R index eba37f8..2ed07e7 100644 --- a/R/run_ML.R +++ b/R/run_ML.R @@ -4,9 +4,11 @@ #' the ML matrices with these new split/CV values instead. #' @noRd .resolveSplitParams <- function(parquet_path, - defaults = list(split = c(0.8, 0), - seed = 5280, - n_fold = 5)) { + defaults = list( + split = c(0.8, 0), + seed = 5280, + n_fold = 5 + )) { # matrix_dir is the directory that contains the parquet files matrix_dir <- normalizePath(dirname(parquet_path)) params_json <- .readMLParameters(matrix_dir) @@ -16,8 +18,8 @@ } list( - split = if (!is.null(params_json$split)) params_json$split else defaults$split, - seed = if (!is.null(params_json$seed)) params_json$seed else defaults$seed, + split = if (!is.null(params_json$split)) params_json$split else defaults$split, + seed = if (!is.null(params_json$seed)) params_json$seed else defaults$seed, n_fold = if (!is.null(params_json$n_fold)) params_json$n_fold else defaults$n_fold ) } @@ -53,8 +55,9 @@ #' #' # LOO analysis stratified by year #' paths_loo <- createMLResultDir("/path/to/results", -#' stratify_by = "year", -#' LOO = TRUE) +#' stratify_by = "year", +#' LOO = TRUE +#' ) #' #' # MDR analysis #' paths_mdr <- createMLResultDir("/path/to/results", MDR = TRUE) @@ -90,16 +93,17 @@ createMLResultDir <- function(path, ) } else { # 
Determine prefixes (only in non-MDR mode) - full_prefix <- paste0(ifelse(isTRUE(LOO), "LOO_", ""), - ifelse(isTRUE(cross_test), "cross_test_", "")) + full_prefix <- paste0( + ifelse(isTRUE(LOO), "LOO_", ""), + ifelse(isTRUE(cross_test), "cross_test_", "") + ) half_prefix <- ifelse(isTRUE(LOO), "LOO_", "") # Determine suffix suffix <- if (is.null(stratify_by) || identical(stratify_by, "")) { "" } else { - switch( - stratify_by, + switch(stratify_by, "country" = "_country", "year" = "_year", stop("`stratify_by` must be NULL, 'country', or 'year'.") @@ -127,20 +131,20 @@ createMLResultDir <- function(path, return(paths) } - # createAllMLResultDir <- function(path) { - # createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = TRUE, MDR = FALSE) - # createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = TRUE) - # createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = TRUE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = TRUE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = TRUE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = TRUE, MDR = FALSE) - # } - # +# createAllMLResultDir <- function(path) { +# createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, 
stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = TRUE) +# createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = TRUE, MDR = FALSE) +# } +# #' Create machine learning input list #' @@ -174,8 +178,9 @@ createMLResultDir <- function(path, #' #' # Cross-test with year stratification #' inputs_ct <- createMLinputList("/path/to/results", -#' stratify_by = "year", -#' cross_test = TRUE) +#' stratify_by = "year", +#' cross_test = TRUE +#' ) #' #' # MDR analysis #' inputs_mdr <- createMLinputList("/path/to/results", MDR = TRUE) @@ -187,10 +192,10 @@ createMLinputList <- function(path, LOO = FALSE, MDR = FALSE, cross_test = FALSE) { - # Validate inputs - if (!is.character(path) || length(path) != 1 || is.na(path)) + if (!is.character(path) || length(path) != 1 || is.na(path)) { stop("`path` must be a valid file path string.") + } path <- normalizePath(path) @@ -225,21 +230,17 @@ createMLinputList <- function(path, # Multi-drug resistance models # ============================ if (MDR) { - parsed <- tibble::tibble(ref_file = files_vec) |> dplyr::mutate( parts = stringr::str_split(basename(ref_file), "_"), - species = purrr::map_chr(parts, ~ .x[1]), - mdr_tag = purrr::map_chr(parts, ~ .x[2]), # always "MDR" + mdr_tag = purrr::map_chr(parts, ~ .x[2]), # always "MDR" phenotype = 
purrr::map_chr(parts, ~ paste(.x[3:4], collapse = "_")), # Feature is 5th + 6th tokens feature_type = purrr::map_chr(parts, ~ .x[5]), feature_subtype = purrr::map_chr(parts, ~ stringr::str_remove(.x[6], "_sparse.parquet")), - feature = purrr::map2_chr(feature_type, feature_subtype, paste, sep = "_"), - output_prefix = paste0("MDR_", phenotype, "_", feature) ) @@ -247,38 +248,43 @@ createMLinputList <- function(path, dplyr::mutate( test_file = NA_character_, matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) - # ============================ - # For all other modeling types - # ============================ + # ============================ + # For all other modeling types + # ============================ } else { - parsed <- tibble::tibble(ref_file = files_vec) |> dplyr::mutate( - parts = stringr::str_split(basename(ref_file), "_"), + parts = stringr::str_split(basename(ref_file), "_"), i_sparse = purrr::map_int(parts, ~ .get_idx(.x, "sparse.parquet")), - i_strat = purrr::map_int(parts, ~ { - if (is.null(stratify_by)) return(NA_integer_) + i_strat = purrr::map_int(parts, ~ { + if (is.null(stratify_by)) { + return(NA_integer_) + } .get_idx(.x, stratify_by) }), # Feature = last two tokens before sparse.parquet feature = purrr::map2_chr(parts, i_sparse, ~ { - i <- .y; x <- .x - if (is.na(i) || i < 3) return(NA_character_) + i <- .y + x <- .x + if (is.na(i) || i < 3) { + return(NA_character_) + } paste(x[(i - 2):(i - 1)], collapse = "_") }), # Drug or drug class extraction drug_or_class = purrr::map2_chr(parts, i_strat, ~ { - i <- .y; x <- .x + i <- .y + x <- .x # Stratified models if (!is.na(i)) { @@ -304,32 +310,40 @@ createMLinputList <- function(path, # Stratification value (if present) strat_value = 
purrr::map2_chr(parts, i_strat, ~ { - i <- .y; x <- .x - if (is.na(i)) return("") + i <- .y + x <- .x + if (is.na(i)) { + return("") + } # default position is two tokens after the strat label j <- i + 2 # if there's an intervening 'leaveout', skip over it if (j <= length(x) && identical(x[j], "leaveout")) j <- j + 1 - if (j <= length(x)) return(x[j]) - "" # no stratification + if (j <= length(x)) { + return(x[j]) + } + "" # no stratification }), # Prefix key for grouping prefix_key = purrr::map2_chr(parts, i_strat, ~ { - i <- .y; x <- .x + i <- .y + x <- .x # Case A: stratified -> prefix before the stratify label if (!is.na(i)) { - if (i - 1 >= 1) return(paste(x[1:(i - 1)], collapse = "_")) + if (i - 1 >= 1) { + return(paste(x[1:(i - 1)], collapse = "_")) + } return("") } # Case B: unstratified -> prefix is first two tokens - if (x[2] == "drug" && x[3] != "class"){ + if (x[2] == "drug" && x[3] != "class") { # Case A: Cje_drug_X return(paste(x[1:2], collapse = "_")) } - if (x[2] == "drug" && x[3] == "class"){ + if (x[2] == "drug" && x[3] == "class") { # Case A: Cje_drug_X return(paste(x[1:3], collapse = "_")) } @@ -345,18 +359,17 @@ createMLinputList <- function(path, test_file = NA_character_, output_prefix = gsub("_sparse\\.parquet$", "", basename(ref_file)), matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) - # ============================ - # Cross-test modeling, no LOO - # ============================ + # ============================ + # Cross-test modeling, no LOO + # ============================ } else if (cross_test && !LOO) { - if (is.null(stratify_by)) { # Case A: stratify_by = NULL, pair across abx within same feature + prefix pairs <- parsed |> @@ -366,8 +379,10 @@ createMLinputList <- 
function(path, dplyr::select(test_file = ref_file, feature, prefix_key, strat_value, test_drug = drug_or_class), by = c("feature", "prefix_key", "strat_value") ) |> - dplyr::filter(ref_file != test_file, - ref_drug != test_drug) |> + dplyr::filter( + ref_file != test_file, + ref_drug != test_drug + ) |> dplyr::distinct() |> dplyr::mutate( output_prefix = paste0( @@ -380,10 +395,10 @@ createMLinputList <- function(path, out <- pairs |> dplyr::mutate( matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) @@ -392,30 +407,29 @@ createMLinputList <- function(path, # Case B: stratify_by != NULL, pair same drug/class, prefix, feature, # but across different stratification groups pairs <- parsed |> - dplyr::select(ref_file, feature, prefix_key, strat_value, - drug_or_class) |> - + dplyr::select( + ref_file, feature, prefix_key, strat_value, + drug_or_class + ) |> # self-join ONLY on prefix_key, drug/class, feature dplyr::inner_join( parsed |> - dplyr::select(test_file = ref_file, - feature, prefix_key, strat_value_test = strat_value, - drug_or_class), + dplyr::select( + test_file = ref_file, + feature, prefix_key, strat_value_test = strat_value, + drug_or_class + ), by = c("prefix_key", "feature", "drug_or_class") ) |> - # do NOT test file against itself dplyr::filter(ref_file != test_file) |> - # enforce different stratification group dplyr::filter(strat_value != strat_value_test) |> - # remove symmetric duplicates (A,B == B,A) dplyr::rowwise() |> dplyr::mutate(pair_id = paste(sort(c(ref_file, test_file)), collapse = "||")) |> dplyr::ungroup() |> dplyr::distinct(pair_id, .keep_all = TRUE) |> - dplyr::mutate( output_prefix = paste0( prefix_key, "_", @@ -429,19 +443,18 @@ createMLinputList <- function(path, out <- 
pairs |> dplyr::mutate( matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) - # ============================ - # Cross-test + LOO modeling - # ============================ + # ============================ + # Cross-test + LOO modeling + # ============================ } else if (cross_test && LOO) { - # LOO requires special directory structure resolution test_path <- file.path(path, stringr::str_remove(basename(paths$matrix_path), "^LOO_")) test_path <- normalizePath(test_path) @@ -461,10 +474,10 @@ createMLinputList <- function(path, out <- loo_pairs |> dplyr::mutate( matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) @@ -472,9 +485,11 @@ createMLinputList <- function(path, } # If we ever get here, something wasn't covered - stop("Unhandled combination of arguments: ", - "MDR=", MDR, ", cross_test=", cross_test, ", LOO=", LOO, - ", stratify_by=", if (is.null(stratify_by)) "NULL" else stratify_by) + stop( + "Unhandled combination of arguments: ", + "MDR=", MDR, ", cross_test=", cross_test, ", LOO=", LOO, + ", stratify_by=", if (is.null(stratify_by)) "NULL" else stratify_by + ) } @@ -544,13 +559,15 @@ createMLinputList <- function(path, #' #' # Run with more threads and minimal output #' runMDRmodels("/path/to/results", -#' threads = 32, -#' verbose = FALSE) +#' threads = 32, +#' verbose = FALSE +#' ) #' #' # Run without saving model fits (save disk space) #' runMDRmodels("/path/to/results", -#' threads = 16, -#' return_fit = FALSE) +#' threads = 
16, +#' return_fit = FALSE +#' ) #' } #' #' @seealso @@ -571,12 +588,12 @@ runMDRmodels <- function(path, use_saved_split = TRUE, shuffle_labels = FALSE, use_pca = FALSE) { - files <- createMLinputList(path, - stratify_by = NULL, - LOO = FALSE, - cross_test = FALSE, - MDR = TRUE) + stratify_by = NULL, + LOO = FALSE, + cross_test = FALSE, + MDR = TRUE + ) if (nrow(files) == 0) { message("No MDR files found to process. Exiting.") @@ -594,18 +611,19 @@ runMDRmodels <- function(path, # Auto tags for shuffled and PCA shuffle_tag <- if (isTRUE(shuffle_labels)) "shuffled_" else "" - pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" + pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" results_list <- future.apply::future_lapply( seq_len(nrow(files)), FUN = function(i) { - - ref_parquet <- files$ref_file[i] + ref_parquet <- files$ref_file[i] output_prefix <- files$output_prefix[i] if (interactive()) { - message(sprintf("[runMDRmodels] %d/%d: %s", - i, nrow(files), basename(ref_parquet))) + message(sprintf( + "[runMDRmodels] %d/%d: %s", + i, nrow(files), basename(ref_parquet) + )) } ml_input <- loadMLInputTibble(ref_parquet) @@ -619,32 +637,37 @@ runMDRmodels <- function(path, list(split = split, seed = 5280, n_fold = n_fold) } - res <- try({ - runMLPipeline( - ml_input_tibble = ml_input, - test_data = NA, - model = "LR", - split = sp$split, - n_fold = sp$n_fold, - prop_vi_top_feats = prop_vi_top_feats, - n_top_feats = NA, - use_pca = use_pca, - pca_threshold = pca_threshold, - shuffle_labels = shuffle_labels, - penalty_vec = 10^seq(-4, -1, length.out = 10), - mix_vec = 0:5 / 5, - select_best_metric = "mcc", - seed = sp$seed, - verbose = verbose, - return_tune_res = return_tune_res, - return_fit = return_fit, - return_pred = return_pred - ) - }, silent = TRUE) + res <- try( + { + runMLPipeline( + ml_input_tibble = ml_input, + test_data = NA, + model = "LR", + split = sp$split, + n_fold = sp$n_fold, + 
prop_vi_top_feats = prop_vi_top_feats, + n_top_feats = NA, + use_pca = use_pca, + pca_threshold = pca_threshold, + shuffle_labels = shuffle_labels, + penalty_vec = 10^seq(-4, -1, length.out = 10), + mix_vec = 0:5 / 5, + select_best_metric = "mcc", + seed = sp$seed, + verbose = verbose, + return_tune_res = return_tune_res, + return_fit = return_fit, + return_pred = return_pred + ) + }, + silent = TRUE + ) if (inherits(res, "try-error")) { - warning("Model failed for: ", output_prefix, - "\n Error: ", attr(res, "condition")$message) + warning( + "Model failed for: ", output_prefix, + "\n Error: ", attr(res, "condition")$message + ) return(NULL) } @@ -652,19 +675,25 @@ runMDRmodels <- function(path, base <- paste0(shuffle_tag, output_prefix, pca_tag) if (!is.null(res$performance_tibble)) { - readr::write_tsv(res$performance_tibble, - file.path(files$out_perf[i], paste0(base, "_performance.tsv"))) + readr::write_tsv( + res$performance_tibble, + file.path(files$out_perf[i], paste0(base, "_performance.tsv")) + ) } if (!is.null(res$top_feat_tibble)) { - readr::write_tsv(res$top_feat_tibble, - file.path(files$out_top[i], paste0(base, "_top_features.tsv"))) + readr::write_tsv( + res$top_feat_tibble, + file.path(files$out_top[i], paste0(base, "_top_features.tsv")) + ) } if (!is.null(res$fit)) { saveRDS(res$fit, file.path(files$out_models[i], paste0(base, "_model_fit.rds"))) } if (!is.null(res$pred)) { - readr::write_tsv(res$pred, - file.path(files$out_pred[i], paste0(base, "_prediction.tsv"))) + readr::write_tsv( + res$pred, + file.path(files$out_pred[i], paste0(base, "_prediction.tsv")) + ) } NULL @@ -783,21 +812,24 @@ runMDRmodels <- function(path, #' #' # Cross-test with year stratification #' runMLmodels("/path/to/results", -#' stratify_by = "year", -#' cross_test = TRUE, -#' threads = 32) +#' stratify_by = "year", +#' cross_test = TRUE, +#' threads = 32 +#' ) #' #' # LOO analysis stratified by country with cross-testing #' runMLmodels("/path/to/results", -#' stratify_by 
= "country", -#' LOO = TRUE, -#' cross_test = TRUE, -#' verbose = TRUE) +#' stratify_by = "country", +#' LOO = TRUE, +#' cross_test = TRUE, +#' verbose = TRUE +#' ) #' #' # Run without saving model fits (save disk space) #' runMLmodels("/path/to/results", -#' stratify_by = "year", -#' return_fit = FALSE) +#' stratify_by = "year", +#' return_fit = FALSE +#' ) #' } #' #' @seealso @@ -823,19 +855,21 @@ runMLmodels <- function(path, use_saved_split = TRUE, shuffle_labels = FALSE, use_pca = FALSE) { - if (!is.null(stratify_by)) { - if (!is.character(stratify_by) || length(stratify_by) != 1L) + if (!is.character(stratify_by) || length(stratify_by) != 1L) { stop("`stratify_by` must be NULL or a single string: 'year' or 'country'.") - if (!stratify_by %in% c("year", "country")) + } + if (!stratify_by %in% c("year", "country")) { stop("`stratify_by` must be NULL, 'year', or 'country'.") + } } files <- createMLinputList(path, - stratify_by = stratify_by, - LOO = LOO, - MDR = FALSE, - cross_test = cross_test) + stratify_by = stratify_by, + LOO = LOO, + MDR = FALSE, + cross_test = cross_test + ) if (nrow(files) == 0) { message("No files found to process. 
Exiting.") @@ -864,8 +898,7 @@ runMLmodels <- function(path, strat_suffix <- if (is.null(stratify_by) || identical(stratify_by, "")) { "" } else { - switch( - stratify_by, + switch(stratify_by, "country" = "_country", "year" = "_year", stop("`stratify_by` must be NULL, 'year', or 'country'.") @@ -874,18 +907,19 @@ runMLmodels <- function(path, # Auto naming for shuffled and PCA shuffle_tag <- if (isTRUE(shuffle_labels)) "shuffled_" else "" - pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" + pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" results_list <- future.apply::future_lapply( seq_len(nrow(files)), FUN = function(i) { - - ref_parquet <- files$ref_file[i] + ref_parquet <- files$ref_file[i] output_prefix <- files$output_prefix[i] if (interactive()) { - message(sprintf("[runMLmodels] %d/%d: %s", - i, nrow(files), basename(ref_parquet))) + message(sprintf( + "[runMLmodels] %d/%d: %s", + i, nrow(files), basename(ref_parquet) + )) } ml_input <- loadMLInputTibble(ref_parquet) @@ -910,32 +944,37 @@ runMLmodels <- function(path, list(split = split, seed = 5280, n_fold = n_fold) } - res <- try({ - runMLPipeline( - ml_input_tibble = ml_input, - test_data = test_data, - model = "LR", - split = sp$split, - n_fold = sp$n_fold, - prop_vi_top_feats = prop_vi_top_feats, - n_top_feats = NA, - use_pca = use_pca, - pca_threshold = pca_threshold, - shuffle_labels = shuffle_labels, - penalty_vec = 10^seq(-4, -1, length.out = 10), - mix_vec = 0:5 / 5, - select_best_metric = "mcc", - seed = sp$seed, - verbose = verbose, - return_tune_res = return_tune_res, - return_fit = return_fit, - return_pred = return_pred - ) - }, silent = TRUE) + res <- try( + { + runMLPipeline( + ml_input_tibble = ml_input, + test_data = test_data, + model = "LR", + split = sp$split, + n_fold = sp$n_fold, + prop_vi_top_feats = prop_vi_top_feats, + n_top_feats = NA, + use_pca = use_pca, + pca_threshold = pca_threshold, + shuffle_labels = 
shuffle_labels, + penalty_vec = 10^seq(-4, -1, length.out = 10), + mix_vec = 0:5 / 5, + select_best_metric = "mcc", + seed = sp$seed, + verbose = verbose, + return_tune_res = return_tune_res, + return_fit = return_fit, + return_pred = return_pred + ) + }, + silent = TRUE + ) if (inherits(res, "try-error")) { - warning("Model failed for: ", output_prefix, - "\n Error: ", attr(res, "condition")$message) + warning( + "Model failed for: ", output_prefix, + "\n Error: ", attr(res, "condition")$message + ) return(NULL) } @@ -943,19 +982,25 @@ runMLmodels <- function(path, base <- paste0(shuffle_tag, config_prefix, output_prefix, pca_tag, strat_suffix) if (!is.null(res$performance_tibble)) { - readr::write_tsv(res$performance_tibble, - file.path(files$out_perf[i], paste0(base, "_performance.tsv"))) + readr::write_tsv( + res$performance_tibble, + file.path(files$out_perf[i], paste0(base, "_performance.tsv")) + ) } if (!is.null(res$top_feat_tibble)) { - readr::write_tsv(res$top_feat_tibble, - file.path(files$out_top[i], paste0(base, "_top_features.tsv"))) + readr::write_tsv( + res$top_feat_tibble, + file.path(files$out_top[i], paste0(base, "_top_features.tsv")) + ) } if (!is.null(res$fit)) { saveRDS(res$fit, file.path(files$out_models[i], paste0(base, "_model_fit.rds"))) } if (!is.null(res$pred)) { - readr::write_tsv(res$pred, - file.path(files$out_pred[i], paste0(base, "_prediction.tsv"))) + readr::write_tsv( + res$pred, + file.path(files$out_pred[i], paste0(base, "_prediction.tsv")) + ) } NULL @@ -973,7 +1018,6 @@ runMLmodels <- function(path, } - #' Run the entire AMR ML pipeline from a parquet-backed DuckDB #' #' This function provides a complete end-to-end AMR machine learning workflow. 
@@ -1006,11 +1050,12 @@ runModelingPipeline <- function(parquet_duckdb_path, pca_threshold = 0.99, verbose = TRUE, use_saved_split = TRUE) { - parquet_duckdb_path <- normalizePath(parquet_duckdb_path) if (!file.exists(parquet_duckdb_path)) { - stop("Parquet-backed DuckDB at ", parquet_duckdb_path, " not found.\n", - "Are you using `{Bug}.duckdb` instead of `{Bug}_parquet.duckdb?`") + stop( + "Parquet-backed DuckDB at ", parquet_duckdb_path, " not found.\n", + "Are you using `{Bug}.duckdb` instead of `{Bug}_parquet.duckdb?`" + ) } out_root <- dirname(parquet_duckdb_path) @@ -1024,9 +1069,9 @@ runModelingPipeline <- function(parquet_duckdb_path, generateMLInputs( parquet_duckdb_path = parquet_duckdb_path, out_path = out_root, - n_fold = n_fold, - split = split, - min_n = min_n, + n_fold = n_fold, + split = split, + min_n = min_n, verbosity = if (verbose) "minimal" else "debug" ) @@ -1089,12 +1134,13 @@ runModelingPipeline <- function(parquet_duckdb_path, # All done! if (verbose) { message("\n=== AMR-ML Pipeline Complete ===") - message("All matrices, models, top feature lists, and performance metrics saved under:\n ", - out_root) + message( + "All matrices, models, top feature lists, and performance metrics saved under:\n ", + out_root + ) message("\nTo inspect model outputs, see directories such as:") message(" ML_performance/, ML_models/, ML_prediction/, ML_top_features/") } invisible(out_root) } - diff --git a/R/run_ml_pipeline.R b/R/run_ml_pipeline.R index 2a97c00..ad9f691 100644 --- a/R/run_ml_pipeline.R +++ b/R/run_ml_pipeline.R @@ -93,20 +93,21 @@ runMLPipeline <- function( .checkArgReturnPred(return_pred) - # Set `n_fold` to `NA` if not using cross-validation. 
if (split[2] != 0) { n_fold <- NA } # Confirm resolved split params - if (verbose) { - mode <- if (split[2] == 0) "cv" else "splits" - message(sprintf("ML split mode: %s | split = c(%.2f, %.2f) | n_fold = %s | seed = %s", - mode, split[1], split[2], - ifelse(is.na(n_fold), "NA", as.character(n_fold)), - as.character(seed))) - } + if (verbose) { + mode <- if (split[2] == 0) "cv" else "splits" + message(sprintf( + "ML split mode: %s | split = c(%.2f, %.2f) | n_fold = %s | seed = %s", + mode, split[1], split[2], + ifelse(is.na(n_fold), "NA", as.character(n_fold)), + as.character(seed) + )) + } # Create a variable indicating whether external `test_data` was provided. This # will be set to `TRUE` later if the `test_data` argument is not `NA`. @@ -116,10 +117,10 @@ runMLPipeline <- function( # Determine whether multi-class classification is to be performed. if (as.character(.getTargetVarName(ml_input_tibble)) == "resistant_classes") { - multi_class <- TRUE - } else { - multi_class <- FALSE - } + multi_class <- TRUE + } else { + multi_class <- FALSE + } if (model != "LR" & multi_class) { stop(paste( @@ -262,7 +263,7 @@ runMLPipeline <- function( mix_vec = mix_vec ) } - + recipe <- buildRecipe(train_data, use_pca = use_pca, pca_threshold = pca_threshold @@ -421,14 +422,16 @@ runMLPipeline <- function( all_results[["fit"]] <- fit } - if(return_pred) { - if(!multi_class){ + if (return_pred) { + if (!multi_class) { all_results[["pred"]] <- test_data_plus_predictions |> - dplyr::select(c(genome_id, .pred_class, .pred_Resistant, - .pred_Susceptible, genome_drug.resistant_phenotype)) - } - all_results[["pred"]] <- test_data_plus_predictions + dplyr::select(c( + genome_id, .pred_class, .pred_Resistant, + .pred_Susceptible, genome_drug.resistant_phenotype + )) } + all_results[["pred"]] <- test_data_plus_predictions + } return(all_results) } diff --git a/vignettes/intro.Rmd b/vignettes/intro.Rmd index 996eb6b..af5bc8e 100644 --- a/vignettes/intro.Rmd +++ b/vignettes/intro.Rmd @@ 
-264,19 +264,19 @@ ml_tibble_reduced <- removeTopFeats(ml_tibble, top_features) ### Precision-recall curve ```{r plot-prc} -test_data_plus_predictions <- readr::read_tsv(results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv) +test_data_plus_predictions <- readr::read_tsv(results / ML_pred / Sfl_drug_AMP_domains_binary_prediction.tsv) plotPRC(test_data_plus_predictions) ``` ### ROC curve ```{r plot-roc} -test_data_plus_predictions <- readr::read_tsv(results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv) +test_data_plus_predictions <- readr::read_tsv(results / ML_pred / Sfl_drug_AMP_domains_binary_prediction.tsv) plotROC(test_data_plus_predictions) ``` ### Variable importance plot ```{r plot-vi} -topfeat <- readr::read_tsv(results/ML_top_features/Sfl_drug_AMP_domains_binary_top_features.tsv) +topfeat <- readr::read_tsv(results / ML_top_features / Sfl_drug_AMP_domains_binary_top_features.tsv) plotTopFeatsVI(topfeat) ``` ### Baseline comparison barplot @@ -326,7 +326,6 @@ You can label the top N features to highlight the strongest hits (default is 5) ```{r} plotFishers(fisher_results) plotFishers(fisher_results, alpha = 0.01, label_top_n = 5) - ``` ## Wrapper to run all models @@ -338,14 +337,15 @@ Given a DuckDB file produced by `runDataProcessing()`, it: 5. 
saves performance metrics, fitted models, predictions, and top feature rankings ``` {r} runModelingPipeline(parquet_duckdb_path, - threads = 16, - n_fold = 5, - split = c(1, 0), - min_n = 25, - prop_vi_top_feats = c(0, 1), - pca_threshold = 0.99, - verbose = TRUE, - use_saved_split = TRUE) + threads = 16, + n_fold = 5, + split = c(1, 0), + min_n = 25, + prop_vi_top_feats = c(0, 1), + pca_threshold = 0.99, + verbose = TRUE, + use_saved_split = TRUE +) ``` Merge the performance and top features of each kind of models into a parquet that will serve as starting data for `amRshiny` package @@ -357,7 +357,7 @@ buildPerformancePq( LOO = FALSE, MDR = FALSE, cross_test = FALSE, - out_parquet = NULL, + out_parquet = NULL, compression = "zstd", verbose = TRUE ) @@ -367,8 +367,8 @@ buildTopFeatsPq( LOO = FALSE, MDR = FALSE, cross_test = FALSE, - out_parquet = NULL, + out_parquet = NULL, compression = "zstd", verbose = TRUE -) +) ``` From 6f0872e6b2995f0b773635f6c0d6b62595583651 Mon Sep 17 00:00:00 2001 From: Abhirupa Ghosh <100681585+AbhirupaGhosh@users.noreply.github.com> Date: Tue, 17 Feb 2026 11:13:19 -0700 Subject: [PATCH 3/7] Refactor ML pipeline for improved readability and logic --- R/run_ml_pipeline.R | 65 ++++++++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/R/run_ml_pipeline.R b/R/run_ml_pipeline.R index ad9f691..9a89def 100644 --- a/R/run_ml_pipeline.R +++ b/R/run_ml_pipeline.R @@ -93,21 +93,20 @@ runMLPipeline <- function( .checkArgReturnPred(return_pred) + # Set `n_fold` to `NA` if not using cross-validation. 
if (split[2] != 0) { n_fold <- NA } # Confirm resolved split params - if (verbose) { - mode <- if (split[2] == 0) "cv" else "splits" - message(sprintf( - "ML split mode: %s | split = c(%.2f, %.2f) | n_fold = %s | seed = %s", - mode, split[1], split[2], - ifelse(is.na(n_fold), "NA", as.character(n_fold)), - as.character(seed) - )) - } + if (verbose) { + mode <- if (split[2] == 0) "cv" else "splits" + message(sprintf("ML split mode: %s | split = c(%.2f, %.2f) | n_fold = %s | seed = %s", + mode, split[1], split[2], + ifelse(is.na(n_fold), "NA", as.character(n_fold)), + as.character(seed))) + } # Create a variable indicating whether external `test_data` was provided. This # will be set to `TRUE` later if the `test_data` argument is not `NA`. @@ -117,10 +116,10 @@ runMLPipeline <- function( # Determine whether multi-class classification is to be performed. if (as.character(.getTargetVarName(ml_input_tibble)) == "resistant_classes") { - multi_class <- TRUE - } else { - multi_class <- FALSE - } + multi_class <- TRUE + } else { + multi_class <- FALSE + } if (model != "LR" & multi_class) { stop(paste( @@ -263,7 +262,7 @@ runMLPipeline <- function( mix_vec = mix_vec ) } - + recipe <- buildRecipe(train_data, use_pca = use_pca, pca_threshold = pca_threshold @@ -297,10 +296,10 @@ runMLPipeline <- function( log2_apop <- .calculateLog2APOP(test_data_plus_predictions) } - nmcc <- .calculatenMCC(test_data_plus_predictions) + mcc <- .calculateMCC(test_data_plus_predictions) if (verbose) { - message(paste("Normalized Matthews correlation coefficient:", nmcc)) + message(paste("Matthews correlation coefficient:", mcc)) } top_feat_tibble <- extractTopFeats(fit, @@ -362,7 +361,7 @@ runMLPipeline <- function( performance_tibble <- tibble::tibble( num_obs = num_obs_ml_input_tibble, n_feat = getNumFeat(ml_input_tibble), model, train_prop = split[1], - val_prop = split[2], n_fold, nmcc, run_time_sec, + val_prop = split[2], n_fold, mcc, run_time_sec, seed, date = as.character(Sys.Date()) ) @@ 
-381,20 +380,22 @@ runMLPipeline <- function( lower_prop_vi_top_feats = prop_vi_top_feats[1], .after = "val_prop" ) |> - tibble::add_column(bal_acc, .after = "nmcc") |> - tibble::add_column(f1, .after = "nmcc") |> - tibble::add_column(log2_apop, .after = "nmcc") + tibble::add_column(bal_acc, .after = "mcc") |> + tibble::add_column(f1, .after = "mcc") |> + tibble::add_column(log2_apop, .after = "mcc") |> + tibble::add_column(sens, .after = "mcc") |> + tibble::add_column(spec, .after = "mcc") } if (model == "LR") { performance_tibble <- performance_tibble |> - tibble::add_column(fit_penalty, .before = "nmcc") |> - tibble::add_column(fit_mixture, .before = "nmcc") + tibble::add_column(fit_penalty, .before = "mcc") |> + tibble::add_column(fit_mixture, .before = "mcc") } else if (model == "RF" || model == "BT") { performance_tibble <- performance_tibble |> - tibble::add_column(fit_trees, .before = "nmcc") |> - tibble::add_column(fit_mtry, .before = "nmcc") |> - tibble::add_column(fit_min_n, .before = "nmcc") + tibble::add_column(fit_trees, .before = "mcc") |> + tibble::add_column(fit_mtry, .before = "mcc") |> + tibble::add_column(fit_min_n, .before = "mcc") } if (external_test_data) { @@ -422,16 +423,14 @@ runMLPipeline <- function( all_results[["fit"]] <- fit } - if (return_pred) { - if (!multi_class) { + if(return_pred) { + if(!multi_class){ all_results[["pred"]] <- test_data_plus_predictions |> - dplyr::select(c( - genome_id, .pred_class, .pred_Resistant, - .pred_Susceptible, genome_drug.resistant_phenotype - )) + dplyr::select(c(genome_id, .pred_class, .pred_Resistant, + .pred_Susceptible, genome_drug.resistant_phenotype)) + } + all_results[["pred"]] <- test_data_plus_predictions } - all_results[["pred"]] <- test_data_plus_predictions - } return(all_results) } From 54d8799c0ffa02fb53a2539fa38fbd4a587ab047 Mon Sep 17 00:00:00 2001 From: AbhirupaGhosh Date: Tue, 17 Feb 2026 18:14:45 +0000 Subject: [PATCH 4/7] Style code (GHA) --- R/run_ml_pipeline.R | 41 
++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/R/run_ml_pipeline.R b/R/run_ml_pipeline.R index 9a89def..bc497f7 100644 --- a/R/run_ml_pipeline.R +++ b/R/run_ml_pipeline.R @@ -93,20 +93,21 @@ runMLPipeline <- function( .checkArgReturnPred(return_pred) - # Set `n_fold` to `NA` if not using cross-validation. if (split[2] != 0) { n_fold <- NA } # Confirm resolved split params - if (verbose) { - mode <- if (split[2] == 0) "cv" else "splits" - message(sprintf("ML split mode: %s | split = c(%.2f, %.2f) | n_fold = %s | seed = %s", - mode, split[1], split[2], - ifelse(is.na(n_fold), "NA", as.character(n_fold)), - as.character(seed))) - } + if (verbose) { + mode <- if (split[2] == 0) "cv" else "splits" + message(sprintf( + "ML split mode: %s | split = c(%.2f, %.2f) | n_fold = %s | seed = %s", + mode, split[1], split[2], + ifelse(is.na(n_fold), "NA", as.character(n_fold)), + as.character(seed) + )) + } # Create a variable indicating whether external `test_data` was provided. This # will be set to `TRUE` later if the `test_data` argument is not `NA`. @@ -116,10 +117,10 @@ runMLPipeline <- function( # Determine whether multi-class classification is to be performed. 
if (as.character(.getTargetVarName(ml_input_tibble)) == "resistant_classes") { - multi_class <- TRUE - } else { - multi_class <- FALSE - } + multi_class <- TRUE + } else { + multi_class <- FALSE + } if (model != "LR" & multi_class) { stop(paste( @@ -262,7 +263,7 @@ runMLPipeline <- function( mix_vec = mix_vec ) } - + recipe <- buildRecipe(train_data, use_pca = use_pca, pca_threshold = pca_threshold @@ -423,14 +424,16 @@ runMLPipeline <- function( all_results[["fit"]] <- fit } - if(return_pred) { - if(!multi_class){ + if (return_pred) { + if (!multi_class) { all_results[["pred"]] <- test_data_plus_predictions |> - dplyr::select(c(genome_id, .pred_class, .pred_Resistant, - .pred_Susceptible, genome_drug.resistant_phenotype)) - } - all_results[["pred"]] <- test_data_plus_predictions + dplyr::select(c( + genome_id, .pred_class, .pred_Resistant, + .pred_Susceptible, genome_drug.resistant_phenotype + )) } + all_results[["pred"]] <- test_data_plus_predictions + } return(all_results) } From d40fe4da9659a9e08446279dc8937f0f969e026c Mon Sep 17 00:00:00 2001 From: Abhirupa Ghosh <100681585+AbhirupaGhosh@users.noreply.github.com> Date: Thu, 12 Mar 2026 16:05:42 -0600 Subject: [PATCH 5/7] Refactor genome_id references in SQL queries --- R/generate_matrices_ml.R | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/R/generate_matrices_ml.R b/R/generate_matrices_ml.R index b19beef..4104f38 100644 --- a/R/generate_matrices_ml.R +++ b/R/generate_matrices_ml.R @@ -280,14 +280,14 @@ skipImbalancedMatrix <- function(genome_ids, if (group_type %in% c("drug_class", "drug_class_year", "drug_class_country")) { genome_ids <- DBI::dbGetQuery(con, sprintf(" WITH class_phenotypes AS ( - SELECT \"genome_drug.genome_id\" AS genome_id, + SELECT \"genome.genome_id\" AS genome_id, MAX(CASE WHEN \"genome_drug.resistant_phenotype\" = 'Resistant' THEN 1 ELSE 0 END) AS any_resistant, MIN(CASE WHEN \"genome_drug.resistant_phenotype\" = 'Susceptible' THEN 1 ELSE 
0 END) AS all_susceptible FROM metadata WHERE %s - GROUP BY \"genome_drug.genome_id\" + GROUP BY \"genome.genome_id\" ) SELECT genome_id FROM class_phenotypes @@ -295,7 +295,7 @@ skipImbalancedMatrix <- function(genome_ids, ", condition_string))[[1]] } else { genome_ids <- DBI::dbGetQuery(con, sprintf(" - SELECT DISTINCT \"genome_drug.genome_id\" + SELECT DISTINCT \"genome.genome_id\" FROM metadata WHERE %s AND \"genome_drug.resistant_phenotype\" IN ('Resistant','Susceptible') @@ -491,7 +491,7 @@ skipImbalancedMatrix <- function(genome_ids, " COPY ( SELECT - f.\"genome_drug.genome_id\" AS genome_id, + f.\"genome.genome_id\" AS genome_id, %s AS feature_id, MAX(CAST(%s AS DOUBLE)) AS value, %s AS \"genome_drug.resistant_phenotype\" @@ -499,12 +499,12 @@ skipImbalancedMatrix <- function(genome_ids, FROM %s JOIN selected_genomes USING (genome_id) JOIN keep_features kf ON %s = kf.feature_id - JOIN metadata f ON genome_id = f.\"genome_drug.genome_id\" + JOIN metadata f ON genome_id = f.\"genome.genome_id\" WHERE %s AND f.\"genome_drug.resistant_phenotype\" IN ('Resistant','Susceptible') %s - GROUP BY f.\"genome_drug.genome_id\", %s %s - ORDER BY f.\"genome_drug.genome_id\", %s + GROUP BY f.\"genome.genome_id\", %s %s + ORDER BY f.\"genome.genome_id\", %s ) TO '%s' (FORMAT 'parquet', COMPRESSION 'zstd') @@ -721,7 +721,7 @@ skipImbalancedMatrix <- function(genome_ids, DBI::dbDisconnect(con0, shutdown = FALSE) classes <- metadata_all |> - dplyr::select(genome_drug.genome_id, resistant_classes) |> + dplyr::select(genome.genome_id, resistant_classes) |> dplyr::distinct() |> dplyr::group_by(resistant_classes) |> dplyr::count() |> @@ -733,7 +733,7 @@ skipImbalancedMatrix <- function(genome_ids, genomes_to_keep <- metadata_all |> dplyr::filter(resistant_classes %in% classes) |> - dplyr::pull(genome_drug.genome_id) + dplyr::pull(genome.genome_id) # Build one matrix per feature type and matrix type for (ftype in names(feature_types)) { @@ -809,21 +809,21 @@ skipImbalancedMatrix <- 
function(genome_ids, " COPY ( SELECT - f.\"genome_drug.genome_id\" AS genome_id, + f.\"genome.genome_id\" AS genome_id, %s AS feature_id, MAX(CAST(%s AS DOUBLE)) AS value, resistant_classes FROM %s JOIN selected_genomes USING (genome_id) JOIN keep_features kf ON %s = kf.feature_id - JOIN metadata f ON genome_id = f.\"genome_drug.genome_id\" + JOIN metadata f ON genome_id = f.\"genome.genome_id\" WHERE resistant_classes <> 'Intermediate' GROUP BY - f.\"genome_drug.genome_id\", + f.\"genome.genome_id\", %s, resistant_classes ORDER BY - f.\"genome_drug.genome_id\", + f.\"genome.genome_id\", %s ) TO '%s' From a022c6e35fe748ae51897a5558918e50181e87df Mon Sep 17 00:00:00 2001 From: Janani Ravi Date: Mon, 11 May 2026 11:08:46 -0600 Subject: [PATCH 6/7] Add MCC and nMCC metrics throughout Added Matthews correlation coefficient (MCC) and normalized MCC (nMCC) support throughout the package. Updated argument checks and plotting docs to accept an avg_mcc option, track mcc in runIFE results, and add related globals. Also updated README, vignette, and docs (including example file paths) to mention MCC and nMCC and to reflect the updated outputs. Moved Cje to Sfl (one consistent example throughout). Minor fixes: ensure test predictions are assigned to all_results in runMLPipeline and improve function return documentation. --- R/arg_check_ml.R | 6 +++--- R/core_ml.R | 25 +++++++++++++++++++------ R/globals.R | 4 ++++ R/ife_ml.R | 8 ++++++-- R/plot_ml.R | 2 +- R/run_ml_pipeline.R | 18 ++++++++++-------- README.Rmd | 2 +- README.md | 2 +- doc/intro.R | 9 +++++---- vignettes/intro.Rmd | 22 +++++++++++----------- 10 files changed, 61 insertions(+), 37 deletions(-) diff --git a/R/arg_check_ml.R b/R/arg_check_ml.R index 41f01c8..f8ab172 100644 --- a/R/arg_check_ml.R +++ b/R/arg_check_ml.R @@ -702,7 +702,7 @@ NULL #' @keywords internal #' @param y_default_eval [chr] y value of default evaluation plot. 
It can be #' "avg_runtime_sec" or one of the following performance metrics: -#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", or "avg_nmcc" +#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", or "avg_nmcc" #' .checkArgYDefaultEval <- function(y_default_eval) { if (!is.character(y_default_eval)) { @@ -710,11 +710,11 @@ NULL } if (!(y_default_eval %in% - c("avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_nmcc")) + c("avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", "avg_nmcc")) ) { stop(paste( "`y_default_eval` must be one of:", - "'avg_f1_score', 'avg_log2_apop', 'avg_bal_acc', 'avg_nmcc'." + "'avg_f1_score', 'avg_log2_apop', 'avg_bal_acc', 'avg_mcc', 'avg_nmcc'." )) } } diff --git a/R/core_ml.R b/R/core_ml.R index 8d84b47..1d72507 100644 --- a/R/core_ml.R +++ b/R/core_ml.R @@ -398,7 +398,7 @@ getConfusionMatrix <- function(test_data_plus_predictions) { #' ML model compared against the actual values. #' #' @inheritParams getConfusionMatrix -#' @return Matthews correlation coefficient (MCC) +#' @return Matthews correlation coefficient (MCC), range -1 to 1 .calculateMCC <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) @@ -410,11 +410,23 @@ getConfusionMatrix <- function(test_data_plus_predictions) { as.numeric() |> round(2) - # nmcc <- (mcc + 1) / 2 - return(mcc) } +#' .calculatenMCC() +#' +#' Returns the normalized (0 to 1) Matthews correlation coefficient (nMCC) +#' based on the AMR phenotype predictions by an ML model compared against +#' the actual values. 
+#' +#' @inheritParams getConfusionMatrix +#' @return Normalized Matthews correlation coefficient (nMCC), range 0 to 1 +.calculatenMCC <- function(test_data_plus_predictions) { + mcc <- .calculateMCC(test_data_plus_predictions) + nmcc <- round((mcc + 1) / 2, 2) + return(nmcc) +} + #' .calculateF1() #' #' Returns the F1 score based on the AMR phenotype predictions by an ML model @@ -607,12 +619,12 @@ getConfusionMatrix <- function(test_data_plus_predictions) { #' calculateEvalMets() #' #' Returns the F1 score, area under the precision-recall curve (AUPRC), balanced -#' accuracy, Matthews correlation coefficient (MCC), +#' accuracy, Matthews correlation coefficient (MCC), normalized MCC (nMCC), #' and log2(AUPRC/prior) based on the AMR #' phenotype predictions by an ML model compared against the actual values. #' #' @inheritParams getConfusionMatrix -#' @return F1 score, AUPRC, balanced accuracy, MCC, and log2(AUPRC/prior) +#' @return F1 score, AUPRC, balanced accuracy, MCC, nMCC, and log2(AUPRC/prior) #' @export calculateEvalMets <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) @@ -623,9 +635,10 @@ calculateEvalMets <- function(test_data_plus_predictions) { sens <- .calculateSensitivity(test_data_plus_predictions) spec <- .calculateSpecificity(test_data_plus_predictions) mcc <- .calculateMCC(test_data_plus_predictions) + nmcc <- .calculatenMCC(test_data_plus_predictions) log2_apop <- .calculateLog2APOP(test_data_plus_predictions) - return(c(f1, auprc, bal_acc, mcc, log2_apop)) + return(c(f1, auprc, bal_acc, mcc, nmcc, log2_apop)) } #' extractTopFeats() diff --git a/R/globals.R b/R/globals.R index 131d016..e1e3ab5 100644 --- a/R/globals.R +++ b/R/globals.R @@ -44,8 +44,12 @@ utils::globalVariables(c( "idx_strat", "model", "neg_log10_adj_p", + "mcc", "nmcc", "num_obs", + "seed", + "sens", + "spec", "output_prefix", "p_value", "pair_id", diff --git a/R/ife_ml.R b/R/ife_ml.R index 3d72a6d..110a17d 100644 --- 
a/R/ife_ml.R +++ b/R/ife_ml.R @@ -38,7 +38,7 @@ removeTopFeats <- function(ml_input_tibble, top_feat_tibble) { #' runIFE #' Removes top features identified by ML models and retrains iteratively; -#' returns nMCC at each iteration. +#' returns MCC at each iteration. #' #' @param ml_input_tibble An ML-ready tibble generated by `loadMLInputTibble()` #' @param by_num [bool] Set to `TRUE` if removing top features as a percentage @@ -83,6 +83,7 @@ runIFE <- function( num_obs_vec <- c() res_prop_vec <- c() fit_mixture_vec <- c() + mcc_vec <- c() nmcc_vec <- c() n_feats_removed_vec <- c() total_feats_removed_vec <- c() @@ -154,6 +155,9 @@ runIFE <- function( fit_mixture_vec[i] <- ml_res$performance_tibble |> dplyr::select(fit_mixture) |> as.numeric() + mcc_vec[i] <- ml_res$performance_tibble |> + dplyr::select(mcc) |> + as.numeric() nmcc_vec[i] <- ml_res$performance_tibble |> dplyr::select(nmcc) |> as.numeric() @@ -247,7 +251,7 @@ runIFE <- function( percent_removed = c(0, percent_removal_vec), removal_type = rep(removal_type, length(num_obs_vec)), num_obs = num_obs_vec, res_prop = res_prop_vec, - fit_mixture = fit_mixture_vec, nmcc = nmcc_vec, + fit_mixture = fit_mixture_vec, mcc = mcc_vec, nmcc = nmcc_vec, n_feats_removed = n_feats_removed_vec, total_feats_removed = total_feats_removed_vec, run_time_sec = run_time_vec ) diff --git a/R/plot_ml.R b/R/plot_ml.R index 90121f3..538c0e2 100644 --- a/R/plot_ml.R +++ b/R/plot_ml.R @@ -78,7 +78,7 @@ plotTopFeatsVI <- function(fit, n_top_feats = 10) { #' or "n_fold" #' @param y_default_eval [chr] y value of default evaluation plot. It can be #' "avg_runtime_sec" or one of the following performance metrics: -#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", or "avg_nmcc" +#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", or "avg_nmcc" #' @param xlab [chr] Label for x axis #' @param ylab [chr] Label for y axis #' @return A `ggplot2` scatterplot (performance metric or runtime vs. 
diff --git a/R/run_ml_pipeline.R b/R/run_ml_pipeline.R index bc497f7..d632bfe 100644 --- a/R/run_ml_pipeline.R +++ b/R/run_ml_pipeline.R @@ -298,9 +298,10 @@ runMLPipeline <- function( } mcc <- .calculateMCC(test_data_plus_predictions) + nmcc <- .calculatenMCC(test_data_plus_predictions) if (verbose) { - message(paste("Matthews correlation coefficient:", mcc)) + message(paste("Matthews correlation coefficient:", mcc, "| nMCC:", nmcc)) } top_feat_tibble <- extractTopFeats(fit, @@ -362,7 +363,7 @@ runMLPipeline <- function( performance_tibble <- tibble::tibble( num_obs = num_obs_ml_input_tibble, n_feat = getNumFeat(ml_input_tibble), model, train_prop = split[1], - val_prop = split[2], n_fold, mcc, run_time_sec, seed, + val_prop = split[2], n_fold, mcc, nmcc, run_time_sec, seed, date = as.character(Sys.Date()) ) @@ -381,11 +382,11 @@ runMLPipeline <- function( lower_prop_vi_top_feats = prop_vi_top_feats[1], .after = "val_prop" ) |> - tibble::add_column(bal_acc, .after = "mcc") |> - tibble::add_column(f1, .after = "mcc") |> - tibble::add_column(log2_apop, .after = "mcc") |> - tibble::add_column(sens, .after = "mcc") |> - tibble::add_column(spec, .after = "mcc") + tibble::add_column(bal_acc, .after = "nmcc") |> + tibble::add_column(f1, .after = "nmcc") |> + tibble::add_column(log2_apop, .after = "nmcc") |> + tibble::add_column(sens, .after = "nmcc") |> + tibble::add_column(spec, .after = "nmcc") } if (model == "LR") { @@ -431,8 +432,9 @@ runMLPipeline <- function( genome_id, .pred_class, .pred_Resistant, .pred_Susceptible, genome_drug.resistant_phenotype )) + } else { + all_results[["pred"]] <- test_data_plus_predictions } - all_results[["pred"]] <- test_data_plus_predictions } return(all_results) diff --git a/README.Rmd b/README.Rmd index 0783647..68a00f0 100644 --- a/README.Rmd +++ b/README.Rmd @@ -103,7 +103,7 @@ This uses specific matrices to test whether ML models can predict resistance aga - **Data preparation**: Load Parquet files and prepare ML-ready datasets - 
**Model training**: User-customizable logistic regression via tidymodels -- **Evaluation**: nMCC, F1, balanced accuracy, AuPRC, and confusion matrices +- **Evaluation**: MCC, nMCC, F1, balanced accuracy, AuPRC, and confusion matrices - **Feature importance**: Extract and rank predictive features See the [package vignette](https://jravilab.github.io/amRml/articles/intro.html) for detailed usage. diff --git a/README.md b/README.md index 0099d07..3940387 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,7 @@ associated with MDR. - **Data preparation**: Load Parquet files and prepare ML-ready datasets - **Model training**: User-customizable logistic regression via tidymodels -- **Evaluation**: nMCC, F1, balanced accuracy, AuPRC, and confusion +- **Evaluation**: MCC, nMCC, F1, balanced accuracy, AuPRC, and confusion matrices - **Feature importance**: Extract and rank predictive features diff --git a/doc/intro.R b/doc/intro.R index 3782033..d7ca536 100644 --- a/doc/intro.R +++ b/doc/intro.R @@ -112,7 +112,8 @@ knitr::opts_chunk$set( ## ----metrics------------------------------------------------------------------ # # Individual metrics -# nmcc <- calculatenMCC(predictions) # Normalized MCC (0-1 scale) +# mcc <- calculateMCC(predictions) # Matthews correlation coefficient (-1 to 1) +# nmcc <- calculatenMCC(predictions) # Normalized MCC (0 to 1) # f1 <- calculateF1(predictions) # F1 score # bal_acc <- calculateBalAcc(predictions) # Balanced accuracy # auprc <- calculateAUPRC(predictions) # Area under PR curve @@ -150,7 +151,7 @@ knitr::opts_chunk$set( # verbose = TRUE # ) # -# # Results include nMCC at each iteration +# # Results include MCC and nMCC at each iteration # ife_results$ife_performance_tibble # ife_results$feats_removed # If return_feats = TRUE # @@ -233,8 +234,8 @@ knitr::opts_chunk$set( # ) # # # 5. 
Compare real vs baseline performance -# cat("Real nMCC:", results$performance_tibble$nmcc, "\n") -# cat("Baseline nMCC:", baseline_results$performance_tibble$nmcc, "\n") +# cat("Real MCC:", results$performance_tibble$mcc, "| nMCC:", results$performance_tibble$nmcc, "\n") +# cat("Baseline MCC:", baseline_results$performance_tibble$mcc, "| nMCC:", baseline_results$performance_tibble$nmcc, "\n") # # # 6. Run iterative feature elimination # ife_results <- runIFE( diff --git a/vignettes/intro.Rmd b/vignettes/intro.Rmd index af5bc8e..2129f6b 100644 --- a/vignettes/intro.Rmd +++ b/vignettes/intro.Rmd @@ -31,7 +31,7 @@ After data curation with `amRdata`, use `generateMLInputs()` to create ML-ready The DuckDB with the parquet views is created with `cleanData()` ```{r generate-matrices} -# Generate all ML input matrices from curated data +# Example uses Shigella flexneri (Sfl); replace paths with your own species and drug. generateMLInputs( parquet_duckdb_path = "results/Sfl_parquet.duckdb", out_path = "results/", @@ -52,7 +52,6 @@ This generates: For classical train/validation/test splits instead of cross-validation: ```{r generate-splits} -# 70% train, 15% validation, 15% test generateMLInputs( parquet_duckdb_path = "results/Sfl_parquet.duckdb", out_path = "results/", @@ -88,7 +87,6 @@ The package expects Parquet files in long (sparse) format with these columns: `loadMLInputTibble()` converts this to wide format (one row per genome, one column per feature) for ML. 
```{r load-data} -# Load a generated matrix ml_tibble <- loadMLInputTibble( parquet_path = "results/matrix/Sfl_drug_AMP_domains_binary_sparse.parquet" ) @@ -134,7 +132,7 @@ results$top_feat_tibble | `model` | Model type (LR) | | `train_prop`, `val_prop` | Train/validation split proportions | | `fit_penalty`, `fit_mixture` | Fitted hyperparameters | -| `nmcc`, `f1`, `bal_acc`, `log2_apop` | Performance metrics | +| `mcc`, `nmcc`, `f1`, `bal_acc`, `log2_apop` | Performance metrics | | `run_time_sec` | Runtime in seconds | **`top_feat_tibble`** - ranked feature importance: @@ -201,7 +199,8 @@ predictions <- predict(fit, test_data) ```{r metrics} # Individual metrics -nmcc <- calculatenMCC(predictions) # Normalized MCC (0-1 scale) +mcc <- calculateMCC(predictions) # Matthews correlation coefficient (-1 to 1) +nmcc <- calculatenMCC(predictions) # Normalized MCC (0 to 1) f1 <- calculateF1(predictions) # F1 score bal_acc <- calculateBalAcc(predictions) # Balanced accuracy auprc <- calculateAUPRC(predictions) # Area under PR curve @@ -216,7 +215,8 @@ conf_mat <- getConfusionMatrix(predictions) | Metric | Function | Description | |--------|----------|-------------| -| nMCC | `calculatenMCC()` | Normalized Matthews correlation coefficient (0-1) | +| MCC | `calculateMCC()` | Matthews correlation coefficient (-1 to 1) | +| nMCC | `calculatenMCC()` | Normalized Matthews correlation coefficient (0 to 1) | | F1 | `calculateF1()` | Harmonic mean of precision and recall | | Balanced accuracy | `calculateBalAcc()` | Average of sensitivity and specificity | | AUPRC | `calculateAUPRC()` | Area under precision-recall curve | @@ -251,7 +251,7 @@ ife_results <- runIFE( verbose = TRUE ) -# Results include nMCC at each iteration +# Results include MCC and nMCC at each iteration ife_results$ife_performance_tibble ife_results$feats_removed @@ -264,19 +264,19 @@ ml_tibble_reduced <- removeTopFeats(ml_tibble, top_features) ### Precision-recall curve ```{r plot-prc} -test_data_plus_predictions 
<- readr::read_tsv(results / ML_pred / Sfl_drug_AMP_domains_binary_prediction.tsv) +test_data_plus_predictions <- readr::read_tsv("results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv") plotPRC(test_data_plus_predictions) ``` ### ROC curve ```{r plot-roc} -test_data_plus_predictions <- readr::read_tsv(results / ML_pred / Sfl_drug_AMP_domains_binary_prediction.tsv) +test_data_plus_predictions <- readr::read_tsv("results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv") plotROC(test_data_plus_predictions) ``` ### Variable importance plot ```{r plot-vi} -topfeat <- readr::read_tsv(results / ML_top_features / Sfl_drug_AMP_domains_binary_top_features.tsv) +topfeat <- readr::read_tsv("results/ML_top_features/Sfl_drug_AMP_domains_binary_top_features.tsv") plotTopFeatsVI(topfeat) ``` ### Baseline comparison barplot @@ -306,7 +306,7 @@ As a non-ML baseline, run Fisher's exact tests with multiple testing correction: ```{r fisher} # Complete Fisher pipeline fisher_results <- runFishers( - matrix_path = "results/matrix/Cje_drug_CIP_genes_binary_sparse.parquet", + matrix_path = "results/matrix/Sfl_drug_AMP_genes_binary_sparse.parquet", Q = 0.05, alternative = "two.sided", susceptible_label = "Susceptible", From 8de8191aac01ee7d3f44d1c7dc4642d1ce969d5a Mon Sep 17 00:00:00 2001 From: jananiravi Date: Mon, 11 May 2026 17:13:56 +0000 Subject: [PATCH 7/7] Style code (GHA) --- vignettes/intro.Rmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vignettes/intro.Rmd b/vignettes/intro.Rmd index 2129f6b..c40caba 100644 --- a/vignettes/intro.Rmd +++ b/vignettes/intro.Rmd @@ -199,7 +199,7 @@ predictions <- predict(fit, test_data) ```{r metrics} # Individual metrics -mcc <- calculateMCC(predictions) # Matthews correlation coefficient (-1 to 1) +mcc <- calculateMCC(predictions) # Matthews correlation coefficient (-1 to 1) nmcc <- calculatenMCC(predictions) # Normalized MCC (0 to 1) f1 <- calculateF1(predictions) # F1 score bal_acc <- calculateBalAcc(predictions) # 
Balanced accuracy