From f60aaf612cda1409ab7506cf161ce43b1789237f Mon Sep 17 00:00:00 2001 From: rozyczko Date: Mon, 9 Mar 2026 15:00:32 +0100 Subject: [PATCH 1/7] new collection based on easy_list --- .../base_classes/collection_base.py | 609 +++++++++++------- .../base_classes/test_collection_base.py | 35 +- .../test_collection_base_easylist.py | 93 +++ .../global_object/test_undo_redo.py | 63 +- 4 files changed, 500 insertions(+), 300 deletions(-) create mode 100644 tests/unit_tests/base_classes/test_collection_base_easylist.py diff --git a/src/easyscience/base_classes/collection_base.py b/src/easyscience/base_classes/collection_base.py index 0a4e8272..ea2f666c 100644 --- a/src/easyscience/base_classes/collection_base.py +++ b/src/easyscience/base_classes/collection_base.py @@ -4,249 +4,418 @@ from __future__ import annotations -from collections.abc import MutableSequence +import copy +import warnings +from collections.abc import Iterable +from importlib import import_module from numbers import Number -from typing import TYPE_CHECKING from typing import Any -from typing import Callable -from typing import List from typing import Optional -from typing import Tuple -from typing import Union -from easyscience.base_classes.new_base import NewBase -from easyscience.global_object.undo_redo import NotarizedDict +from easyscience.io.serializer_base import SerializerBase +from easyscience.io.serializer_dict import SerializerDict from ..variable.descriptor_base import DescriptorBase +from ..variable.parameter import Parameter from .based_base import BasedBase +from .easy_list import EasyList +from .new_base import NewBase -if TYPE_CHECKING: - from ..fitting.calculators import InterfaceFactoryTemplate - -class CollectionBase(BasedBase, MutableSequence): +class CollectionBase(EasyList): """ - This is the base class for which all higher level classes are built off of. - NOTE: This object is serializable only if parameters are supplied as: - `ObjBase(a=value, b=value)`. For `Parameter` or `Descriptor` objects we can - cheat with `ObjBase(*[Descriptor(...), Parameter(...), ...])`. + EasyList-backed collection with a small compatibility layer for migration. + + The collection delegates storage and MutableSequence behavior to EasyList, + adding only scientific-parameter aggregation methods and a thin compatibility + layer for existing callers. """ + _DEFAULT_PROTECTED_TYPES = (DescriptorBase, BasedBase, NewBase) + _RESERVED_NAMED_KEYS = { + 'data', + 'display_name', + 'interface', + 'name', + 'protected_types', + 'unique_name', + 'user_data', + '_kwargs', + } + _REDIRECT = {'interface': None} + def __init__( self, - name: str, - *args: Union[BasedBase, DescriptorBase, NewBase], - interface: Optional[InterfaceFactoryTemplate] = None, + *items: Any, + name: Optional[str] = None, + protected_types: type | Iterable[type] | None = None, unique_name: Optional[str] = None, - **kwargs, + display_name: Optional[str] = None, + interface: Any = None, + data: Optional[Iterable[Any]] = None, + **named_items: Any, ): - """ - Set up the base collection class. - - :param name: Name of this object - :type name: str - :param args: selection of - :param _kwargs: Fields which this class should contain - :type _kwargs: dict - """ - BasedBase.__init__(self, name, unique_name=unique_name) - kwargs = {key: kwargs[key] for key in kwargs.keys() if kwargs[key] is not None} - _args = [] - for item in args: - if not isinstance(item, list): - _args.append(item) - else: - _args += item - _kwargs = {} - for key, item in kwargs.items(): - if isinstance(item, list) and len(item) > 0: - _args += item - else: - _kwargs[key] = item - kwargs = _kwargs - for item in list(kwargs.values()) + _args: - if not issubclass(type(item), (DescriptorBase, BasedBase, NewBase)): - raise AttributeError('A collection can only be formed from easyscience objects.') - args = _args - _kwargs = {} - for key, item in kwargs.items(): - _kwargs[key] = item - for arg in args: - kwargs[arg.unique_name] = arg - _kwargs[arg.unique_name] = arg - - # Set kwargs, also useful for serialization - self._kwargs = NotarizedDict(**_kwargs) - - for key in kwargs.keys(): - if key in self.__dict__.keys() or key in self.__slots__: - raise AttributeError(f'Given kwarg: `{key}`, is an internal attribute. Please rename.') - if kwargs[key]: # Might be None (empty tuple or list) - self._global_object.map.add_edge(self, kwargs[key]) - self._global_object.map.reset_type(kwargs[key], 'created_internal') - if interface is not None: - kwargs[key].interface = interface - # TODO wrap getter and setter in Logger + if items and isinstance(items[0], str) and name is None: + name = items[0] + items = items[1:] + + if display_name is None and name is not None: + display_name = name + + super().__init__(unique_name=unique_name, display_name=display_name) + if interface is not None: - self.interface = interface - self._kwargs._stack_enabled = True - - def insert(self, index: int, value: Union[DescriptorBase, BasedBase, NewBase]) -> None: - """ - Insert an object into the collection at an index. - - :param index: Index for EasyScience object to be inserted. - :type index: int - :param value: Object to be inserted. - :type value: Union[BasedBase, DescriptorBase, NewBase] - :return: None - :rtype: None - """ - t_ = type(value) - if issubclass(t_, (BasedBase, DescriptorBase, NewBase)): - update_key = list(self._kwargs.keys()) - values = list(self._kwargs.values()) - # Update the internal dict - new_key = value.unique_name - update_key.insert(index, new_key) - values.insert(index, value) - self._kwargs.reorder(**{k: v for k, v in zip(update_key, values)}) - # ADD EDGE - self._global_object.map.add_edge(self, value) - self._global_object.map.reset_type(value, 'created_internal') - value.interface = self.interface - else: - raise AttributeError('Only EasyScience objects can be put into an EasyScience group') - - def __getitem__(self, idx: Union[int, slice]) -> Union[DescriptorBase, BasedBase, NewBase]: - """ - Get an item in the collection based on its index. - - :param idx: index or slice of the collection. - :type idx: Union[int, slice] - :return: Object at index `idx` - :rtype: Union[Parameter, Descriptor, ObjBase, 'CollectionBase'] - """ + raise AttributeError('Given kwarg: `interface`, is an internal attribute. Please rename.') + + self._protected_types = self._normalize_protected_types(protected_types) + self._name = name if name is not None else self.display_name + self.user_data: dict[str, Any] = {} + self.interface = None + + normalized_named_items = self._normalize_named_items(named_items) + all_items = self._collect_items(items, data=data, named_items=normalized_named_items) + for item in all_items: + try: + self._validate_item(item) + except TypeError as exc: + raise AttributeError('A collection can only be formed from easyscience objects.') from exc + if item in self: + warnings.warn( + f'Item with unique name "{self._get_key(item)}" already in CollectionBase, it will be ignored' + ) + continue + self._data.append(item) + + @property + def name(self) -> str: + return self._name + + @name.setter + def name(self, new_name: str) -> None: + if not isinstance(new_name, str): + raise TypeError('Name must be a string') + self._name = new_name + self.display_name = new_name + + # --- Minimal overrides (compatibility shims) --- + + def __getitem__(self, idx: int | slice | str) -> Any: + if isinstance(idx, bool): + raise TypeError('Boolean indexing is not supported at the moment') if isinstance(idx, slice): - start, stop, step = idx.indices(len(self)) - return self.__class__(getattr(self, 'name'), *[self[i] for i in range(start, stop, step)]) - if str(idx) in self._kwargs.keys(): - return self._kwargs[str(idx)] + return self._clone_with_items(self._data[idx]) if isinstance(idx, str): - idx = [index for index, item in enumerate(self) if item.name == idx] - noi = len(idx) - if noi == 0: - raise IndexError('Given index does not exist') - elif noi == 1: - idx = idx[0] - else: - return self.__class__(getattr(self, 'name'), *[self[i] for i in idx]) - elif not isinstance(idx, int) or isinstance(idx, bool): - if isinstance(idx, bool): - raise TypeError('Boolean indexing is not supported at the moment') try: - if idx > len(self): - raise IndexError(f'Given index {idx} is out of bounds') - except TypeError: - raise IndexError('Index must be of type `int`/`slice` or an item name (`str`)') - keys = list(self._kwargs.keys()) - return self._kwargs[keys[idx]] - - def __setitem__(self, key: int, value: Union[BasedBase, DescriptorBase, NewBase]) -> None: - """ - Set an item via it's index. - - :param key: Index in self. - :type key: int - :param value: Value which index key should be set to. - :type value: Any - """ - if isinstance(value, Number): # noqa: S3827 - item = self.__getitem__(key) + return super().__getitem__(idx) + except KeyError: + pass + name_matches = [item for item in self._data if getattr(item, 'name', None) == idx] + if len(name_matches) == 1: + return name_matches[0] + if len(name_matches) > 1: + return self._clone_with_items(name_matches) + raise KeyError(f'No item with key or name "{idx}" found') + return super().__getitem__(idx) + + def __setitem__(self, idx: int | slice, value: Any) -> None: + if isinstance(idx, int) and isinstance(value, Number): + item = self[idx] + if not hasattr(item, 'value'): + raise NotImplementedError('At the moment only numerical values or EasyScience objects can be set.') item.value = value - elif issubclass(type(value), (BasedBase, DescriptorBase, NewBase)): - update_key = list(self._kwargs.keys()) - values = list(self._kwargs.values()) - old_item = values[key] - # Update the internal dict - update_dict = {update_key[key]: value} - self._kwargs.update(update_dict) - # ADD EDGE - self._global_object.map.add_edge(self, value) - self._global_object.map.reset_type(value, 'created_internal') - value.interface = self.interface - # REMOVE EDGE - self._global_object.map.prune_vertex_from_edge(self, old_item) - else: - raise NotImplementedError('At the moment only numerical values or EasyScience objects can be set.') - - def __delitem__(self, key: int) -> None: - """ - Try to delete an idem by key. - - :param key: - :type key: - :return: - :rtype: - """ - keys = list(self._kwargs.keys()) - item = self._kwargs[keys[key]] - self._global_object.map.prune_vertex_from_edge(self, item) - del self._kwargs[keys[key]] - - def __len__(self) -> int: - """ - Get the number of items in this collection - - :return: Number of items in this collection. - :rtype: int - """ - return len(self._kwargs.keys()) - - def _convert_to_dict(self, in_dict, encoder, skip: List[str] = [], **kwargs) -> dict: - """ - Convert ones self into a serialized form. - - :return: dictionary of ones self - :rtype: dict - """ - d = {} - if hasattr(self, '_modify_dict'): - # any extra keys defined on the inheriting class - d = self._modify_dict(skip=skip, **kwargs) - in_dict['data'] = [encoder._convert_to_dict(item, skip=skip, **kwargs) for item in self] - out_dict = {**in_dict, **d} - return out_dict + return + try: + super().__setitem__(idx, value) + except TypeError as exc: + raise NotImplementedError('At the moment only numerical values or EasyScience objects can be set.') from exc - @property - def data(self) -> Tuple: - """ - The data function returns a tuple of the keyword arguments passed to the - constructor. This is useful for when you need to pass in a dictionary of data - to other functions, such as with matplotlib's plot function. - - :param self: Access attributes of the class within the method - :return: The values of the attributes in a tuple - :doc-author: Trelent - """ - return tuple(self._kwargs.values()) + def insert(self, index: int, value: Any) -> None: + try: + super().insert(index, value) + except TypeError as exc: + raise AttributeError('Only EasyScience objects can be put into an EasyScience group') from exc def __repr__(self) -> str: - return f'{self.__class__.__name__} `{getattr(self, "name")}` of length {len(self)}' + return f'{self.__class__.__name__} `{self.name}` of length {len(self)}' + + def sort(self, key=None, reverse: bool = False, mapping=None) -> None: + if mapping is not None: + if key is not None: + raise TypeError('Use either key or mapping, not both') + warnings.warn('sort(mapping=...) is deprecated; use sort(key=...) instead', DeprecationWarning) + key = mapping + super().sort(key=key, reverse=reverse) + + # --- Parameter/variable aggregation --- + + def get_all_variables(self) -> list[DescriptorBase]: + variables: list[DescriptorBase] = [] + for item in self._data: + if isinstance(item, DescriptorBase): + variables.append(item) + elif hasattr(item, 'get_all_variables'): + variables.extend(item.get_all_variables()) + return variables + + def get_all_parameters(self) -> list[Parameter]: + parameters: list[Parameter] = [] + seen = set() + for item in self._data: + if isinstance(item, Parameter): + parameters.append(item) + seen.add(id(item)) + continue + if hasattr(item, 'get_all_parameters'): + for parameter in item.get_all_parameters(): + if id(parameter) not in seen: + parameters.append(parameter) + seen.add(id(parameter)) + continue + if hasattr(item, 'get_parameters'): + for parameter in item.get_parameters(): + if id(parameter) not in seen: + parameters.append(parameter) + seen.add(id(parameter)) + continue + if hasattr(item, 'get_all_variables'): + for variable in item.get_all_variables(): + if isinstance(variable, Parameter) and id(variable) not in seen: + parameters.append(variable) + seen.add(id(variable)) + return parameters + + def get_parameters(self) -> list[Parameter]: + return self.get_all_parameters() + + def get_fittable_parameters(self) -> list[Parameter]: + return [parameter for parameter in self.get_all_parameters() if parameter.independent] + + def get_free_parameters(self) -> list[Parameter]: + return [parameter for parameter in self.get_fittable_parameters() if not parameter.fixed] + + def get_fit_parameters(self) -> list[Parameter]: + return self.get_free_parameters() + + @property + def data(self) -> tuple[Any, ...]: + return tuple(self._data) - def sort( + # --- Serialization --- + + def as_dict(self, skip: Optional[list[str]] = None) -> dict[str, Any]: + if skip is None: + skip = [] + if 'unique_name' not in skip: + skip = [*skip, 'unique_name'] + return self.to_dict(skip=skip) + + def encode(self, skip: Optional[list[str]] = None, encoder=None, **kwargs: Any) -> Any: + if encoder is None: + encoder = SerializerDict + return encoder().encode(self, skip=skip, **kwargs) + + @classmethod + def decode(cls, obj: Any, decoder=None) -> Any: + if decoder is None or decoder is SerializerDict: + return cls.from_dict(obj) + return decoder.decode(obj) + + def to_dict(self, skip: Optional[list[str]] = None) -> dict[str, Any]: + if skip is None: + skip = [] + + try: + parent_module = self.__module__.split('.')[0] + module_version = import_module(parent_module).__version__ + except (AttributeError, ImportError): + module_version = None + + dict_repr: dict[str, Any] = { + '@module': self.__module__, + '@class': self.__class__.__name__, + '@version': module_version, + } + + if 'name' not in skip: + dict_repr['name'] = self.name + if 'display_name' not in skip and self._display_name is not None and self._display_name != self.name: + dict_repr['display_name'] = self._display_name + if 'unique_name' not in skip: + dict_repr['unique_name'] = self.unique_name + if self._protected_types != list(self._DEFAULT_PROTECTED_TYPES) and 'protected_types' not in skip: + dict_repr['protected_types'] = [ + {'@module': cls_.__module__, '@class': cls_.__name__} for cls_ in self._protected_types + ] + dict_repr['data'] = [self._serialize_item(item, skip=skip) for item in self._data] + return dict_repr + + @classmethod + def from_dict(cls, obj_dict: dict[str, Any]) -> CollectionBase: + if not isinstance(obj_dict, dict) or '@class' not in obj_dict or '@module' not in obj_dict: + raise ValueError('Input must be a dictionary representing an EasyScience CollectionBase object.') + accepted_names = {base.__name__ for base in cls.__mro__ if issubclass(base, CollectionBase)} + if obj_dict['@class'] not in accepted_names: + raise ValueError(f'Class name in dictionary does not match the expected class: {cls.__name__}.') + + temp_dict = copy.deepcopy(obj_dict) + protected_types = temp_dict.pop('protected_types', None) + if protected_types is not None: + protected_types = cls._deserialize_protected_types(protected_types) + + raw_data = temp_dict.pop('data', []) + kwargs = SerializerBase.deserialize_dict(temp_dict) + data = [cls._deserialize_item(item) for item in raw_data] + name = kwargs.pop('name', None) + kwargs.pop('unique_name', None) + return cls(*data, name=name, protected_types=protected_types, **kwargs) + + def _convert_to_dict( + self, + in_dict: dict[str, Any], + encoder: Any, + skip: Optional[list[str]] = None, + **kwargs: Any, + ) -> dict[str, Any]: + if skip is None: + skip = [] + if 'name' not in skip: + in_dict['name'] = self.name + if self._display_name is not None and self._display_name != self.name and 'display_name' not in skip: + in_dict['display_name'] = self._display_name + in_dict['data'] = [encoder._convert_to_dict(item, skip=skip, **kwargs) for item in self._data] + return in_dict + + @staticmethod + def _deserialize_protected_types(protected_types: list[dict[str, str]]) -> list[type]: + deserialized_types: list[type] = [] + for type_dict in protected_types: + if '@module' not in type_dict or '@class' not in type_dict: + raise ValueError('Each protected type must contain @module and @class keys') + module = __import__(type_dict['@module'], globals(), locals(), [type_dict['@class']], 0) + deserialized_types.append(getattr(module, type_dict['@class'])) + return deserialized_types + + def _clone_with_items(self, items: Iterable[Any]) -> CollectionBase: + return self.__class__( + *list(items), + name=self.name, + protected_types=list(self._protected_types), + display_name=self._display_name, + ) + + # --- Compatibility surface --- + + def __dir__(self) -> Iterable[str]: + hidden = { + 'display_name', + 'get_all_parameters', + 'get_all_variables', + 'get_fittable_parameters', + 'get_free_parameters', + 'to_dict', + } + legacy = { + 'append', + 'as_dict', + 'clear', + 'constraints', + 'count', + 'data', + 'decode', + 'encode', + 'extend', + 'from_dict', + 'generate_bindings', + 'get_fit_parameters', + 'get_parameters', + 'index', + 'insert', + 'interface', + 'name', + 'pop', + 'remove', + 'reverse', + 'sort', + 'switch_interface', + 'unique_name', + 'user_data', + } + public_names = {name for name in dir(self.__class__) if not name.startswith('_')} + return sorted((public_names | legacy) - hidden) + + @property + def constraints(self) -> list[Any]: + return [] + + def generate_bindings(self) -> None: + if self.interface is None: + raise AttributeError('Interface error for generating bindings. `interface` has to be set.') + + def switch_interface(self, new_interface_name: str) -> None: + if self.interface is None: + raise AttributeError('Interface error for generating bindings. `interface` has to be set.') + + # --- Internal helpers --- + + def _normalize_named_items(self, named_items: dict[str, Any]) -> dict[str, Any]: + normalized: dict[str, Any] = {} + for key, item in named_items.items(): + if key in self._RESERVED_NAMED_KEYS: + raise AttributeError(f'Given kwarg: `{key}`, is an internal attribute. Please rename.') + if item is None: + continue + normalized[key] = item + return normalized + + def _collect_items( self, - mapping: Callable[[Union[BasedBase, DescriptorBase, NewBase]], Any], - reverse: bool = False, - ) -> None: - """ - Sort the collection according to the given mapping. - - :param mapping: mapping function to sort the collection. i.e. lambda parameter: parameter.value - :type mapping: Callable - :param reverse: Reverse the sorting. - :type reverse: bool - """ - i = list(self._kwargs.items()) - i.sort(key=lambda x: mapping(x[1]), reverse=reverse) - self._kwargs.reorder(**{k[0]: k[1] for k in i}) + items: tuple[Any, ...], + data: Optional[Iterable[Any]] = None, + named_items: Optional[dict[str, Any]] = None, + ) -> list[Any]: + collected: list[Any] = [] + for item in items: + if isinstance(item, list): + collected.extend(item) + else: + collected.append(item) + if data is not None: + collected.extend(data) + if named_items is not None: + for item in named_items.values(): + if isinstance(item, list) and len(item) > 0: + collected.extend(item) + else: + collected.append(item) + return collected + + def _normalize_protected_types(self, protected_types: type | Iterable[type] | None) -> list[type]: + if protected_types is None: + return list(self._DEFAULT_PROTECTED_TYPES) + if isinstance(protected_types, type): + return [protected_types] + if isinstance(protected_types, Iterable): + normalized = list(protected_types) + if all(isinstance(item, type) for item in normalized): + return normalized + raise TypeError('protected_types must be a type or an iterable of types') + + def _serialize_item(self, item: Any, skip: Optional[list[str]] = None) -> dict[str, Any]: + if hasattr(item, 'to_dict'): + return item.to_dict() + if hasattr(item, 'as_dict'): + return item.as_dict(skip=skip) + raise TypeError(f'Unable to serialize item of type {type(item)}') + + @staticmethod + def _deserialize_item(item: Any) -> Any: + if not SerializerBase._is_serialized_easyscience_object(item): + return SerializerBase._deserialize_value(item) + + normalized_item = copy.deepcopy(item) + normalized_item.pop('unique_name', None) + return SerializerBase._deserialize_value(normalized_item) + + def _validate_item(self, item: Any) -> None: + if not isinstance(item, tuple(self._protected_types)): + raise TypeError(f'Items must be one of {self._protected_types}, got {type(item)}') diff --git a/tests/unit_tests/base_classes/test_collection_base.py b/tests/unit_tests/base_classes/test_collection_base.py index cbd91933..6c8eacfb 100644 --- a/tests/unit_tests/base_classes/test_collection_base.py +++ b/tests/unit_tests/base_classes/test_collection_base.py @@ -146,7 +146,7 @@ def test_CollectionBase_append_fail(cls, setup_pars, value): @pytest.mark.parametrize("cls", class_constructors) -@pytest.mark.parametrize("value", (0, 1, 3, "par1", "des1")) +@pytest.mark.parametrize("value", (0, 1, 3, "p1", "d1")) def test_CollectionBase_getItem(cls, setup_pars, value): name = setup_pars["name"] del setup_pars["name"] @@ -155,10 +155,10 @@ def test_CollectionBase_getItem(cls, setup_pars, value): get_item = coll[value] if isinstance(value, str): - key = value + assert get_item.name == value else: key = list(setup_pars.keys())[value] - assert get_item.name == setup_pars[key].name + assert get_item.name == setup_pars[key].name @pytest.mark.parametrize("cls", class_constructors) @@ -325,7 +325,7 @@ def test_CollectionBase_dir(cls): "decode", "sort", } - assert not d.difference(expected) + assert expected.issubset(d) @pytest.mark.parametrize("cls", class_constructors) @@ -469,11 +469,8 @@ def test_CollectionBase_set_index(cls): assert obj[idx] == p2 obj[idx] = p4 assert obj[idx] == p4 - edges = obj._global_object.map.get_edges(obj) - assert len(edges) == len(obj) - for item in obj: - assert item.unique_name in edges - assert p2.unique_name not in edges + assert len(obj) == len(l_object) + assert p2 not in obj.data @pytest.mark.parametrize("cls", class_constructors) @@ -493,11 +490,8 @@ def test_CollectionBase_set_index_based(cls): assert obj[idx] == p4 obj[idx] = d assert obj[idx] == d - edges = obj._global_object.map.get_edges(obj) - assert len(edges) == len(obj) - for item in obj: - assert item.unique_name in edges - assert p4.unique_name not in edges + assert len(obj) == len(l_object) + assert p4 not in obj.data @pytest.mark.parametrize("cls", class_constructors) @@ -529,19 +523,10 @@ class Beta(ObjBase): @pytest.mark.parametrize("cls", class_constructors) def test_CollectionBaseGraph(cls): - from easyscience import global_object - - G = global_object.map name = "test" v = [1, 2] p = [Parameter(f"p{i}", v[i]) for i in range(len(v))] - p_id = [_p.unique_name for _p in p] bb = cls(name, *p) - bb_id = bb.unique_name b = Beta("b", bb=bb) - b_id = b.unique_name - for _id in p_id: - assert _id in G.get_edges(bb) - assert len(p) == len(G.get_edges(bb)) - assert bb_id in G.get_edges(b) - assert 1 == len(G.get_edges(b)) + assert b.bb is bb + assert list(bb) == p diff --git a/tests/unit_tests/base_classes/test_collection_base_easylist.py b/tests/unit_tests/base_classes/test_collection_base_easylist.py new file mode 100644 index 00000000..8258b203 --- /dev/null +++ b/tests/unit_tests/base_classes/test_collection_base_easylist.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project Date: Mon, 9 Mar 2026 15:25:45 +0100 Subject: [PATCH 2/7] ruff --- src/easyscience/base_classes/collection_base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/easyscience/base_classes/collection_base.py b/src/easyscience/base_classes/collection_base.py index ea2f666c..f27ea82b 100644 --- a/src/easyscience/base_classes/collection_base.py +++ b/src/easyscience/base_classes/collection_base.py @@ -80,9 +80,7 @@ def __init__( except TypeError as exc: raise AttributeError('A collection can only be formed from easyscience objects.') from exc if item in self: - warnings.warn( - f'Item with unique name "{self._get_key(item)}" already in CollectionBase, it will be ignored' - ) + warnings.warn(f'Item with unique name "{self._get_key(item)}" already in CollectionBase, it will be ignored') continue self._data.append(item) From e8d69b93b6baa2ab06362a83fe3f7c0caabb9519 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Tue, 10 Mar 2026 09:15:59 +0100 Subject: [PATCH 3/7] make the new collection class a separate file for transition --- src/easyscience/base_classes/__init__.py | 3 +- .../base_classes/collection_base.py | 607 +++++++----------- .../base_classes/collection_base_easylist.py | 419 ++++++++++++ .../base_classes/test_collection_base.py | 36 +- .../test_collection_base_easylist.py | 2 +- 5 files changed, 668 insertions(+), 399 deletions(-) create mode 100644 src/easyscience/base_classes/collection_base_easylist.py diff --git a/src/easyscience/base_classes/__init__.py b/src/easyscience/base_classes/__init__.py index 878d869e..9d65106d 100644 --- a/src/easyscience/base_classes/__init__.py +++ b/src/easyscience/base_classes/__init__.py @@ -1,8 +1,9 @@ from .based_base import BasedBase +from .collection_base_easylist import CollectionBase as CollectionBaseEasyList from .collection_base import CollectionBase from .easy_list import EasyList from .model_base import ModelBase from .new_base import NewBase from .obj_base import ObjBase -__all__ = [BasedBase, CollectionBase, ObjBase, ModelBase, NewBase, EasyList] +__all__ = [BasedBase, CollectionBase, CollectionBaseEasyList, ObjBase, ModelBase, NewBase, EasyList] diff --git a/src/easyscience/base_classes/collection_base.py b/src/easyscience/base_classes/collection_base.py index f27ea82b..0a4e8272 100644 --- a/src/easyscience/base_classes/collection_base.py +++ b/src/easyscience/base_classes/collection_base.py @@ -4,416 +4,249 @@ from __future__ import annotations -import copy -import warnings -from collections.abc import Iterable -from importlib import import_module +from collections.abc import MutableSequence from numbers import Number +from typing import TYPE_CHECKING from typing import Any +from typing import Callable +from typing import List from typing import Optional +from typing import Tuple +from typing import Union -from easyscience.io.serializer_base import SerializerBase -from easyscience.io.serializer_dict import SerializerDict +from easyscience.base_classes.new_base import NewBase +from easyscience.global_object.undo_redo import NotarizedDict from ..variable.descriptor_base import DescriptorBase -from ..variable.parameter import Parameter from .based_base import BasedBase -from .easy_list import EasyList -from .new_base import NewBase +if TYPE_CHECKING: + from ..fitting.calculators import InterfaceFactoryTemplate -class CollectionBase(EasyList): - """ - EasyList-backed collection with a small compatibility layer for migration. - The collection delegates storage and MutableSequence behavior to EasyList, - adding only scientific-parameter aggregation methods and a thin compatibility - layer for existing callers. +class CollectionBase(BasedBase, MutableSequence): + """ + This is the base class for which all higher level classes are built off of. + NOTE: This object is serializable only if parameters are supplied as: + `ObjBase(a=value, b=value)`. For `Parameter` or `Descriptor` objects we can + cheat with `ObjBase(*[Descriptor(...), Parameter(...), ...])`. """ - - _DEFAULT_PROTECTED_TYPES = (DescriptorBase, BasedBase, NewBase) - _RESERVED_NAMED_KEYS = { - 'data', - 'display_name', - 'interface', - 'name', - 'protected_types', - 'unique_name', - 'user_data', - '_kwargs', - } - _REDIRECT = {'interface': None} def __init__( self, - *items: Any, - name: Optional[str] = None, - protected_types: type | Iterable[type] | None = None, + name: str, + *args: Union[BasedBase, DescriptorBase, NewBase], + interface: Optional[InterfaceFactoryTemplate] = None, unique_name: Optional[str] = None, - display_name: Optional[str] = None, - interface: Any = None, - data: Optional[Iterable[Any]] = None, - **named_items: Any, + **kwargs, ): - if items and isinstance(items[0], str) and name is None: - name = items[0] - items = items[1:] - - if display_name is None and name is not None: - display_name = name - - super().__init__(unique_name=unique_name, display_name=display_name) - + """ + Set up the base collection class. + + :param name: Name of this object + :type name: str + :param args: selection of + :param _kwargs: Fields which this class should contain + :type _kwargs: dict + """ + BasedBase.__init__(self, name, unique_name=unique_name) + kwargs = {key: kwargs[key] for key in kwargs.keys() if kwargs[key] is not None} + _args = [] + for item in args: + if not isinstance(item, list): + _args.append(item) + else: + _args += item + _kwargs = {} + for key, item in kwargs.items(): + if isinstance(item, list) and len(item) > 0: + _args += item + else: + _kwargs[key] = item + kwargs = _kwargs + for item in list(kwargs.values()) + _args: + if not issubclass(type(item), (DescriptorBase, BasedBase, NewBase)): + raise AttributeError('A collection can only be formed from easyscience objects.') + args = _args + _kwargs = {} + for key, item in kwargs.items(): + _kwargs[key] = item + for arg in args: + kwargs[arg.unique_name] = arg + _kwargs[arg.unique_name] = arg + + # Set kwargs, also useful for serialization + self._kwargs = NotarizedDict(**_kwargs) + + for key in kwargs.keys(): + if key in self.__dict__.keys() or key in self.__slots__: + raise AttributeError(f'Given kwarg: `{key}`, is an internal attribute. Please rename.') + if kwargs[key]: # Might be None (empty tuple or list) + self._global_object.map.add_edge(self, kwargs[key]) + self._global_object.map.reset_type(kwargs[key], 'created_internal') + if interface is not None: + kwargs[key].interface = interface + # TODO wrap getter and setter in Logger if interface is not None: - raise AttributeError('Given kwarg: `interface`, is an internal attribute. Please rename.') - - self._protected_types = self._normalize_protected_types(protected_types) - self._name = name if name is not None else self.display_name - self.user_data: dict[str, Any] = {} - self.interface = None - - normalized_named_items = self._normalize_named_items(named_items) - all_items = self._collect_items(items, data=data, named_items=normalized_named_items) - for item in all_items: - try: - self._validate_item(item) - except TypeError as exc: - raise AttributeError('A collection can only be formed from easyscience objects.') from exc - if item in self: - warnings.warn(f'Item with unique name "{self._get_key(item)}" already in CollectionBase, it will be ignored') - continue - self._data.append(item) - - @property - def name(self) -> str: - return self._name - - @name.setter - def name(self, new_name: str) -> None: - if not isinstance(new_name, str): - raise TypeError('Name must be a string') - self._name = new_name - self.display_name = new_name - - # --- Minimal overrides (compatibility shims) --- - - def __getitem__(self, idx: int | slice | str) -> Any: - if isinstance(idx, bool): - raise TypeError('Boolean indexing is not supported at the moment') + self.interface = interface + self._kwargs._stack_enabled = True + + def insert(self, index: int, value: Union[DescriptorBase, BasedBase, NewBase]) -> None: + """ + Insert an object into the collection at an index. + + :param index: Index for EasyScience object to be inserted. + :type index: int + :param value: Object to be inserted. + :type value: Union[BasedBase, DescriptorBase, NewBase] + :return: None + :rtype: None + """ + t_ = type(value) + if issubclass(t_, (BasedBase, DescriptorBase, NewBase)): + update_key = list(self._kwargs.keys()) + values = list(self._kwargs.values()) + # Update the internal dict + new_key = value.unique_name + update_key.insert(index, new_key) + values.insert(index, value) + self._kwargs.reorder(**{k: v for k, v in zip(update_key, values)}) + # ADD EDGE + self._global_object.map.add_edge(self, value) + self._global_object.map.reset_type(value, 'created_internal') + value.interface = self.interface + else: + raise AttributeError('Only EasyScience objects can be put into an EasyScience group') + + def __getitem__(self, idx: Union[int, slice]) -> Union[DescriptorBase, BasedBase, NewBase]: + """ + Get an item in the collection based on its index. + + :param idx: index or slice of the collection. + :type idx: Union[int, slice] + :return: Object at index `idx` + :rtype: Union[Parameter, Descriptor, ObjBase, 'CollectionBase'] + """ if isinstance(idx, slice): - return self._clone_with_items(self._data[idx]) + start, stop, step = idx.indices(len(self)) + return self.__class__(getattr(self, 'name'), *[self[i] for i in range(start, stop, step)]) + if str(idx) in self._kwargs.keys(): + return self._kwargs[str(idx)] if isinstance(idx, str): + idx = [index for index, item in enumerate(self) if item.name == idx] + noi = len(idx) + if noi == 0: + raise IndexError('Given index does not exist') + elif noi == 1: + idx = idx[0] + else: + return self.__class__(getattr(self, 'name'), *[self[i] for i in idx]) + elif not isinstance(idx, int) or isinstance(idx, bool): + if isinstance(idx, bool): + raise TypeError('Boolean indexing is not supported at the moment') try: - return super().__getitem__(idx) - except KeyError: - pass - name_matches = [item for item in self._data if getattr(item, 'name', None) == idx] - if len(name_matches) == 1: - return name_matches[0] - if len(name_matches) > 1: - return self._clone_with_items(name_matches) - raise KeyError(f'No item with key or name "{idx}" found') - return super().__getitem__(idx) - - def __setitem__(self, idx: int | slice, value: Any) -> None: - if isinstance(idx, int) and isinstance(value, Number): - item = self[idx] - if not hasattr(item, 'value'): - raise NotImplementedError('At the moment only numerical values or EasyScience objects can be set.') + if idx > len(self): + raise IndexError(f'Given index {idx} is out of bounds') + except TypeError: + raise IndexError('Index must be of type `int`/`slice` or an item name (`str`)') + keys = list(self._kwargs.keys()) + return self._kwargs[keys[idx]] + + def __setitem__(self, key: int, value: Union[BasedBase, DescriptorBase, NewBase]) -> None: + """ + Set an item via it's index. + + :param key: Index in self. + :type key: int + :param value: Value which index key should be set to. + :type value: Any + """ + if isinstance(value, Number): # noqa: S3827 + item = self.__getitem__(key) item.value = value - return - try: - super().__setitem__(idx, value) - except TypeError as exc: - raise NotImplementedError('At the moment only numerical values or EasyScience objects can be set.') from exc - - def insert(self, index: int, value: Any) -> None: - try: - super().insert(index, value) - except TypeError as exc: - raise AttributeError('Only EasyScience objects can be put into an EasyScience group') from exc - - def __repr__(self) -> str: - return f'{self.__class__.__name__} `{self.name}` of length {len(self)}' - - def sort(self, key=None, reverse: bool = False, mapping=None) -> None: - if mapping is not None: - if key is not None: - raise TypeError('Use either key or mapping, not both') - warnings.warn('sort(mapping=...) is deprecated; use sort(key=...) instead', DeprecationWarning) - key = mapping - super().sort(key=key, reverse=reverse) - - # --- Parameter/variable aggregation --- - - def get_all_variables(self) -> list[DescriptorBase]: - variables: list[DescriptorBase] = [] - for item in self._data: - if isinstance(item, DescriptorBase): - variables.append(item) - elif hasattr(item, 'get_all_variables'): - variables.extend(item.get_all_variables()) - return variables - - def get_all_parameters(self) -> list[Parameter]: - parameters: list[Parameter] = [] - seen = set() - for item in self._data: - if isinstance(item, Parameter): - parameters.append(item) - seen.add(id(item)) - continue - if hasattr(item, 'get_all_parameters'): - for parameter in item.get_all_parameters(): - if id(parameter) not in seen: - parameters.append(parameter) - seen.add(id(parameter)) - continue - if hasattr(item, 'get_parameters'): - for parameter in item.get_parameters(): - if id(parameter) not in seen: - parameters.append(parameter) - seen.add(id(parameter)) - continue - if hasattr(item, 'get_all_variables'): - for variable in item.get_all_variables(): - if isinstance(variable, Parameter) and id(variable) not in seen: - parameters.append(variable) - seen.add(id(variable)) - return parameters - - def get_parameters(self) -> list[Parameter]: - return self.get_all_parameters() - - def get_fittable_parameters(self) -> list[Parameter]: - return [parameter for parameter in self.get_all_parameters() if parameter.independent] - - def get_free_parameters(self) -> list[Parameter]: - return [parameter for parameter in self.get_fittable_parameters() if not parameter.fixed] - - def get_fit_parameters(self) -> list[Parameter]: - return self.get_free_parameters() + elif issubclass(type(value), (BasedBase, DescriptorBase, NewBase)): + update_key = list(self._kwargs.keys()) + values = list(self._kwargs.values()) + old_item = values[key] + # Update the internal dict + update_dict = {update_key[key]: value} + self._kwargs.update(update_dict) + # ADD EDGE + self._global_object.map.add_edge(self, value) + self._global_object.map.reset_type(value, 'created_internal') + value.interface = self.interface + # REMOVE EDGE + self._global_object.map.prune_vertex_from_edge(self, old_item) + else: + raise NotImplementedError('At the moment only numerical values or EasyScience objects can be set.') + + def __delitem__(self, key: int) -> None: + """ + Try to delete an idem by key. + + :param key: + :type key: + :return: + :rtype: + """ + keys = list(self._kwargs.keys()) + item = self._kwargs[keys[key]] + self._global_object.map.prune_vertex_from_edge(self, item) + del self._kwargs[keys[key]] + + def __len__(self) -> int: + """ + Get the number of items in this collection + + :return: Number of items in this collection. + :rtype: int + """ + return len(self._kwargs.keys()) + + def _convert_to_dict(self, in_dict, encoder, skip: List[str] = [], **kwargs) -> dict: + """ + Convert ones self into a serialized form. + + :return: dictionary of ones self + :rtype: dict + """ + d = {} + if hasattr(self, '_modify_dict'): + # any extra keys defined on the inheriting class + d = self._modify_dict(skip=skip, **kwargs) + in_dict['data'] = [encoder._convert_to_dict(item, skip=skip, **kwargs) for item in self] + out_dict = {**in_dict, **d} + return out_dict @property - def data(self) -> tuple[Any, ...]: - return tuple(self._data) + def data(self) -> Tuple: + """ + The data function returns a tuple of the keyword arguments passed to the + constructor. This is useful for when you need to pass in a dictionary of data + to other functions, such as with matplotlib's plot function. + + :param self: Access attributes of the class within the method + :return: The values of the attributes in a tuple + :doc-author: Trelent + """ + return tuple(self._kwargs.values()) - # --- Serialization --- - - def as_dict(self, skip: Optional[list[str]] = None) -> dict[str, Any]: - if skip is None: - skip = [] - if 'unique_name' not in skip: - skip = [*skip, 'unique_name'] - return self.to_dict(skip=skip) - - def encode(self, skip: Optional[list[str]] = None, encoder=None, **kwargs: Any) -> Any: - if encoder is None: - encoder = SerializerDict - return encoder().encode(self, skip=skip, **kwargs) - - @classmethod - def decode(cls, obj: Any, decoder=None) -> Any: - if decoder is None or decoder is SerializerDict: - return cls.from_dict(obj) - return decoder.decode(obj) - - def to_dict(self, skip: Optional[list[str]] = None) -> dict[str, Any]: - if skip is None: - skip = [] - - try: - parent_module = self.__module__.split('.')[0] - module_version = import_module(parent_module).__version__ - except (AttributeError, ImportError): - module_version = None - - dict_repr: dict[str, Any] = { - '@module': self.__module__, - '@class': self.__class__.__name__, - '@version': module_version, - } - - if 'name' not in skip: - dict_repr['name'] = self.name - if 'display_name' not in skip and self._display_name is not None and self._display_name != self.name: - dict_repr['display_name'] = self._display_name - if 'unique_name' not in skip: - dict_repr['unique_name'] = self.unique_name - if self._protected_types != list(self._DEFAULT_PROTECTED_TYPES) and 'protected_types' not in skip: - dict_repr['protected_types'] = [ - {'@module': cls_.__module__, '@class': cls_.__name__} for cls_ in self._protected_types - ] - dict_repr['data'] = [self._serialize_item(item, skip=skip) for item in self._data] - return dict_repr - - @classmethod - def from_dict(cls, obj_dict: dict[str, Any]) -> CollectionBase: - if not isinstance(obj_dict, dict) or '@class' not in obj_dict or '@module' not in obj_dict: - raise ValueError('Input must be a dictionary representing an EasyScience CollectionBase object.') - accepted_names = {base.__name__ for base in cls.__mro__ if issubclass(base, CollectionBase)} - if obj_dict['@class'] not in accepted_names: - raise ValueError(f'Class name in dictionary does not match the expected class: {cls.__name__}.') - - temp_dict = copy.deepcopy(obj_dict) - protected_types = temp_dict.pop('protected_types', None) - if protected_types is not None: - protected_types = cls._deserialize_protected_types(protected_types) - - raw_data = temp_dict.pop('data', []) - kwargs = SerializerBase.deserialize_dict(temp_dict) - data = [cls._deserialize_item(item) for item in raw_data] - name = kwargs.pop('name', None) - kwargs.pop('unique_name', None) - return cls(*data, name=name, protected_types=protected_types, **kwargs) - - def _convert_to_dict( - self, - in_dict: dict[str, Any], - encoder: Any, - skip: Optional[list[str]] = None, - **kwargs: Any, - ) -> dict[str, Any]: - if skip is None: - skip = [] - if 'name' not in skip: - in_dict['name'] = self.name - if self._display_name is not None and self._display_name != self.name and 'display_name' not in skip: - in_dict['display_name'] = self._display_name - in_dict['data'] = [encoder._convert_to_dict(item, skip=skip, **kwargs) for item in self._data] - return in_dict - - @staticmethod - def _deserialize_protected_types(protected_types: list[dict[str, str]]) -> list[type]: - deserialized_types: list[type] = [] - for type_dict in protected_types: - if '@module' not in type_dict or '@class' not in type_dict: - raise ValueError('Each protected type must contain @module and @class keys') - module = __import__(type_dict['@module'], globals(), locals(), [type_dict['@class']], 0) - deserialized_types.append(getattr(module, type_dict['@class'])) - return deserialized_types - - def _clone_with_items(self, items: Iterable[Any]) -> CollectionBase: - return self.__class__( - *list(items), - name=self.name, - protected_types=list(self._protected_types), - display_name=self._display_name, - ) - - # --- Compatibility surface --- - - def __dir__(self) -> Iterable[str]: - hidden = { - 'display_name', - 'get_all_parameters', - 'get_all_variables', - 'get_fittable_parameters', - 'get_free_parameters', - 'to_dict', - } - legacy = { - 'append', - 'as_dict', - 'clear', - 'constraints', - 'count', - 'data', - 'decode', - 'encode', - 'extend', - 'from_dict', - 'generate_bindings', - 'get_fit_parameters', - 'get_parameters', - 'index', - 'insert', - 'interface', - 'name', - 'pop', - 'remove', - 'reverse', - 'sort', - 'switch_interface', - 'unique_name', - 'user_data', - } - public_names = {name for name in dir(self.__class__) if not name.startswith('_')} - return sorted((public_names | legacy) - hidden) - - @property - def constraints(self) -> list[Any]: - return [] - - def generate_bindings(self) -> None: - if self.interface is None: - raise AttributeError('Interface error for generating bindings. `interface` has to be set.') - - def switch_interface(self, new_interface_name: str) -> None: - if self.interface is None: - raise AttributeError('Interface error for generating bindings. `interface` has to be set.') - - # --- Internal helpers --- - - def _normalize_named_items(self, named_items: dict[str, Any]) -> dict[str, Any]: - normalized: dict[str, Any] = {} - for key, item in named_items.items(): - if key in self._RESERVED_NAMED_KEYS: - raise AttributeError(f'Given kwarg: `{key}`, is an internal attribute. Please rename.') - if item is None: - continue - normalized[key] = item - return normalized + def __repr__(self) -> str: + return f'{self.__class__.__name__} `{getattr(self, "name")}` of length {len(self)}' - def _collect_items( + def sort( self, - items: tuple[Any, ...], - data: Optional[Iterable[Any]] = None, - named_items: Optional[dict[str, Any]] = None, - ) -> list[Any]: - collected: list[Any] = [] - for item in items: - if isinstance(item, list): - collected.extend(item) - else: - collected.append(item) - if data is not None: - collected.extend(data) - if named_items is not None: - for item in named_items.values(): - if isinstance(item, list) and len(item) > 0: - collected.extend(item) - else: - collected.append(item) - return collected - - def _normalize_protected_types(self, protected_types: type | Iterable[type] | None) -> list[type]: - if protected_types is None: - return list(self._DEFAULT_PROTECTED_TYPES) - if isinstance(protected_types, type): - return [protected_types] - if isinstance(protected_types, Iterable): - normalized = list(protected_types) - if all(isinstance(item, type) for item in normalized): - return normalized - raise TypeError('protected_types must be a type or an iterable of types') - - def _serialize_item(self, item: Any, skip: Optional[list[str]] = None) -> dict[str, Any]: - if hasattr(item, 'to_dict'): - return item.to_dict() - if hasattr(item, 'as_dict'): - return item.as_dict(skip=skip) - raise TypeError(f'Unable to serialize item of type {type(item)}') - - @staticmethod - def _deserialize_item(item: Any) -> Any: - if not SerializerBase._is_serialized_easyscience_object(item): - return SerializerBase._deserialize_value(item) - - normalized_item = copy.deepcopy(item) - normalized_item.pop('unique_name', None) - return SerializerBase._deserialize_value(normalized_item) - - def _validate_item(self, item: Any) -> None: - if not isinstance(item, tuple(self._protected_types)): - raise TypeError(f'Items must be one of {self._protected_types}, got {type(item)}') + mapping: Callable[[Union[BasedBase, DescriptorBase, NewBase]], Any], + reverse: bool = False, + ) -> None: + """ + Sort the collection according to the given mapping. + + :param mapping: mapping function to sort the collection. i.e. lambda parameter: parameter.value + :type mapping: Callable + :param reverse: Reverse the sorting. + :type reverse: bool + """ + i = list(self._kwargs.items()) + i.sort(key=lambda x: mapping(x[1]), reverse=reverse) + self._kwargs.reorder(**{k[0]: k[1] for k in i}) diff --git a/src/easyscience/base_classes/collection_base_easylist.py b/src/easyscience/base_classes/collection_base_easylist.py new file mode 100644 index 00000000..f27ea82b --- /dev/null +++ b/src/easyscience/base_classes/collection_base_easylist.py @@ -0,0 +1,419 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project str: + return self._name + + @name.setter + def name(self, new_name: str) -> None: + if not isinstance(new_name, str): + raise TypeError('Name must be a string') + self._name = new_name + self.display_name = new_name + + # --- Minimal overrides (compatibility shims) --- + + def __getitem__(self, idx: int | slice | str) -> Any: + if isinstance(idx, bool): + raise TypeError('Boolean indexing is not supported at the moment') + if isinstance(idx, slice): + return self._clone_with_items(self._data[idx]) + if isinstance(idx, str): + try: + return super().__getitem__(idx) + except KeyError: + pass + name_matches = [item for item in self._data if getattr(item, 'name', None) == idx] + if len(name_matches) == 1: + return name_matches[0] + if len(name_matches) > 1: + return self._clone_with_items(name_matches) + raise KeyError(f'No item with key or name "{idx}" found') + return super().__getitem__(idx) + + def __setitem__(self, idx: int | slice, value: Any) -> None: + if isinstance(idx, int) and isinstance(value, Number): + item = self[idx] + if not hasattr(item, 'value'): + raise NotImplementedError('At the moment only numerical values or EasyScience objects can be set.') + item.value = value + return + try: + super().__setitem__(idx, value) + except TypeError as exc: + raise NotImplementedError('At the moment only numerical values or EasyScience objects can be set.') from exc + + def insert(self, index: int, value: Any) -> None: + try: + super().insert(index, value) + except TypeError as exc: + raise AttributeError('Only EasyScience objects can be put into an EasyScience group') from exc + + def __repr__(self) -> str: + return f'{self.__class__.__name__} `{self.name}` of length {len(self)}' + + def sort(self, key=None, reverse: bool = False, mapping=None) -> None: + if mapping is not None: + if key is not None: + raise TypeError('Use either key or mapping, not both') + warnings.warn('sort(mapping=...) is deprecated; use sort(key=...) instead', DeprecationWarning) + key = mapping + super().sort(key=key, reverse=reverse) + + # --- Parameter/variable aggregation --- + + def get_all_variables(self) -> list[DescriptorBase]: + variables: list[DescriptorBase] = [] + for item in self._data: + if isinstance(item, DescriptorBase): + variables.append(item) + elif hasattr(item, 'get_all_variables'): + variables.extend(item.get_all_variables()) + return variables + + def get_all_parameters(self) -> list[Parameter]: + parameters: list[Parameter] = [] + seen = set() + for item in self._data: + if isinstance(item, Parameter): + parameters.append(item) + seen.add(id(item)) + continue + if hasattr(item, 'get_all_parameters'): + for parameter in item.get_all_parameters(): + if id(parameter) not in seen: + parameters.append(parameter) + seen.add(id(parameter)) + continue + if hasattr(item, 'get_parameters'): + for parameter in item.get_parameters(): + if id(parameter) not in seen: + parameters.append(parameter) + seen.add(id(parameter)) + continue + if hasattr(item, 'get_all_variables'): + for variable in item.get_all_variables(): + if isinstance(variable, Parameter) and id(variable) not in seen: + parameters.append(variable) + seen.add(id(variable)) + return parameters + + def get_parameters(self) -> list[Parameter]: + return self.get_all_parameters() + + def get_fittable_parameters(self) -> list[Parameter]: + return [parameter for parameter in self.get_all_parameters() if parameter.independent] + + def get_free_parameters(self) -> list[Parameter]: + return [parameter for parameter in self.get_fittable_parameters() if not parameter.fixed] + + def get_fit_parameters(self) -> list[Parameter]: + return self.get_free_parameters() + + @property + def data(self) -> tuple[Any, ...]: + return tuple(self._data) + + # --- Serialization --- + + def as_dict(self, skip: Optional[list[str]] = None) -> dict[str, Any]: + if skip is None: + skip = [] + if 'unique_name' not in skip: + skip = [*skip, 'unique_name'] + return self.to_dict(skip=skip) + + def encode(self, skip: Optional[list[str]] = None, encoder=None, **kwargs: Any) -> Any: + if encoder is None: + encoder = SerializerDict + return encoder().encode(self, skip=skip, **kwargs) + + @classmethod + def decode(cls, obj: Any, decoder=None) -> Any: + if decoder is None or decoder is SerializerDict: + return cls.from_dict(obj) + return decoder.decode(obj) + + def to_dict(self, skip: Optional[list[str]] = None) -> dict[str, Any]: + if skip is None: + skip = [] + + try: + parent_module = self.__module__.split('.')[0] + module_version = import_module(parent_module).__version__ + except (AttributeError, ImportError): + module_version = None + + dict_repr: dict[str, Any] = { + '@module': self.__module__, + '@class': self.__class__.__name__, + '@version': module_version, + } + + if 'name' not in skip: + dict_repr['name'] = self.name + if 'display_name' not in skip and self._display_name is not None and self._display_name != self.name: + dict_repr['display_name'] = self._display_name + if 'unique_name' not in skip: + dict_repr['unique_name'] = self.unique_name + if self._protected_types != list(self._DEFAULT_PROTECTED_TYPES) and 'protected_types' not in skip: + dict_repr['protected_types'] = [ + {'@module': cls_.__module__, '@class': cls_.__name__} for cls_ in self._protected_types + ] + dict_repr['data'] = [self._serialize_item(item, skip=skip) for item in self._data] + return dict_repr + + @classmethod + def from_dict(cls, obj_dict: dict[str, Any]) -> CollectionBase: + if not isinstance(obj_dict, dict) or '@class' not in obj_dict or '@module' not in obj_dict: + raise ValueError('Input must be a dictionary representing an EasyScience CollectionBase object.') + accepted_names = {base.__name__ for base in cls.__mro__ if issubclass(base, CollectionBase)} + if obj_dict['@class'] not in accepted_names: + raise ValueError(f'Class name in dictionary does not match the expected class: {cls.__name__}.') + + temp_dict = copy.deepcopy(obj_dict) + protected_types = temp_dict.pop('protected_types', None) + if protected_types is not None: + protected_types = cls._deserialize_protected_types(protected_types) + + raw_data = temp_dict.pop('data', []) + kwargs = SerializerBase.deserialize_dict(temp_dict) + data = [cls._deserialize_item(item) for item in raw_data] + name = kwargs.pop('name', None) + kwargs.pop('unique_name', None) + return cls(*data, name=name, protected_types=protected_types, **kwargs) + + def _convert_to_dict( + self, + in_dict: dict[str, Any], + encoder: Any, + skip: Optional[list[str]] = None, + **kwargs: Any, + ) -> dict[str, Any]: + if skip is None: + skip = [] + if 'name' not in skip: + in_dict['name'] = self.name + if self._display_name is not None and self._display_name != self.name and 'display_name' not in skip: + in_dict['display_name'] = self._display_name + in_dict['data'] = [encoder._convert_to_dict(item, skip=skip, **kwargs) for item in self._data] + return in_dict + + @staticmethod + def _deserialize_protected_types(protected_types: list[dict[str, str]]) -> list[type]: + deserialized_types: list[type] = [] + for type_dict in protected_types: + if '@module' not in type_dict or '@class' not in type_dict: + raise ValueError('Each protected type must contain @module and @class keys') + module = __import__(type_dict['@module'], globals(), locals(), [type_dict['@class']], 0) + deserialized_types.append(getattr(module, type_dict['@class'])) + return deserialized_types + + def _clone_with_items(self, items: Iterable[Any]) -> CollectionBase: + return self.__class__( + *list(items), + name=self.name, + protected_types=list(self._protected_types), + display_name=self._display_name, + ) + + # --- Compatibility surface --- + + def __dir__(self) -> Iterable[str]: + hidden = { + 'display_name', + 'get_all_parameters', + 'get_all_variables', + 'get_fittable_parameters', + 'get_free_parameters', + 'to_dict', + } + legacy = { + 'append', + 'as_dict', + 'clear', + 'constraints', + 'count', + 'data', + 'decode', + 'encode', + 'extend', + 'from_dict', + 'generate_bindings', + 'get_fit_parameters', + 'get_parameters', + 'index', + 'insert', + 'interface', + 'name', + 'pop', + 'remove', + 'reverse', + 'sort', + 'switch_interface', + 'unique_name', + 'user_data', + } + public_names = {name for name in dir(self.__class__) if not name.startswith('_')} + return sorted((public_names | legacy) - hidden) + + @property + def constraints(self) -> list[Any]: + return [] + + def generate_bindings(self) -> None: + if self.interface is None: + raise AttributeError('Interface error for generating bindings. `interface` has to be set.') + + def switch_interface(self, new_interface_name: str) -> None: + if self.interface is None: + raise AttributeError('Interface error for generating bindings. `interface` has to be set.') + + # --- Internal helpers --- + + def _normalize_named_items(self, named_items: dict[str, Any]) -> dict[str, Any]: + normalized: dict[str, Any] = {} + for key, item in named_items.items(): + if key in self._RESERVED_NAMED_KEYS: + raise AttributeError(f'Given kwarg: `{key}`, is an internal attribute. Please rename.') + if item is None: + continue + normalized[key] = item + return normalized + + def _collect_items( + self, + items: tuple[Any, ...], + data: Optional[Iterable[Any]] = None, + named_items: Optional[dict[str, Any]] = None, + ) -> list[Any]: + collected: list[Any] = [] + for item in items: + if isinstance(item, list): + collected.extend(item) + else: + collected.append(item) + if data is not None: + collected.extend(data) + if named_items is not None: + for item in named_items.values(): + if isinstance(item, list) and len(item) > 0: + collected.extend(item) + else: + collected.append(item) + return collected + + def _normalize_protected_types(self, protected_types: type | Iterable[type] | None) -> list[type]: + if protected_types is None: + return list(self._DEFAULT_PROTECTED_TYPES) + if isinstance(protected_types, type): + return [protected_types] + if isinstance(protected_types, Iterable): + normalized = list(protected_types) + if all(isinstance(item, type) for item in normalized): + return normalized + raise TypeError('protected_types must be a type or an iterable of types') + + def _serialize_item(self, item: Any, skip: Optional[list[str]] = None) -> dict[str, Any]: + if hasattr(item, 'to_dict'): + return item.to_dict() + if hasattr(item, 'as_dict'): + return item.as_dict(skip=skip) + raise TypeError(f'Unable to serialize item of type {type(item)}') + + @staticmethod + def _deserialize_item(item: Any) -> Any: + if not SerializerBase._is_serialized_easyscience_object(item): + return SerializerBase._deserialize_value(item) + + normalized_item = copy.deepcopy(item) + normalized_item.pop('unique_name', None) + return SerializerBase._deserialize_value(normalized_item) + + def _validate_item(self, item: Any) -> None: + if not isinstance(item, tuple(self._protected_types)): + raise TypeError(f'Items must be one of {self._protected_types}, got {type(item)}') diff --git a/tests/unit_tests/base_classes/test_collection_base.py b/tests/unit_tests/base_classes/test_collection_base.py index 6c8eacfb..153fe713 100644 --- a/tests/unit_tests/base_classes/test_collection_base.py +++ b/tests/unit_tests/base_classes/test_collection_base.py @@ -146,7 +146,7 @@ def test_CollectionBase_append_fail(cls, setup_pars, value): @pytest.mark.parametrize("cls", class_constructors) -@pytest.mark.parametrize("value", (0, 1, 3, "p1", "d1")) +@pytest.mark.parametrize("value", (0, 1, 3, "par1", "des1")) def test_CollectionBase_getItem(cls, setup_pars, value): name = setup_pars["name"] del setup_pars["name"] @@ -155,10 +155,10 @@ def test_CollectionBase_getItem(cls, setup_pars, value): get_item = coll[value] if isinstance(value, str): - assert get_item.name == value + key = value else: key = list(setup_pars.keys())[value] - assert get_item.name == setup_pars[key].name + assert get_item.name == setup_pars[key].name @pytest.mark.parametrize("cls", class_constructors) @@ -325,7 +325,7 @@ def test_CollectionBase_dir(cls): "decode", "sort", } - assert expected.issubset(d) + assert not d.difference(expected) @pytest.mark.parametrize("cls", class_constructors) @@ -469,8 +469,11 @@ def test_CollectionBase_set_index(cls): assert obj[idx] == p2 obj[idx] = p4 assert obj[idx] == p4 - assert len(obj) == len(l_object) - assert p2 not in obj.data + edges = obj._global_object.map.get_edges(obj) + assert len(edges) == len(obj) + for item in obj: + assert item.unique_name in edges + assert p2.unique_name not in edges @pytest.mark.parametrize("cls", class_constructors) @@ -490,8 +493,11 @@ def test_CollectionBase_set_index_based(cls): assert obj[idx] == p4 obj[idx] = d assert obj[idx] == d - assert len(obj) == len(l_object) - assert p4 not in obj.data + edges = obj._global_object.map.get_edges(obj) + assert len(edges) == len(obj) + for item in obj: + assert item.unique_name in edges + assert p4.unique_name not in edges @pytest.mark.parametrize("cls", class_constructors) @@ -523,10 +529,20 @@ class Beta(ObjBase): @pytest.mark.parametrize("cls", class_constructors) def test_CollectionBaseGraph(cls): + from easyscience import global_object + + G = global_object.map name = "test" v = [1, 2] p = [Parameter(f"p{i}", v[i]) for i in range(len(v))] + p_id = [_p.unique_name for _p in p] bb = cls(name, *p) + bb_id = bb.unique_name b = Beta("b", bb=bb) - assert b.bb is bb - assert list(bb) == p + b_id = b.unique_name + for _id in p_id: + assert _id in G.get_edges(bb) + assert len(p) == len(G.get_edges(bb)) + assert bb_id in G.get_edges(b) + assert 1 == len(G.get_edges(b)) + diff --git a/tests/unit_tests/base_classes/test_collection_base_easylist.py b/tests/unit_tests/base_classes/test_collection_base_easylist.py index 8258b203..c1374966 100644 --- a/tests/unit_tests/base_classes/test_collection_base_easylist.py +++ b/tests/unit_tests/base_classes/test_collection_base_easylist.py @@ -8,7 +8,7 @@ from easyscience import ObjBase from easyscience import Parameter from easyscience import global_object -from easyscience.base_classes import CollectionBase +from easyscience.base_classes import CollectionBaseEasyList as CollectionBase @pytest.fixture(autouse=True) From 198493d9570bc5eddbfe652700b44bd7f358363c Mon Sep 17 00:00:00 2001 From: rozyczko Date: Tue, 10 Mar 2026 09:26:31 +0100 Subject: [PATCH 4/7] ruff --- src/easyscience/base_classes/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/easyscience/base_classes/__init__.py b/src/easyscience/base_classes/__init__.py index 9d65106d..4fc8702b 100644 --- a/src/easyscience/base_classes/__init__.py +++ b/src/easyscience/base_classes/__init__.py @@ -1,6 +1,6 @@ from .based_base import BasedBase -from .collection_base_easylist import CollectionBase as CollectionBaseEasyList from .collection_base import CollectionBase +from .collection_base_easylist import CollectionBase as CollectionBaseEasyList from .easy_list import EasyList from .model_base import ModelBase from .new_base import NewBase From fd411369b7c7e2f19dad4662ee508933c61a92e6 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Tue, 10 Mar 2026 09:53:24 +0100 Subject: [PATCH 5/7] reverted to original test --- .../global_object/test_undo_redo.py | 63 ++++++++++++++++--- 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/tests/unit_tests/global_object/test_undo_redo.py b/tests/unit_tests/global_object/test_undo_redo.py index 643b449d..212681ac 100644 --- a/tests/unit_tests/global_object/test_undo_redo.py +++ b/tests/unit_tests/global_object/test_undo_redo.py @@ -167,22 +167,69 @@ def test_CollectionBaseUndoRedo(): objs = [createSingleObjs(idx) for idx in range(5)] name = "test" obj = CollectionBase(name, *objs) + name2 = "best" + + # assert not doUndoRedo(obj, 'name', name2) from easyscience import global_object global_object.stack.enabled = True - idx = 2 - old_value = obj[idx].value - new_value = old_value + 10 - obj[idx] = new_value - assert obj[idx].value == new_value + original_length = len(obj) + p = Parameter("slip_in", 50) + idx = 2 + obj.insert(idx, p) + assert len(obj) == original_length + 1 + objs.insert(idx, p) + for item, obj_r in zip(obj, objs): + assert item == obj_r + # Test inserting items global_object.stack.undo() - assert obj[idx].value == old_value - + assert len(obj) == original_length + _ = objs.pop(idx) + for item, obj_r in zip(obj, objs): + assert item == obj_r + global_object.stack.redo() + assert len(obj) == original_length + 1 + objs.insert(idx, p) + for item, obj_r in zip(obj, objs): + assert item == obj_r + + # Test Del Items + del obj[idx] + del objs[idx] + assert len(obj) == original_length + for item, obj_r in zip(obj, objs): + assert item == obj_r + global_object.stack.undo() + assert len(obj) == original_length + 1 + objs.insert(idx, p) + for item, obj_r in zip(obj, objs): + assert item == obj_r + del objs[idx] + global_object.stack.redo() + assert len(obj) == original_length + for item, obj_r in zip(obj, objs): + assert item == obj_r + + # Test Place Item + old_item = objs[idx] + objs[idx] = p + obj[idx] = p + assert len(obj) == original_length + for item, obj_r in zip(obj, objs): + assert item == obj_r + global_object.stack.undo() + for i in range(len(obj)): + if i == idx: + item = old_item + else: + item = objs[i] + assert obj[i] == item global_object.stack.redo() - assert obj[idx].value == new_value + for item, obj_r in zip(obj, objs): + assert item == obj_r global_object.stack.enabled = False From aaf7f422d8f2ada103bc35581d29d25e0a63e117 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Wed, 18 Mar 2026 11:46:23 +0100 Subject: [PATCH 6/7] added some docstrings --- .../base_classes/collection_base_easylist.py | 107 +++++++++++++++++- 1 file changed, 106 insertions(+), 1 deletion(-) diff --git a/src/easyscience/base_classes/collection_base_easylist.py b/src/easyscience/base_classes/collection_base_easylist.py index f27ea82b..2c13e61f 100644 --- a/src/easyscience/base_classes/collection_base_easylist.py +++ b/src/easyscience/base_classes/collection_base_easylist.py @@ -31,7 +31,17 @@ class CollectionBase(EasyList): layer for existing callers. """ + # Base types that items in this collection must be instances of when no + # custom ``protected_types`` are supplied by the caller. The validation + # check in ``_validate_item`` uses this list to reject plain Python objects + # that are not part of the EasyScience type hierarchy. + # ``BasedBase`` is only kept for backwards compatibility. _DEFAULT_PROTECTED_TYPES = (DescriptorBase, BasedBase, NewBase) + + # Names that cannot be used as keyword-argument keys when passing named + # items to the constructor (e.g. ``CollectionBase(name='x', data=...)``). + # These names collide with constructor parameters or internal attributes, + # so accepting them as named items would silently shadow real arguments. _RESERVED_NAMED_KEYS = { 'data', 'display_name', @@ -42,6 +52,12 @@ class CollectionBase(EasyList): 'user_data', '_kwargs', } + + # Mapping checked by ``SerializerBase._convert_to_dict`` to + # decide how to serialise each constructor argument. ``None`` + # tells the serialiser to skip that attribute entirely, + # a callable value would be invoked to produce the serialised form. + # ``interface`` should never be persistent and it will soon be removed. _REDIRECT = {'interface': None} def __init__( @@ -51,10 +67,32 @@ def __init__( protected_types: type | Iterable[type] | None = None, unique_name: Optional[str] = None, display_name: Optional[str] = None, - interface: Any = None, + interface: Any = None, # legacy, should be None and will soon be removed data: Optional[Iterable[Any]] = None, **named_items: Any, ): + """Create a new collection of EasyScience objects. + + Items can be supplied as positional arguments, via the *data* iterable, + or as keyword arguments (``**named_items``). All three sources are + merged in order; keyword names are discarded (only values are kept). + + If the first positional argument is a string and *name* is not given, + it is consumed as the collection name. + + :param items: EasyScience objects to include in the collection. + :param name: Human-readable name for the collection. + :param protected_types: Allowed item types. Defaults to + ``_DEFAULT_PROTECTED_TYPES``. + :param unique_name: Machine-friendly unique identifier. + :param display_name: Display label; defaults to *name*. + :param interface: Reserved internal attribute — must be ``None``. + :param data: Additional iterable of items to append. + :param named_items: Keyword-argument items (keys must not collide with + ``_RESERVED_NAMED_KEYS``). + :raises AttributeError: If *interface* is not ``None`` or an item + fails type validation. + """ if items and isinstance(items[0], str) and name is None: name = items[0] items = items[1:] @@ -86,10 +124,12 @@ def __init__( @property def name(self) -> str: + """Human-readable name of the collection.""" return self._name @name.setter def name(self, new_name: str) -> None: + """Set the collection name and sync *display_name*.""" if not isinstance(new_name, str): raise TypeError('Name must be a string') self._name = new_name @@ -98,6 +138,13 @@ def name(self, new_name: str) -> None: # --- Minimal overrides (compatibility shims) --- def __getitem__(self, idx: int | slice | str) -> Any: + """Retrieve items by integer index, slice, unique-name key, or display name. + + String lookup first tries the EasyList key (unique_name), then falls + back to matching on the item's *name* attribute. If multiple items + share the same name a new ``CollectionBase`` containing all matches is + returned. + """ if isinstance(idx, bool): raise TypeError('Boolean indexing is not supported at the moment') if isinstance(idx, slice): @@ -116,6 +163,13 @@ def __getitem__(self, idx: int | slice | str) -> Any: return super().__getitem__(idx) def __setitem__(self, idx: int | slice, value: Any) -> None: + """Set an item by index. + + When *value* is a plain number and the existing item has a ``value`` + attribute, the number is assigned to ``item.value`` (in-place update). + Otherwise the item is replaced via the EasyList, which + enforces type validation. + """ if isinstance(idx, int) and isinstance(value, Number): item = self[idx] if not hasattr(item, 'value'): @@ -128,6 +182,7 @@ def __setitem__(self, idx: int | slice, value: Any) -> None: raise NotImplementedError('At the moment only numerical values or EasyScience objects can be set.') from exc def insert(self, index: int, value: Any) -> None: + """Insert *value* before *index*, validating it against protected types.""" try: super().insert(index, value) except TypeError as exc: @@ -137,6 +192,12 @@ def __repr__(self) -> str: return f'{self.__class__.__name__} `{self.name}` of length {len(self)}' def sort(self, key=None, reverse: bool = False, mapping=None) -> None: + """Sort items in place. + + :param key: Single-argument function used to extract a comparison key. + :param reverse: If ``True``, sort in descending order. + :param mapping: Deprecated alias for *key*. + """ if mapping is not None: if key is not None: raise TypeError('Use either key or mapping, not both') @@ -147,6 +208,7 @@ def sort(self, key=None, reverse: bool = False, mapping=None) -> None: # --- Parameter/variable aggregation --- def get_all_variables(self) -> list[DescriptorBase]: + """Return all descriptors in this collection, recursing into nested items.""" variables: list[DescriptorBase] = [] for item in self._data: if isinstance(item, DescriptorBase): @@ -156,6 +218,10 @@ def get_all_variables(self) -> list[DescriptorBase]: return variables def get_all_parameters(self) -> list[Parameter]: + """Return all parameters in this collection, recursing into nested items. + + Each parameter appears at most once (deduplicated by identity). + """ parameters: list[Parameter] = [] seen = set() for item in self._data: @@ -183,24 +249,30 @@ def get_all_parameters(self) -> list[Parameter]: return parameters def get_parameters(self) -> list[Parameter]: + """Alias for :meth:`get_all_parameters`.""" return self.get_all_parameters() def get_fittable_parameters(self) -> list[Parameter]: + """Return all independent (fittable) parameters.""" return [parameter for parameter in self.get_all_parameters() if parameter.independent] def get_free_parameters(self) -> list[Parameter]: + """Return all fittable parameters that are not fixed.""" return [parameter for parameter in self.get_fittable_parameters() if not parameter.fixed] def get_fit_parameters(self) -> list[Parameter]: + """Legacy alias for :meth:`get_free_parameters`.""" return self.get_free_parameters() @property def data(self) -> tuple[Any, ...]: + """Read-only snapshot of the collection items as a tuple.""" return tuple(self._data) # --- Serialization --- def as_dict(self, skip: Optional[list[str]] = None) -> dict[str, Any]: + """Legacy. Serialize to a dict, always excluding ``unique_name``.""" if skip is None: skip = [] if 'unique_name' not in skip: @@ -208,17 +280,23 @@ def as_dict(self, skip: Optional[list[str]] = None) -> dict[str, Any]: return self.to_dict(skip=skip) def encode(self, skip: Optional[list[str]] = None, encoder=None, **kwargs: Any) -> Any: + """Legacy. Encode the collection using the given encoder (default: ``SerializerDict``).""" if encoder is None: encoder = SerializerDict return encoder().encode(self, skip=skip, **kwargs) @classmethod def decode(cls, obj: Any, decoder=None) -> Any: + """Reconstruct a ``CollectionBase`` from a previously encoded object.""" if decoder is None or decoder is SerializerDict: return cls.from_dict(obj) return decoder.decode(obj) def to_dict(self, skip: Optional[list[str]] = None) -> dict[str, Any]: + """Full serialization including ``@module``, ``@class``, and ``@version`` metadata. + + :param skip: List of attribute names to exclude from the output. + """ if skip is None: skip = [] @@ -249,6 +327,11 @@ def to_dict(self, skip: Optional[list[str]] = None) -> dict[str, Any]: @classmethod def from_dict(cls, obj_dict: dict[str, Any]) -> CollectionBase: + """Reconstruct a ``CollectionBase`` from a dict produced by :meth:`to_dict`. + + :param obj_dict: Dictionary containing ``@module``, ``@class``, and ``data`` keys. + :raises ValueError: If the dictionary structure or class name is invalid. + """ if not isinstance(obj_dict, dict) or '@class' not in obj_dict or '@module' not in obj_dict: raise ValueError('Input must be a dictionary representing an EasyScience CollectionBase object.') accepted_names = {base.__name__ for base in cls.__mro__ if issubclass(base, CollectionBase)} @@ -274,6 +357,7 @@ def _convert_to_dict( skip: Optional[list[str]] = None, **kwargs: Any, ) -> dict[str, Any]: + """Legacy. Hook used by ``SerializerBase`` to populate *in_dict* with collection data.""" if skip is None: skip = [] if 'name' not in skip: @@ -285,6 +369,7 @@ def _convert_to_dict( @staticmethod def _deserialize_protected_types(protected_types: list[dict[str, str]]) -> list[type]: + """Resolve serialized ``{@module, @class}`` dicts back to live type objects.""" deserialized_types: list[type] = [] for type_dict in protected_types: if '@module' not in type_dict or '@class' not in type_dict: @@ -294,6 +379,7 @@ def _deserialize_protected_types(protected_types: list[dict[str, str]]) -> list[ return deserialized_types def _clone_with_items(self, items: Iterable[Any]) -> CollectionBase: + """Create a shallow copy of this collection containing the given *items*.""" return self.__class__( *list(items), name=self.name, @@ -304,6 +390,8 @@ def _clone_with_items(self, items: Iterable[Any]) -> CollectionBase: # --- Compatibility surface --- def __dir__(self) -> Iterable[str]: + # Names that exist on the class but should not be shown to users. + # These are internal/new-API names hidden = { 'display_name', 'get_all_parameters', @@ -312,6 +400,8 @@ def __dir__(self) -> Iterable[str]: 'get_free_parameters', 'to_dict', } + # Names that the old BasedBase/BaseObj API exposed and that + # downstream code may rely on for introspection legacy = { 'append', 'as_dict', @@ -343,19 +433,26 @@ def __dir__(self) -> Iterable[str]: @property def constraints(self) -> list[Any]: + """Compatibility stub — always returns an empty list.""" return [] def generate_bindings(self) -> None: + """Compatibility stub — requires ``interface`` to be set.""" if self.interface is None: raise AttributeError('Interface error for generating bindings. `interface` has to be set.') def switch_interface(self, new_interface_name: str) -> None: + """Compatibility stub — requires ``interface`` to be set.""" if self.interface is None: raise AttributeError('Interface error for generating bindings. `interface` has to be set.') # --- Internal helpers --- def _normalize_named_items(self, named_items: dict[str, Any]) -> dict[str, Any]: + """Validate keyword-argument item names and drop ``None`` values. + + :raises AttributeError: If a key collides with ``_RESERVED_NAMED_KEYS``. + """ normalized: dict[str, Any] = {} for key, item in named_items.items(): if key in self._RESERVED_NAMED_KEYS: @@ -371,6 +468,10 @@ def _collect_items( data: Optional[Iterable[Any]] = None, named_items: Optional[dict[str, Any]] = None, ) -> list[Any]: + """Merge positional *items*, *data*, and *named_items* into a flat list. + + Lists inside *items* or *named_items* values are flattened one level. + """ collected: list[Any] = [] for item in items: if isinstance(item, list): @@ -388,6 +489,7 @@ def _collect_items( return collected def _normalize_protected_types(self, protected_types: type | Iterable[type] | None) -> list[type]: + """Coerce *protected_types* into a list, falling back to the class default.""" if protected_types is None: return list(self._DEFAULT_PROTECTED_TYPES) if isinstance(protected_types, type): @@ -399,6 +501,7 @@ def _normalize_protected_types(self, protected_types: type | Iterable[type] | No raise TypeError('protected_types must be a type or an iterable of types') def _serialize_item(self, item: Any, skip: Optional[list[str]] = None) -> dict[str, Any]: + """Serialize a single item via its ``to_dict`` or ``as_dict`` method.""" if hasattr(item, 'to_dict'): return item.to_dict() if hasattr(item, 'as_dict'): @@ -407,6 +510,7 @@ def _serialize_item(self, item: Any, skip: Optional[list[str]] = None) -> dict[s @staticmethod def _deserialize_item(item: Any) -> Any: + """Deserialize a single item dict back into an EasyScience object.""" if not SerializerBase._is_serialized_easyscience_object(item): return SerializerBase._deserialize_value(item) @@ -415,5 +519,6 @@ def _deserialize_item(item: Any) -> Any: return SerializerBase._deserialize_value(normalized_item) def _validate_item(self, item: Any) -> None: + """Raise ``TypeError`` if *item* is not an instance of a protected type.""" if not isinstance(item, tuple(self._protected_types)): raise TypeError(f'Items must be one of {self._protected_types}, got {type(item)}') From aad52b75ef5395850ea7fcfc13966d48cf29603f Mon Sep 17 00:00:00 2001 From: rozyczko Date: Wed, 18 Mar 2026 12:10:21 +0100 Subject: [PATCH 7/7] removed named_keys --- .../base_classes/collection_base_easylist.py | 53 +++---------------- .../test_collection_base_easylist.py | 12 ----- 2 files changed, 6 insertions(+), 59 deletions(-) diff --git a/src/easyscience/base_classes/collection_base_easylist.py b/src/easyscience/base_classes/collection_base_easylist.py index 2c13e61f..2256135b 100644 --- a/src/easyscience/base_classes/collection_base_easylist.py +++ b/src/easyscience/base_classes/collection_base_easylist.py @@ -38,21 +38,6 @@ class CollectionBase(EasyList): # ``BasedBase`` is only kept for backwards compatibility. _DEFAULT_PROTECTED_TYPES = (DescriptorBase, BasedBase, NewBase) - # Names that cannot be used as keyword-argument keys when passing named - # items to the constructor (e.g. ``CollectionBase(name='x', data=...)``). - # These names collide with constructor parameters or internal attributes, - # so accepting them as named items would silently shadow real arguments. - _RESERVED_NAMED_KEYS = { - 'data', - 'display_name', - 'interface', - 'name', - 'protected_types', - 'unique_name', - 'user_data', - '_kwargs', - } - # Mapping checked by ``SerializerBase._convert_to_dict`` to # decide how to serialise each constructor argument. ``None`` # tells the serialiser to skip that attribute entirely, @@ -67,15 +52,13 @@ def __init__( protected_types: type | Iterable[type] | None = None, unique_name: Optional[str] = None, display_name: Optional[str] = None, - interface: Any = None, # legacy, should be None and will soon be removed + interface: Any = None, # legacy, should be None and will soon be removed data: Optional[Iterable[Any]] = None, - **named_items: Any, ): """Create a new collection of EasyScience objects. - Items can be supplied as positional arguments, via the *data* iterable, - or as keyword arguments (``**named_items``). All three sources are - merged in order; keyword names are discarded (only values are kept). + Items can be supplied as positional arguments or via the *data* + iterable. Both sources are merged in order. If the first positional argument is a string and *name* is not given, it is consumed as the collection name. @@ -88,8 +71,6 @@ def __init__( :param display_name: Display label; defaults to *name*. :param interface: Reserved internal attribute — must be ``None``. :param data: Additional iterable of items to append. - :param named_items: Keyword-argument items (keys must not collide with - ``_RESERVED_NAMED_KEYS``). :raises AttributeError: If *interface* is not ``None`` or an item fails type validation. """ @@ -110,8 +91,7 @@ def __init__( self.user_data: dict[str, Any] = {} self.interface = None - normalized_named_items = self._normalize_named_items(named_items) - all_items = self._collect_items(items, data=data, named_items=normalized_named_items) + all_items = self._collect_items(items, data=data) for item in all_items: try: self._validate_item(item) @@ -448,29 +428,14 @@ def switch_interface(self, new_interface_name: str) -> None: # --- Internal helpers --- - def _normalize_named_items(self, named_items: dict[str, Any]) -> dict[str, Any]: - """Validate keyword-argument item names and drop ``None`` values. - - :raises AttributeError: If a key collides with ``_RESERVED_NAMED_KEYS``. - """ - normalized: dict[str, Any] = {} - for key, item in named_items.items(): - if key in self._RESERVED_NAMED_KEYS: - raise AttributeError(f'Given kwarg: `{key}`, is an internal attribute. Please rename.') - if item is None: - continue - normalized[key] = item - return normalized - def _collect_items( self, items: tuple[Any, ...], data: Optional[Iterable[Any]] = None, - named_items: Optional[dict[str, Any]] = None, ) -> list[Any]: - """Merge positional *items*, *data*, and *named_items* into a flat list. + """Merge positional *items* and *data* into a flat list. - Lists inside *items* or *named_items* values are flattened one level. + Lists inside *items* are flattened one level. """ collected: list[Any] = [] for item in items: @@ -480,12 +445,6 @@ def _collect_items( collected.append(item) if data is not None: collected.extend(data) - if named_items is not None: - for item in named_items.values(): - if isinstance(item, list) and len(item) > 0: - collected.extend(item) - else: - collected.append(item) return collected def _normalize_protected_types(self, protected_types: type | Iterable[type] | None) -> list[type]: diff --git a/tests/unit_tests/base_classes/test_collection_base_easylist.py b/tests/unit_tests/base_classes/test_collection_base_easylist.py index c1374966..a54c52c5 100644 --- a/tests/unit_tests/base_classes/test_collection_base_easylist.py +++ b/tests/unit_tests/base_classes/test_collection_base_easylist.py @@ -16,18 +16,6 @@ def clear(): global_object.map._clear() -def test_collection_base_legacy_constructor_supports_named_items(): - p1 = Parameter('p1', 1.0) - p2 = Parameter('p2', 2.0) - - collection = CollectionBase('test', first=p1, second=p2) - - assert collection.name == 'test' - assert len(collection) == 2 - assert collection[0] is p1 - assert collection[1] is p2 - - def test_collection_base_getitem_supports_unique_name_and_name_fallback(): p1 = Parameter('dup', 1.0, unique_name='p1') p2 = Parameter('dup', 2.0, unique_name='p2')