Skip to content
This repository was archived by the owner on Feb 26, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
Changelog
=========

Version 0.5.1
-------------

Improvements
~~~~~~~~~~~~

- Save relative paths and omit filenames when saving simulation paths in the simulation campaign config.


Version 0.5.0
-------------

Expand Down
89 changes: 37 additions & 52 deletions src/blueetl/campaign/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from collections.abc import Iterator
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
Expand Down Expand Up @@ -32,20 +33,6 @@ def _resolve_simulation_path(path_prefix: str, path: str, filename: str) -> str:
return str(path_obj)


def _reduce_simulation_path(path_prefix: str, path: str, filename: str) -> str:
"""Reduce to shorter path, without path_prefix and filename."""
if not path or path.startswith("https://"):
# do not convert excluded simulations or nexus urls
return path
path_obj = Path(path)
if path_obj.is_relative_to(path_prefix):
path_obj = path_obj.relative_to(path_prefix)
if path_obj.name == filename:
# the paths in the xarray simulation campaign config don't include the filename
path_obj = path_obj.parent
return str(path_obj)


@dataclass
class SimulationRow:
"""Simulation row in the simulation campaign."""
Expand Down Expand Up @@ -77,21 +64,24 @@ def __init__(
and the parameters used for each simulation (for example: seed, ca...).
name: name of the Simulation Campaign.
attrs: dict of custom attributes.
config_dir: if specified, it's used to resolve relative paths in attrs.
config_dir: absolute directory used to resolve path_prefix, if relative.
"""
self._validate(attrs, config_dir)
self._name = name
self._attrs = attrs.copy()
if config_dir:
if "path_prefix" in self._attrs:
self._attrs["path_prefix"] = str(
resolve_path(config_dir, self._attrs["path_prefix"])
)
if "circuit_config" in self._attrs:
self._attrs["circuit_config"] = str(
resolve_path(config_dir, self._attrs["circuit_config"])
)
self._attrs = deepcopy(attrs)
self._path_prefix = resolve_path(config_dir or "", self._attrs.pop("path_prefix"))
self._is_sonata = self._attrs["circuit_config"].endswith(".json")
self._data = data.copy()
self._data[SIMULATION_PATH] = self._resolve_paths(self._data[SIMULATION_PATH])
self._data_resolved = data.copy()
self._data_resolved[SIMULATION_PATH] = self._resolve_paths(data[SIMULATION_PATH])

def _validate(self, attrs: dict, config_dir: Optional[Path]) -> None:
"""Validate the attrs."""
for key in "path_prefix", "circuit_config":
if not attrs.get(key):
raise ValueError(f"{key} is missing or empty in the simulation campaign")
if config_dir is None and not Path(attrs["path_prefix"]).is_absolute():
raise ValueError("path_prefix must be set to an absolute path when config_dir is None")

def _get_simulation_filename(self) -> str:
"""Return the filename of each simulation in the campaign."""
Expand All @@ -101,17 +91,7 @@ def _resolve_paths(self, simulation_paths: pd.Series) -> pd.Series:
"""Resolve the simulation paths."""
return simulation_paths.apply(
lambda path: _resolve_simulation_path(
path_prefix=self.attrs.get("path_prefix", ""),
path=path,
filename=self._get_simulation_filename(),
)
)

def _reduce_paths(self, simulation_paths: pd.Series) -> pd.Series:
"""Reduce the simulation paths."""
return simulation_paths.apply(
lambda path: _reduce_simulation_path(
path_prefix=self.attrs.get("path_prefix", ""),
path_prefix=str(self.path_prefix),
path=path,
filename=self._get_simulation_filename(),
)
Expand All @@ -126,7 +106,8 @@ def __eq__(self, other: object) -> bool:
(
self.name == other.name,
self.attrs == other.attrs,
self._data.equals(other._data),
self.path_prefix == other.path_prefix,
self._data_resolved.equals(other._data_resolved),
)
)

Expand All @@ -140,6 +121,11 @@ def attrs(self) -> dict:
"""Return the attributes dict associated with the simulations campaign."""
return self._attrs

@property
def path_prefix(self) -> Path:
"""Return the path prefix, i.e. the absolute path containing the simulations."""
return self._path_prefix

@property
def condition_names(self) -> list[str]:
"""Return the names of the parameters used to run the simulations."""
Expand All @@ -150,16 +136,13 @@ def conditions(self) -> pd.DataFrame:
"""Return the DataFrame of the parameters used to run the simulations."""
return self._data[self.condition_names]

def is_coupled(self):
def is_coupled(self) -> bool:
"""Return True if the coords are coupled, False otherwise."""
return bool(self.attrs.get("__coupled__"))

def is_sonata(self):
def is_sonata(self) -> bool:
"""Return True if the simulations are in SONATA format, False otherwise."""
circuit_config = self.attrs.get("circuit_config", "")
if not circuit_config:
raise RuntimeError("circuit_config is missing in the simulation campaign")
return circuit_config.endswith(".json")
return self._is_sonata

@classmethod
def load(cls, path: StrOrPath) -> "SimulationCampaign":
Expand Down Expand Up @@ -201,7 +184,7 @@ def to_dict(self) -> dict:
"format": "blueetl",
"version": 1,
"name": self.name,
"attrs": self.attrs,
"attrs": {**self.attrs, "path_prefix": str(self.path_prefix)},
"data": self._data.to_dict(orient="records"),
}

Expand Down Expand Up @@ -240,10 +223,12 @@ def from_xarray(

def to_xarray(self) -> xr.DataArray:
"""Convert the configuration to xarray.DataArray."""
s = self._data.set_index(self.condition_names)[SIMULATION_PATH]
s = self._reduce_paths(s)
data = self._data
if condition_names := self.condition_names:
data = data.set_index(condition_names)
s = data[SIMULATION_PATH]
s.name = self.name
attrs = self.attrs.copy()
attrs = {**self.attrs, "path_prefix": str(self.path_prefix)}
coupled = attrs.pop("__coupled__", None)
if not coupled:
# generated by GenerateSimulationCampaign
Expand All @@ -252,7 +237,7 @@ def to_xarray(self) -> xr.DataArray:
else:
# generated by GenerateCoupledCoordsSimulationCampaign
da = xr.DataArray(
list(s),
s.to_list(),
name=s.name,
dims=coupled,
coords={coupled: s.index},
Expand All @@ -276,13 +261,13 @@ def __len__(self) -> int:

def __iter__(self) -> Iterator[SimulationRow]:
"""Iterate over the simulation rows."""
for i, (_, sim_dict) in enumerate(self._data.etl.iterdict()):
for i, (_, sim_dict) in enumerate(self._data_resolved.etl.iterdict()):
path = sim_dict.pop(SIMULATION_PATH)
yield SimulationRow(index=i, path=path, conditions=sim_dict)

def __getitem__(self, index: int) -> SimulationRow:
"""Return a specific simulation row."""
sim_dict = self._data.loc[index].to_dict()
sim_dict = self._data_resolved.loc[index].to_dict()
path = sim_dict.pop(SIMULATION_PATH)
return SimulationRow(index=index, path=path, conditions=sim_dict)

Expand All @@ -291,7 +276,7 @@ def get(self, *args, **kwargs) -> pd.DataFrame:

See ``etl.q`` for the filter syntax.
"""
return self._data.copy().etl.q(*args, **kwargs)
return self._data_resolved.copy().etl.q(*args, **kwargs)

def ids(self, *args, **kwargs) -> np.ndarray:
"""Return a numpy array with the ids of the selected simulations.
Expand Down
24 changes: 18 additions & 6 deletions tests/unit/campaign/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@ def blueetl_config_dict_simple():
{
"ca": 1.0,
"depolarization": 4.0,
"simulation_path": "/tmp/simple/uuid/1/simulation_config.json",
"simulation_path": "uuid/1",
},
{
"ca": 2.0,
"depolarization": 3.0,
"simulation_path": "/tmp/simple/uuid/2/simulation_config.json",
"simulation_path": "uuid/2",
},
{
"ca": 2.0,
"depolarization": 4.0,
"simulation_path": "/tmp/simple/uuid/3/simulation_config.json",
"simulation_path": "uuid/3",
},
],
}
Expand All @@ -98,7 +98,7 @@ def blueetl_config_dict_coupled():
{
"ca": 2.0,
"depolarization": 4.0,
"simulation_path": "/tmp/coupled/uuid/1/simulation_config.json",
"simulation_path": "uuid/1",
},
],
}
Expand All @@ -117,13 +117,25 @@ def xarray_config_obj_coupled(xarray_config_dict_coupled):
@pytest.fixture
def blueetl_config_dataframe_simple(blueetl_config_dict_simple):
d = blueetl_config_dict_simple
return pd.DataFrame.from_records(d["data"])
df = pd.DataFrame.from_records(d["data"])
df["simulation_path"] = [
"",
"/tmp/simple/uuid/1/simulation_config.json",
"/tmp/simple/uuid/2/simulation_config.json",
"/tmp/simple/uuid/3/simulation_config.json",
]
return df


@pytest.fixture
def blueetl_config_dataframe_coupled(blueetl_config_dict_coupled):
d = blueetl_config_dict_coupled
return pd.DataFrame.from_records(d["data"])
df = pd.DataFrame.from_records(d["data"])
df["simulation_path"] = [
"",
"/tmp/coupled/uuid/1/simulation_config.json",
]
return df


@pytest.fixture
Expand Down
44 changes: 22 additions & 22 deletions tests/unit/campaign/test_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from pathlib import Path

import pandas as pd
import pytest
Expand Down Expand Up @@ -208,38 +207,40 @@ def test_simulations_config_get_all(input_obj, expected_df, lazy_fixture):


@pytest.mark.parametrize(
"input_obj",
"filename, expected",
[
"blueetl_config_obj_simple",
"blueetl_config_obj_coupled",
("circuit_sonata.json", True),
("CircuitConfig", False),
],
)
@pytest.mark.parametrize(
"filename, expected",
"input_dict",
[
("circuit_sonata.json", True),
("CircuitConfig", False),
"blueetl_config_dict_simple",
"blueetl_config_dict_coupled",
],
)
def test_simulations_config_is_sonata(input_obj, filename, expected, lazy_fixture):
input_obj = lazy_fixture(input_obj)
input_obj.attrs["circuit_config"] = f"/path/to/{filename}"
result = input_obj.is_sonata()
assert result == expected
def test_simulations_config_is_sonata(input_dict, filename, expected, lazy_fixture):
input_dict = lazy_fixture(input_dict)
input_dict["attrs"]["circuit_config"] = f"/path/to/{filename}"
obj = test_module.SimulationCampaign.from_dict(input_dict)
assert obj.is_sonata() == expected


@pytest.mark.parametrize(
"input_obj",
"input_dict",
[
"blueetl_config_obj_simple",
"blueetl_config_obj_coupled",
"blueetl_config_dict_simple",
"blueetl_config_dict_coupled",
],
)
def test_simulations_config_is_sonata_raises(input_obj, lazy_fixture):
input_obj = lazy_fixture(input_obj)
del input_obj.attrs["circuit_config"]
with pytest.raises(RuntimeError, match="circuit_config is missing in the simulation campaign"):
input_obj.is_sonata()
def test_simulations_config_raises_when_circuit_config_is_missing(input_dict, lazy_fixture):
input_dict = lazy_fixture(input_dict)
del input_dict["attrs"]["circuit_config"]
with pytest.raises(
ValueError, match="circuit_config is missing or empty in the simulation campaign"
):
test_module.SimulationCampaign.from_dict(input_dict)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -329,5 +330,4 @@ def test_simulations_config_with_relative_paths(config_file):

result = test_module.SimulationCampaign.load(config_file)

assert Path(result.attrs["path_prefix"]).is_absolute()
assert Path(result.attrs["circuit_config"]).is_absolute()
assert result.path_prefix.is_absolute()