Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
* Fixed test in `test-ppc-distributions.R` that incorrectly used `ppc_dens()` instead of `ppd_dens()` when testing PPD functions
* New functions `mcmc_dots` and `mcmc_dots_by_chain` for dot plots of MCMC draws by @behramulukir (#402)
* Default to `quantiles=100` for all dot plots by @behramulukir (#402)
* Make diagnostic color scale helpers handle `"neff"` and `"neff_ratio"` explicitly, avoiding reliance on partial matching.
* Replace `stopifnot()` checks with `abort()` and descriptive input-validation messages in helper and diagnostics internals.
* Use `"neff_ratio"` consistently in diagnostic color scale helpers to avoid relying on partial matching of `"neff"`.

# bayesplot 1.15.0
Expand Down
56 changes: 41 additions & 15 deletions R/helpers-mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ prepare_mcmc_array <- function(x,
x <- as.array(x)
}

stopifnot(is.matrix(x) || is.array(x))
if (!(is.matrix(x) || is.array(x))) {
abort("'x' must be a matrix or array.")
}
if (is.array(x) && !(length(dim(x)) %in% c(2,3))) {
abort("Arrays should have 2 or 3 dimensions. See help('MCMC-overview').")
}
Expand Down Expand Up @@ -80,9 +82,11 @@ select_parameters <-
patterns = character(),
complete_pars = character()) {

stopifnot(is.character(explicit),
is.character(patterns),
is.character(complete_pars))
if (!is.character(explicit) ||
!is.character(patterns) ||
!is.character(complete_pars)) {
abort("'explicit', 'patterns', and 'complete_pars' must be character vectors.")
}

if (!length(explicit) && !length(patterns)) {
return(complete_pars)
Expand Down Expand Up @@ -132,7 +136,9 @@ melt_mcmc.mcmc_array <- function(x,
value.name = "Value",
as.is = TRUE,
...) {
stopifnot(is_mcmc_array(x))
if (!is_mcmc_array(x)) {
abort("'x' must be an mcmc_array.")
}

long <- reshape2::melt(
data = x,
Expand Down Expand Up @@ -168,7 +174,9 @@ melt_mcmc.matrix <- function(x,
#' @param parnames Character vector of parameter names
#' @return x with a modified dimnames.
set_mcmc_dimnames <- function(x, parnames) {
stopifnot(is_3d_array(x))
if (!is_3d_array(x)) {
abort("'x' must be a 3-D array.")
}
dimnames(x) <- list(
Iteration = seq_len(nrow(x)),
Chain = seq_len(ncol(x)),
Expand Down Expand Up @@ -201,7 +209,9 @@ is_df_with_chain <- function(x) {
}

validate_df_with_chain <- function(x) {
stopifnot(is_df_with_chain(x))
if (!is_df_with_chain(x)) {
abort("'x' must be a data frame with a chain column.")
}
x <- as.data.frame(x)
if (!is.null(x$chain)) {
if (is.null(x$Chain)) {
Expand Down Expand Up @@ -311,7 +321,9 @@ parameter_names <- function(x) UseMethod("parameter_names")

#' @export
parameter_names.array <- function(x) {
stopifnot(is_3d_array(x))
if (!is_3d_array(x)) {
abort("'x' must be a 3-D array.")
}
dimnames(x)[[3]] %||% abort("No parameter names found.")
}
#' @export
Expand Down Expand Up @@ -350,12 +362,16 @@ is_mcmc_array <- function(x) {

# Check if 3-D array has multiple chains
has_multiple_chains <- function(x) {
stopifnot(is_3d_array(x))
if (!is_3d_array(x)) {
abort("'x' must be a 3-D array.")
}
isTRUE(dim(x)[2] > 1)
}
# Check if 3-D array has multiple parameters
has_multiple_params <- function(x) {
stopifnot(is_3d_array(x))
if (!is_3d_array(x)) {
abort("'x' must be a 3-D array.")
}
isTRUE(dim(x)[3] > 1)
}

Expand Down Expand Up @@ -412,7 +428,9 @@ apply_transformations.matrix <- function(x, ..., transformations = list()) {

#' @export
apply_transformations.array <- function(x, ..., transformations = list()) {
stopifnot(length(dim(x)) == 3)
if (length(dim(x)) != 3) {
abort("'x' must be a 3-D array.")
}
pars <- dimnames(x)[[3]]
x_transforms <- validate_transformations(transformations, pars)
for (p in names(x_transforms)) {
Expand All @@ -423,7 +441,9 @@ apply_transformations.array <- function(x, ..., transformations = list()) {
}

rename_transformed_pars <- function(pars, transformations) {
stopifnot(is.character(pars), is.list(transformations))
if (!is.character(pars) || !is.list(transformations)) {
abort("'pars' must be a character vector and 'transformations' must be a list.")
}
has_names <- sapply(transformations, is.character)
if (any(has_names)) {
nms <- names(which(has_names))
Expand Down Expand Up @@ -456,18 +476,24 @@ num_chains.mcmc_array <- function(x, ...) dim(x)[2]
num_iters.mcmc_array <- function(x, ...) dim(x)[1]
#' @export
num_params.data.frame <- function(x, ...) {
stopifnot("Parameter" %in% colnames(x))
if (!("Parameter" %in% colnames(x))) {
abort("'x' must contain a 'Parameter' column.")
}
length(unique(x$Parameter))
}
#' @export
num_chains.data.frame <- function(x, ...) {
stopifnot("Chain" %in% colnames(x))
if (!("Chain" %in% colnames(x))) {
abort("'x' must contain a 'Chain' column.")
}
length(unique(x$Chain))
}
#' @export
num_iters.data.frame <- function(x, ...) {
cols <- colnames(x)
stopifnot("Iteration" %in% cols || "Draws" %in% cols)
if (!("Iteration" %in% cols || "Draws" %in% cols)) {
abort("'x' must contain an 'Iteration' or 'Draws' column.")
}

if ("Iteration" %in% cols) {
n <- length(unique(x$Iteration))
Expand Down
32 changes: 24 additions & 8 deletions R/helpers-ppc.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ all_counts <- function(x, ...) {
#' @return Either throws an error or returns a numeric vector.
#' @noRd
validate_y <- function(y) {
stopifnot(is.numeric(y))
if (!is.numeric(y)) {
abort("'y' must be numeric.")
}

if (!(inherits(y, "ts") && is.null(dim(y)))) {
if (!is_vector_or_1Darray(y)) {
Expand Down Expand Up @@ -67,9 +69,13 @@ validate_y <- function(y) {
#' @noRd
validate_predictions <- function(predictions, n_obs = NULL) {
# sanity checks
stopifnot(is.matrix(predictions), is.numeric(predictions))
if (!is.matrix(predictions) || !is.numeric(predictions)) {
abort("'predictions' must be a numeric matrix.")
}
if (!is.null(n_obs)) {
stopifnot(length(n_obs) == 1, n_obs == as.integer(n_obs))
if (length(n_obs) != 1 || !identical(n_obs, as.integer(n_obs))) {
abort("'n_obs' must be a single integer.")
}
}

if (is.integer(predictions)) {
Expand Down Expand Up @@ -111,7 +117,9 @@ validate_pit <- function(pit) {
abort("NAs not allowed in 'pit'.")
}

stopifnot(is.numeric(pit))
if (!is.numeric(pit)) {
abort("'pit' must be numeric.")
}

if (!is_vector_or_1Darray(pit)) {
abort("'pit' must be a vector or 1D array.")
Expand All @@ -137,8 +145,12 @@ validate_pit <- function(pit) {
#' @noRd
validate_group <- function(group, n_obs) {
# sanity checks
stopifnot(is.vector(group) || is.factor(group),
length(n_obs) == 1, n_obs == as.integer(n_obs))
if (!(is.vector(group) || is.factor(group))) {
abort("'group' must be a vector or factor.")
}
if (length(n_obs) != 1 || !identical(n_obs, as.integer(n_obs))) {
abort("'n_obs' must be a single integer.")
}

if (!is.factor(group)) {
group <- as.factor(group)
Expand Down Expand Up @@ -175,7 +187,9 @@ validate_x <- function(x = NULL, y, unique_x = FALSE) {
}
}

stopifnot(is.numeric(x))
if (!is.numeric(x)) {
abort("'x' must be numeric.")
}

if (!is_vector_or_1Darray(x)) {
abort("'x' must be a vector or 1D array.")
Expand All @@ -191,7 +205,9 @@ validate_x <- function(x = NULL, y, unique_x = FALSE) {
}

if (unique_x) {
stopifnot(identical(length(x), length(unique(x))))
if (!identical(length(x), length(unique(x)))) {
abort("'x' must contain only unique values.")
}
}

unname(x)
Expand Down
16 changes: 12 additions & 4 deletions R/mcmc-diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,9 @@ diagnostic_factor.neff_ratio <- function(x, ..., breaks = c(0.1, 0.5)) {

diagnostic_data_frame <- function(x) {
x <- auto_name(sort(x))
stopifnot(!anyDuplicated(names(x)))
if (anyDuplicated(names(x))) {
abort("Diagnostic values must have unique parameter names.")
}
diagnostic <- class(x)[1]

d <- tibble::tibble(
Expand Down Expand Up @@ -584,7 +586,9 @@ drop_NAs_and_warn <- function(x) {
# @param x object returned by prepare_mcmc_array
# @param lags user's 'lags' argument
acf_data <- function(x, lags) {
stopifnot(is_mcmc_array(x))
if (!is_mcmc_array(x)) {
abort("'x' must be an mcmc_array.")
}
n_iter <- num_iters(x)
n_chain <- num_chains(x)
n_param <- num_params(x)
Expand Down Expand Up @@ -628,7 +632,9 @@ new_rhat <- function(x) {
}

validate_rhat <- function(x) {
stopifnot(is.numeric(x), !is.list(x), !is.array(x))
if (!is.numeric(x) || is.list(x) || is.array(x)) {
abort("'rhat' must be a numeric vector.")
}
if (any(x < 0, na.rm = TRUE)) {
abort("All 'rhat' values must be positive.")
}
Expand Down Expand Up @@ -656,7 +662,9 @@ new_neff_ratio <- function(x) {
}

validate_neff_ratio <- function(x) {
stopifnot(is.numeric(x), !is.list(x), !is.array(x))
if (!is.numeric(x) || is.list(x) || is.array(x)) {
abort("'neff_ratio' must be a numeric vector.")
}
if (any(x < 0, na.rm = TRUE)) {
abort("All neff ratios must be positive.")
}
Expand Down
Loading