diff --git a/NEWS.md b/NEWS.md index bf6aaef8..a7537c16 100644 --- a/NEWS.md +++ b/NEWS.md @@ -10,6 +10,8 @@ * Fixed test in `test-ppc-distributions.R` that incorrectly used `ppc_dens()` instead of `ppd_dens()` when testing PPD functions * New functions `mcmc_dots` and `mcmc_dots_by_chain` for dot plots of MCMC draws by @behramulukir (#402) * Default to `quantiles=100` for all dot plots by @behramulukir (#402) +* Make diagnostic color scale helpers handle `"neff"` and `"neff_ratio"` explicitly, avoiding reliance on partial matching. +* Replace `stopifnot()` checks with `abort()` and descriptive input-validation messages in helper and diagnostics internals. * Use `"neff_ratio"` consistently in diagnostic color scale helpers to avoid relying on partial matching of `"neff"`. # bayesplot 1.15.0 diff --git a/R/helpers-mcmc.R b/R/helpers-mcmc.R index 41e2c4ee..c133d35c 100644 --- a/R/helpers-mcmc.R +++ b/R/helpers-mcmc.R @@ -24,7 +24,9 @@ prepare_mcmc_array <- function(x, x <- as.array(x) } - stopifnot(is.matrix(x) || is.array(x)) + if (!(is.matrix(x) || is.array(x))) { + abort("'x' must be a matrix or array.") + } if (is.array(x) && !(length(dim(x)) %in% c(2,3))) { abort("Arrays should have 2 or 3 dimensions. See help('MCMC-overview').") } @@ -80,9 +82,11 @@ select_parameters <- patterns = character(), complete_pars = character()) { - stopifnot(is.character(explicit), - is.character(patterns), - is.character(complete_pars)) + if (!is.character(explicit) || + !is.character(patterns) || + !is.character(complete_pars)) { + abort("'explicit', 'patterns', and 'complete_pars' must be character vectors.") + } if (!length(explicit) && !length(patterns)) { return(complete_pars) @@ -132,7 +136,9 @@ melt_mcmc.mcmc_array <- function(x, value.name = "Value", as.is = TRUE, ...) { - stopifnot(is_mcmc_array(x)) + if (!is_mcmc_array(x)) { + abort("'x' must be an mcmc_array.") + } long <- reshape2::melt( data = x, @@ -168,7 +174,9 @@ melt_mcmc.matrix <- function(x, #' @param parnames Character vector of parameter names #' @return x with a modified dimnames. set_mcmc_dimnames <- function(x, parnames) { - stopifnot(is_3d_array(x)) + if (!is_3d_array(x)) { + abort("'x' must be a 3-D array.") + } dimnames(x) <- list( Iteration = seq_len(nrow(x)), Chain = seq_len(ncol(x)), @@ -201,7 +209,9 @@ is_df_with_chain <- function(x) { } validate_df_with_chain <- function(x) { - stopifnot(is_df_with_chain(x)) + if (!is_df_with_chain(x)) { + abort("'x' must be a data frame with a chain column.") + } x <- as.data.frame(x) if (!is.null(x$chain)) { if (is.null(x$Chain)) { @@ -311,7 +321,9 @@ parameter_names <- function(x) UseMethod("parameter_names") #' @export parameter_names.array <- function(x) { - stopifnot(is_3d_array(x)) + if (!is_3d_array(x)) { + abort("'x' must be a 3-D array.") + } dimnames(x)[[3]] %||% abort("No parameter names found.") } #' @export @@ -350,12 +362,16 @@ is_mcmc_array <- function(x) { # Check if 3-D array has multiple chains has_multiple_chains <- function(x) { - stopifnot(is_3d_array(x)) + if (!is_3d_array(x)) { + abort("'x' must be a 3-D array.") + } isTRUE(dim(x)[2] > 1) } # Check if 3-D array has multiple parameters has_multiple_params <- function(x) { - stopifnot(is_3d_array(x)) + if (!is_3d_array(x)) { + abort("'x' must be a 3-D array.") + } isTRUE(dim(x)[3] > 1) } @@ -412,7 +428,9 @@ apply_transformations.matrix <- function(x, ..., transformations = list()) { #' @export apply_transformations.array <- function(x, ..., transformations = list()) { - stopifnot(length(dim(x)) == 3) + if (length(dim(x)) != 3) { + abort("'x' must be a 3-D array.") + } pars <- dimnames(x)[[3]] x_transforms <- validate_transformations(transformations, pars) for (p in names(x_transforms)) { @@ -423,7 +441,9 @@ apply_transformations.array <- function(x, ..., transformations = list()) { } rename_transformed_pars <- function(pars, transformations) { - stopifnot(is.character(pars), is.list(transformations)) + if (!is.character(pars) || !is.list(transformations)) { + abort("'pars' must be a character vector and 'transformations' must be a list.") + } has_names <- sapply(transformations, is.character) if (any(has_names)) { nms <- names(which(has_names)) @@ -456,18 +476,24 @@ num_chains.mcmc_array <- function(x, ...) dim(x)[2] num_iters.mcmc_array <- function(x, ...) dim(x)[1] #' @export num_params.data.frame <- function(x, ...) { - stopifnot("Parameter" %in% colnames(x)) + if (!("Parameter" %in% colnames(x))) { + abort("'x' must contain a 'Parameter' column.") + } length(unique(x$Parameter)) } #' @export num_chains.data.frame <- function(x, ...) { - stopifnot("Chain" %in% colnames(x)) + if (!("Chain" %in% colnames(x))) { + abort("'x' must contain a 'Chain' column.") + } length(unique(x$Chain)) } #' @export num_iters.data.frame <- function(x, ...) { cols <- colnames(x) - stopifnot("Iteration" %in% cols || "Draws" %in% cols) + if (!("Iteration" %in% cols || "Draws" %in% cols)) { + abort("'x' must contain an 'Iteration' or 'Draws' column.") + } if ("Iteration" %in% cols) { n <- length(unique(x$Iteration)) diff --git a/R/helpers-ppc.R b/R/helpers-ppc.R index 5206b9af..c52764db 100644 --- a/R/helpers-ppc.R +++ b/R/helpers-ppc.R @@ -38,7 +38,9 @@ all_counts <- function(x, ...) { #' @return Either throws an error or returns a numeric vector. #' @noRd validate_y <- function(y) { - stopifnot(is.numeric(y)) + if (!is.numeric(y)) { + abort("'y' must be numeric.") + } if (!(inherits(y, "ts") && is.null(dim(y)))) { if (!is_vector_or_1Darray(y)) { @@ -67,9 +69,13 @@ validate_y <- function(y) { #' @noRd validate_predictions <- function(predictions, n_obs = NULL) { # sanity checks - stopifnot(is.matrix(predictions), is.numeric(predictions)) + if (!is.matrix(predictions) || !is.numeric(predictions)) { + abort("'predictions' must be a numeric matrix.") + } if (!is.null(n_obs)) { - stopifnot(length(n_obs) == 1, n_obs == as.integer(n_obs)) + if (length(n_obs) != 1 || !identical(n_obs, as.integer(n_obs))) { + abort("'n_obs' must be a single integer.") + } } if (is.integer(predictions)) { @@ -111,7 +117,9 @@ validate_pit <- function(pit) { abort("NAs not allowed in 'pit'.") } - stopifnot(is.numeric(pit)) + if (!is.numeric(pit)) { + abort("'pit' must be numeric.") + } if (!is_vector_or_1Darray(pit)) { abort("'pit' must be a vector or 1D array.") @@ -137,8 +145,12 @@ validate_pit <- function(pit) { #' @noRd validate_group <- function(group, n_obs) { # sanity checks - stopifnot(is.vector(group) || is.factor(group), - length(n_obs) == 1, n_obs == as.integer(n_obs)) + if (!(is.vector(group) || is.factor(group))) { + abort("'group' must be a vector or factor.") + } + if (length(n_obs) != 1 || !identical(n_obs, as.integer(n_obs))) { + abort("'n_obs' must be a single integer.") + } if (!is.factor(group)) { group <- as.factor(group) @@ -175,7 +187,9 @@ validate_x <- function(x = NULL, y, unique_x = FALSE) { } } - stopifnot(is.numeric(x)) + if (!is.numeric(x)) { + abort("'x' must be numeric.") + } if (!is_vector_or_1Darray(x)) { abort("'x' must be a vector or 1D array.") @@ -191,7 +205,9 @@ validate_x <- function(x = NULL, y, unique_x = FALSE) { } if (unique_x) { - stopifnot(identical(length(x), length(unique(x)))) + if (!identical(length(x), length(unique(x)))) { + abort("'x' must contain only unique values.") + } } unname(x) diff --git a/R/mcmc-diagnostics.R b/R/mcmc-diagnostics.R index 2a71b8e0..4692d3e8 100644 --- a/R/mcmc-diagnostics.R +++ b/R/mcmc-diagnostics.R @@ -389,7 +389,9 @@ diagnostic_factor.neff_ratio <- function(x, ..., breaks = c(0.1, 0.5)) { diagnostic_data_frame <- function(x) { x <- auto_name(sort(x)) - stopifnot(!anyDuplicated(names(x))) + if (anyDuplicated(names(x))) { + abort("Diagnostic values must have unique parameter names.") + } diagnostic <- class(x)[1] d <- tibble::tibble( @@ -584,7 +586,9 @@ drop_NAs_and_warn <- function(x) { # @param x object returned by prepare_mcmc_array # @param lags user's 'lags' argument acf_data <- function(x, lags) { - stopifnot(is_mcmc_array(x)) + if (!is_mcmc_array(x)) { + abort("'x' must be an mcmc_array.") + } n_iter <- num_iters(x) n_chain <- num_chains(x) n_param <- num_params(x) @@ -628,7 +632,9 @@ new_rhat <- function(x) { } validate_rhat <- function(x) { - stopifnot(is.numeric(x), !is.list(x), !is.array(x)) + if (!is.numeric(x) || is.list(x) || is.array(x)) { + abort("'rhat' must be a numeric vector.") + } if (any(x < 0, na.rm = TRUE)) { abort("All 'rhat' values must be positive.") } @@ -656,7 +662,9 @@ new_neff_ratio <- function(x) { } validate_neff_ratio <- function(x) { - stopifnot(is.numeric(x), !is.list(x), !is.array(x)) + if (!is.numeric(x) || is.list(x) || is.array(x)) { + abort("'neff_ratio' must be a numeric vector.") + } if (any(x < 0, na.rm = TRUE)) { abort("All neff ratios must be positive.") }