From 7f8d27f46fae0921b67eb85452b726f62e2dd78e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 22 Apr 2026 10:13:30 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 903911142 --- learned_optimization/jax_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/learned_optimization/jax_utils.py b/learned_optimization/jax_utils.py index 144a217..e6a95b3 100644 --- a/learned_optimization/jax_utils.py +++ b/learned_optimization/jax_utils.py @@ -22,6 +22,13 @@ import jax.numpy as jnp import numpy as onp +try: + # JAX v0.10.0 or newer + from jax.extend.core import unsafe_am_i_under_a_jit_DO_NOT_USE # pylint: disable=g-import-not-at-top +except ImportError: + # JAX v0.9.2 or older + from jax.core import unsafe_am_i_under_a_jit_DO_NOT_USE # pylint: disable=g-import-not-at-top + def maybe_static_cond(pred, true_fn, false_fn, val): """Conditional that first checks if pred can be determined at compile time.""" @@ -51,7 +58,7 @@ def in_jit() -> bool: jax.core.thread_local_state.trace_state.trace_stack # type: ignore ) - return jax.core.unsafe_am_i_under_a_jit_DO_NOT_USE() + return unsafe_am_i_under_a_jit_DO_NOT_USE() Carry = TypeVar("Carry")