diff --git a/src/easyscience/__init__.py b/src/easyscience/__init__.py index 167e98e0..ae709682 100644 --- a/src/easyscience/__init__.py +++ b/src/easyscience/__init__.py @@ -14,6 +14,7 @@ from .base_classes import ObjBase # noqa: E402 from .fitting import AvailableMinimizers # noqa: E402 from .fitting import Fitter # noqa: E402 +from .legacy import CollectionBase # noqa: E402 from .variable import DescriptorNumber # noqa: E402 from .variable import Parameter # noqa: E402 @@ -23,6 +24,7 @@ __version__, global_object, ObjBase, + CollectionBase, AvailableMinimizers, Fitter, DescriptorNumber, diff --git a/src/easyscience/base_classes/__init__.py b/src/easyscience/base_classes/__init__.py index b5dc0418..6c465346 100644 --- a/src/easyscience/base_classes/__init__.py +++ b/src/easyscience/base_classes/__init__.py @@ -1,11 +1,11 @@ # SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause +from ..legacy.collection_base import CollectionBase +from ..legacy.obj_base import ObjBase from .based_base import BasedBase -from .collection_base import CollectionBase from .easy_list import EasyList from .model_base import ModelBase from .new_base import NewBase -from .obj_base import ObjBase __all__ = [BasedBase, CollectionBase, ObjBase, ModelBase, NewBase, EasyList] diff --git a/src/easyscience/base_classes/collection_base.py b/src/easyscience/base_classes/collection_base.py index 3328abc5..f53dddf1 100644 --- a/src/easyscience/base_classes/collection_base.py +++ b/src/easyscience/base_classes/collection_base.py @@ -1,251 +1,18 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -from collections.abc import MutableSequence -from numbers import Number -from typing import TYPE_CHECKING -from typing import Any -from typing import Callable -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union - -from easyscience.base_classes.new_base import NewBase -from easyscience.global_object.undo_redo import NotarizedDict - -from ..variable.descriptor_base import DescriptorBase -from .based_base import BasedBase - -if TYPE_CHECKING: - from ..fitting.calculators import InterfaceFactoryTemplate - - -class CollectionBase(BasedBase, MutableSequence): - """This is the base class for which all higher level classes are - built off of. - - NOTE: This object is serializable only if parameters are supplied as: - `ObjBase(a=value, b=value)`. For `Parameter` or `Descriptor` objects we can - cheat with `ObjBase(*[Descriptor(...), Parameter(...), ...])`. - """ - - def __init__( - self, - name: str, - *args: Union[BasedBase, DescriptorBase, NewBase], - interface: Optional[InterfaceFactoryTemplate] = None, - unique_name: Optional[str] = None, - **kwargs, - ): - """Set up the base collection class. - - :param name: Name of this object - :type name: str - :param args: selection of - :param _kwargs: Fields which this class should contain - :type _kwargs: dict - """ - BasedBase.__init__(self, name, unique_name=unique_name) - kwargs = {key: kwargs[key] for key in kwargs.keys() if kwargs[key] is not None} - _args = [] - for item in args: - if not isinstance(item, list): - _args.append(item) - else: - _args += item - _kwargs = {} - for key, item in kwargs.items(): - if isinstance(item, list) and len(item) > 0: - _args += item - else: - _kwargs[key] = item - kwargs = _kwargs - for item in list(kwargs.values()) + _args: - if not issubclass(type(item), (DescriptorBase, BasedBase, NewBase)): - raise AttributeError('A collection can only be formed from easyscience objects.') - args = _args - _kwargs = {} - for key, item in kwargs.items(): - _kwargs[key] = item - for arg in args: - kwargs[arg.unique_name] = arg - _kwargs[arg.unique_name] = arg - - # Set kwargs, also useful for serialization - self._kwargs = NotarizedDict(**_kwargs) - - for key in kwargs.keys(): - if key in self.__dict__.keys() or key in self.__slots__: - raise AttributeError( - f'Given kwarg: `{key}`, is an internal attribute. Please rename.' - ) - if kwargs[key]: # Might be None (empty tuple or list) - self._global_object.map.add_edge(self, kwargs[key]) - self._global_object.map.reset_type(kwargs[key], 'created_internal') - if interface is not None: - kwargs[key].interface = interface - # TODO wrap getter and setter in Logger - if interface is not None: - self.interface = interface - self._kwargs._stack_enabled = True - - def insert(self, index: int, value: Union[DescriptorBase, BasedBase, NewBase]) -> None: - """Insert an object into the collection at an index. - - :param index: Index for EasyScience object to be inserted. - :type index: int - :param value: Object to be inserted. - :type value: Union[BasedBase, DescriptorBase, NewBase] - :return: None - :rtype: None - """ - t_ = type(value) - if issubclass(t_, (BasedBase, DescriptorBase, NewBase)): - update_key = list(self._kwargs.keys()) - values = list(self._kwargs.values()) - # Update the internal dict - new_key = value.unique_name - update_key.insert(index, new_key) - values.insert(index, value) - self._kwargs.reorder(**{k: v for k, v in zip(update_key, values)}) - # ADD EDGE - self._global_object.map.add_edge(self, value) - self._global_object.map.reset_type(value, 'created_internal') - value.interface = self.interface - else: - raise AttributeError('Only EasyScience objects can be put into an EasyScience group') - - def __getitem__(self, idx: Union[int, slice]) -> Union[DescriptorBase, BasedBase, NewBase]: - """Get an item in the collection based on its index. - - :param idx: index or slice of the collection. - :type idx: Union[int, slice] - :return: Object at index `idx` - :rtype: Union[Parameter, Descriptor, ObjBase, 'CollectionBase'] - """ - if isinstance(idx, slice): - start, stop, step = idx.indices(len(self)) - return self.__class__( - getattr(self, 'name'), *[self[i] for i in range(start, stop, step)] - ) - if str(idx) in self._kwargs.keys(): - return self._kwargs[str(idx)] - if isinstance(idx, str): - idx = [index for index, item in enumerate(self) if item.name == idx] - noi = len(idx) - if noi == 0: - raise IndexError('Given index does not exist') - elif noi == 1: - idx = idx[0] - else: - return self.__class__(getattr(self, 'name'), *[self[i] for i in idx]) - elif not isinstance(idx, int) or isinstance(idx, bool): - if isinstance(idx, bool): - raise TypeError('Boolean indexing is not supported at the moment') - try: - if idx > len(self): - raise IndexError(f'Given index {idx} is out of bounds') - except TypeError: - raise IndexError('Index must be of type `int`/`slice` or an item name (`str`)') - keys = list(self._kwargs.keys()) - return self._kwargs[keys[idx]] - - def __setitem__(self, key: int, value: Union[BasedBase, DescriptorBase, NewBase]) -> None: - """Set an item via it's index. - - :param key: Index in self. - :type key: int - :param value: Value which index key should be set to. - :type value: Any - """ - if isinstance(value, Number): # noqa: S3827 - item = self.__getitem__(key) - item.value = value - elif issubclass(type(value), (BasedBase, DescriptorBase, NewBase)): - update_key = list(self._kwargs.keys()) - values = list(self._kwargs.values()) - old_item = values[key] - # Update the internal dict - update_dict = {update_key[key]: value} - self._kwargs.update(update_dict) - # ADD EDGE - self._global_object.map.add_edge(self, value) - self._global_object.map.reset_type(value, 'created_internal') - value.interface = self.interface - # REMOVE EDGE - self._global_object.map.prune_vertex_from_edge(self, old_item) - else: - raise NotImplementedError( - 'At the moment only numerical values or EasyScience objects can be set.' - ) - - def __delitem__(self, key: int) -> None: - """Try to delete an idem by key. - - :param key: - :type key: - :return: - :rtype: - """ - keys = list(self._kwargs.keys()) - item = self._kwargs[keys[key]] - self._global_object.map.prune_vertex_from_edge(self, item) - del self._kwargs[keys[key]] - - def __len__(self) -> int: - """Get the number of items in this collection. - - :return: Number of items in this collection. - :rtype: int - """ - return len(self._kwargs.keys()) - - def _convert_to_dict(self, in_dict, encoder, skip: List[str] = [], **kwargs) -> dict: - """Convert ones self into a serialized form. - - :return: dictionary of ones self - :rtype: dict - """ - d = {} - if hasattr(self, '_modify_dict'): - # any extra keys defined on the inheriting class - d = self._modify_dict(skip=skip, **kwargs) - in_dict['data'] = [encoder._convert_to_dict(item, skip=skip, **kwargs) for item in self] - out_dict = {**in_dict, **d} - return out_dict - - @property - def data(self) -> Tuple: - """The data function returns a tuple of the keyword arguments - passed to the constructor. This is useful for when you need to - pass in a dictionary of data to other functions, such as with - matplotlib's plot function. - - :param self: Access attributes of the class within the method - :return: The values of the attributes in a tuple :doc-author: - Trelent - """ - return tuple(self._kwargs.values()) - - def __repr__(self) -> str: - return f'{self.__class__.__name__} `{getattr(self, "name")}` of length {len(self)}' - - def sort( - self, - mapping: Callable[[Union[BasedBase, DescriptorBase, NewBase]], Any], - reverse: bool = False, - ) -> None: - """Sort the collection according to the given mapping. - - :param mapping: mapping function to sort the collection. i.e. - lambda parameter: parameter.value - :type mapping: Callable - :param reverse: Reverse the sorting. - :type reverse: bool - """ - i = list(self._kwargs.items()) - i.sort(key=lambda x: mapping(x[1]), reverse=reverse) - self._kwargs.reorder(**{k[0]: k[1] for k in i}) +""" +.. deprecated:: + This module has been moved to `easyscience.legacy.collection_base`. + Please update your imports. +""" + +import warnings + +from ..legacy.collection_base import CollectionBase # noqa: F401 + +warnings.warn( + 'easyscience.base_classes.collection_base is deprecated. ' + 'Please import from easyscience.legacy.collection_base instead.', + DeprecationWarning, + stacklevel=2, +) diff --git a/src/easyscience/base_classes/easy_list.py b/src/easyscience/base_classes/easy_list.py index 39f04eaa..de0cc786 100644 --- a/src/easyscience/base_classes/easy_list.py +++ b/src/easyscience/base_classes/easy_list.py @@ -17,13 +17,15 @@ from typing import overload from easyscience.io.serializer_base import SerializerBase +from easyscience.variable.descriptor_base import DescriptorBase +from .model_base import ModelBase from .new_base import NewBase ProtectedType_ = TypeVar('ProtectedType', bound=NewBase) -class EasyList(NewBase, MutableSequence[ProtectedType_]): +class EasyList(ModelBase, MutableSequence[ProtectedType_]): # If we were to inherit from List instead of MutableSequence, # we would have to overwrite "extend", "remove", "__iadd__", "count", "append", "__iter__" and "clear" def __init__( @@ -191,6 +193,24 @@ def _get_key(self, obj) -> str: """ return obj.unique_name + def get_all_variables(self) -> List[DescriptorBase]: + """Get all `Descriptor` and `Parameter` objects from all + elements that are derived from `ModelBase`. + + For each element that is a `ModelBase` instance, the element's + own `get_all_variables()` method is called and the results are + collected into a single flat list. + + :return: Flat list of all `DescriptorBase` objects from all + `ModelBase` elements. + :rtype: List[DescriptorBase] + """ + all_vars: List[DescriptorBase] = [] + for item in self._data: + if isinstance(item, ModelBase): + all_vars.extend(item.get_all_variables()) + return all_vars + # Overwriting methods def __repr__(self) -> str: diff --git a/src/easyscience/base_classes/obj_base.py b/src/easyscience/base_classes/obj_base.py index 8576b67b..d3edfffa 100644 --- a/src/easyscience/base_classes/obj_base.py +++ b/src/easyscience/base_classes/obj_base.py @@ -1,156 +1,18 @@ -# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -from typing import TYPE_CHECKING -from typing import Callable -from typing import Optional - -from ..utils.classTools import addLoggedProp -from ..variable.descriptor_base import DescriptorBase -from .based_base import BasedBase - -if TYPE_CHECKING: - from ..io import SerializerComponent - - -class ObjBase(BasedBase): - """This is the base class for which all higher level classes are - built off of. - - NOTE: This object is serializable only if parameters are supplied as: - `ObjBase(a=value, b=value)`. For `Parameter` or `Descriptor` objects we can - cheat with `ObjBase(*[Descriptor(...), Parameter(...), ...])`. - """ - - def __init__( - self, - name: str, - unique_name: Optional[str] = None, - *args: Optional[SerializerComponent], - **kwargs: Optional[SerializerComponent], - ): - """Set up the base class. - - :param name: Name of this object - :param args: Any arguments? - :param kwargs: Fields which this class should contain - """ - super(ObjBase, self).__init__(name=name, unique_name=unique_name) - # If Parameter or Descriptor is given as arguments... - for arg in args: - if issubclass(type(arg), (ObjBase, DescriptorBase)): - kwargs[getattr(arg, 'name')] = arg - # Set kwargs, also useful for serialization - known_keys = self.__dict__.keys() - self._kwargs = kwargs - for key in kwargs.keys(): - if key in known_keys: - raise AttributeError('Kwargs cannot overwrite class attributes in ObjBase.') - if issubclass(type(kwargs[key]), (BasedBase, DescriptorBase)) or 'CollectionBase' in [ - c.__name__ for c in type(kwargs[key]).__bases__ - ]: - self._global_object.map.add_edge(self, kwargs[key]) - self._global_object.map.reset_type(kwargs[key], 'created_internal') - addLoggedProp( - self, - key, - self.__getter(key), - self.__setter(key), - get_id=key, - my_self=self, - test_class=ObjBase, - ) - - def _add_component(self, key: str, component: SerializerComponent) -> None: - """Dynamically add a component to the class. This is an internal - method, though can be called remotely. The recommended - alternative is to use typing, i.e. - - .. code-block:: python - - class Foo(Bar): - def __init__(self, foo: Parameter, bar: Parameter): - super(Foo, self).__init__(bar=bar) - self._add_component('foo', foo) - - :param key: Name of component to be added - :param component: Component to be added - :return: None - """ - self._kwargs[key] = component - self._global_object.map.add_edge(self, component) - self._global_object.map.reset_type(component, 'created_internal') - addLoggedProp( - self, - key, - self.__getter(key), - self.__setter(key), - get_id=key, - my_self=self, - test_class=ObjBase, - ) - - def __setattr__(self, key: str, value: SerializerComponent) -> None: - # Assume that the annotation is a ClassVar - old_obj = None - if ( - hasattr(self.__class__, '__annotations__') - and key in self.__class__.__annotations__ - and hasattr(self.__class__.__annotations__[key], '__args__') - and issubclass( - getattr(value, '__old_class__', value.__class__), - self.__class__.__annotations__[key].__args__, - ) - ): - if issubclass(type(getattr(self, key, None)), (BasedBase, DescriptorBase)): - old_obj = self.__getattribute__(key) - self._global_object.map.prune_vertex_from_edge(self, old_obj) - self._add_component(key, value) - else: - if hasattr(self, key) and issubclass(type(value), (BasedBase, DescriptorBase)): - old_obj = self.__getattribute__(key) - self._global_object.map.prune_vertex_from_edge(self, old_obj) - self._global_object.map.add_edge(self, value) - super(ObjBase, self).__setattr__(key, value) - # Update the interface bindings if something changed (BasedBase and Descriptor) - if old_obj is not None: - old_interface = getattr(self, 'interface', None) - if old_interface is not None: - self.generate_bindings() - - def __repr__(self) -> str: - return f'{self.__class__.__name__} `{getattr(self, "name")}`' - - @staticmethod - def __getter(key: str) -> Callable[[SerializerComponent], SerializerComponent]: - def getter(obj: SerializerComponent) -> SerializerComponent: - return obj._kwargs[key] - - return getter - - @staticmethod - def __setter(key: str) -> Callable[[SerializerComponent], None]: - def setter(obj: SerializerComponent, value: float) -> None: - if issubclass(obj._kwargs[key].__class__, (DescriptorBase)) and not issubclass( - value.__class__, (DescriptorBase) - ): - obj._kwargs[key].value = value - else: - obj._kwargs[key] = value - - return setter - - # @staticmethod - # def __setter(key: str) -> Callable[[Union[B, V]], None]: - # def setter(obj: Union[V, B], value: float) -> None: - # if issubclass(obj._kwargs[key].__class__, Descriptor): - # if issubclass(obj._kwargs[key].__class__, Descriptor): - # obj._kwargs[key] = value - # else: - # obj._kwargs[key].value = value - # else: - # obj._kwargs[key] = value - # - # return setter +""" +.. deprecated:: + This module has been moved to `easyscience.legacy.obj_base`. + Please update your imports. +""" + +import warnings + +from ..legacy.obj_base import ObjBase # noqa: F401 + +warnings.warn( + 'easyscience.base_classes.obj_base is deprecated. ' + 'Please import from easyscience.legacy.obj_base instead.', + DeprecationWarning, + stacklevel=2, +) diff --git a/src/easyscience/job/analysis.py b/src/easyscience/job/analysis.py index 448f1c3c..ab8544c6 100644 --- a/src/easyscience/job/analysis.py +++ b/src/easyscience/job/analysis.py @@ -6,8 +6,8 @@ import numpy as np -from ..base_classes.obj_base import ObjBase from ..fitting.minimizers import MinimizerBase +from ..legacy.obj_base import ObjBase class AnalysisBase(ObjBase, metaclass=ABCMeta): diff --git a/src/easyscience/job/experiment.py b/src/easyscience/job/experiment.py index 0f9577c3..eba9439a 100644 --- a/src/easyscience/job/experiment.py +++ b/src/easyscience/job/experiment.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -from ..base_classes.obj_base import ObjBase +from ..legacy.obj_base import ObjBase class ExperimentBase(ObjBase): diff --git a/src/easyscience/job/job.py b/src/easyscience/job/job.py index b48cbb1d..bcddee70 100644 --- a/src/easyscience/job/job.py +++ b/src/easyscience/job/job.py @@ -4,7 +4,7 @@ from abc import ABCMeta from abc import abstractmethod -from ..base_classes.obj_base import ObjBase +from ..legacy.obj_base import ObjBase from .analysis import AnalysisBase from .experiment import ExperimentBase from .theoreticalmodel import TheoreticalModelBase diff --git a/src/easyscience/job/theoreticalmodel.py b/src/easyscience/job/theoreticalmodel.py index 609f889e..6ebc3978 100644 --- a/src/easyscience/job/theoreticalmodel.py +++ b/src/easyscience/job/theoreticalmodel.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -from ..base_classes.obj_base import ObjBase +from ..legacy.obj_base import ObjBase class TheoreticalModelBase(ObjBase): diff --git a/src/easyscience/legacy/__init__.py b/src/easyscience/legacy/__init__.py new file mode 100644 index 00000000..482abc26 --- /dev/null +++ b/src/easyscience/legacy/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +from .collection_base import CollectionBase # noqa: F401 +from .obj_base import ObjBase # noqa: F401 + +__all__ = ['CollectionBase', 'ObjBase'] diff --git a/src/easyscience/legacy/collection_base.py b/src/easyscience/legacy/collection_base.py new file mode 100644 index 00000000..7ec88e74 --- /dev/null +++ b/src/easyscience/legacy/collection_base.py @@ -0,0 +1,262 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import warnings +from collections.abc import MutableSequence +from numbers import Number +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +from easyscience.base_classes.new_base import NewBase +from easyscience.global_object.undo_redo import NotarizedDict + +from ..base_classes.based_base import BasedBase +from ..variable.descriptor_base import DescriptorBase + +if TYPE_CHECKING: + from ..fitting.calculators import InterfaceFactoryTemplate + + +class CollectionBase(BasedBase, MutableSequence): + """This is the base class for which all higher level classes are + built off of. + + .. deprecated:: + `CollectionBase` is deprecated and will be removed in a future version. + Please migrate to `ModelBase` or `EasyList` instead. + + NOTE: This object is serializable only if parameters are supplied as: + `ObjBase(a=value, b=value)`. For `Parameter` or `Descriptor` objects we can + cheat with `ObjBase(*[Descriptor(...), Parameter(...), ...])`. + """ + + def __init__( + self, + name: str, + *args: Union[BasedBase, DescriptorBase, NewBase], + interface: Optional[InterfaceFactoryTemplate] = None, + unique_name: Optional[str] = None, + **kwargs, + ): + """Set up the base collection class. + + :param name: Name of this object + :type name: str + :param args: selection of + :param _kwargs: Fields which this class should contain + :type _kwargs: dict + """ + warnings.warn( + 'CollectionBase is deprecated and will be removed in a future version. ' + 'Please migrate to ModelBase or EasyList.', + DeprecationWarning, + stacklevel=2, + ) + BasedBase.__init__(self, name, unique_name=unique_name) + kwargs = {key: kwargs[key] for key in kwargs.keys() if kwargs[key] is not None} + _args = [] + for item in args: + if not isinstance(item, list): + _args.append(item) + else: + _args += item + _kwargs = {} + for key, item in kwargs.items(): + if isinstance(item, list) and len(item) > 0: + _args += item + else: + _kwargs[key] = item + kwargs = _kwargs + for item in list(kwargs.values()) + _args: + if not issubclass(type(item), (DescriptorBase, BasedBase, NewBase)): + raise AttributeError('A collection can only be formed from easyscience objects.') + args = _args + _kwargs = {} + for key, item in kwargs.items(): + _kwargs[key] = item + for arg in args: + kwargs[arg.unique_name] = arg + _kwargs[arg.unique_name] = arg + + # Set kwargs, also useful for serialization + self._kwargs = NotarizedDict(**_kwargs) + + for key in kwargs.keys(): + if key in self.__dict__.keys() or key in self.__slots__: + raise AttributeError( + f'Given kwarg: `{key}`, is an internal attribute. Please rename.' + ) + if kwargs[key]: # Might be None (empty tuple or list) + self._global_object.map.add_edge(self, kwargs[key]) + self._global_object.map.reset_type(kwargs[key], 'created_internal') + if interface is not None: + kwargs[key].interface = interface + # TODO wrap getter and setter in Logger + if interface is not None: + self.interface = interface + self._kwargs._stack_enabled = True + + def insert(self, index: int, value: Union[DescriptorBase, BasedBase, NewBase]) -> None: + """Insert an object into the collection at an index. + + :param index: Index for EasyScience object to be inserted. + :type index: int + :param value: Object to be inserted. + :type value: Union[BasedBase, DescriptorBase, NewBase] + :return: None + :rtype: None + """ + t_ = type(value) + if issubclass(t_, (BasedBase, DescriptorBase, NewBase)): + update_key = list(self._kwargs.keys()) + values = list(self._kwargs.values()) + # Update the internal dict + new_key = value.unique_name + update_key.insert(index, new_key) + values.insert(index, value) + self._kwargs.reorder(**{k: v for k, v in zip(update_key, values)}) + # ADD EDGE + self._global_object.map.add_edge(self, value) + self._global_object.map.reset_type(value, 'created_internal') + value.interface = self.interface + else: + raise AttributeError('Only EasyScience objects can be put into an EasyScience group') + + def __getitem__(self, idx: Union[int, slice]) -> Union[DescriptorBase, BasedBase, NewBase]: + """Get an item in the collection based on its index. + + :param idx: index or slice of the collection. + :type idx: Union[int, slice] + :return: Object at index `idx` + :rtype: Union[Parameter, Descriptor, ObjBase, 'CollectionBase'] + """ + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + return self.__class__( + getattr(self, 'name'), *[self[i] for i in range(start, stop, step)] + ) + if str(idx) in self._kwargs.keys(): + return self._kwargs[str(idx)] + if isinstance(idx, str): + idx = [index for index, item in enumerate(self) if item.name == idx] + noi = len(idx) + if noi == 0: + raise IndexError('Given index does not exist') + elif noi == 1: + idx = idx[0] + else: + return self.__class__(getattr(self, 'name'), *[self[i] for i in idx]) + elif not isinstance(idx, int) or isinstance(idx, bool): + if isinstance(idx, bool): + raise TypeError('Boolean indexing is not supported at the moment') + try: + if idx > len(self): + raise IndexError(f'Given index {idx} is out of bounds') + except TypeError: + raise IndexError('Index must be of type `int`/`slice` or an item name (`str`)') + keys = list(self._kwargs.keys()) + return self._kwargs[keys[idx]] + + def __setitem__(self, key: int, value: Union[BasedBase, DescriptorBase, NewBase]) -> None: + """Set an item via it's index. + + :param key: Index in self. + :type key: int + :param value: Value which index key should be set to. + :type value: Any + """ + if isinstance(value, Number): # noqa: S3827 + item = self.__getitem__(key) + item.value = value + elif issubclass(type(value), (BasedBase, DescriptorBase, NewBase)): + update_key = list(self._kwargs.keys()) + values = list(self._kwargs.values()) + old_item = values[key] + # Update the internal dict + update_dict = {update_key[key]: value} + self._kwargs.update(update_dict) + # ADD EDGE + self._global_object.map.add_edge(self, value) + self._global_object.map.reset_type(value, 'created_internal') + value.interface = self.interface + # REMOVE EDGE + self._global_object.map.prune_vertex_from_edge(self, old_item) + else: + raise NotImplementedError( + 'At the moment only numerical values or EasyScience objects can be set.' + ) + + def __delitem__(self, key: int) -> None: + """Try to delete an idem by key. + + :param key: + :type key: + :return: + :rtype: + """ + keys = list(self._kwargs.keys()) + item = self._kwargs[keys[key]] + self._global_object.map.prune_vertex_from_edge(self, item) + del self._kwargs[keys[key]] + + def __len__(self) -> int: + """Get the number of items in this collection. + + :return: Number of items in this collection. + :rtype: int + """ + return len(self._kwargs.keys()) + + def _convert_to_dict(self, in_dict, encoder, skip: List[str] = [], **kwargs) -> dict: + """Convert ones self into a serialized form. + + :return: dictionary of ones self + :rtype: dict + """ + d = {} + if hasattr(self, '_modify_dict'): + # any extra keys defined on the inheriting class + d = self._modify_dict(skip=skip, **kwargs) + in_dict['data'] = [encoder._convert_to_dict(item, skip=skip, **kwargs) for item in self] + out_dict = {**in_dict, **d} + return out_dict + + @property + def data(self) -> Tuple: + """The data function returns a tuple of the keyword arguments + passed to the constructor. This is useful for when you need to + pass in a dictionary of data to other functions, such as with + matplotlib's plot function. + + :param self: Access attributes of the class within the method + :return: The values of the attributes in a tuple :doc-author: + Trelent + """ + return tuple(self._kwargs.values()) + + def __repr__(self) -> str: + return f'{self.__class__.__name__} `{getattr(self, "name")}` of length {len(self)}' + + def sort( + self, + mapping: Callable[[Union[BasedBase, DescriptorBase, NewBase]], Any], + reverse: bool = False, + ) -> None: + """Sort the collection according to the given mapping. + + :param mapping: mapping function to sort the collection. i.e. + lambda parameter: parameter.value + :type mapping: Callable + :param reverse: Reverse the sorting. + :type reverse: bool + """ + i = list(self._kwargs.items()) + i.sort(key=lambda x: mapping(x[1]), reverse=reverse) + self._kwargs.reorder(**{k[0]: k[1] for k in i}) diff --git a/src/easyscience/legacy/obj_base.py b/src/easyscience/legacy/obj_base.py new file mode 100644 index 00000000..7a6d611a --- /dev/null +++ b/src/easyscience/legacy/obj_base.py @@ -0,0 +1,154 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING +from typing import Callable +from typing import Optional + +from ..base_classes.based_base import BasedBase +from ..utils.classTools import addLoggedProp +from ..variable.descriptor_base import DescriptorBase + +if TYPE_CHECKING: + from ..io import SerializerComponent + + +class ObjBase(BasedBase): + """This is the base class for which all higher level classes are + built off of. + + .. deprecated:: + `ObjBase` is deprecated and will be removed in a future version. + Please migrate to `ModelBase` instead. + + NOTE: This object is serializable only if parameters are supplied as: + `ObjBase(a=value, b=value)`. For `Parameter` or `Descriptor` objects we can + cheat with `ObjBase(*[Descriptor(...), Parameter(...), ...])`. + """ + + def __init__( + self, + name: str, + unique_name: Optional[str] = None, + *args: Optional[SerializerComponent], + **kwargs: Optional[SerializerComponent], + ): + """Set up the base class. + + :param name: Name of this object + :param args: Any arguments? + :param kwargs: Fields which this class should contain + """ + warnings.warn( + 'ObjBase is deprecated and will be removed in a future version. ' + 'Please migrate to ModelBase.', + DeprecationWarning, + stacklevel=2, + ) + super(ObjBase, self).__init__(name=name, unique_name=unique_name) + # If Parameter or Descriptor is given as arguments... + for arg in args: + if issubclass(type(arg), (ObjBase, DescriptorBase)): + kwargs[getattr(arg, 'name')] = arg + # Set kwargs, also useful for serialization + known_keys = self.__dict__.keys() + self._kwargs = kwargs + for key in kwargs.keys(): + if key in known_keys: + raise AttributeError('Kwargs cannot overwrite class attributes in ObjBase.') + if issubclass(type(kwargs[key]), (BasedBase, DescriptorBase)) or 'CollectionBase' in [ + c.__name__ for c in type(kwargs[key]).__bases__ + ]: + self._global_object.map.add_edge(self, kwargs[key]) + self._global_object.map.reset_type(kwargs[key], 'created_internal') + addLoggedProp( + self, + key, + self.__getter(key), + self.__setter(key), + get_id=key, + my_self=self, + test_class=ObjBase, + ) + + def _add_component(self, key: str, component: SerializerComponent) -> None: + """Dynamically add a component to the class. This is an internal + method, though can be called remotely. The recommended + alternative is to use typing, i.e. + + .. code-block:: python + + class Foo(Bar): + def __init__(self, foo: Parameter, bar: Parameter): + super(Foo, self).__init__(bar=bar) + self._add_component('foo', foo) + + :param key: Name of component to be added + :param component: Component to be added + :return: None + """ + self._kwargs[key] = component + self._global_object.map.add_edge(self, component) + self._global_object.map.reset_type(component, 'created_internal') + addLoggedProp( + self, + key, + self.__getter(key), + self.__setter(key), + get_id=key, + my_self=self, + test_class=ObjBase, + ) + + def __setattr__(self, key: str, value: SerializerComponent) -> None: + # Assume that the annotation is a ClassVar + old_obj = None + if ( + hasattr(self.__class__, '__annotations__') + and key in self.__class__.__annotations__ + and hasattr(self.__class__.__annotations__[key], '__args__') + and issubclass( + getattr(value, '__old_class__', value.__class__), + self.__class__.__annotations__[key].__args__, + ) + ): + if issubclass(type(getattr(self, key, None)), (BasedBase, DescriptorBase)): + old_obj = self.__getattribute__(key) + self._global_object.map.prune_vertex_from_edge(self, old_obj) + self._add_component(key, value) + else: + if hasattr(self, key) and issubclass(type(value), (BasedBase, DescriptorBase)): + old_obj = self.__getattribute__(key) + self._global_object.map.prune_vertex_from_edge(self, old_obj) + self._global_object.map.add_edge(self, value) + super(ObjBase, self).__setattr__(key, value) + # Update the interface bindings if something changed (BasedBase and Descriptor) + if old_obj is not None: + old_interface = getattr(self, 'interface', None) + if old_interface is not None: + self.generate_bindings() + + def __repr__(self) -> str: + return f'{self.__class__.__name__} `{getattr(self, "name")}`' + + @staticmethod + def __getter(key: str) -> Callable[[SerializerComponent], SerializerComponent]: + def getter(obj: SerializerComponent) -> SerializerComponent: + return obj._kwargs[key] + + return getter + + @staticmethod + def __setter(key: str) -> Callable[[SerializerComponent], None]: + def setter(obj: SerializerComponent, value: float) -> None: + if issubclass(obj._kwargs[key].__class__, (DescriptorBase)) and not issubclass( + value.__class__, (DescriptorBase) + ): + obj._kwargs[key].value = value + else: + obj._kwargs[key] = value + + return setter diff --git a/tests/unit/base_classes/test_deprecated_wrappers.py b/tests/unit/base_classes/test_deprecated_wrappers.py new file mode 100644 index 00000000..01388b3f --- /dev/null +++ b/tests/unit/base_classes/test_deprecated_wrappers.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for the deprecated wrapper modules in easyscience.base_classes. + +These modules now only emit DeprecationWarnings and re-export from +easyscience.legacy. We test that the warnings are raised and that +the classes are still usable. +""" + +import warnings + +import pytest + +from easyscience import global_object + + +@pytest.fixture(autouse=True) +def _clear_map(): + """Clear the global object map before and after each test.""" + global_object.map._clear() + yield + global_object.map._clear() + + +# --------------------------------------------------------------------------- +# Deprecated wrapper: easyscience.base_classes.collection_base +# --------------------------------------------------------------------------- + + +def test_import_collection_base_warns(): + """Importing easyscience.base_classes.collection_base emits DeprecationWarning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + from easyscience.base_classes.collection_base import CollectionBase # noqa: F811 + + assert len(w) >= 1, 'Expected at least one DeprecationWarning' + deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecation_warnings) >= 1, 'Expected a DeprecationWarning' + msg = str(deprecation_warnings[0].message) + assert 'deprecated' in msg.lower() + assert 'legacy.collection_base' in msg + + +def test_collection_base_still_works_from_deprecated(): + """The class imported from the deprecated wrapper still works.""" + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + from easyscience.base_classes.collection_base import CollectionBase # noqa: F811 + + from easyscience import Parameter + + p = Parameter('p1', 1.0) + coll = CollectionBase('test', p) + assert len(coll) == 1 + assert coll[0].name == 'p1' + + +# --------------------------------------------------------------------------- +# Deprecated wrapper: easyscience.base_classes.obj_base +# --------------------------------------------------------------------------- + + +def test_import_obj_base_warns(): + """Importing easyscience.base_classes.obj_base emits DeprecationWarning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + from easyscience.base_classes.obj_base import ObjBase # noqa: F811 + + assert len(w) >= 1, 'Expected at least one DeprecationWarning' + deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecation_warnings) >= 1, 'Expected a DeprecationWarning' + msg = str(deprecation_warnings[0].message) + assert 'deprecated' in msg.lower() + assert 'legacy.obj_base' in msg + + +def test_obj_base_still_works_from_deprecated(): + """The class imported from the deprecated wrapper still works.""" + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + from easyscience.base_classes.obj_base import ObjBase # noqa: F811 + + from easyscience import Parameter + + p = Parameter('p1', 1.0) + obj = ObjBase('test', p1=p) + assert obj.p1.value == 1.0 diff --git a/tests/unit/base_classes/test_easy_list.py b/tests/unit/base_classes/test_easy_list.py index 32f6a60b..daee3519 100644 --- a/tests/unit/base_classes/test_easy_list.py +++ b/tests/unit/base_classes/test_easy_list.py @@ -7,7 +7,10 @@ from easyscience import global_object from easyscience.base_classes.easy_list import EasyList +from easyscience.base_classes.model_base import ModelBase from easyscience.base_classes.new_base import NewBase +from easyscience.variable import DescriptorNumber +from easyscience.variable import Parameter class Alpha(NewBase): @@ -24,6 +27,52 @@ def __init__(self, unique_name=None, display_name=None): super().__init__(unique_name=unique_name, display_name=display_name) +class MockModel(ModelBase): + """A ModelBase subclass with a Parameter and a DescriptorNumber for testing get_all_variables.""" + + def __init__(self, unique_name=None, display_name=None, temperature=25, volume=1.0): + super().__init__(unique_name=unique_name, display_name=display_name) + self._temperature = Parameter(name='temperature', value=temperature) + self._volume = DescriptorNumber(name='volume', value=volume) + + @property + def temperature(self): + return self._temperature + + @temperature.setter + def temperature(self, value): + self._temperature.value = value + + @property + def volume(self): + return self._volume + + @volume.setter + def volume(self, value): + self._volume.value = value + + +class MockModelNested(ModelBase): + """A ModelBase subclass that contains another ModelBase to test nested variable collection.""" + + def __init__(self, unique_name=None, display_name=None, component=None, pressure=0): + super().__init__(unique_name=unique_name, display_name=display_name) + self._pressure = Parameter(name='pressure', value=pressure) + self._component = component or MockModel(unique_name='inner', temperature=30, volume=2.0) + + @property + def pressure(self): + return self._pressure + + @pressure.setter + def pressure(self, value): + self._pressure.value = value + + @property + def component(self): + return self._component + + class TestEasyList: @pytest.fixture(autouse=True) def clear(self): @@ -524,3 +573,111 @@ def test_from_dict_round_trip(self): assert el2[0].unique_name == 'a1' assert el2[1].unique_name == 'a2' assert d == el2.to_dict() # The dicts should be the same after round trip + + # --- get_all_variables --- + + def test_get_all_variables_empty_list(self): + """An empty EasyList should return an empty list of variables.""" + el = EasyList(protected_types=ModelBase) + assert el.get_all_variables() == [] + + def test_get_all_variables_no_modelbase_elements(self): + """An EasyList with only plain NewBase elements (not ModelBase) should return an empty list.""" + a1 = Alpha(unique_name='a1') + a2 = NewBase(unique_name='nb1') + el = EasyList(a1, a2) + assert el.get_all_variables() == [] + + def test_get_all_variables_single_modelbase(self): + """A single ModelBase element should return its Parameter and DescriptorNumber.""" + m1 = MockModel(unique_name='m1', temperature=10, volume=5.0) + el = EasyList(m1, protected_types=ModelBase) + vars = el.get_all_variables() + assert len(vars) == 2 + names = {v.name for v in vars} + assert 'temperature' in names + assert 'volume' in names + # Verify specific values + temp_var = next(v for v in vars if v.name == 'temperature') + assert temp_var.value == 10 + vol_var = next(v for v in vars if v.name == 'volume') + assert vol_var.value == 5.0 + + def test_get_all_variables_multiple_modelbase(self): + """Multiple ModelBase elements should return all their combined variables.""" + m1 = MockModel(unique_name='m1', temperature=10, volume=5.0) + m2 = MockModel(unique_name='m2', temperature=99, volume=1.5) + el = EasyList(m1, m2, protected_types=ModelBase) + vars = el.get_all_variables() + assert len(vars) == 4 + names = {v.name for v in vars} + assert names == {'temperature', 'volume'} + + def test_get_all_variables_mixed_elements(self): + """Only ModelBase-derived elements contribute variables; plain NewBase elements are skipped.""" + m1 = MockModel(unique_name='m1', temperature=10, volume=5.0) + a1 = Alpha(unique_name='a1') + el = EasyList(m1, a1) + vars = el.get_all_variables() + assert len(vars) == 2 + names = {v.name for v in vars} + assert names == {'temperature', 'volume'} + + def test_get_all_variables_nested_model(self): + """A ModelBase containing another ModelBase should recursively collect all variables.""" + inner = MockModel(unique_name='inner', temperature=30, volume=2.0) + parent = MockModelNested(unique_name='parent', component=inner, pressure=100) + el = EasyList(parent, protected_types=ModelBase) + vars = el.get_all_variables() + # parent: pressure (Parameter), inner: temperature (Parameter), volume (DescriptorNumber) + assert len(vars) == 3 + names = {v.name for v in vars} + assert names == {'pressure', 'temperature', 'volume'} + + def test_get_all_variables_returns_descriptorbase_instances(self): + """All returned items should be instances of DescriptorBase.""" + m1 = MockModel(unique_name='m1', temperature=10, volume=5.0) + m2 = MockModel(unique_name='m2', temperature=99, volume=1.5) + el = EasyList(m1, m2, protected_types=ModelBase) + for v in el.get_all_variables(): + from easyscience.variable.descriptor_base import DescriptorBase + + assert isinstance(v, DescriptorBase) + + def test_get_all_variables_nested_easylist(self): + """An EasyList containing another EasyList with mixed NewBase/ModelBase elements + should collect variables from the inner EasyList's ModelBase items, + skipping plain NewBase items.""" + inner_model = MockModel(unique_name='inner_m', temperature=50, volume=3.0) + inner_plain = Alpha(unique_name='inner_a') + inner_list = EasyList(inner_model, inner_plain) + outer_model = MockModel(unique_name='outer_m', temperature=70, volume=4.0) + outer_list = EasyList(inner_list, outer_model) + + # --- outer_list structure --- + assert len(outer_list) == 2 + assert outer_list[0] is inner_list + assert outer_list[1] is outer_model + + # --- inner_list structure --- + assert len(inner_list) == 2 + assert inner_list[0] is inner_model + assert inner_list[1] is inner_plain + # inner_list own get_all_variables should only see inner_model (skip Alpha) + inner_vars = inner_list.get_all_variables() + assert len(inner_vars) == 2 + + # --- outer_list.get_all_variables --- + vars = outer_list.get_all_variables() + # inner_model: temperature (50), volume (3.0); outer_model: temperature (70), volume (4.0) + assert len(vars) == 4 + + # Verify all returned items are DescriptorBase instances + for v in vars: + assert isinstance(v, DescriptorNumber) + + # Collect temperatures and volumes from both models + temps = {v.value for v in vars if v.name == 'temperature'} + vols = {v.value for v in vars if v.name == 'volume'} + assert temps == {50, 70} + assert vols == {3.0, 4.0} diff --git a/tests/unit/base_classes/test_obj_base.py b/tests/unit/base_classes/test_obj_base.py index 357aa4f1..5eb1b2ea 100644 --- a/tests/unit/base_classes/test_obj_base.py +++ b/tests/unit/base_classes/test_obj_base.py @@ -148,7 +148,7 @@ def test_ObjBase_as_dict(clear, setup_pars: dict): obtained = obj.as_dict() assert isinstance(obtained, dict) expected = { - '@module': 'easyscience.base_classes.obj_base', + '@module': 'easyscience.legacy.obj_base', '@class': 'ObjBase', '@version': easyscience.__version__, 'name': 'test', diff --git a/tests/unit/legacy/test_collection_base.py b/tests/unit/legacy/test_collection_base.py new file mode 100644 index 00000000..1a1115da --- /dev/null +++ b/tests/unit/legacy/test_collection_base.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for easyscience.legacy.collection_base.CollectionBase. + +These tests cover methods not exercised by the existing +tests/unit/base_classes/test_collection_base.py suite. +""" + +import warnings + +import pytest + +from easyscience import DescriptorNumber +from easyscience import Parameter +from easyscience import global_object +from easyscience.legacy.collection_base import CollectionBase + + +@pytest.fixture(autouse=True) +def _clear_map(): + global_object.map._clear() + yield + global_object.map._clear() + + +@pytest.fixture +def setup_pars(): + return { + 'name': 'test', + 'par1': Parameter('p1', 0.1, fixed=True), + 'des1': DescriptorNumber('d1', 0.1), + 'par2': Parameter('p2', 0.2), + 'des2': DescriptorNumber('d2', 0.2), + 'par3': Parameter('p3', 0.3), + } + + +# --------------------------------------------------------------------------- +# Deprecation warning on instantiation +# --------------------------------------------------------------------------- + +def test_instantiation_emits_deprecation_warning(): + """Instantiating CollectionBase should emit a DeprecationWarning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + CollectionBase('test') + deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecation_warnings) >= 1 + assert 'deprecated' in str(deprecation_warnings[0].message).lower() + + +# --------------------------------------------------------------------------- +# sort +# --------------------------------------------------------------------------- + +def test_sort_ascending(setup_pars): + """Sort by parameter value in ascending order.""" + name = setup_pars.pop('name') + coll = CollectionBase(name, **setup_pars) + coll.sort(mapping=lambda item: item.value, reverse=False) + values = [item.value for item in coll] + assert values == sorted(values) + + +def test_sort_descending(setup_pars): + """Sort by parameter value in descending order.""" + name = setup_pars.pop('name') + coll = CollectionBase(name, **setup_pars) + coll.sort(mapping=lambda item: item.value, reverse=True) + values = [item.value for item in coll] + assert values == sorted(values, reverse=True) + + +# --------------------------------------------------------------------------- +# data property +# --------------------------------------------------------------------------- + +def test_data_property(setup_pars): + """The data property returns a tuple of stored items.""" + name = setup_pars.pop('name') + coll = CollectionBase(name, **setup_pars) + d = coll.data + assert isinstance(d, tuple) + assert len(d) == 5 + # Items should be the same objects + for item_from_data, item_from_iter in zip(d, coll): + assert item_from_data is item_from_iter + + +# --------------------------------------------------------------------------- +# __setitem__ with an EasyScience object (not a Number) +# --------------------------------------------------------------------------- + +def test_setitem_with_easyscience_object(setup_pars): + """Replace an item at an index with another EasyScience object.""" + name = setup_pars.pop('name') + coll = CollectionBase(name, **setup_pars) + + n_before = len(coll) + old_item = coll[0] + new_item = Parameter('replacement', 99.0) + + coll[0] = new_item + + assert len(coll) == n_before + assert coll[0].name == 'replacement' + assert coll[0].value == 99.0 + # Old item should be removed from the graph + assert old_item.unique_name not in global_object.map.get_edges(coll) + + +# --------------------------------------------------------------------------- +# __getitem__ with duplicate names returns a new CollectionBase +# --------------------------------------------------------------------------- + +def test_getitem_duplicate_names_returns_collection(setup_pars): + """When multiple items share the same name, __getitem__ returns a sub-collection.""" + name = setup_pars.pop('name') + # Add two items with the same display name + p1 = Parameter('same_name', 1.0) + p2 = Parameter('same_name', 2.0) + coll = CollectionBase(name, p1, p2) + + result = coll['same_name'] + assert isinstance(result, CollectionBase) + assert len(result) == 2 + + +def test_getitem_nonexistent_name_raises(setup_pars): + """Looking up a nonexistent name raises IndexError.""" + name = setup_pars.pop('name') + coll = CollectionBase(name, **setup_pars) + + with pytest.raises(IndexError, match='Given index does not exist'): + _ = coll['nonexistent'] + + +# --------------------------------------------------------------------------- +# insert method (called by MutableSequence.append, but test edge case) +# --------------------------------------------------------------------------- + +def test_insert_at_specific_index(setup_pars): + """Insert an item at a specific index (not just via append).""" + name = setup_pars.pop('name') + coll = CollectionBase(name, **setup_pars) + + n_before = len(coll) + new_item = Parameter('inserted', 42.0) + + coll.insert(2, new_item) + + assert len(coll) == n_before + 1 + assert coll[2].name == 'inserted' + assert coll[2].value == 42.0 diff --git a/tests/unit/legacy/test_obj_base.py b/tests/unit/legacy/test_obj_base.py new file mode 100644 index 00000000..650f0579 --- /dev/null +++ b/tests/unit/legacy/test_obj_base.py @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for easyscience.legacy.obj_base.ObjBase. + +These tests cover methods not exercised by the existing +tests/unit/base_classes/test_obj_base.py suite. +""" + +import warnings +from typing import ClassVar + +import pytest + +from easyscience import Parameter +from easyscience import global_object +from easyscience.legacy.obj_base import ObjBase + + +@pytest.fixture(autouse=True) +def _clear_map(): + global_object.map._clear() + yield + global_object.map._clear() + + +# --------------------------------------------------------------------------- +# Deprecation warning on instantiation +# --------------------------------------------------------------------------- + +def test_instantiation_emits_deprecation_warning(): + """Instantiating ObjBase should emit a DeprecationWarning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + ObjBase('test') + deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecation_warnings) >= 1 + assert 'deprecated' in str(deprecation_warnings[0].message).lower() + + +# --------------------------------------------------------------------------- +# __repr__ +# --------------------------------------------------------------------------- + +def test_repr(): + """Verify __repr__ returns the expected format.""" + obj = ObjBase('my_name') + r = repr(obj) + assert "ObjBase" in r + assert "my_name" in r + + +# --------------------------------------------------------------------------- +# __setattr__ with interface bindings regeneration +# --------------------------------------------------------------------------- + +def test_setattr_with_interface_calls_generate_bindings(): + """When replacing an attribute that has an interface, generate_bindings is called.""" + from unittest.mock import Mock + + class A(ObjBase): + a: ClassVar[Parameter] + + def __init__(self, a: Parameter): + super().__init__('A', a=a) + + p1 = Parameter('a', 1.0) + a = A(p1) + + # Attach a mock interface + mock_iface = Mock() + a.interface = mock_iface + + # Replace the parameter — should trigger generate_bindings via __setattr__ + p2 = Parameter('a', 2.0) + a.a = p2 + + mock_iface.generate_bindings.assert_called() + + +# --------------------------------------------------------------------------- +# __setattr__ without annotation (the else branch) +# --------------------------------------------------------------------------- + +def test_setattr_without_annotation(): + """Setting a non-annotated BasedBase/DescriptorBase attribute updates graph.""" + + class A(ObjBase): + def __init__(self, p: Parameter): + super().__init__('A', p=p) + + p1 = Parameter('p', 1.0) + a = A(p1) + + graph = global_object.map + edges_before = set(graph.get_edges(a)) + + # Replace the parameter with a new one + p2 = Parameter('p', 2.0) + a.p = p2 + + edges_after = set(graph.get_edges(a)) + assert edges_before != edges_after + assert p2.unique_name in edges_after + assert p1.unique_name not in edges_after + + +# --------------------------------------------------------------------------- +# __setter with DescriptorBase path +# --------------------------------------------------------------------------- + +def test_setter_sets_descriptor_value(): + """When setting a Descriptor via the logged property, the descriptor's value is updated.""" + from easyscience import DescriptorNumber + + d = DescriptorNumber('d1', 0.5) + obj = ObjBase('test', d1=d) + obj.d1 = 3.14 + assert obj.d1.value == 3.14