diff --git a/CHANGELOG.rst b/CHANGELOG.rst index fafb74a..27b8bb4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,15 @@ Changelog ========= +Version 0.5.1 +------------- + +Improvements +~~~~~~~~~~~~ + +- Save relative paths and omit filenames when saving simulation paths in the simulation campaign config. + + Version 0.5.0 ------------- diff --git a/src/blueetl/campaign/config.py b/src/blueetl/campaign/config.py index 4dc1136..c620426 100644 --- a/src/blueetl/campaign/config.py +++ b/src/blueetl/campaign/config.py @@ -2,6 +2,7 @@ import logging from collections.abc import Iterator +from copy import deepcopy from dataclasses import dataclass from pathlib import Path from typing import Any, Optional @@ -32,20 +33,6 @@ def _resolve_simulation_path(path_prefix: str, path: str, filename: str) -> str: return str(path_obj) -def _reduce_simulation_path(path_prefix: str, path: str, filename: str) -> str: - """Reduce to shorter path, without path_prefix and filename.""" - if not path or path.startswith("https://"): - # do not convert excluded simulations or nexus urls - return path - path_obj = Path(path) - if path_obj.is_relative_to(path_prefix): - path_obj = path_obj.relative_to(path_prefix) - if path_obj.name == filename: - # the paths in the xarray simulation campaign config don't include the filename - path_obj = path_obj.parent - return str(path_obj) - - @dataclass class SimulationRow: """Simulation row in the simulation campaign.""" @@ -77,21 +64,24 @@ def __init__( and the parameters used for each simulation (for example: seed, ca...). name: name of the Simulation Campaign. attrs: dict of custom attributes. - config_dir: if specified, it's used to resolve relative paths in attrs. + config_dir: absolute directory used to resolve path_prefix, if relative. """ + self._validate(attrs, config_dir) self._name = name - self._attrs = attrs.copy() - if config_dir: - if "path_prefix" in self._attrs: - self._attrs["path_prefix"] = str( - resolve_path(config_dir, self._attrs["path_prefix"]) - ) - if "circuit_config" in self._attrs: - self._attrs["circuit_config"] = str( - resolve_path(config_dir, self._attrs["circuit_config"]) - ) + self._attrs = deepcopy(attrs) + self._path_prefix = resolve_path(config_dir or "", self._attrs.pop("path_prefix")) + self._is_sonata = self._attrs["circuit_config"].endswith(".json") self._data = data.copy() - self._data[SIMULATION_PATH] = self._resolve_paths(self._data[SIMULATION_PATH]) + self._data_resolved = data.copy() + self._data_resolved[SIMULATION_PATH] = self._resolve_paths(data[SIMULATION_PATH]) + + def _validate(self, attrs: dict, config_dir: Optional[Path]) -> None: + """Validate the attrs.""" + for key in "path_prefix", "circuit_config": + if not attrs.get(key): + raise ValueError(f"{key} is missing or empty in the simulation campaign") + if config_dir is None and not Path(attrs["path_prefix"]).is_absolute(): + raise ValueError("path_prefix must be set to an absolute path when config_dir is None") def _get_simulation_filename(self) -> str: """Return the filename of each simulation in the campaign.""" @@ -101,17 +91,7 @@ def _resolve_paths(self, simulation_paths: pd.Series) -> pd.Series: """Resolve the simulation paths.""" return simulation_paths.apply( lambda path: _resolve_simulation_path( - path_prefix=self.attrs.get("path_prefix", ""), - path=path, - filename=self._get_simulation_filename(), - ) - ) - - def _reduce_paths(self, simulation_paths: pd.Series) -> pd.Series: - """Reduce the simulation paths.""" - return simulation_paths.apply( - lambda path: _reduce_simulation_path( - path_prefix=self.attrs.get("path_prefix", ""), + path_prefix=str(self.path_prefix), path=path, filename=self._get_simulation_filename(), ) @@ -126,7 +106,8 @@ def __eq__(self, other: object) -> bool: ( self.name == other.name, self.attrs == other.attrs, - self._data.equals(other._data), + self.path_prefix == other.path_prefix, + self._data_resolved.equals(other._data_resolved), ) ) @@ -140,6 +121,11 @@ def attrs(self) -> dict: """Return the attributes dict associated with the simulations campaign.""" return self._attrs + @property + def path_prefix(self) -> Path: + """Return the path prefix, i.e. the absolute path containing the simulations.""" + return self._path_prefix + @property def condition_names(self) -> list[str]: """Return the names of the parameters used to run the simulations.""" @@ -150,16 +136,13 @@ def conditions(self) -> pd.DataFrame: """Return the DataFrame of the parameters used to run the simulations.""" return self._data[self.condition_names] - def is_coupled(self): + def is_coupled(self) -> bool: """Return True if the coords are coupled, False otherwise.""" return bool(self.attrs.get("__coupled__")) - def is_sonata(self): + def is_sonata(self) -> bool: """Return True if the simulations are in SONATA format, False otherwise.""" - circuit_config = self.attrs.get("circuit_config", "") - if not circuit_config: - raise RuntimeError("circuit_config is missing in the simulation campaign") - return circuit_config.endswith(".json") + return self._is_sonata @classmethod def load(cls, path: StrOrPath) -> "SimulationCampaign": @@ -201,7 +184,7 @@ def to_dict(self) -> dict: "format": "blueetl", "version": 1, "name": self.name, - "attrs": self.attrs, + "attrs": {**self.attrs, "path_prefix": str(self.path_prefix)}, "data": self._data.to_dict(orient="records"), } @@ -240,10 +223,12 @@ def from_xarray( def to_xarray(self) -> xr.DataArray: """Convert the configuration to xarray.DataArray.""" - s = self._data.set_index(self.condition_names)[SIMULATION_PATH] - s = self._reduce_paths(s) + data = self._data + if condition_names := self.condition_names: + data = data.set_index(condition_names) + s = data[SIMULATION_PATH] s.name = self.name - attrs = self.attrs.copy() + attrs = {**self.attrs, "path_prefix": str(self.path_prefix)} coupled = attrs.pop("__coupled__", None) if not coupled: # generated by GenerateSimulationCampaign @@ -252,7 +237,7 @@ def to_xarray(self) -> xr.DataArray: else: # generated by GenerateCoupledCoordsSimulationCampaign da = xr.DataArray( - list(s), + s.to_list(), name=s.name, dims=coupled, coords={coupled: s.index}, @@ -276,13 +261,13 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[SimulationRow]: """Iterate over the simulation rows.""" - for i, (_, sim_dict) in enumerate(self._data.etl.iterdict()): + for i, (_, sim_dict) in enumerate(self._data_resolved.etl.iterdict()): path = sim_dict.pop(SIMULATION_PATH) yield SimulationRow(index=i, path=path, conditions=sim_dict) def __getitem__(self, index: int) -> SimulationRow: """Return a specific simulation row.""" - sim_dict = self._data.loc[index].to_dict() + sim_dict = self._data_resolved.loc[index].to_dict() path = sim_dict.pop(SIMULATION_PATH) return SimulationRow(index=index, path=path, conditions=sim_dict) @@ -291,7 +276,7 @@ def get(self, *args, **kwargs) -> pd.DataFrame: See ``etl.q`` for the filter syntax. """ - return self._data.copy().etl.q(*args, **kwargs) + return self._data_resolved.copy().etl.q(*args, **kwargs) def ids(self, *args, **kwargs) -> np.ndarray: """Return a numpy array with the ids of the selected simulations. diff --git a/tests/unit/campaign/conftest.py b/tests/unit/campaign/conftest.py index a2cb9d0..7a3138e 100644 --- a/tests/unit/campaign/conftest.py +++ b/tests/unit/campaign/conftest.py @@ -61,17 +61,17 @@ def blueetl_config_dict_simple(): { "ca": 1.0, "depolarization": 4.0, - "simulation_path": "/tmp/simple/uuid/1/simulation_config.json", + "simulation_path": "uuid/1", }, { "ca": 2.0, "depolarization": 3.0, - "simulation_path": "/tmp/simple/uuid/2/simulation_config.json", + "simulation_path": "uuid/2", }, { "ca": 2.0, "depolarization": 4.0, - "simulation_path": "/tmp/simple/uuid/3/simulation_config.json", + "simulation_path": "uuid/3", }, ], } @@ -98,7 +98,7 @@ def blueetl_config_dict_coupled(): { "ca": 2.0, "depolarization": 4.0, - "simulation_path": "/tmp/coupled/uuid/1/simulation_config.json", + "simulation_path": "uuid/1", }, ], } @@ -117,13 +117,25 @@ def xarray_config_obj_coupled(xarray_config_dict_coupled): @pytest.fixture def blueetl_config_dataframe_simple(blueetl_config_dict_simple): d = blueetl_config_dict_simple - return pd.DataFrame.from_records(d["data"]) + df = pd.DataFrame.from_records(d["data"]) + df["simulation_path"] = [ + "", + "/tmp/simple/uuid/1/simulation_config.json", + "/tmp/simple/uuid/2/simulation_config.json", + "/tmp/simple/uuid/3/simulation_config.json", + ] + return df @pytest.fixture def blueetl_config_dataframe_coupled(blueetl_config_dict_coupled): d = blueetl_config_dict_coupled - return pd.DataFrame.from_records(d["data"]) + df = pd.DataFrame.from_records(d["data"]) + df["simulation_path"] = [ + "", + "/tmp/coupled/uuid/1/simulation_config.json", + ] + return df @pytest.fixture diff --git a/tests/unit/campaign/test_config.py b/tests/unit/campaign/test_config.py index 3689261..0ad090a 100644 --- a/tests/unit/campaign/test_config.py +++ b/tests/unit/campaign/test_config.py @@ -1,5 +1,4 @@ import json -from pathlib import Path import pandas as pd import pytest @@ -208,38 +207,40 @@ def test_simulations_config_get_all(input_obj, expected_df, lazy_fixture): @pytest.mark.parametrize( - "input_obj", + "filename, expected", [ - "blueetl_config_obj_simple", - "blueetl_config_obj_coupled", + ("circuit_sonata.json", True), + ("CircuitConfig", False), ], ) @pytest.mark.parametrize( - "filename, expected", + "input_dict", [ - ("circuit_sonata.json", True), - ("CircuitConfig", False), + "blueetl_config_dict_simple", + "blueetl_config_dict_coupled", ], ) -def test_simulations_config_is_sonata(input_obj, filename, expected, lazy_fixture): - input_obj = lazy_fixture(input_obj) - input_obj.attrs["circuit_config"] = f"/path/to/{filename}" - result = input_obj.is_sonata() - assert result == expected +def test_simulations_config_is_sonata(input_dict, filename, expected, lazy_fixture): + input_dict = lazy_fixture(input_dict) + input_dict["attrs"]["circuit_config"] = f"/path/to/{filename}" + obj = test_module.SimulationCampaign.from_dict(input_dict) + assert obj.is_sonata() == expected @pytest.mark.parametrize( - "input_obj", + "input_dict", [ - "blueetl_config_obj_simple", - "blueetl_config_obj_coupled", + "blueetl_config_dict_simple", + "blueetl_config_dict_coupled", ], ) -def test_simulations_config_is_sonata_raises(input_obj, lazy_fixture): - input_obj = lazy_fixture(input_obj) - del input_obj.attrs["circuit_config"] - with pytest.raises(RuntimeError, match="circuit_config is missing in the simulation campaign"): - input_obj.is_sonata() +def test_simulations_config_raises_when_circuit_config_is_missing(input_dict, lazy_fixture): + input_dict = lazy_fixture(input_dict) + del input_dict["attrs"]["circuit_config"] + with pytest.raises( + ValueError, match="circuit_config is missing or empty in the simulation campaign" + ): + test_module.SimulationCampaign.from_dict(input_dict) @pytest.mark.parametrize( @@ -329,5 +330,4 @@ def test_simulations_config_with_relative_paths(config_file): result = test_module.SimulationCampaign.load(config_file) - assert Path(result.attrs["path_prefix"]).is_absolute() - assert Path(result.attrs["circuit_config"]).is_absolute() + assert result.path_prefix.is_absolute()