Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions R/arg_check_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -702,19 +702,19 @@ NULL
#' @keywords internal
#' @param y_default_eval [chr] y value of default evaluation plot. It can be
#' "avg_runtime_sec" or one of the following performance metrics:
#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", or "avg_nmcc"
#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", or "avg_nmcc"
#'
.checkArgYDefaultEval <- function(y_default_eval) {
if (!is.character(y_default_eval)) {
stop("The `y_default_eval` argument can only take character values.")
}

if (!(y_default_eval %in%
c("avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_nmcc"))
c("avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", "avg_nmcc"))
) {
stop(paste(
"`y_default_eval` must be one of:",
"'avg_f1_score', 'avg_log2_apop', 'avg_bal_acc', 'avg_nmcc'."
"'avg_f1_score', 'avg_log2_apop', 'avg_bal_acc', 'avg_mcc', 'avg_nmcc'."
))
}
}
Expand Down
39 changes: 26 additions & 13 deletions R/core_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -391,28 +391,40 @@ getConfusionMatrix <- function(test_data_plus_predictions) {
return(CM)
}

#' .calculatenMCC()
#' .calculateMCC()
#'
#' Returns the normalized (to a 0 to 1 scale instead of -1 to 1) Matthews
#' correlation coefficient (nMCC) based on the AMR phenotype predictions by an
#' Returns the Matthews correlation coefficient (MCC)
#' based on the AMR phenotype predictions by an
#' ML model compared against the actual values.
#'
#' @inheritParams getConfusionMatrix
#' @return Normalized (to a 0 to 1 scale instead of -1 to 1) Matthews
#' correlation coefficient (nMCC)
.calculatenMCC <- function(test_data_plus_predictions) {
#' @return Matthews correlation coefficient (MCC), range -1 to 1
.calculateMCC <- function(test_data_plus_predictions) {
.checkArgTestDataPlusPredictions(test_data_plus_predictions)

target_var <- .getTargetVarName(test_data_plus_predictions)

mcc <- test_data_plus_predictions |>
yardstick::mcc(truth = !!target_var, estimate = .pred_class) |>
dplyr::select(.estimate) |>
as.numeric()
as.numeric() |>
round(2)

nmcc <- (mcc + 1) / 2
return(mcc)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might still need to round to two decimals, right?

}

return(round(nmcc, 2))
#' .calculatenMCC()
#'
#' Returns the normalized (0 to 1) Matthews correlation coefficient (nMCC)
#' based on the AMR phenotype predictions by an ML model compared against
#' the actual values.
#'
#' @inheritParams getConfusionMatrix
#' @return Normalized Matthews correlation coefficient (nMCC), range 0 to 1
.calculatenMCC <- function(test_data_plus_predictions) {
mcc <- .calculateMCC(test_data_plus_predictions)
nmcc <- round((mcc + 1) / 2, 2)
return(nmcc)
}

#' .calculateF1()
Expand Down Expand Up @@ -607,12 +619,12 @@ getConfusionMatrix <- function(test_data_plus_predictions) {
#' calculateEvalMets()
#'
#' Returns the F1 score, area under the precision-recall curve (AUPRC), balanced
#' accuracy, normalized (to a 0 to 1 scale instead of -1 to 1) Matthews
#' correlation coefficient (nMCC), and log2(AUPRC/prior) based on the AMR
#' accuracy, Matthews correlation coefficient (MCC), normalized MCC (nMCC),
#' and log2(AUPRC/prior) based on the AMR
#' phenotype predictions by an ML model compared against the actual values.
#'
#' @inheritParams getConfusionMatrix
#' @return F1 score, AUPRC, balanced accuracy, nMCC, and log2(AUPRC/prior)
#' @return F1 score, AUPRC, balanced accuracy, MCC, nMCC, and log2(AUPRC/prior)
#' @export
calculateEvalMets <- function(test_data_plus_predictions) {
.checkArgTestDataPlusPredictions(test_data_plus_predictions)
Expand All @@ -622,10 +634,11 @@ calculateEvalMets <- function(test_data_plus_predictions) {
bal_acc <- .calculateBalAcc(test_data_plus_predictions)
sens <- .calculateSensitivity(test_data_plus_predictions)
spec <- .calculateSpecificity(test_data_plus_predictions)
mcc <- .calculateMCC(test_data_plus_predictions)
nmcc <- .calculatenMCC(test_data_plus_predictions)
log2_apop <- .calculateLog2APOP(test_data_plus_predictions)

return(c(f1, auprc, bal_acc, nmcc, log2_apop))
return(c(f1, auprc, bal_acc, mcc, nmcc, log2_apop))
}

#' extractTopFeats()
Expand Down
26 changes: 13 additions & 13 deletions R/generate_matrices_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -288,22 +288,22 @@ skipImbalancedMatrix <- function(genome_ids,
if (group_type %in% c("drug_class", "drug_class_year", "drug_class_country")) {
genome_ids <- DBI::dbGetQuery(con, sprintf("
WITH class_phenotypes AS (
SELECT \"genome_drug.genome_id\" AS genome_id,
SELECT \"genome.genome_id\" AS genome_id,
MAX(CASE WHEN \"genome_drug.resistant_phenotype\" = 'Resistant'
THEN 1 ELSE 0 END) AS any_resistant,
MIN(CASE WHEN \"genome_drug.resistant_phenotype\" = 'Susceptible'
THEN 1 ELSE 0 END) AS all_susceptible
FROM metadata
WHERE %s
GROUP BY \"genome_drug.genome_id\"
GROUP BY \"genome.genome_id\"
)
SELECT genome_id
FROM class_phenotypes
WHERE any_resistant = 1 OR all_susceptible = 1
", condition_string))[[1]]
} else {
genome_ids <- DBI::dbGetQuery(con, sprintf("
SELECT DISTINCT \"genome_drug.genome_id\"
SELECT DISTINCT \"genome.genome_id\"
FROM metadata
WHERE %s
AND \"genome_drug.resistant_phenotype\" IN ('Resistant','Susceptible')
Expand Down Expand Up @@ -499,20 +499,20 @@ skipImbalancedMatrix <- function(genome_ids,
"
COPY (
SELECT
f.\"genome_drug.genome_id\" AS genome_id,
f.\"genome.genome_id\" AS genome_id,
%s AS feature_id,
MAX(CAST(%s AS DOUBLE)) AS value,
%s AS \"genome_drug.resistant_phenotype\"
%s
FROM %s
JOIN selected_genomes USING (genome_id)
JOIN keep_features kf ON %s = kf.feature_id
JOIN metadata f ON genome_id = f.\"genome_drug.genome_id\"
JOIN metadata f ON genome_id = f.\"genome.genome_id\"
WHERE %s
AND f.\"genome_drug.resistant_phenotype\" IN ('Resistant','Susceptible')
%s
GROUP BY f.\"genome_drug.genome_id\", %s %s
ORDER BY f.\"genome_drug.genome_id\", %s
GROUP BY f.\"genome.genome_id\", %s %s
ORDER BY f.\"genome.genome_id\", %s
)
TO '%s'
(FORMAT 'parquet', COMPRESSION 'zstd')
Expand Down Expand Up @@ -733,7 +733,7 @@ skipImbalancedMatrix <- function(genome_ids,
DBI::dbDisconnect(con0, shutdown = FALSE)

classes <- metadata_all |>
dplyr::select(genome_drug.genome_id, resistant_classes) |>
dplyr::select(genome.genome_id, resistant_classes) |>
dplyr::distinct() |>
dplyr::group_by(resistant_classes) |>
dplyr::count() |>
Expand All @@ -745,7 +745,7 @@ skipImbalancedMatrix <- function(genome_ids,

genomes_to_keep <- metadata_all |>
dplyr::filter(resistant_classes %in% classes) |>
dplyr::pull(genome_drug.genome_id)
dplyr::pull(genome.genome_id)

# Build one matrix per feature type and matrix type
for (ftype in names(feature_types)) {
Expand Down Expand Up @@ -825,21 +825,21 @@ skipImbalancedMatrix <- function(genome_ids,
"
COPY (
SELECT
f.\"genome_drug.genome_id\" AS genome_id,
f.\"genome.genome_id\" AS genome_id,
%s AS feature_id,
MAX(CAST(%s AS DOUBLE)) AS value,
resistant_classes
FROM %s
JOIN selected_genomes USING (genome_id)
JOIN keep_features kf ON %s = kf.feature_id
JOIN metadata f ON genome_id = f.\"genome_drug.genome_id\"
JOIN metadata f ON genome_id = f.\"genome.genome_id\"
WHERE resistant_classes <> 'Intermediate'
GROUP BY
f.\"genome_drug.genome_id\",
f.\"genome.genome_id\",
%s,
resistant_classes
ORDER BY
f.\"genome_drug.genome_id\",
f.\"genome.genome_id\",
%s
)
TO '%s'
Expand Down
4 changes: 4 additions & 0 deletions R/globals.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,12 @@ utils::globalVariables(c(
"idx_strat",
"model",
"neg_log10_adj_p",
"mcc",
"nmcc",
"num_obs",
"seed",
"sens",
"spec",
"output_prefix",
"p_value",
"pair_id",
Expand Down
8 changes: 6 additions & 2 deletions R/ife_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ removeTopFeats <- function(ml_input_tibble, top_feat_tibble) {

#' runIFE
#' Removes top features identified by ML models and retrains iteratively;
#' returns nMCC at each iteration.
#' returns MCC at each iteration.
#'
#' @param ml_input_tibble An ML-ready tibble generated by `loadMLInputTibble()`
#' @param by_num [bool] Set to `TRUE` if removing top features as a percentage
Expand Down Expand Up @@ -83,6 +83,7 @@ runIFE <- function(
num_obs_vec <- c()
res_prop_vec <- c()
fit_mixture_vec <- c()
mcc_vec <- c()
nmcc_vec <- c()
n_feats_removed_vec <- c()
total_feats_removed_vec <- c()
Expand Down Expand Up @@ -154,6 +155,9 @@ runIFE <- function(
fit_mixture_vec[i] <- ml_res$performance_tibble |>
dplyr::select(fit_mixture) |>
as.numeric()
mcc_vec[i] <- ml_res$performance_tibble |>
dplyr::select(mcc) |>
as.numeric()
nmcc_vec[i] <- ml_res$performance_tibble |>
dplyr::select(nmcc) |>
as.numeric()
Expand Down Expand Up @@ -247,7 +251,7 @@ runIFE <- function(
percent_removed = c(0, percent_removal_vec),
removal_type = rep(removal_type, length(num_obs_vec)),
num_obs = num_obs_vec, res_prop = res_prop_vec,
fit_mixture = fit_mixture_vec, nmcc = nmcc_vec,
fit_mixture = fit_mixture_vec, mcc = mcc_vec, nmcc = nmcc_vec,
n_feats_removed = n_feats_removed_vec,
total_feats_removed = total_feats_removed_vec, run_time_sec = run_time_vec
)
Expand Down
2 changes: 1 addition & 1 deletion R/plot_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ plotTopFeatsVI <- function(fit, n_top_feats = 10) {
#' or "n_fold"
#' @param y_default_eval [chr] y value of default evaluation plot. It can be
#' "avg_runtime_sec" or one of the following performance metrics:
#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", or "avg_nmcc"
#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", or "avg_nmcc"
#' @param xlab [chr] Label for x axis
#' @param ylab [chr] Label for y axis
#' @return A `ggplot2` scatterplot (performance metric or runtime vs.
Expand Down
19 changes: 11 additions & 8 deletions R/run_ml_pipeline.R
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,11 @@ runMLPipeline <- function(
log2_apop <- .calculateLog2APOP(test_data_plus_predictions)
}

mcc <- .calculateMCC(test_data_plus_predictions)
nmcc <- .calculatenMCC(test_data_plus_predictions)

if (verbose) {
message(paste("Normalized Matthews correlation coefficient:", nmcc))
message(paste("Matthews correlation coefficient:", mcc, "| nMCC:", nmcc))
}

top_feat_tibble <- extractTopFeats(fit,
Expand Down Expand Up @@ -362,7 +363,7 @@ runMLPipeline <- function(
performance_tibble <- tibble::tibble(
num_obs = num_obs_ml_input_tibble,
n_feat = getNumFeat(ml_input_tibble), model, train_prop = split[1],
val_prop = split[2], n_fold, nmcc, run_time_sec,
val_prop = split[2], n_fold, mcc, nmcc, run_time_sec, seed,
date = as.character(Sys.Date())
)

Expand All @@ -383,18 +384,20 @@ runMLPipeline <- function(
) |>
tibble::add_column(bal_acc, .after = "nmcc") |>
tibble::add_column(f1, .after = "nmcc") |>
tibble::add_column(log2_apop, .after = "nmcc")
tibble::add_column(log2_apop, .after = "nmcc") |>
tibble::add_column(sens, .after = "nmcc") |>
tibble::add_column(spec, .after = "nmcc")
}

if (model == "LR") {
performance_tibble <- performance_tibble |>
tibble::add_column(fit_penalty, .before = "nmcc") |>
tibble::add_column(fit_mixture, .before = "nmcc")
tibble::add_column(fit_penalty, .before = "mcc") |>
tibble::add_column(fit_mixture, .before = "mcc")
} else if (model == "RF" || model == "BT") {
performance_tibble <- performance_tibble |>
tibble::add_column(fit_trees, .before = "nmcc") |>
tibble::add_column(fit_mtry, .before = "nmcc") |>
tibble::add_column(fit_min_n, .before = "nmcc")
tibble::add_column(fit_trees, .before = "mcc") |>
tibble::add_column(fit_mtry, .before = "mcc") |>
tibble::add_column(fit_min_n, .before = "mcc")
}

if (external_test_data) {
Expand Down
2 changes: 1 addition & 1 deletion README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ This uses specific matrices to test whether ML models can predict resistance aga

- **Data preparation**: Load Parquet files and prepare ML-ready datasets
- **Model training**: User-customizable logistic regression via tidymodels
- **Evaluation**: nMCC, F1, balanced accuracy, AuPRC, and confusion matrices
- **Evaluation**: MCC, nMCC, F1, balanced accuracy, AuPRC, and confusion matrices
- **Feature importance**: Extract and rank predictive features

See the [package vignette](https://jravilab.github.io/amRml/articles/intro.html) for detailed usage.
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ associated with MDR.
- **Data preparation**: Load Parquet files and prepare ML-ready datasets
- **Model training**: User-customizable logistic regression via
tidymodels
- **Evaluation**: nMCC, F1, balanced accuracy, AuPRC, and confusion
- **Evaluation**: MCC, nMCC, F1, balanced accuracy, AuPRC, and confusion
matrices
- **Feature importance**: Extract and rank predictive features

Expand Down
9 changes: 5 additions & 4 deletions doc/intro.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ knitr::opts_chunk$set(

## ----metrics------------------------------------------------------------------
# # Individual metrics
# nmcc <- calculatenMCC(predictions) # Normalized MCC (0-1 scale)
# mcc <- calculateMCC(predictions) # Matthews correlation coefficient (-1 to 1)
# nmcc <- calculatenMCC(predictions) # Normalized MCC (0 to 1)
# f1 <- calculateF1(predictions) # F1 score
# bal_acc <- calculateBalAcc(predictions) # Balanced accuracy
# auprc <- calculateAUPRC(predictions) # Area under PR curve
Expand Down Expand Up @@ -150,7 +151,7 @@ knitr::opts_chunk$set(
# verbose = TRUE
# )
#
# # Results include nMCC at each iteration
# # Results include MCC and nMCC at each iteration
# ife_results$ife_performance_tibble
# ife_results$feats_removed # If return_feats = TRUE
#
Expand Down Expand Up @@ -233,8 +234,8 @@ knitr::opts_chunk$set(
# )
#
# # 5. Compare real vs baseline performance
# cat("Real nMCC:", results$performance_tibble$nmcc, "\n")
# cat("Baseline nMCC:", baseline_results$performance_tibble$nmcc, "\n")
# cat("Real MCC:", results$performance_tibble$mcc, "| nMCC:", results$performance_tibble$nmcc, "\n")
# cat("Baseline MCC:", baseline_results$performance_tibble$mcc, "| nMCC:", baseline_results$performance_tibble$nmcc, "\n")
#
# # 6. Run iterative feature elimination
# ife_results <- runIFE(
Expand Down
Loading
Loading