diff --git a/pixi.lock b/pixi.lock index efafa5aa..0d173a32 100644 --- a/pixi.lock +++ b/pixi.lock @@ -3888,8 +3888,8 @@ packages: requires_python: '>=3.5' - pypi: ./ name: easydynamics - version: 0.4.0+dev5 - sha256: ff8f55922804cdb622d0eb0aecd00105aeb30e470aa79b05bf6f92556ad8ce67 + version: 0.3.0+devdirty8 + sha256: 23d0790d25938acbe8e96780e087d1b29916963a0f6b4d04b906012d906c5cfb requires_dist: - darkdetect - easyscience diff --git a/pyproject.toml b/pyproject.toml index 17452cd7..befd67f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -256,8 +256,7 @@ select = [ # Ignore specific rules globally ignore = [ 'COM812', # https://docs.astral.sh/ruff/rules/missing-trailing-comma/ - # The following is replaced by 'D'/[tool.ruff.lint.pydocstyle] and [tool.pydoclint] - 'DOC', # https://docs.astral.sh/ruff/rules/#pydoclint-doc + # The following is replaced by 'D'/[tool.ruff.lint.pydocstyle] and [tool.pydoclint] 'DOC', # https://docs.astral.sh/ruff/rules/#pydoclint-doc # Disable, as [tool.format_docstring] split one-line docstrings into the canonical multi-line layout 'D200', # https://docs.astral.sh/ruff/rules/unnecessary-multiline-docstring/ ] diff --git a/src/easydynamics/base_classes/__init__.py b/src/easydynamics/base_classes/__init__.py new file mode 100644 index 00000000..cb15f1dd --- /dev/null +++ b/src/easydynamics/base_classes/__init__.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +from easydynamics.base_classes.easydynamics_base import EasyDynamicsBase +from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase + +__all__ = [ + 'EasyDynamicsModelBase', + 'EasyDynamicsBase', +] diff --git a/src/easydynamics/base_classes/easydynamics_base.py b/src/easydynamics/base_classes/easydynamics_base.py new file mode 100644 index 00000000..1f38c9d3 --- /dev/null +++ b/src/easydynamics/base_classes/easydynamics_base.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +from easyscience.base_classes.new_base import NewBase + + +class EasyDynamicsBase(NewBase): + """Base class for all EasyDynamics classes.""" + + def __init__( + self, + name: str | None = 'MyEasyDynamicsModel', + display_name: str | None = 'MyEasyDynamicsModel', + unique_name: str | None = None, + ) -> None: + """ + Initialize the EasyDynamicsBase. + + Parameters + ---------- + name : str | None, default='MyEasyDynamicsModel' + Name of the model. + display_name : str | None, default='MyEasyDynamicsModel' + Display name of the model. + unique_name : str | None, default=None + Unique name of the model. If None, a unique name will be generated. + + Raises + ------ + TypeError + If name is not a string or None. + """ + super().__init__(display_name=display_name, unique_name=unique_name) + + if name is not None and not isinstance(name, str): + raise TypeError('Name must be a string or None.') + self._name = name + + @property + def name(self) -> str | None: + """ + Get the name of the model. + + Returns + ------- + str | None + The name of the model. + """ + return self._name + + @name.setter + def name(self, name_str: str | None) -> None: + """ + Set the name of the model. + + Parameters + ---------- + name_str : str | None + The new name to set. + + Raises + ------ + TypeError + If name_str is not a string or None. + """ + + if name_str is not None and not isinstance(name_str, str): + raise TypeError('Name must be a string or None.') + self._name = name_str diff --git a/src/easydynamics/base_classes/easydynamics_modelbase.py b/src/easydynamics/base_classes/easydynamics_modelbase.py new file mode 100644 index 00000000..abc167c9 --- /dev/null +++ b/src/easydynamics/base_classes/easydynamics_modelbase.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +import scipp as sc +from easyscience.base_classes import ModelBase + +from easydynamics.utils.utils import _validate_unit + + +class EasyDynamicsModelBase(ModelBase): + """Base class for all EasyDynamics models.""" + + def __init__( + self, + unit: str | sc.Unit = 'meV', + name: str | None = 'MyEasyDynamicsModel', + display_name: str | None = 'MyEasyDynamicsModel', + unique_name: str | None = None, + ) -> None: + """ + Initialize the EasyDynamicsModelBase. + + Parameters + ---------- + unit : str | sc.Unit, default='meV' + Unit of the model. + name : str | None, default='MyEasyDynamicsModel' + Name of the model. + display_name : str | None, default='MyEasyDynamicsModel' + Display name of the model. + unique_name : str | None, default=None + Unique name of the model. If None, a unique name will be generated. + + Raises + ------ + TypeError + If name is not a string or None. + """ + super().__init__(display_name=display_name, unique_name=unique_name) + self._unit = _validate_unit(unit) + + if name is not None and not isinstance(name, str): + raise TypeError('Name must be a string or None.') + self._name = name + + @property + def unit(self) -> str | sc.Unit | None: + """ + Get the unit of the model. + + Returns + ------- + str | sc.Unit | None + The unit of the model. + """ + + return self._unit + + @unit.setter + def unit(self, _unit_str: str) -> None: + """ + Unit is read-only and cannot be set directly. + + Parameters + ---------- + _unit_str : str + The new unit to set (ignored). + + Raises + ------ + AttributeError + Always raised to indicate that the unit is read-only. + """ + raise AttributeError( + f'Unit is read-only. Use convert_unit to change the unit between allowed types ' + f'or create a new {self.__class__.__name__} with the desired unit.' + ) # noqa: E501 + + @property + def name(self) -> str | None: + """ + Get the name of the model. + + Returns + ------- + str | None + The name of the model. + """ + return self._name + + @name.setter + def name(self, name_str: str) -> None: + """ + Set the name of the model. + + Parameters + ---------- + name_str : str + The new name to set. + + Raises + ------ + TypeError + If name_str is not a string or None. + """ + + if name_str is not None and not isinstance(name_str, str): + raise TypeError('Name must be a string or None.') + self._name = name_str diff --git a/src/easydynamics/convolution/analytical_convolution.py b/src/easydynamics/convolution/analytical_convolution.py index a835f215..1162381e 100644 --- a/src/easydynamics/convolution/analytical_convolution.py +++ b/src/easydynamics/convolution/analytical_convolution.py @@ -38,10 +38,12 @@ class AnalyticalConvolution(ConvolutionBase): def __init__( self, energy: np.ndarray | sc.Variable, - energy_unit: str | sc.Unit = 'meV', + unit: str | sc.Unit = 'meV', sample_components: ComponentCollection | ModelComponent | None = None, resolution_components: ComponentCollection | ModelComponent | None = None, energy_offset: Numeric | Parameter = 0.0, + display_name: str | None = 'MyConvolution', + unique_name: str | None = None, ) -> None: """ Initialize an AnalyticalConvolution. @@ -50,7 +52,7 @@ def __init__( ---------- energy : np.ndarray | sc.Variable 1D array of energy values where the convolution is evaluated. - energy_unit : str | sc.Unit, default='meV' + unit : str | sc.Unit, default='meV' The unit of the energy. sample_components : ComponentCollection | ModelComponent | None, default=None The sample model to be convolved. @@ -58,13 +60,19 @@ def __init__( The resolution model to convolve with. energy_offset : Numeric | Parameter, default=0.0 An offset to shift the energy values by. + display_name : str | None, default='MyConvolution' + Display name of the model. + unique_name : str | None, default=None + Unique name of the model. If None, a unique name will be generated. """ super().__init__( energy=energy, - energy_unit=energy_unit, + unit=unit, sample_components=sample_components, resolution_components=resolution_components, energy_offset=energy_offset, + display_name=display_name, + unique_name=unique_name, ) def convolution( diff --git a/src/easydynamics/convolution/convolution.py b/src/easydynamics/convolution/convolution.py index 32d3fb96..449be1a3 100644 --- a/src/easydynamics/convolution/convolution.py +++ b/src/easydynamics/convolution/convolution.py @@ -56,8 +56,10 @@ def __init__( extension_factor: Numeric | None = 0.2, temperature: Parameter | Numeric | None = None, temperature_unit: str | sc.Unit = 'K', - energy_unit: str | sc.Unit = 'meV', + unit: str | sc.Unit = 'meV', normalize_detailed_balance: bool = True, + display_name: str | None = 'MyConvolution', + unique_name: str | None = None, ) -> None: """ Initialize the Convolution class. @@ -80,10 +82,14 @@ def __init__( The temperature to use for detailed balance correction. temperature_unit : str | sc.Unit, default='K' The unit of the temperature parameter. - energy_unit : str | sc.Unit, default='meV' + unit : str | sc.Unit, default='meV' The unit of the energy. normalize_detailed_balance : bool, default=True Whether to normalize the detailed balance correction. Default is True. + display_name : str | None, default='MyConvolution' + Display name of the model. + unique_name : str | None, default=None + Unique name of the model. If None, a unique name will be generated. """ self._convolution_plan_is_valid = False @@ -97,8 +103,10 @@ def __init__( extension_factor=extension_factor, temperature=temperature, temperature_unit=temperature_unit, - energy_unit=energy_unit, + unit=unit, normalize_detailed_balance=normalize_detailed_balance, + display_name=display_name, + unique_name=unique_name, ) self._reactions_enabled = True diff --git a/src/easydynamics/convolution/convolution_base.py b/src/easydynamics/convolution/convolution_base.py index d328dbed..dd98d5e9 100644 --- a/src/easydynamics/convolution/convolution_base.py +++ b/src/easydynamics/convolution/convolution_base.py @@ -5,12 +5,13 @@ import scipp as sc from easyscience.variable import Parameter +from easydynamics.base_classes import EasyDynamicsModelBase from easydynamics.sample_model.component_collection import ComponentCollection from easydynamics.sample_model.components.model_component import ModelComponent from easydynamics.utils.utils import Numeric -class ConvolutionBase: +class ConvolutionBase(EasyDynamicsModelBase): """ Base class for convolutions of sample and resolution models. @@ -22,8 +23,10 @@ def __init__( energy: np.ndarray | sc.Variable, sample_components: ComponentCollection | ModelComponent | None = None, resolution_components: ComponentCollection | ModelComponent | None = None, - energy_unit: str | sc.Unit = 'meV', + unit: str | sc.Unit = 'meV', energy_offset: Numeric | Parameter = 0.0, + display_name: str | None = 'MyConvolution', + unique_name: str | None = None, ) -> None: """ Initialize the ConvolutionBase. @@ -36,10 +39,14 @@ def __init__( The sample model to be convolved. resolution_components : ComponentCollection | ModelComponent | None, default=None The resolution model to convolve with. - energy_unit : str | sc.Unit, default='meV' + unit : str | sc.Unit, default='meV' The unit of the energy. energy_offset : Numeric | Parameter, default=0.0 The energy offset applied to the convolution. + display_name : str | None, default='MyConvolution' + Display name of the model. + unique_name : str | None, default=None + Unique name of the model. If None, a unique name will be generated. Raises ------ @@ -49,28 +56,29 @@ def __init__( sample_components is not a ComponentCollection or ModelComponent, or if resolution_components is not a ComponentCollection or ModelComponent. """ + + super().__init__( + unit=unit, + display_name=display_name, + unique_name=unique_name, + ) + if isinstance(energy, Numeric): energy = np.array([float(energy)]) if not isinstance(energy, (np.ndarray, sc.Variable)): raise TypeError(f'Energy must be a numpy ndarray or a scipp Variable. Got {energy}') - if not isinstance(energy_unit, (str, sc.Unit)): - raise TypeError('Energy_unit must be a string or sc.Unit.') - if isinstance(energy, np.ndarray): - energy = sc.array(dims=['energy'], values=energy, unit=energy_unit) + energy = sc.array(dims=['energy'], values=energy, unit=unit) if isinstance(energy_offset, Numeric): - energy_offset = Parameter( - name='energy_offset', value=float(energy_offset), unit=energy_unit - ) + energy_offset = Parameter(name='energy_offset', value=float(energy_offset), unit=unit) if not isinstance(energy_offset, Parameter): raise TypeError('Energy_offset must be a number or a Parameter.') self._energy = energy - self._energy_unit = energy_unit self._energy_offset = energy_offset if sample_components is not None and not ( @@ -202,62 +210,42 @@ def energy(self, energy: np.ndarray | sc.Variable) -> None: if isinstance(energy, sc.Variable): self._energy = energy - self._energy_unit = energy.unit - - @property - def energy_unit(self) -> str: - """ - Get the energy unit. - - Returns - ------- - str - The unit of the energy. - """ - return self._energy_unit - - @energy_unit.setter - def energy_unit(self, _unit_str: str) -> None: - """Energy unit.""" - raise AttributeError( - f'Unit is read-only. Use convert_unit to change the unit between allowed types ' - f'or create a new {self.__class__.__name__} with the desired unit.' - ) # noqa: E501 + self._unit = energy.unit - def convert_energy_unit(self, energy_unit: str | sc.Unit) -> None: + def convert_unit(self, unit: str | sc.Unit) -> None: """ Convert the energy and energy_offset to the specified unit. Parameters ---------- - energy_unit : str | sc.Unit + unit : str | sc.Unit The unit of the energy. Raises ------ TypeError - If energy_unit is not a string or scipp unit. + If unit is not a string or scipp unit. Exception If energy cannot be converted to the specified unit. """ - if not isinstance(energy_unit, (str, sc.Unit)): + if not isinstance(unit, (str, sc.Unit)): raise TypeError('Energy unit must be a string or scipp unit.') old_energy = self.energy.copy() try: - self.energy = sc.to_unit(self.energy, energy_unit) + self.energy = sc.to_unit(self.energy, unit) except Exception as e: self.energy = old_energy raise e old_energy_offset = self.energy_offset try: - self.energy_offset.convert_unit(energy_unit) + self.energy_offset.convert_unit(unit) except Exception as e: self.energy_offset = old_energy_offset raise e - self._energy_unit = energy_unit + self._unit = unit @property def sample_components(self) -> ComponentCollection | ModelComponent: diff --git a/src/easydynamics/convolution/numerical_convolution.py b/src/easydynamics/convolution/numerical_convolution.py index 87e6b49e..17fc3904 100644 --- a/src/easydynamics/convolution/numerical_convolution.py +++ b/src/easydynamics/convolution/numerical_convolution.py @@ -32,8 +32,10 @@ def __init__( extension_factor: Numeric | None = 0.2, temperature: Parameter | Numeric | None = None, temperature_unit: str | sc.Unit = 'K', - energy_unit: str | sc.Unit = 'meV', + unit: str | sc.Unit = 'meV', normalize_detailed_balance: bool = True, + display_name: str | None = 'MyConvolution', + unique_name: str | None = None, ) -> None: """ Initialize the NumericalConvolution object. @@ -56,10 +58,14 @@ def __init__( The temperature to use for detailed balance correction. temperature_unit : str | sc.Unit, default='K' The unit of the temperature parameter. - energy_unit : str | sc.Unit, default='meV' + unit : str | sc.Unit, default='meV' The unit of the energy. normalize_detailed_balance : bool, default=True Whether to normalize the detailed balance correction. Default is True. + display_name : str | None, default='MyConvolution' + Display name of the model. + unique_name : str | None, default=None + Unique name of the model. If None, a unique name will be generated. """ super().__init__( energy=energy, @@ -70,8 +76,10 @@ def __init__( extension_factor=extension_factor, temperature=temperature, temperature_unit=temperature_unit, - energy_unit=energy_unit, + unit=unit, normalize_detailed_balance=normalize_detailed_balance, + display_name=display_name, + unique_name=unique_name, ) def convolution( diff --git a/src/easydynamics/convolution/numerical_convolution_base.py b/src/easydynamics/convolution/numerical_convolution_base.py index e2d6f5b6..d6bdc1ca 100644 --- a/src/easydynamics/convolution/numerical_convolution_base.py +++ b/src/easydynamics/convolution/numerical_convolution_base.py @@ -42,8 +42,10 @@ def __init__( extension_factor: Numeric | None = 0.2, temperature: Parameter | Numeric | None = None, temperature_unit: str | sc.Unit = 'K', - energy_unit: str | sc.Unit = 'meV', + unit: str | sc.Unit = 'meV', normalize_detailed_balance: bool = True, + display_name: str | None = 'MyConvolution', + unique_name: str | None = None, ) -> None: """ Initialize the NumericalConvolutionBase. @@ -66,10 +68,14 @@ def __init__( The temperature to use for detailed balance correction. temperature_unit : str | sc.Unit, default='K' The unit of the temperature parameter. - energy_unit : str | sc.Unit, default='meV' + unit : str | sc.Unit, default='meV' The unit of the energy. normalize_detailed_balance : bool, default=True Whether to normalize the detailed balance correction. + display_name : str | None, default='MyConvolution' + Display name of the model. + unique_name : str | None, default=None + Unique name of the model. If None, a unique name will be generated. Raises ------ @@ -82,8 +88,10 @@ def __init__( energy=energy, sample_components=sample_components, resolution_components=resolution_components, - energy_unit=energy_unit, + unit=unit, energy_offset=energy_offset, + display_name=display_name, + unique_name=unique_name, ) if temperature is not None and not isinstance(temperature, (Numeric, Parameter)): @@ -434,7 +442,7 @@ def __repr__(self) -> str: f'energy=array of shape {self.energy.values.shape},\n ' f'sample_components={repr(self.sample_components)}, \n' f'resolution_components={repr(self.resolution_components)},\n ' - f'energy_unit={self._energy_unit}, ' + f'unit={self.unit}, ' f'upsample_factor={self.upsample_factor}, ' f'extension_factor={self.extension_factor}, ' f'temperature={self.temperature}, ' diff --git a/src/easydynamics/experiment/experiment.py b/src/easydynamics/experiment/experiment.py index 0d305400..909fb35a 100644 --- a/src/easydynamics/experiment/experiment.py +++ b/src/easydynamics/experiment/experiment.py @@ -7,15 +7,15 @@ import numpy as np import plopp as pp import scipp as sc -from easyscience.base_classes.new_base import NewBase from plopp.backends.matplotlib.figure import InteractiveFigure from scipp.io import load_hdf5 as sc_load_hdf5 from scipp.io import save_hdf5 as sc_save_hdf5 +from easydynamics.base_classes.easydynamics_base import EasyDynamicsBase from easydynamics.utils.utils import _in_notebook -class Experiment(NewBase): +class Experiment(EasyDynamicsBase): """ Holds data from an experiment as a sc.DataArray along with metadata. diff --git a/src/easydynamics/sample_model/components/model_component.py b/src/easydynamics/sample_model/components/model_component.py index f4d0c15e..33235025 100644 --- a/src/easydynamics/sample_model/components/model_component.py +++ b/src/easydynamics/sample_model/components/model_component.py @@ -8,13 +8,13 @@ import numpy as np import scipp as sc -from easyscience.base_classes.model_base import ModelBase from scipp import UnitError +from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase from easydynamics.utils.utils import Numeric -class ModelComponent(ModelBase): +class ModelComponent(EasyDynamicsModelBase): """Abstract base class for all model components.""" def __init__( @@ -28,15 +28,14 @@ def __init__( Parameters ---------- - unit : str | sc.Unit, default='meV' + unit : str | sc.Unit, default="meV" The unit of the model component. display_name : str | None, default=None A human-readable name for the component. unique_name : str | None, default=None A unique identifier for the component. """ - self.validate_unit(unit) - super().__init__(display_name=display_name, unique_name=unique_name) + super().__init__(unit=unit, display_name=display_name, unique_name=unique_name) self._unit = unit @property @@ -163,26 +162,6 @@ def _prepare_x_for_evaluate( return np.sort(x_in) - @staticmethod - def validate_unit(unit: str | sc.Unit | None) -> None: - """ - Validate that the unit is either a string or a scipp Unit. - - Parameters - ---------- - unit : str | sc.Unit | None - The unit to validate. - - Raises - ------ - TypeError - If unit is not a string or scipp Unit. - """ - if unit is not None and not isinstance(unit, (str, sc.Unit)): - raise TypeError( - f'unit must be None, a string, or a scipp Unit, got {type(unit).__name__}' - ) - def convert_unit(self, unit: str | sc.Unit) -> None: """ Convert the unit of the Parameters in the component. diff --git a/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py b/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py index 177d5e79..0e83351e 100644 --- a/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py +++ b/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py @@ -2,15 +2,15 @@ # SPDX-License-Identifier: BSD-3-Clause import scipp as sc -from easyscience.base_classes.model_base import ModelBase from easyscience.variable import DescriptorNumber from easyscience.variable import Parameter from scipp import UnitError +from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase from easydynamics.utils.utils import Numeric -class DiffusionModelBase(ModelBase): +class DiffusionModelBase(EasyDynamicsModelBase): """Base class for constructing diffusion models.""" def __init__( @@ -56,48 +56,13 @@ def __init__( scale = Parameter(name='scale', value=float(scale), fixed=False, min=0.0, unit=unit) - super().__init__(display_name=display_name, unique_name=unique_name) - self._unit = unit + super().__init__(display_name=display_name, unique_name=unique_name, unit=unit) self._scale = scale # ------------------------------------------------------------------ # Properties # ------------------------------------------------------------------ - @property - def unit(self) -> str | sc.Unit | None: - """ - Get the unit of the energy axis of the DiffusionModel. - - Returns - ------- - str | sc.Unit | None - Unit of the DiffusionModel. - """ - return str(self._unit) - - @unit.setter - def unit(self, _unit_str: str) -> None: - """ - The unit of the energy axis is read-only. - - To change the unit, use convert_unit or create a new DiffusionModel with the desired unit. - - Parameters - ---------- - _unit_str : str - The new unit to set (ignored). - - Raises - ------ - AttributeError - Always, since the unit is read-only. - """ - raise AttributeError( - f'Unit is read-only. Use convert_unit to change the unit between allowed types ' - f'or create a new {self.__class__.__name__} with the desired unit.' - ) # noqa: E501 - @property def scale(self) -> Parameter: """ diff --git a/src/easydynamics/sample_model/model_base.py b/src/easydynamics/sample_model/model_base.py index 1b5d9363..7949f217 100644 --- a/src/easydynamics/sample_model/model_base.py +++ b/src/easydynamics/sample_model/model_base.py @@ -5,18 +5,17 @@ import numpy as np import scipp as sc -from easyscience.base_classes.model_base import ModelBase as EasyScienceModelBase from easyscience.variable import Parameter +from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase from easydynamics.sample_model.component_collection import ComponentCollection from easydynamics.sample_model.components.model_component import ModelComponent from easydynamics.utils.utils import Numeric from easydynamics.utils.utils import Q_type from easydynamics.utils.utils import _validate_and_convert_Q -from easydynamics.utils.utils import _validate_unit -class ModelBase(EasyScienceModelBase): +class ModelBase(EasyDynamicsModelBase): """ Base class for Sample Models. @@ -54,10 +53,10 @@ def __init__( If components is not a ModelComponent or ComponentCollection. """ super().__init__( + unit=unit, display_name=display_name, unique_name=unique_name, ) - self._unit = _validate_unit(unit) self._Q = _validate_and_convert_Q(Q) if components is not None and not isinstance( @@ -145,12 +144,12 @@ def clear_components(self) -> None: @property def unit(self) -> str | sc.Unit | None: """ - Get the unit of the ComponentCollection. + Get the unit of the SampleModel. Returns ------- str | sc.Unit | None - The unit of the ComponentCollection. + The unit of the SampleModel. """ return self._unit @@ -175,41 +174,6 @@ def unit(self, _unit_str: str) -> None: f'or create a new {self.__class__.__name__} with the desired unit.' ) # noqa: E501 - def convert_unit(self, unit: str | sc.Unit) -> None: - """ - Convert the unit of the ComponentCollection and all its components. - - Parameters - ---------- - unit : str | sc.Unit - The new unit to convert to. - - Raises - ------ - TypeError - If the provided unit is not a string or sc.Unit. - Exception - If the provided unit is not compatible with the current unit. - """ - - old_unit = self._unit - - if not isinstance(unit, (str, sc.Unit)): - raise TypeError(f'Unit must be a string or sc.Unit, got {type(unit).__name__}') - try: - for component in self.components: - component.convert_unit(unit) - self._unit = unit - except Exception as e: - # Attempt to rollback on failure - try: - for component in self.components: - component.convert_unit(old_unit) - except Exception: # noqa: S110 - pass # Best effort rollback - raise e - self._on_components_change() - @property def components(self) -> list[ModelComponent]: """ @@ -315,6 +279,42 @@ def clear_Q(self, confirm: bool = False) -> None: # ------------------------------------------------------------------ # Other methods # ------------------------------------------------------------------ + + def convert_unit(self, unit: str | sc.Unit) -> None: + """ + Convert the unit of the ComponentCollection and all its components. + + Parameters + ---------- + unit : str | sc.Unit + The new unit to convert to. + + Raises + ------ + TypeError + If the provided unit is not a string or sc.Unit. + Exception + If the provided unit is not compatible with the current unit. + """ + + old_unit = self._unit + + if not isinstance(unit, (str, sc.Unit)): + raise TypeError(f'Unit must be a string or sc.Unit, got {type(unit).__name__}') + try: + for component in self.components: + component.convert_unit(unit) + self._unit = unit + except Exception as e: + # Attempt to rollback on failure + try: + for component in self.components: + component.convert_unit(old_unit) + except Exception: # noqa: S110 + pass # Best effort rollback + raise e + self._on_components_change() + def fix_all_parameters(self) -> None: """Fix all Parameters in all ComponentCollections.""" for par in self.get_all_variables(): diff --git a/src/easydynamics/utils/utils.py b/src/easydynamics/utils/utils.py index 3d383d9c..d9fc27c0 100644 --- a/src/easydynamics/utils/utils.py +++ b/src/easydynamics/utils/utils.py @@ -84,8 +84,11 @@ def _validate_unit(unit: str | sc.Unit | None) -> sc.Unit | None: if unit is not None and not isinstance(unit, (str, sc.Unit)): raise TypeError(f'unit must be None, a string, or a scipp Unit, got {type(unit).__name__}') - if isinstance(unit, str): - unit = sc.Unit(unit) + # if isinstance(unit, str): + # unit = sc.Unit(unit) + + if isinstance(unit, sc.Unit): + unit = str(unit) return unit diff --git a/tests/unit/easydynamics/base_classes/test_easydynamics_base.py b/tests/unit/easydynamics/base_classes/test_easydynamics_base.py new file mode 100644 index 00000000..a020fa97 --- /dev/null +++ b/tests/unit/easydynamics/base_classes/test_easydynamics_base.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from easydynamics.base_classes import EasyDynamicsBase + + +class TestEasyDynamicsBase: + """Tests for the EasyDynamicsBase class.""" + + @pytest.fixture + def easy_dynamics_base(self): + """Fixture for creating an instance of EasyDynamicsBase.""" + + return EasyDynamicsBase(name='TestModel') + + def test_initialization(self, easy_dynamics_base): + """Test that the EasyDynamicsBase is initialized correctly.""" + + # WHEN THEN EXPECT + assert easy_dynamics_base.name == 'TestModel' + assert easy_dynamics_base.display_name == 'MyEasyDynamicsModel' + assert easy_dynamics_base.unique_name is not None + + def test_init_raises_type_error_for_invalid_name(self): + """Test that initializing with an invalid name raises a TypeError.""" + # WHEN THEN EXPECT + with pytest.raises(TypeError, match='Name must be a string or None.'): + EasyDynamicsBase(name=123) # Not a string + + def test_init_name_can_be_none(self): + """Test that initializing with name as None works correctly.""" + # WHEN THEN EXPECT + model = EasyDynamicsBase(name=None) + + # THEN EXPECT + assert model.name is None + + def test_name_setter_and_getter(self, easy_dynamics_base): + """Test that the name setter and getter work correctly.""" + # WHEN THEN EXPECT + assert easy_dynamics_base.name == 'TestModel' + + # THEN + easy_dynamics_base.name = 'NewName' + + # EXPECT + assert easy_dynamics_base.name == 'NewName' + + # THEN + easy_dynamics_base.name = None + + # EXPECT + assert easy_dynamics_base.name is None + + @pytest.mark.parametrize( + 'invalid_name', + [ + 123, # Not a string + [1, 2, 3], # Not a string + {'name': 'Test'}, # Not a string + ], + ids=['integer', 'list', 'dict'], + ) + def test_name_setter_invalid_type(self, easy_dynamics_base, invalid_name): + """Test that setting the name to an invalid type raises a TypeError.""" + # WHEN THEN EXPECT + with pytest.raises(TypeError, match='Name must be a string or None.'): + easy_dynamics_base.name = invalid_name diff --git a/tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py b/tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py new file mode 100644 index 00000000..8d7612e6 --- /dev/null +++ b/tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from easydynamics.base_classes import EasyDynamicsModelBase + + +class TestEasyDynamicsModelBase: + """Tests for the EasyDynamicsModelBase class.""" + + @pytest.fixture + def easy_dynamics_modelbase(self): + """Fixture for creating an instance of EasyDynamicsModelBase.""" + + return EasyDynamicsModelBase(name='TestModel', unit='meV') + + def test_initialization(self, easy_dynamics_modelbase): + """Test that the EasyDynamicsModelBase is initialized correctly.""" + + # WHEN THEN EXPECT + assert easy_dynamics_modelbase.name == 'TestModel' + assert easy_dynamics_modelbase.display_name == 'MyEasyDynamicsModel' + assert easy_dynamics_modelbase.unique_name is not None + + def test_init_raises_type_error_for_invalid_name(self): + """Test that initializing with an invalid name raises a TypeError.""" + # WHEN THEN EXPECT + with pytest.raises(TypeError, match='Name must be a string or None.'): + EasyDynamicsModelBase(name=123) # Not a string + + def test_init_name_can_be_none(self): + """Test that initializing with name as None works correctly.""" + # WHEN THEN EXPECT + model = EasyDynamicsModelBase(name=None) + + # THEN EXPECT + assert model.name is None + + def test_name_setter_and_getter(self, easy_dynamics_modelbase): + """Test that the name setter and getter work correctly.""" + # WHEN THEN EXPECT + assert easy_dynamics_modelbase.name == 'TestModel' + + # THEN + easy_dynamics_modelbase.name = 'NewName' + + # EXPECT + assert easy_dynamics_modelbase.name == 'NewName' + + # THEN + easy_dynamics_modelbase.name = None + + # EXPECT + assert easy_dynamics_modelbase.name is None + + @pytest.mark.parametrize( + 'invalid_name', + [ + 123, # Not a string + [1, 2, 3], # Not a string + {'name': 'Test'}, # Not a string + ], + ids=['integer', 'list', 'dict'], + ) + def test_name_setter_invalid_type(self, easy_dynamics_modelbase, invalid_name): + """Test that setting the name to an invalid type raises a TypeError.""" + # WHEN THEN EXPECT + with pytest.raises(TypeError, match='Name must be a string or None.'): + easy_dynamics_modelbase.name = invalid_name + + def test_unit_property(self, easy_dynamics_modelbase): + # WHEN THEN EXPECT + + assert easy_dynamics_modelbase.unit == 'meV' + + def test_unit_setter_raises(self, easy_dynamics_modelbase): + # WHEN / THEN / EXPECT + with pytest.raises(AttributeError, match='Use convert_unit to change '): + easy_dynamics_modelbase.unit = 'K' diff --git a/tests/unit/easydynamics/convolution/test_convolution.py b/tests/unit/easydynamics/convolution/test_convolution.py index 468df680..ad70ad5a 100644 --- a/tests/unit/easydynamics/convolution/test_convolution.py +++ b/tests/unit/easydynamics/convolution/test_convolution.py @@ -76,7 +76,7 @@ def test_init(self, default_convolution): assert default_convolution.upsample_factor == 5 assert default_convolution.extension_factor == 0.2 assert default_convolution.temperature is None - assert default_convolution.energy_unit == 'meV' + assert default_convolution.unit == 'meV' assert default_convolution.normalize_detailed_balance is True assert isinstance(default_convolution._energy_grid, EnergyGrid) @@ -110,7 +110,7 @@ def test_init_components(self, convolution_with_components): assert convolution_with_components.upsample_factor == 5 assert convolution_with_components.extension_factor == 0.2 assert convolution_with_components.temperature is None - assert convolution_with_components.energy_unit == 'meV' + assert convolution_with_components.unit == 'meV' assert convolution_with_components.normalize_detailed_balance is True assert isinstance(convolution_with_components._energy_grid, EnergyGrid) diff --git a/tests/unit/easydynamics/convolution/test_convolution_base.py b/tests/unit/easydynamics/convolution/test_convolution_base.py index 44ae133c..393c4af3 100644 --- a/tests/unit/easydynamics/convolution/test_convolution_base.py +++ b/tests/unit/easydynamics/convolution/test_convolution_base.py @@ -78,7 +78,7 @@ def test_init_energy_numerical_none_offset(self): 'energy': 'invalid', 'sample_components': ComponentCollection(), 'resolution_components': ComponentCollection(), - 'energy_unit': 'meV', + 'unit': 'meV', 'energy_offset': 0, }, 'Energy must be', @@ -88,7 +88,7 @@ def test_init_energy_numerical_none_offset(self): 'energy': np.linspace(-10, 10, 100), 'sample_components': 'invalid', 'resolution_components': ComponentCollection(), - 'energy_unit': 'meV', + 'unit': 'meV', 'energy_offset': 0, }, ( @@ -101,7 +101,7 @@ def test_init_energy_numerical_none_offset(self): 'energy': np.linspace(-10, 10, 100), 'sample_components': ComponentCollection(), 'resolution_components': 'invalid', - 'energy_unit': 'meV', + 'unit': 'meV', 'energy_offset': 0, }, ( @@ -114,17 +114,17 @@ def test_init_energy_numerical_none_offset(self): 'energy': np.linspace(-10, 10, 100), 'sample_components': ComponentCollection(), 'resolution_components': ComponentCollection(), - 'energy_unit': 123, + 'unit': 123, 'energy_offset': 0, }, - 'Energy_unit must be ', + 'unit must be ', ), ( { 'energy': np.linspace(-10, 10, 100), 'sample_components': ComponentCollection(), 'resolution_components': ComponentCollection(), - 'energy_unit': 'meV', + 'unit': 'meV', 'energy_offset': 'invalid', }, 'Energy_offset must be ', @@ -173,48 +173,48 @@ def test_energy_setter_invalid_type_raises(self, convolution_base): ): convolution_base.energy = 'invalid' - def test_energy_unit_property(self, convolution_base): + def test_unit_property(self, convolution_base): # WHEN THEN EXPECT assert convolution_base.energy.unit == 'meV' - def test_energy_unit_setter_raises(self, convolution_base): + def test_unit_setter_raises(self, convolution_base): # WHEN THEN EXPECT with pytest.raises( AttributeError, match='Use convert_unit to change the unit between allowed types ', ): - convolution_base.energy_unit = 'K' + convolution_base.unit = 'K' - def test_convert_energy_unit(self, convolution_base): + def test_convert_unit(self, convolution_base): # WHEN THEN - convolution_base.convert_energy_unit('eV') + convolution_base.convert_unit('eV') # EXPECT assert convolution_base.energy.unit == 'eV' - assert convolution_base.energy_unit == 'eV' + assert convolution_base.unit == 'eV' assert np.allclose(convolution_base.energy.values, np.linspace(-0.01, 0.01, 100)) - def test_convert_energy_unit_invalid_type_raises(self, convolution_base): + def test_convert_unit_invalid_type_raises(self, convolution_base): # WHEN THEN EXPECT with pytest.raises( TypeError, match='Energy unit must be a string or scipp unit.', ): - convolution_base.convert_energy_unit(123) + convolution_base.convert_unit(123) - def test_convert_energy_unit_invalid_unit_rollback(self, convolution_base): + def test_convert_unit_invalid_unit_rollback(self, convolution_base): # WHEN THEN with pytest.raises( UnitError, match='Conversion from `meV` to `s` is not valid.', ): - convolution_base.convert_energy_unit('s') + convolution_base.convert_unit('s') # EXPECT - assert convolution_base.energy_unit == 'meV' + assert convolution_base.unit == 'meV' assert np.allclose(convolution_base.energy.values, np.linspace(-10, 10, 100)) - def test_convert_energy_unit_invalid_offset_unit_rollback(self, convolution_base): + def test_convert_unit_invalid_offset_unit_rollback(self, convolution_base): # WHEN convolution_base.energy_offset = Parameter(name='energy_offset', value=5, unit='s') @@ -223,10 +223,10 @@ def test_convert_energy_unit_invalid_offset_unit_rollback(self, convolution_base UnitError, match='Conversion from `s` to `meV` is not valid.', ): - convolution_base.convert_energy_unit('meV') + convolution_base.convert_unit('meV') # EXPECT - assert convolution_base.energy_unit == 'meV' + assert convolution_base.unit == 'meV' assert convolution_base.energy_offset.unit == 's' def test_energy_offset_property(self, convolution_base): diff --git a/tests/unit/easydynamics/convolution/test_numerical_convolution.py b/tests/unit/easydynamics/convolution/test_numerical_convolution.py index 29e5dfd6..de28a2bc 100644 --- a/tests/unit/easydynamics/convolution/test_numerical_convolution.py +++ b/tests/unit/easydynamics/convolution/test_numerical_convolution.py @@ -48,7 +48,7 @@ def test_init(self, default_numerical_convolution): assert default_numerical_convolution.upsample_factor == 5 assert default_numerical_convolution.extension_factor == 0.2 assert default_numerical_convolution.temperature is None - assert default_numerical_convolution.energy_unit == 'meV' + assert default_numerical_convolution.unit == 'meV' assert default_numerical_convolution.normalize_detailed_balance is True assert isinstance(default_numerical_convolution._energy_grid, EnergyGrid) diff --git a/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py b/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py index 3934eefb..1c462608 100644 --- a/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py +++ b/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py @@ -46,7 +46,7 @@ def test_init(self, default_numerical_convolution_base): assert default_numerical_convolution_base.upsample_factor == 5 assert default_numerical_convolution_base.extension_factor == 0.2 assert default_numerical_convolution_base.temperature is None - assert default_numerical_convolution_base.energy_unit == 'meV' + assert default_numerical_convolution_base.unit == 'meV' assert default_numerical_convolution_base.normalize_detailed_balance is True assert isinstance(default_numerical_convolution_base._energy_grid, EnergyGrid) @@ -63,7 +63,7 @@ def test_init_with_custom_parameters(self): extension_factor = 0.5 temperature = 300.0 temperature_unit = 'K' - energy_unit = 'meV' + unit = 'meV' normalize_detailed_balance = False # THEN @@ -75,7 +75,7 @@ def test_init_with_custom_parameters(self): extension_factor=extension_factor, temperature=temperature, temperature_unit=temperature_unit, - energy_unit=energy_unit, + unit=unit, normalize_detailed_balance=normalize_detailed_balance, ) @@ -84,7 +84,7 @@ def test_init_with_custom_parameters(self): assert numerical_convolution_base.extension_factor == extension_factor assert numerical_convolution_base.temperature.value == temperature assert numerical_convolution_base.temperature.unit == temperature_unit - assert numerical_convolution_base.energy_unit == energy_unit + assert numerical_convolution_base.unit == unit assert numerical_convolution_base.normalize_detailed_balance == normalize_detailed_balance assert isinstance(numerical_convolution_base._energy_grid, EnergyGrid) @@ -502,7 +502,7 @@ def test_repr(self, default_numerical_convolution_base): assert 'resolution_components=' in repr_str # Important parameters - assert 'energy_unit=meV' in repr_str + assert 'unit=meV' in repr_str assert 'upsample_factor=5' in repr_str assert 'extension_factor=0.2' in repr_str assert 'temperature=None' in repr_str diff --git a/tests/unit/easydynamics/sample_model/components/test_model_component.py b/tests/unit/easydynamics/sample_model/components/test_model_component.py index 9e47d4e9..24f3ac57 100644 --- a/tests/unit/easydynamics/sample_model/components/test_model_component.py +++ b/tests/unit/easydynamics/sample_model/components/test_model_component.py @@ -8,8 +8,6 @@ from easydynamics.sample_model.components.model_component import ModelComponent -Numeric = float | int - class DummyComponent(ModelComponent): def __init__(self): diff --git a/tests/unit/easydynamics/utils/test_utils.py b/tests/unit/easydynamics/utils/test_utils.py index 7aa72a78..680d2a28 100644 --- a/tests/unit/easydynamics/utils/test_utils.py +++ b/tests/unit/easydynamics/utils/test_utils.py @@ -92,13 +92,13 @@ def test_validate_unit_valid(self, unit_input): if unit_input is None: assert unit is None else: - assert isinstance(unit, sc.Unit) + assert isinstance(unit, str) def test_validate_unit_string_conversion(self): - unit = _validate_unit('meV') + unit = _validate_unit(sc.Unit('meV')) - assert isinstance(unit, sc.Unit) - assert unit == sc.Unit('meV') + assert isinstance(unit, str) + assert unit == 'meV' @pytest.mark.parametrize( 'unit_input',