diff --git a/ultraplot/axes/__init__.py b/ultraplot/axes/__init__.py index 1c8163dcf..2838f21d0 100644 --- a/ultraplot/axes/__init__.py +++ b/ultraplot/axes/__init__.py @@ -19,6 +19,10 @@ from .shared import _SharedAxes # noqa: F401 from .three import ThreeAxes # noqa: F401 +_ASTRO_AXES_CLASS = None +_ASTROPY_WCS_TYPES = () +_ASTRO_LOADED = False + # Prevent importing module names and set order of appearance for objects __all__ = [ "Axes", @@ -34,15 +38,63 @@ # NOTE: We integrate with cartopy and basemap rather than using matplotlib's # native projection system. Therefore axes names are not part of public API. _cls_dict = {} # track valid names -for _cls in (CartesianAxes, PolarAxes, _CartopyAxes, _BasemapAxes, ThreeAxes): + + +def _refresh_cls_table(): + global _cls_table + _cls_table = "\n".join( + " " + + key + + " " * (max(map(len, _cls_dict)) - len(key) + 7) + + ("GeoAxes" if cls.__name__[:1] == "_" else cls.__name__) + for key, cls in _cls_dict.items() + ) + + +def _register_projection_class(_cls): for _name in (_cls._name, *_cls._name_aliases): with context._state_context(_cls, name="ultraplot_" + _name): - mproj.register_projection(_cls) + if "ultraplot_" + _name not in mproj.get_projection_names(): + mproj.register_projection(_cls) _cls_dict[_name] = _cls -_cls_table = "\n".join( - " " - + key - + " " * (max(map(len, _cls_dict)) - len(key) + 7) - + ("GeoAxes" if cls.__name__[:1] == "_" else cls.__name__) - for key, cls in _cls_dict.items() -) + _refresh_cls_table() + + +for _cls in (CartesianAxes, PolarAxes, _CartopyAxes, _BasemapAxes, ThreeAxes): + _register_projection_class(_cls) + + +def _load_astro_axes(): + global _ASTROPY_WCS_TYPES, _ASTRO_AXES_CLASS, _ASTRO_LOADED + if _ASTRO_LOADED: + return _ASTRO_AXES_CLASS + from .astro import ASTROPY_WCS_TYPES as _types, AstroAxes as _astro_axes + + _ASTRO_LOADED = True + _ASTROPY_WCS_TYPES = _types + _ASTRO_AXES_CLASS = _astro_axes + if _ASTRO_AXES_CLASS is not None: + if "AstroAxes" not in __all__: + __all__.append("AstroAxes") + _register_projection_class(_ASTRO_AXES_CLASS) + return _ASTRO_AXES_CLASS + + +def get_astro_axes_class(*, load=False): + if load: + _load_astro_axes() + return _ASTRO_AXES_CLASS + + +def get_astropy_wcs_types(*, load=False): + if load: + _load_astro_axes() + return _ASTROPY_WCS_TYPES + + +def __getattr__(name): + if name == "AstroAxes": + return get_astro_axes_class(load=True) + if name == "ASTROPY_WCS_TYPES": + return get_astropy_wcs_types(load=True) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/ultraplot/axes/astro.py b/ultraplot/axes/astro.py new file mode 100644 index 000000000..8f124d102 --- /dev/null +++ b/ultraplot/axes/astro.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 +""" +Astropy WCS axes integration. +""" + +import inspect +import numbers +from collections.abc import Iterable + +from ..config import rc +from ..utils import _not_none +from . import base +from .cartesian import CartesianAxes + +try: + from astropy.visualization.wcsaxes.core import WCSAxes + from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS +except ImportError: # pragma: no cover + WCSAxes = None + ASTROPY_WCS_TYPES = () +else: + ASTROPY_WCS_TYPES = (BaseLowLevelWCS, BaseHighLevelWCS) + + +if WCSAxes is not None: + + class AstroAxes(base.Axes, WCSAxes): + """ + Native UltraPlot wrapper for Astropy WCS axes. + """ + + _name = "astro" + _name_aliases = ("astropy", "wcs") + + def _update_background(self, **kwargs): + kw_face, kw_edge = rc._get_background_props(**kwargs) + self.patch.update(kw_face) + self.patch.update(kw_edge) + + def _get_coord_helper(self, axis): + index = {"x": 0, "y": 1}[axis] + try: + return self.coords[index] + except IndexError: + return None + + def _share_coord_signature(self, axis): + coord = self._get_coord_helper(axis) + if coord is None: + return None + unit = getattr(coord, "coord_unit", None) + if unit is not None and hasattr(unit, "to_string"): + unit = unit.to_string() + return ( + getattr(coord, "coord_type", None), + unit, + getattr(coord, "default_label", None), + ) + + def _update_coord_locator(self, axis, locator): + coord = self._get_coord_helper(axis) + if coord is None or locator is None: + return + if isinstance(locator, numbers.Real) and not isinstance(locator, bool): + coord.set_ticks(number=locator) + return + if isinstance(locator, Iterable) and not isinstance(locator, (str, bytes)): + coord.set_ticks(values=locator) + return + raise TypeError( + "AstroAxes.format only supports numeric or iterable tick locators. " + f"Received {locator!r}. Use ax.coords[...] for advanced locator setup." + ) + + def _update_coord_formatter(self, axis, formatter): + coord = self._get_coord_helper(axis) + if coord is None or formatter is None: + return + if isinstance(formatter, str) or callable(formatter): + coord.set_major_formatter(formatter) + return + raise TypeError( + "AstroAxes.format only supports string or callable tick formatters. " + f"Received {formatter!r}. Use ax.coords[...] for advanced formatter setup." + ) + + def _update_coord_ticks( + self, + axis, + *, + grid=None, + gridcolor=None, + tickcolor=None, + ticklen=None, + tickwidth=None, + tickdir=None, + ticklabelpad=None, + ticklabelcolor=None, + ticklabelsize=None, + ticklabelweight=None, + tickminor=None, + ): + coord = self._get_coord_helper(axis) + if coord is None: + return + if tickminor is not None: + coord.display_minor_ticks(bool(tickminor)) + major = {} + if ticklen is not None: + major["length"] = ticklen + if tickwidth is not None: + major["width"] = tickwidth + if tickcolor is not None: + major["color"] = tickcolor + if tickdir is not None: + major["direction"] = tickdir + if ticklabelpad is not None: + major["pad"] = ticklabelpad + if ticklabelcolor is not None: + major["labelcolor"] = ticklabelcolor + if ticklabelsize is not None: + major["labelsize"] = ticklabelsize + if major: + coord.tick_params(**major) + if ticklabelweight is not None: + coord.set_ticklabel(weight=ticklabelweight) + if grid is not None or gridcolor is not None: + kw = {} + if gridcolor is not None: + kw["color"] = gridcolor + coord.grid(draw_grid=grid, **kw) + + def _update_axis_label( + self, + axis, + *, + label=None, + labelpad=None, + labelcolor=None, + labelsize=None, + labelweight=None, + label_kw=None, + ): + coord = self._get_coord_helper(axis) + if coord is None: + return + if label is None and not any( + value is not None + for value in (labelpad, labelcolor, labelsize, labelweight) + ): + return + setter = getattr(self, f"set_{axis}label") + getter = getattr(self, f"get_{axis}label") + kw = dict(label_kw or {}) + if labelcolor is not None: + kw["color"] = labelcolor + if labelsize is not None: + kw["size"] = labelsize + if labelweight is not None: + kw["weight"] = labelweight + if labelpad is not None: + kw["labelpad"] = labelpad + setter(getter() if label is None else label, **kw) + + def _update_limits(self, axis, *, lim=None, min_=None, max_=None, reverse=None): + lo = hi = None + if lim is not None: + lo, hi = lim + lo = _not_none(min_=min_, lim_0=lo) + hi = _not_none(max_=max_, lim_1=hi) + if lo is not None or hi is not None: + get_lim = getattr(self, f"get_{axis}lim") + set_lim = getattr(self, f"set_{axis}lim") + cur_lo, cur_hi = get_lim() + set_lim((_not_none(lo, cur_lo), _not_none(hi, cur_hi))) + if reverse is not None: + inverted = getattr(self, f"{axis}axis_inverted")() + if bool(reverse) != bool(inverted): + getattr(self, f"invert_{axis}axis")() + + def _share_axis_limits(self, other, which): + self._shared_axes[which].join(self, other) + axis = getattr(self, f"{which}axis") + other_axis = getattr(other, f"{which}axis") + setattr(self, f"_share{which}", other) + axis.major = other_axis.major + axis.minor = other_axis.minor + get_lim = getattr(other, f"get_{which}lim") + set_lim = getattr(self, f"set_{which}lim") + get_auto = getattr(other, f"get_autoscale{which}_on") + set_lim(*get_lim(), emit=False, auto=get_auto()) + axis._scale = other_axis._scale + + def _sharex_setup(self, sharex, *, labels=True, limits=True): + super()._sharex_setup(sharex) + level = ( + 3 + if self._panel_sharex_group and self._is_panel_group_member(sharex) + else self.figure._sharex + ) + if level not in range(5): + raise ValueError(f"Invalid sharing level sharex={level!r}.") + if sharex in (None, self) or not isinstance(sharex, AstroAxes): + return + if level > 0 and labels: + self._sharex = sharex + if level > 1 and limits: + self._share_axis_limits(sharex, "x") + + def _sharey_setup(self, sharey, *, labels=True, limits=True): + super()._sharey_setup(sharey) + level = ( + 3 + if self._panel_sharey_group and self._is_panel_group_member(sharey) + else self.figure._sharey + ) + if level not in range(5): + raise ValueError(f"Invalid sharing level sharey={level!r}.") + if sharey in (None, self) or not isinstance(sharey, AstroAxes): + return + if level > 0 and labels: + self._sharey = sharey + if level > 1 and limits: + self._share_axis_limits(sharey, "y") + + def _is_ticklabel_on(self, side: str) -> bool: + axis = "x" if side in ("labelbottom", "labeltop") else "y" + coord = self._get_coord_helper(axis) + if coord is None or not coord.get_ticklabel_visible(): + return False + positions = coord.get_ticklabel_position() + tokens = { + "labelbottom": "b", + "labeltop": "t", + "labelleft": "l", + "labelright": "r", + "bottom": "b", + "top": "t", + "left": "l", + "right": "r", + } + token = tokens.get(side, side) + if token in positions: + return True + if "#" in positions: + return token == ("b" if axis == "x" else "l") + return False + + def _get_ticklabel_state(self, axis: str) -> dict[str, bool]: + sides = ("top", "bottom") if axis == "x" else ("left", "right") + return { + f"label{side}": self._is_ticklabel_on(f"label{side}") for side in sides + } + + def _set_ticklabel_state(self, axis: str, state: dict): + coord = self._get_coord_helper(axis) + if coord is None: + return + positions = [] + for side in ("bottom", "top") if axis == "x" else ("left", "right"): + if state.get(f"label{side}", False): + positions.append(side[0]) + position = "".join(positions) + coord.set_ticklabel_position(position) + coord.set_axislabel_position(position) + coord.set_ticklabel_visible(bool(positions)) + + def _apply_ticklabel_state(self, axis: str, state: dict): + self._set_ticklabel_state(axis, state) + + def format( + self, + *, + aspect=None, + xreverse=None, + yreverse=None, + xlim=None, + ylim=None, + xmin=None, + ymin=None, + xmax=None, + ymax=None, + xformatter=None, + yformatter=None, + xlocator=None, + ylocator=None, + xtickminor=None, + ytickminor=None, + xtickcolor=None, + ytickcolor=None, + xticklen=None, + yticklen=None, + xtickwidth=None, + ytickwidth=None, + xtickdir=None, + ytickdir=None, + xticklabelpad=None, + yticklabelpad=None, + xticklabelcolor=None, + yticklabelcolor=None, + xticklabelsize=None, + yticklabelsize=None, + xticklabelweight=None, + yticklabelweight=None, + xlabel=None, + ylabel=None, + xlabelpad=None, + ylabelpad=None, + xlabelcolor=None, + ylabelcolor=None, + xlabelsize=None, + ylabelsize=None, + xlabelweight=None, + ylabelweight=None, + xgrid=None, + ygrid=None, + xgridcolor=None, + ygridcolor=None, + xlabel_kw=None, + ylabel_kw=None, + **kwargs, + ): + if aspect is not None: + self.set_aspect(aspect) + self._update_limits("x", lim=xlim, min_=xmin, max_=xmax, reverse=xreverse) + self._update_limits("y", lim=ylim, min_=ymin, max_=ymax, reverse=yreverse) + self._update_coord_locator("x", xlocator) + self._update_coord_locator("y", ylocator) + self._update_coord_formatter("x", xformatter) + self._update_coord_formatter("y", yformatter) + self._update_coord_ticks( + "x", + grid=xgrid, + gridcolor=xgridcolor, + tickcolor=xtickcolor, + ticklen=xticklen, + tickwidth=xtickwidth, + tickdir=xtickdir, + ticklabelpad=xticklabelpad, + ticklabelcolor=xticklabelcolor, + ticklabelsize=xticklabelsize, + ticklabelweight=xticklabelweight, + tickminor=xtickminor, + ) + self._update_coord_ticks( + "y", + grid=ygrid, + gridcolor=ygridcolor, + tickcolor=ytickcolor, + ticklen=yticklen, + tickwidth=ytickwidth, + tickdir=ytickdir, + ticklabelpad=yticklabelpad, + ticklabelcolor=yticklabelcolor, + ticklabelsize=yticklabelsize, + ticklabelweight=yticklabelweight, + tickminor=ytickminor, + ) + self._update_axis_label( + "x", + label=xlabel, + labelpad=xlabelpad, + labelcolor=xlabelcolor, + labelsize=xlabelsize, + labelweight=xlabelweight, + label_kw=xlabel_kw, + ) + self._update_axis_label( + "y", + label=ylabel, + labelpad=ylabelpad, + labelcolor=ylabelcolor, + labelsize=ylabelsize, + labelweight=ylabelweight, + label_kw=ylabel_kw, + ) + return base.Axes.format(self, **kwargs) + + AstroAxes._format_signatures[AstroAxes] = inspect.signature(CartesianAxes.format) +else: # pragma: no cover + AstroAxes = None diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 7a0f045bf..695faace2 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -929,6 +929,8 @@ def _add_inset_axes( if proj is None: if self._name in ("cartopy", "basemap"): proj = copy.copy(self.projection) + elif self._name == "astro" and getattr(self, "wcs", None) is not None: + proj = self.wcs else: proj = self._name kwargs = self.figure._parse_proj(proj, **kwargs) @@ -3286,6 +3288,26 @@ def _is_ticklabel_on(self, side: str) -> bool: return axis.get_tick_params().get(self._label_key(side), False) + def _get_ticklabel_state(self, axis: str) -> dict[str, bool]: + """ + Return visible ticklabel sides for one logical axis. + """ + sides = ("top", "bottom") if axis == "x" else ("left", "right") + return {f"label{side}": self._is_ticklabel_on(f"label{side}") for side in sides} + + def _set_ticklabel_state(self, axis: str, state: dict) -> None: + """ + Apply logical ticklabel visibility to one logical axis. + """ + cleaned = {k: (True if v in ("x", "y") else v) for k, v in state.items()} + mapped = { + self._label_key(key): value + for key, value in cleaned.items() + if key.startswith("label") + } + if mapped: + getattr(self, f"{axis}axis").set_tick_params(**mapped) + @docstring._snippet_manager def inset(self, *args, **kwargs): """ diff --git a/ultraplot/axes/container.py b/ultraplot/axes/container.py index 028d98b80..297098ef4 100644 --- a/ultraplot/axes/container.py +++ b/ultraplot/axes/container.py @@ -686,6 +686,23 @@ def get_external_child(self): """ return self.get_external_axes() + def get_transform(self, *args, **kwargs): + """ + Delegate projection-specific transform lookups to the external axes. + + Some external axes classes (for example WCSAxes) accept extra arguments + like ``frame`` on ``get_transform()``. Without an explicit override here, + the container inherits ``Artist.get_transform()`` and masks that API. + """ + if self._external_axes is not None: + ext_get_transform = getattr( + type(self._external_axes), "get_transform", None + ) + base_get_transform = getattr(maxes.Axes, "get_transform", None) + if args or kwargs or ext_get_transform is not base_get_transform: + return self._external_axes.get_transform(*args, **kwargs) + return super().get_transform() + def clear(self): """Clear the container and mark external axes as stale.""" # Mark external axes as stale before clearing diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index ce13b41cd..d45f0269a 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -1479,6 +1479,16 @@ def _is_ticklabel_on(self, side: str) -> bool: return False return adapter.is_label_on(side) + def _get_ticklabel_state(self, axis: str) -> dict[str, bool]: + sides = ("top", "bottom") if axis == "x" else ("left", "right") + return {f"label{side}": self._is_ticklabel_on(f"label{side}") for side in sides} + + def _set_ticklabel_state(self, axis: str, state: dict) -> None: + sides = ("top", "bottom") if axis == "x" else ("left", "right") + self._toggle_gridliner_labels( + **{f"label{side}": state.get(f"label{side}", False) for side in sides} + ) + def _clear_edge_lon_labels(self) -> None: for label in self._edge_lon_labels: try: diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 4edab717d..f0412cd8e 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -18,7 +18,6 @@ import matplotlib.axes as maxes import matplotlib.figure as mfigure import matplotlib.gridspec as mgridspec -import matplotlib.projections as mproj import matplotlib.text as mtext import matplotlib.transforms as mtransforms import numpy as np @@ -32,6 +31,7 @@ from . import constructor from . import gridspec as pgridspec from .config import rc, rc_matplotlib +from .internals.projections import finalize_projection_kwargs, resolve_projection_kwargs from .internals import ( _not_none, _pop_params, @@ -974,6 +974,15 @@ def _share_axes_compatible(self, ref, other, which: str): ): return False, "different Geo projection classes" + astro_cls = paxes.get_astro_axes_class() + ref_astro = astro_cls is not None and isinstance(ref, astro_cls) + other_astro = astro_cls is not None and isinstance(other, astro_cls) + if ref_astro or other_astro: + if not (ref_astro and other_astro): + return False, "astro and non-astro axes cannot be shared" + if ref._share_coord_signature(which) != other._share_coord_signature(which): + return False, "different Astro coordinate families" + # Polar and non-polar should not share. ref_polar = isinstance(ref, paxes.PolarAxes) other_polar = isinstance(other, paxes.PolarAxes) @@ -981,7 +990,12 @@ def _share_axes_compatible(self, ref, other, which: str): return False, "polar and non-polar axes cannot be shared" # Non-geo external axes are generally Cartesian-like in UltraPlot. - if not ref_geo and not other_geo and not (ref_external or other_external): + if ( + not ref_geo + and not other_geo + and not (ref_external or other_external) + and not (ref_astro or other_astro) + ): if not ( isinstance(ref, paxes.CartesianAxes) and isinstance(other, paxes.CartesianAxes) @@ -1180,15 +1194,10 @@ def _share_ticklabels(self, *, axis: str) -> None: axes = list(self._iter_axes(panels=True, hidden=False)) groups = self._group_axes_by_axis(axes, axis) - # Version-dependent label name mapping for reading back params - label_keys = self._label_key_map() - # Process each group independently for _, group_axes in groups.items(): # Build baseline from MAIN axes only (exclude panels) - baseline, skip_group = self._compute_baseline_tick_state( - group_axes, axis, label_keys - ) + baseline, skip_group = self._compute_baseline_tick_state(group_axes, axis) if skip_group: continue @@ -1202,27 +1211,10 @@ def _share_ticklabels(self, *, axis: str) -> None: continue # Apply to geo/cartesian appropriately - self._set_ticklabel_state(axi, axis, masked) + axi._set_ticklabel_state(axis, masked) self.stale = True - def _label_key_map(self): - """ - Return a mapping for version-dependent label keys for Matplotlib tick params. - """ - first_axi = next(self._iter_axes(panels=True), None) - if first_axi is None: - return { - "labelleft": "labelleft", - "labelright": "labelright", - "labeltop": "labeltop", - "labelbottom": "labelbottom", - } - return { - name: first_axi._label_key(name) - for name in ("labelleft", "labelright", "labeltop", "labelbottom") - } - def _group_axes_by_axis(self, axes, axis: str): """ Group axes by row (x) or column (y). Panels included; invalid subplotspec skipped. @@ -1243,7 +1235,7 @@ def _group_key(ax): groups[key].append(axi) return groups - def _compute_baseline_tick_state(self, group_axes, axis: str, label_keys): + def _compute_baseline_tick_state(self, group_axes, axis: str): """ Build a baseline ticklabel visibility dict from MAIN axes (panels excluded). Returns (baseline_dict, skip_group: bool). Emits warnings when encountering @@ -1264,9 +1256,11 @@ def _compute_baseline_tick_state(self, group_axes, axis: str, label_keys): return {}, True # Supported axes types - if not isinstance( - axi, (paxes.CartesianAxes, paxes._CartopyAxes, paxes._BasemapAxes) - ): + supported = (paxes.CartesianAxes, paxes._CartopyAxes, paxes._BasemapAxes) + astro_cls = paxes.get_astro_axes_class() + if astro_cls is not None: + supported = (*supported, astro_cls) + if not isinstance(axi, supported): warnings._warn_ultraplot( f"Tick label sharing not implemented for {type(axi)} subplots." ) @@ -1276,17 +1270,9 @@ def _compute_baseline_tick_state(self, group_axes, axis: str, label_keys): subplot_types.add(type(axi)) # Collect label visibility state - if isinstance(axi, paxes.CartesianAxes): - params = getattr(axi, f"{axis}axis").get_tick_params() - for side in sides: - key = label_keys[f"label{side}"] - if params.get(key): - baseline[key] = params[key] - elif isinstance(axi, paxes.GeoAxes): - for side in sides: - key = f"label{side}" - if axi._is_ticklabel_on(key): - baseline[key] = axi._is_ticklabel_on(key) + for key, value in axi._get_ticklabel_state(axis).items(): + if value: + baseline[key] = value if unsupported_found: return {}, True @@ -1305,16 +1291,12 @@ def _apply_border_mask( ): """ Apply figure-border constraints and panel opposite-side suppression. - Keeps label key mapping per-axis for cartesian. """ from .axes.cartesian import OPPOSITE_SIDE masked = baseline.copy() for side in sides: label = f"label{side}" - if isinstance(axi, paxes.CartesianAxes): - # Use per-axis version-mapped key when writing - label = axi._label_key(label) # Only keep labels on true figure borders if axi not in outer_axes[side]: @@ -1356,16 +1338,6 @@ def _effective_share_level(self, axi, axis: str, sides: tuple[str, str]) -> int: return level - def _set_ticklabel_state(self, axi, axis: str, state: dict): - """Apply the computed ticklabel state to cartesian or geo axes.""" - if state: - # Normalize "x"/"y" values to booleans for both Geo and Cartesian axes - cleaned = {k: (True if v in ("x", "y") else v) for k, v in state.items()} - if isinstance(axi, paxes.GeoAxes): - axi._toggle_gridliner_labels(**cleaned) - else: - getattr(axi, f"{axis}axis").set_tick_params(**cleaned) - def _context_adjusting(self, cache=True): """ Prevent re-running auto layout steps due to draws triggered by figure @@ -1410,82 +1382,22 @@ def _parse_proj( axes class. Input projection can be a string, `matplotlib.axes.Axes`, `cartopy.crs.Projection`, or `mpl_toolkits.basemap.Basemap`. """ - # Parse arguments proj = _not_none(proj=proj, projection=projection, default="cartesian") proj_kw = _not_none(proj_kw=proj_kw, projection_kw=projection_kw, default={}) backend = self._parse_backend(backend, basemap) - if isinstance(proj, str): - proj = proj.lower() - if isinstance(self, paxes.Axes): - proj = self._name - elif isinstance(self, maxes.Axes): - raise ValueError("Matplotlib axes cannot be added to ultraplot figures.") - - # Search axes projections - name = None - - # Handle cartopy/basemap Projection objects directly - # These should be converted to Ultraplot GeoAxes - if not isinstance(proj, str): - # Check if it's a cartopy or basemap projection object - if constructor.Projection is not object and isinstance( - proj, constructor.Projection - ): - # It's a cartopy projection - use cartopy backend - name = "ultraplot_cartopy" - kwargs["map_projection"] = proj - elif constructor.Basemap is not object and isinstance( - proj, constructor.Basemap - ): - # It's a basemap projection - name = "ultraplot_basemap" - kwargs["map_projection"] = proj - # If not recognized, leave name as None and it will pass through - - if name is None and isinstance(proj, str): - try: - mproj.get_projection_class("ultraplot_" + proj) - except (KeyError, ValueError): - pass - else: - name = "ultraplot_" + proj - if name is None and isinstance(proj, str): - # Try geographic projections first if cartopy/basemap available - if ( - constructor.Projection is not object - or constructor.Basemap is not object - ): - try: - proj_obj = constructor.Proj( - proj, backend=backend, include_axes=True, **proj_kw - ) - name = "ultraplot_" + proj_obj._proj_backend - kwargs["map_projection"] = proj_obj - except ValueError: - # Not a geographic projection, will try matplotlib registry below - pass - - # If not geographic, check if registered globally in Matplotlib (e.g., 'ternary', 'polar', '3d') - if name is None and proj in mproj.get_projection_names(): - name = proj - - # Helpful error message if still not found - if name is None and isinstance(proj, str): - raise ValueError( - f"Invalid projection name {proj!r}. If you are trying to generate a " - "GeoAxes with a cartopy.crs.Projection or mpl_toolkits.basemap.Basemap " - "then cartopy or basemap must be installed. Otherwise the known axes " - f"subclasses are:\n{paxes._cls_table}" - ) + return resolve_projection_kwargs( + self, + proj, + proj_kw=proj_kw, + backend=backend, + kwargs=kwargs, + ) - # Only set projection if we found a named projection - # Otherwise preserve the original projection (e.g., cartopy Projection objects) - if name is not None: - kwargs["projection"] = name - # If name is None and proj is not a string, it means we have a non-string - # projection (e.g., cartopy.crs.Projection object) that should be passed through - # The original projection kwarg is already in kwargs, so no action needed - return kwargs + def _wrap_external_projection(self, **kwargs): + """ + Wrap non-ultraplot projection classes in an external container. + """ + return finalize_projection_kwargs(self, kwargs) def _get_align_axes(self, side): """ @@ -1807,31 +1719,24 @@ def _add_axes_panel( *getattr(ax, f"get_{'y' if side in ('left','right') else 'x'}lim")(), auto=True, ) + filled = kw.get("filled", False) + shared_state = None # Push main axes tick labels to the outside relative to the added panel # Skip this for filled panels (colorbars/legends) - if not kw.get("filled", False) and share: - if isinstance(ax, paxes.GeoAxes): - if side == "top": - ax._toggle_gridliner_labels(labeltop=False) - elif side == "bottom": - ax._toggle_gridliner_labels(labelbottom=False) - elif side == "left": - ax._toggle_gridliner_labels(labelleft=False) - elif side == "right": - ax._toggle_gridliner_labels(labelright=False) - else: - if side == "top": - ax.xaxis.set_tick_params(**{ax._label_key("labeltop"): False}) - elif side == "bottom": - ax.xaxis.set_tick_params(**{ax._label_key("labelbottom"): False}) - elif side == "left": - ax.yaxis.set_tick_params(**{ax._label_key("labelleft"): False}) - elif side == "right": - ax.yaxis.set_tick_params(**{ax._label_key("labelright"): False}) - - # Panel labels: prefer outside only for non-sharing top/right; otherwise keep off + if not filled and share: + shared_axis = "y" if side in ("left", "right") else "x" + shared_state = ax._get_ticklabel_state(shared_axis) + main_state = shared_state.copy() + main_state[f"label{side}"] = False + ax._set_ticklabel_state(shared_axis, main_state) + + # Panel labels: for non-sharing panels, keep labels on the outer edges of the + # full stack. For shared panels, only propagate the panel-side labels where + # the existing sharing logic expects them (top/right). if side == "top": - if not share: + if not share and not filled: + ax.xaxis.tick_bottom() + ax.xaxis.set_label_position("bottom") pax.xaxis.set_tick_params( **{ pax._label_key("labeltop"): True, @@ -1839,11 +1744,17 @@ def _add_axes_panel( } ) else: - on = ax.xaxis.get_tick_params()[ax._label_key("labeltop")] - pax.xaxis.set_tick_params(**{pax._label_key("labeltop"): on}) - ax.yaxis.set_tick_params(labeltop=False) + on = shared_state is not None and shared_state.get("labeltop", False) + pax.xaxis.set_tick_params( + **{ + pax._label_key("labeltop"): on, + pax._label_key("labelbottom"): False, + } + ) elif side == "right": - if not share: + if not share and not filled: + ax.yaxis.tick_left() + ax.yaxis.set_label_position("left") pax.yaxis.set_tick_params( **{ pax._label_key("labelright"): True, @@ -1851,9 +1762,47 @@ def _add_axes_panel( } ) else: - on = ax.yaxis.get_tick_params()[ax._label_key("labelright")] - pax.yaxis.set_tick_params(**{pax._label_key("labelright"): on}) - ax.yaxis.set_tick_params(**{ax._label_key("labelright"): False}) + on = shared_state is not None and shared_state.get("labelright", False) + pax.yaxis.set_tick_params( + **{ + pax._label_key("labelright"): on, + pax._label_key("labelleft"): False, + } + ) + elif side == "left" and not share and not filled: + ax.yaxis.tick_right() + ax.yaxis.set_label_position("right") + ax.yaxis.set_tick_params( + **{ + ax._label_key("labelleft"): False, + ax._label_key("labelright"): True, + } + ) + pax.yaxis.tick_left() + pax.yaxis.set_label_position("left") + pax.yaxis.set_tick_params( + **{ + pax._label_key("labelleft"): True, + pax._label_key("labelright"): False, + } + ) + elif side == "bottom" and not share and not filled: + ax.xaxis.tick_top() + ax.xaxis.set_label_position("top") + ax.xaxis.set_tick_params( + **{ + ax._label_key("labelbottom"): False, + ax._label_key("labeltop"): True, + } + ) + pax.xaxis.tick_bottom() + pax.xaxis.set_label_position("bottom") + pax.xaxis.set_tick_params( + **{ + pax._label_key("labelbottom"): True, + pax._label_key("labeltop"): False, + } + ) return pax @@ -1986,47 +1935,6 @@ def _add_subplot(self, *args, **kwargs): kwargs.setdefault("number", 1 + max(self._subplot_dict, default=0)) kwargs.pop("refwidth", None) # TODO: remove this - # Use container approach for external projections to make them ultraplot-compatible - projection_name = kwargs.get("projection") - external_axes_class = None - external_axes_kwargs = {} - - if projection_name and isinstance(projection_name, str): - # Check if this is an external (non-ultraplot) projection - # Skip external wrapping for projections that start with "ultraplot_" prefix - # as these are already Ultraplot axes classes - if not projection_name.startswith("ultraplot_"): - try: - # Get the projection class - proj_class = mproj.get_projection_class(projection_name) - - # Check if it's not a built-in ultraplot axes - # Only wrap if it's NOT a subclass of Ultraplot's Axes - if not issubclass(proj_class, paxes.Axes): - # Store the external axes class and original projection name - external_axes_class = proj_class - external_axes_kwargs["projection"] = projection_name - - # Create or get the container class for this external axes type - from .axes.container import create_external_axes_container - - container_name = f"_ultraplot_container_{projection_name}" - - # Check if container is already registered - if container_name not in mproj.get_projection_names(): - container_class = create_external_axes_container( - proj_class, projection_name=container_name - ) - mproj.register_projection(container_class) - - # Use the container projection and pass external axes info - kwargs["projection"] = container_name - kwargs["external_axes_class"] = external_axes_class - kwargs["external_axes_kwargs"] = external_axes_kwargs - except (KeyError, ValueError): - # Projection not found, let matplotlib handle the error - pass - # Remove _subplot_spec from kwargs if present to prevent it from being passed # to .set() or other methods that don't accept it. kwargs.pop("_subplot_spec", None) diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 5c4ac4066..2f39667cf 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -37,6 +37,15 @@ __all__ = ["GridSpec", "SubplotGrid"] +class _GridCommandResult(tuple): + """ + Tuple subclass marking one result per axes from `SubplotGrid` dispatch. + """ + + def __new__(cls, values): + return super().__new__(cls, values) + + # Gridspec vector arguments # Valid for figure() and GridSpec() _shared_docstring = """ @@ -1890,11 +1899,20 @@ def __getattr__(self, attr): return objs[0] if len(self) == 1 else objs elif all(map(callable, objs)): + def _dispatch_value(obj, idx): + if isinstance(obj, _GridCommandResult): + return obj[idx] + return obj + @functools.wraps(objs[0]) def _iterate_subplots(*args, **kwargs): result = [] - for func in objs: - result.append(func(*args, **kwargs)) + for idx, func in enumerate(objs): + iargs = tuple(_dispatch_value(arg, idx) for arg in args) + ikwargs = { + key: _dispatch_value(val, idx) for key, val in kwargs.items() + } + result.append(func(*iargs, **ikwargs)) if len(self) == 1: return result[0] elif all(res is None for res in result): @@ -1902,7 +1920,7 @@ def _iterate_subplots(*args, **kwargs): elif all(isinstance(res, paxes.Axes) for res in result): return SubplotGrid(result, n=self._n, order=self._order) else: - return tuple(result) + return _GridCommandResult(result) _iterate_subplots.__doc__ = inspect.getdoc(objs[0]) return _iterate_subplots diff --git a/ultraplot/internals/projections.py b/ultraplot/internals/projections.py new file mode 100644 index 000000000..61eb45f01 --- /dev/null +++ b/ultraplot/internals/projections.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python3 +""" +Projection binding registry used by figure axes creation. +""" + +from dataclasses import dataclass, field + +import matplotlib.projections as mproj + +from .. import constructor + + +@dataclass(frozen=True) +class ProjectionContext: + """ + Context passed to projection bindings. + """ + + figure: object + proj_kw: dict + backend: str | None + + +@dataclass(frozen=True) +class ProjectionResolution: + """ + Resolved projection plus any injected keyword arguments. + """ + + projection: object | str | None = None + kwargs: dict = field(default_factory=dict) + + def as_kwargs(self, kwargs=None): + merged = dict(kwargs or {}) + if self.projection is not None: + merged["projection"] = self.projection + merged.update(self.kwargs) + return merged + + +@dataclass(frozen=True) +class ProjectionBinding: + """ + Projection matcher and resolver pair. + """ + + name: str + matcher: object + resolver: object + + +_PROJECTION_BINDINGS = [] + + +def register_projection_binding(name, matcher, resolver=None): + """ + Register a projection binding. Can be used as a decorator. + """ + if resolver is None: + + def decorator(func): + _PROJECTION_BINDINGS.append(ProjectionBinding(name, matcher, func)) + return func + + return decorator + + _PROJECTION_BINDINGS.append(ProjectionBinding(name, matcher, resolver)) + return resolver + + +def iter_projection_bindings(): + """ + Return the registered projection bindings. + """ + return tuple(_PROJECTION_BINDINGS) + + +def _get_axes_module(): + from .. import axes as paxes + + return paxes + + +def _looks_like_astropy_projection(proj): + module = getattr(type(proj), "__module__", "") + return module.startswith("astropy.") + + +def _prefixed_projection_name(name): + if name.startswith("ultraplot_"): + return name if name in mproj.get_projection_names() else None + prefixed = "ultraplot_" + name + try: + mproj.get_projection_class(prefixed) + except (KeyError, ValueError): + return None + return prefixed + + +def _container_projection_name(external_axes_class): + token = f"{external_axes_class.__module__}_{external_axes_class.__name__}" + return "_ultraplot_container_" + token.replace(".", "_").replace("-", "_").lower() + + +def _wrap_external_projection(figure, projection): + if projection is None: + return ProjectionResolution() + + external_axes_class = None + external_axes_kwargs = {} + if isinstance(projection, str): + if projection.startswith("ultraplot_") or projection.startswith( + "_ultraplot_container_" + ): + return ProjectionResolution(projection=projection) + try: + external_axes_class = mproj.get_projection_class(projection) + except (KeyError, ValueError): + return ProjectionResolution(projection=projection) + elif hasattr(projection, "_as_mpl_axes"): + try: + external_axes_class, external_axes_kwargs = ( + figure._process_projection_requirements(projection=projection) + ) + except Exception: + return ProjectionResolution(projection=projection) + else: + return ProjectionResolution(projection=projection) + + paxes = _get_axes_module() + if issubclass(external_axes_class, paxes.Axes): + return ProjectionResolution( + projection=projection, + kwargs=dict(external_axes_kwargs), + ) + + from ..axes.container import create_external_axes_container + + container_name = _container_projection_name(external_axes_class) + if container_name not in mproj.get_projection_names(): + container_class = create_external_axes_container( + external_axes_class, projection_name=container_name + ) + mproj.register_projection(container_class) + + return ProjectionResolution( + projection=container_name, + kwargs={ + "external_axes_class": external_axes_class, + "external_axes_kwargs": dict(external_axes_kwargs), + }, + ) + + +@register_projection_binding( + "astropy_wcs_string", + lambda proj, context: isinstance(proj, str) + and proj in ("astro", "astropy", "wcs", "ultraplot_astro"), +) +def _resolve_astropy_wcs_string(proj, context): + if _get_axes_module().get_astro_axes_class(load=True) is None: + return ProjectionResolution() + return ProjectionResolution(projection="ultraplot_astro") + + +@register_projection_binding( + "native_ultraplot_string", + lambda proj, context: isinstance(proj, str) + and _prefixed_projection_name(proj) is not None, +) +def _resolve_native_ultraplot_string(proj, context): + return ProjectionResolution(projection=_prefixed_projection_name(proj)) + + +@register_projection_binding( + "astropy_wcs_object", + lambda proj, context: ( + not isinstance(proj, str) + and _looks_like_astropy_projection(proj) + and bool(_get_axes_module().get_astropy_wcs_types(load=True)) + and isinstance(proj, _get_axes_module().get_astropy_wcs_types()) + ), +) +def _resolve_astropy_wcs_object(proj, context): + return ProjectionResolution(projection="ultraplot_astro", kwargs={"wcs": proj}) + + +@register_projection_binding( + "cartopy_projection_object", + lambda proj, context: ( + not isinstance(proj, str) + and constructor.Projection is not object + and isinstance(proj, constructor.Projection) + ), +) +def _resolve_cartopy_projection_object(proj, context): + return ProjectionResolution( + projection="ultraplot_cartopy", + kwargs={"map_projection": proj}, + ) + + +@register_projection_binding( + "basemap_projection_object", + lambda proj, context: ( + not isinstance(proj, str) + and constructor.Basemap is not object + and isinstance(proj, constructor.Basemap) + ), +) +def _resolve_basemap_projection_object(proj, context): + return ProjectionResolution( + projection="ultraplot_basemap", + kwargs={"map_projection": proj}, + ) + + +@register_projection_binding( + "geographic_projection_name", + lambda proj, context: isinstance(proj, str) + and (constructor.Projection is not object or constructor.Basemap is not object), +) +def _resolve_geographic_projection_name(proj, context): + try: + proj_obj = constructor.Proj( + proj, + backend=context.backend, + include_axes=True, + **context.proj_kw, + ) + except ValueError: + return ProjectionResolution() + return ProjectionResolution( + projection="ultraplot_" + proj_obj._proj_backend, + kwargs={"map_projection": proj_obj}, + ) + + +@register_projection_binding( + "registered_matplotlib_string", + lambda proj, context: isinstance(proj, str) + and proj in mproj.get_projection_names(), +) +def _resolve_registered_matplotlib_string(proj, context): + return ProjectionResolution(projection=proj) + + +def resolve_projection(proj, *, figure, proj_kw=None, backend=None): + """ + Resolve a user projection spec to a final projection and kwargs. + """ + proj_kw = proj_kw or {} + if isinstance(proj, str): + proj = proj.lower() + context = ProjectionContext(figure=figure, proj_kw=proj_kw, backend=backend) + + resolution = None + for binding in _PROJECTION_BINDINGS: + if binding.matcher(proj, context): + resolution = binding.resolver(proj, context) + if resolution.projection is not None or resolution.kwargs: + break + + if resolution is None or (resolution.projection is None and not resolution.kwargs): + if isinstance(proj, str): + paxes = _get_axes_module() + raise ValueError( + f"Invalid projection name {proj!r}. If you are trying to generate a " + "GeoAxes with a cartopy.crs.Projection or mpl_toolkits.basemap.Basemap " + "then cartopy or basemap must be installed. Otherwise the known axes " + f"subclasses are:\n{paxes._cls_table}" + ) + resolution = ProjectionResolution(projection=proj) + + final = _wrap_external_projection(figure, resolution.projection) + merged_kwargs = dict(resolution.kwargs) + merged_kwargs.update(final.kwargs) + projection = ( + final.projection if final.projection is not None else resolution.projection + ) + return ProjectionResolution(projection=projection, kwargs=merged_kwargs) + + +def resolve_projection_kwargs(figure, proj, *, proj_kw=None, backend=None, kwargs=None): + """ + Resolve a projection and merge the result into an existing keyword dictionary. + """ + resolution = resolve_projection( + proj, + figure=figure, + proj_kw=proj_kw, + backend=backend, + ) + return resolution.as_kwargs(kwargs) + + +def finalize_projection_kwargs(figure, kwargs): + """ + Finalize an already-parsed projection dictionary. + """ + projection = kwargs.get("projection") + if projection is None: + return kwargs + final = _wrap_external_projection(figure, projection) + return final.as_kwargs(kwargs) diff --git a/ultraplot/tests/test_astro_axes.py b/ultraplot/tests/test_astro_axes.py new file mode 100644 index 000000000..8403263ff --- /dev/null +++ b/ultraplot/tests/test_astro_axes.py @@ -0,0 +1,192 @@ +import warnings + +import numpy as np +import pytest + +import ultraplot as uplt +from ultraplot import axes as paxes + +pytest.importorskip("astropy.visualization.wcsaxes") +from astropy.wcs import WCS + + +def _make_test_wcs(): + wcs = WCS(naxis=2) + wcs.wcs.crpix = [50.0, 50.0] + wcs.wcs.cdelt = [-0.066667, 0.066667] + wcs.wcs.crval = [0.0, -90.0] + wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"] + wcs.wcs.cunit = ["deg", "deg"] + return wcs + + +def test_add_subplot_with_wcs_projection_returns_native_astro_axes(): + fig = uplt.figure() + ax = fig.add_subplot(111, projection=_make_test_wcs()) + + assert paxes.AstroAxes is not None + assert isinstance(ax, paxes.AstroAxes) + assert not (hasattr(ax, "has_external_axes") and ax.has_external_axes()) + assert ax.get_transform("icrs") is not None + + fig.canvas.draw() + bbox = ax.get_tightbbox(fig.canvas.get_renderer()) + assert bbox.width > 0 + assert bbox.height > 0 + + +def test_add_axes_with_wcs_projection_supports_basic_formatting(): + fig = uplt.figure() + ax = fig.add_axes([0.1, 0.1, 0.8, 0.8], projection=_make_test_wcs()) + + ax.format(xlabel="RA", ylabel="Dec", title="Sky", xgrid=True, ygrid=True) + + assert isinstance(ax, paxes.AstroAxes) + assert ax.get_xlabel() == "RA" + assert ax.get_ylabel() == "Dec" + assert ax.get_title() == "Sky" + + fig.canvas.draw() + bbox = ax.get_tightbbox(fig.canvas.get_renderer()) + assert bbox.width > 0 + assert bbox.height > 0 + + +def test_string_wcs_projection_uses_native_astro_axes(): + fig = uplt.figure() + ax = fig.add_subplot(111, projection="wcs", wcs=_make_test_wcs()) + + assert isinstance(ax, paxes.AstroAxes) + + +def test_same_family_astro_axes_can_share_without_warning(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + fig, (ax1, ax2) = uplt.subplots( + nrows=2, + proj=[_make_test_wcs(), _make_test_wcs()], + sharex=2, + ) + + messages = [str(item.message) for item in caught] + assert not any("Skipping incompatible x-axis sharing" in msg for msg in messages) + assert ax1.get_shared_x_axes().joined(ax1, ax2) + + +def test_different_astro_coordinate_families_do_not_share(): + galactic = _make_test_wcs() + galactic.wcs.ctype = ["GLON-TAN", "GLAT-TAN"] + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + fig, (ax1, ax2) = uplt.subplots( + nrows=2, + proj=[_make_test_wcs(), galactic], + sharex=2, + ) + + messages = [str(item.message) for item in caught] + assert any("different Astro coordinate families" in msg for msg in messages) + assert not ax1.get_shared_x_axes().joined(ax1, ax2) + + +def test_subplot_grid_arrow_dispatches_per_axes_transforms(): + fig, axs = uplt.subplots( + ncols=2, + proj=[_make_test_wcs(), _make_test_wcs()], + share=0, + ) + axs.imshow(np.zeros((16, 16)), origin="lower") + + arrows = axs.arrow( + 0.0, + -89.95, + 0.0, + 0.02, + head_width=0, + head_length=0, + width=0.01, + transform=axs.get_transform("icrs"), + ) + + assert len(arrows) == 2 + fig.canvas.draw() + + +def test_astro_axes_share_ticklabels_without_hiding_outer_wcs_labels(): + fig, axs = uplt.subplots( + ncols=2, + proj=[_make_test_wcs(), _make_test_wcs()], + ) + axs.imshow(np.zeros((16, 16)), origin="lower") + + fig.canvas.draw() + + assert axs[0].coords[1].get_ticklabel_visible() + assert axs[0].coords[1].get_axislabel_position() + assert not axs[1].coords[1].get_ticklabel_visible() + assert not axs[1].coords[1].get_axislabel_position() + + +def test_astro_axes_preserve_shared_top_labels(): + fig, axs = uplt.subplots( + nrows=2, + proj=[_make_test_wcs(), _make_test_wcs()], + ) + axs.imshow(np.zeros((16, 16)), origin="lower") + for ax in axs: + ax.coords[0].set_ticklabel_position("t") + ax.coords[0].set_axislabel_position("t") + + fig.canvas.draw() + + assert axs[0].coords[0].get_ticklabel_position() == ["t"] + assert axs[0].coords[0].get_axislabel_position() == ["t"] + assert not axs[1].coords[0].get_ticklabel_position() + assert not axs[1].coords[0].get_axislabel_position() + + +def test_astro_axes_preserve_shared_right_labels(): + fig, axs = uplt.subplots( + ncols=2, + proj=[_make_test_wcs(), _make_test_wcs()], + ) + axs.imshow(np.zeros((16, 16)), origin="lower") + for ax in axs: + ax.coords[1].set_ticklabel_position("r") + ax.coords[1].set_axislabel_position("r") + + fig.canvas.draw() + + assert not axs[0].coords[1].get_ticklabel_position() + assert not axs[0].coords[1].get_axislabel_position() + assert axs[1].coords[1].get_ticklabel_position() == ["r"] + assert axs[1].coords[1].get_axislabel_position() == ["r"] + + +def test_astro_axes_panels_preserve_explicit_top_right_labels(): + fig, axs = uplt.subplots( + nrows=2, + ncols=2, + proj=[_make_test_wcs() for _ in range(4)], + ) + axs.imshow(np.zeros((16, 16)), origin="lower") + for ax in axs: + ax.coords[0].set_ticklabel_position("t") + ax.coords[0].set_axislabel_position("t") + ax.coords[1].set_ticklabel_position("r") + ax.coords[1].set_axislabel_position("r") + + pax_top = axs[0].panel("top") + pax_right = axs[1].panel("right") + fig.canvas.draw() + + assert not axs[0].coords[0].get_ticklabel_position() + assert not axs[0].coords[0].get_axislabel_position() + assert pax_top._is_ticklabel_on("labeltop") + assert not pax_top._is_ticklabel_on("labelbottom") + + assert not axs[1].coords[1].get_ticklabel_position() + assert not axs[1].coords[1].get_axislabel_position() + assert pax_right._is_ticklabel_on("labelright") + assert not pax_right._is_ticklabel_on("labelleft") diff --git a/ultraplot/tests/test_external_container_mocked.py b/ultraplot/tests/test_external_container_mocked.py index bb2c30305..5e3bef78c 100644 --- a/ultraplot/tests/test_external_container_mocked.py +++ b/ultraplot/tests/test_external_container_mocked.py @@ -233,6 +233,29 @@ def get_tightbbox(self, renderer): return super().get_tightbbox(renderer) +class MockProjectionTransformAxes(MockExternalAxes): + """Mock external axes with a projection-aware get_transform API.""" + + def __init__(self, fig, *args, transform_id=None, **kwargs): + self.transform_id = transform_id + self.transform_calls = [] + super().__init__(fig, *args, **kwargs) + + def get_transform(self, frame=None): + self.transform_calls.append(frame) + return (self.transform_id, frame) + + +class MockProjectionObject: + """Projection-like object resolved by Matplotlib via _as_mpl_axes.""" + + def __init__(self, transform_id="mock"): + self.transform_id = transform_id + + def _as_mpl_axes(self): + return MockProjectionTransformAxes, {"transform_id": self.transform_id} + + # Tests @@ -261,6 +284,34 @@ def test_container_creation_with_external_axes(): assert isinstance(ax.get_external_child(), MockExternalAxes) +def test_add_axes_wraps_projection_object_and_delegates_get_transform(): + """Projection objects should be wrapped and keep custom transform APIs.""" + fig = uplt.figure() + ax = fig.add_axes([0.1, 0.1, 0.8, 0.8], projection=MockProjectionObject("mock-wcs")) + + assert ax.has_external_child() + child = ax.get_external_child() + assert isinstance(child, MockProjectionTransformAxes) + + transform = ax.get_transform("icrs") + assert transform == ("mock-wcs", "icrs") + assert child.transform_calls == ["icrs"] + + +def test_add_subplot_wraps_projection_object_and_delegates_get_transform(): + """Subplots should also wrap projection objects via the external container.""" + fig = uplt.figure() + ax = fig.add_subplot(111, projection=MockProjectionObject("subplot-wcs")) + + assert ax.has_external_child() + child = ax.get_external_child() + assert isinstance(child, MockProjectionTransformAxes) + + transform = ax.get_transform("fk5") + assert transform == ("subplot-wcs", "fk5") + assert child.transform_calls == ["fk5"] + + def test_external_axes_removed_from_figure_axes(): """Test that external axes is removed from figure axes list.""" fig = uplt.figure() diff --git a/ultraplot/tests/test_imports.py b/ultraplot/tests/test_imports.py index f7ba6e2e0..e51a9d377 100644 --- a/ultraplot/tests/test_imports.py +++ b/ultraplot/tests/test_imports.py @@ -34,6 +34,42 @@ def test_import_is_lightweight(): assert out == "[]" +def test_loading_axes_does_not_import_astropy(): + code = """ +import json +import sys +import ultraplot as uplt +uplt.subplots() +mods = [name for name in sys.modules if name == "astropy" or name.startswith("astropy.")] +print(json.dumps(sorted(mods))) +""" + out = _run(code) + assert out == "[]" + + +def test_axes_astro_attr_is_lazy_optional(): + code = """ +import importlib.util +import json +import sys +import ultraplot.axes as paxes +spec = importlib.util.find_spec("astropy.visualization.wcsaxes") +astro = paxes.AstroAxes +mods = [name for name in sys.modules if name == "astropy" or name.startswith("astropy.")] +print(json.dumps({ + "available": bool(spec), + "astro_is_none": astro is None, + "loaded": bool(mods), +})) +""" + out = json.loads(_run(code)) + if out["available"]: + assert not out["astro_is_none"] + assert out["loaded"] + else: + assert out["astro_is_none"] + + def test_star_import_exposes_public_api(): code = """ from ultraplot import * # noqa: F403 diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index c4d7c6d96..c10547e80 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -691,6 +691,30 @@ def assert_panel(axi_panel, side, share_flag): assert_panel(pax_bottom, "bottom", share_panels) +def test_shared_panels_preserve_explicit_top_right_labels(): + fig, axs = uplt.subplots(nrows=2, ncols=2) + for ax in axs: + ax.imshow(np.zeros((10, 10))) + ax.xaxis.tick_top() + ax.xaxis.set_label_position("top") + ax.xaxis.set_tick_params(labeltop=True, labelbottom=False) + ax.yaxis.tick_right() + ax.yaxis.set_label_position("right") + ax.yaxis.set_tick_params(labelright=True, labelleft=False) + + pax_top = axs[0].panel("top") + pax_right = axs[1].panel("right") + fig.canvas.draw() + + assert not axs[0]._is_ticklabel_on("labeltop") + assert pax_top._is_ticklabel_on("labeltop") + assert not pax_top._is_ticklabel_on("labelbottom") + + assert not axs[1]._is_ticklabel_on("labelright") + assert pax_right._is_ticklabel_on("labelright") + assert not pax_right._is_ticklabel_on("labelleft") + + def test_non_rectangular_outside_labels_top(): """ Check that non-rectangular layouts work with outside labels. @@ -775,6 +799,59 @@ def test_panel_share_flag_controls_group_membership(): assert ax2[0]._panel_sharex_group is False +def test_nonsharing_left_panel_moves_main_labels_outside(): + fig, axs = uplt.subplots() + ax = axs[0] + ax.format(ylabel="main ylabel") + pax = ax.panel("left", share=False) + pax.format(ylabel="panel ylabel") + + fig.canvas.draw() + + assert not ax._is_ticklabel_on("labelleft") + assert ax._is_ticklabel_on("labelright") + assert pax._is_ticklabel_on("labelleft") + assert not pax._is_ticklabel_on("labelright") + assert ax.yaxis.get_label_position() == "right" + assert pax.yaxis.get_label_position() == "left" + + +def test_nonsharing_bottom_panel_moves_main_labels_outside(): + fig, axs = uplt.subplots() + ax = axs[0] + ax.format(xlabel="main xlabel") + pax = ax.panel("bottom", share=False) + pax.format(xlabel="panel xlabel") + + fig.canvas.draw() + + assert not ax._is_ticklabel_on("labelbottom") + assert ax._is_ticklabel_on("labeltop") + assert pax._is_ticklabel_on("labelbottom") + assert not pax._is_ticklabel_on("labeltop") + assert ax.xaxis.get_label_position() == "top" + assert pax.xaxis.get_label_position() == "bottom" + + +def test_nonsharing_left_panel_gap_matches_right_panel(): + def _panel_gap(side): + fig, axs = uplt.subplots() + ax = axs[0] + ax.format(ylabel="main ylabel") + pax = ax.panel(side, share=False) + pax.format(xlabel="panel xlabel", ylabel="panel ylabel") + fig.canvas.draw() + main = ax.get_position().bounds + panel = pax.get_position().bounds + if side == "left": + return main[0] - (panel[0] + panel[2]) + return panel[0] - (main[0] + main[2]) + + gap_left = _panel_gap("left") + gap_right = _panel_gap("right") + assert abs(gap_left - gap_right) < 1e-3 + + def test_ticklabels_with_guides_share_true_cartesian(): """ With share=True, tick labels should only appear on bottom row and left column