from __future__ import annotations
from collections.abc import Mapping
from collections.abc import Sequence
import inspect
import io
import os
import warnings
import oyaml as yaml
import audeer
from audobject.core import define
from audobject.core import resolver
from audobject.core import utils
class Object:
    r"""Base class for objects that can be serialized to YAML.

    The arguments of ``__init__`` are serialized.
    Every argument has to be assigned
    to an attribute of the same name,
    unless it is hidden or borrowed,
    see :attr:`arguments`.

    Examples:
        >>> class Foo(Object):
        ...     def __init__(self, bar: str):
        ...         self.bar = bar
        >>> foo = Foo("hello object!")
        >>> print(foo)
        $audobject.core.object.Foo:
          bar: hello object!

    """

    def __init__(
        self,
        **kwargs,
    ):
        # Only the names are stored here;
        # ``arguments`` later looks up the values
        # in attributes of the same name.
        self.__dict__[define.KEYWORD_ARGUMENTS] = list(kwargs)

    @property
    def arguments(self) -> Mapping[str, object]:
        r"""Returns arguments that are serialized.

        Returns:
            Dictionary of arguments and their values.

        Raises:
            RuntimeError: if arguments are found that are not assigned to
                attributes of the same name
            RuntimeError: if a borrowed argument cannot be found
                on the attribute it is borrowed from

        Examples:
            >>> import audobject.testing
            >>> o = audobject.testing.TestObject("test", point=(1, 1))
            >>> o.arguments
            {'name': 'test', 'point': (1, 1)}

        """  # noqa: E501
        signature = inspect.signature(self.__init__)
        # non-keyword arguments from __init__
        names = [p.name for p in signature.parameters.values() if p.name != "kwargs"]
        # additional keyword arguments
        names.extend(self.__dict__.get(define.KEYWORD_ARGUMENTS, []))

        # remove hidden and borrowed attributes
        borrowed = self.borrowed_arguments
        hidden = self.hidden_arguments
        # attributes we borrow from are not arguments themselves
        lenders = set(borrowed.values())
        names = [
            name
            for name in names
            if name != "self"
            and not name.startswith("_")
            and name not in hidden
            and name not in lenders
        ]

        # check for missing attributes
        if missing := [
            name
            for name in names
            if name not in self.__dict__ and name not in borrowed
        ]:
            raise RuntimeError(
                f"Arguments {missing} of {self.__class__} "
                "not assigned to attributes of same name."
            )

        # pick arguments from self and borrowed attributes
        args = {name: self.__dict__[name] for name in names if name in self.__dict__}
        for key, lender in borrowed.items():
            args[key] = self._borrow_value(key, lender)
        return args

    def _borrow_value(self, key: str, lender: str) -> object:
        r"""Fetch borrowed argument ``key`` from attribute ``lender``.

        Args:
            key: name of the borrowed argument
            lender: name of the attribute of ``self`` to borrow from

        Returns:
            borrowed value

        Raises:
            RuntimeError: if ``self`` has no attribute ``lender``,
                or if ``lender`` is not a mapping
                and has no attribute ``key``
            KeyError: if ``lender`` is a mapping
                that has neither an item nor an attribute ``key``

        """
        message = f"Cannot borrow attribute '{key}' from 'self.{lender}'."

        # Prefer the instance dictionary,
        # but also accept class attributes and properties
        if lender in self.__dict__:
            source = self.__dict__[lender]
        elif hasattr(self, lender):
            source = getattr(self, lender)
        else:
            raise RuntimeError(message)

        # 1. attribute stored on the lender instance itself
        instance_attributes = getattr(source, "__dict__", {})
        if key in instance_attributes:
            return instance_attributes[key]
        # 2. item of a mapping; checked before generic attribute access
        #    so that keys like "keys" or "values" of a plain dict
        #    resolve to the stored item and not to the dict method
        if isinstance(source, Mapping) and (key in source or not hasattr(source, key)):
            return source[key]
        # 3. any other attribute, e.g. property, class attribute, slot
        if hasattr(source, key):
            return getattr(source, key)
        raise RuntimeError(message)

    @property
    def borrowed_arguments(self) -> Mapping[str, str]:
        r"""Returns borrowed arguments.

        Keys are the names of the borrowed arguments,
        values are the names of the attributes they are borrowed from.

        Returns:
            Dictionary with borrowed arguments.

        """
        return self.__dict__.get(define.BORROWED_ATTRIBUTES, {})

    @property
    def hidden_arguments(self) -> Sequence[str]:
        r"""Returns hidden arguments.

        Hidden arguments are excluded from :attr:`arguments`.

        Returns:
            List with names of hidden arguments.

        """
        return self.__dict__.get(define.HIDDEN_ATTRIBUTES, [])

    @property
    def id(self) -> str:
        r"""Object identifier.

        The ID of an object is created from its non-hidden arguments.

        Returns:
            object identifier

        Examples:
            >>> class Foo(Object):
            ...     def __init__(self, bar: str):
            ...         self.bar = bar
            >>> foo1 = Foo("I am unique!")
            >>> foo1.id
            '893df240-babe-d796-cdf1-c436171b7a96'
            >>> foo2 = Foo("I am different!")
            >>> foo2.id
            '9303f2a5-bfc9-e5ff-0ffa-a9846e2d2190'
            >>> foo3 = Foo("I am unique!")
            >>> foo1.id == foo3.id
            True

        """
        # the version is excluded from the string the ID is derived from
        string = self.to_yaml_s(include_version=False)
        return audeer.uid(from_string=string)

    @property
    def is_loaded_from_dict(self) -> bool:
        r"""Check if object was loaded from a dictionary.

        Returns ``True``
        if object was initialized
        from a dictionary,
        e.g. after loading it from a YAML file.

        Returns:
            ``True`` if object was loaded from a dictionary,
            otherwise ``False``

        """
        return define.OBJECT_LOADED in self.__dict__

    @staticmethod
    @audeer.deprecated(
        removal_version="1.0.0",
        alternative="audobject.from_dict",
    )
    def from_dict(  # noqa: D102
        d: Mapping[str, object],
        root: str | None = None,
        **kwargs,
    ) -> Object:  # pragma: no cover
        # imported locally, presumably to avoid a circular import
        from audobject.core.api import from_dict

        return from_dict(d, root, **kwargs)

    @staticmethod
    @audeer.deprecated(
        removal_version="1.0.0",
        alternative="audobject.from_yaml",
    )
    def from_yaml(  # noqa: D102
        path_or_stream: str | io.IOBase,
        **kwargs,
    ) -> Object:  # pragma: no cover
        # imported locally, presumably to avoid a circular import
        from audobject.core.api import from_yaml

        return from_yaml(path_or_stream, **kwargs)

    @staticmethod
    @audeer.deprecated(
        removal_version="1.0.0",
        alternative="audobject.from_yaml_s",
    )
    def from_yaml_s(  # noqa: D102
        yaml_string: str,
        **kwargs,
    ) -> Object:  # pragma: no cover
        # imported locally, presumably to avoid a circular import
        from audobject.core.api import from_yaml_s

        return from_yaml_s(yaml_string, **kwargs)

    @property
    def resolvers(self) -> Mapping[str, resolver.Base]:
        r"""Return resolvers.

        Keys are the names of the arguments
        that are encoded with a custom resolver.

        Returns:
            Dictionary with resolvers.

        """
        return self.__dict__.get(define.CUSTOM_VALUE_RESOLVERS, {})

    @property
    def short_id(self) -> str:
        r"""Short object identifier.

        The short ID consists of eight characters
        and is created from its non-hidden arguments.

        Returns:
            short object identifier

        Examples:
            >>> class Foo(Object):
            ...     def __init__(self, bar: str):
            ...         self.bar = bar
            >>> foo1 = Foo("I am unique!")
            >>> foo1.id
            '893df240-babe-d796-cdf1-c436171b7a96'
            >>> foo1.short_id
            '171b7a96'
            >>> foo2 = Foo("I am different!")
            >>> foo2.short_id
            '6e2d2190'
            >>> foo3 = Foo("I am unique!")
            >>> foo1.short_id == foo3.short_id
            True

        """
        string = self.to_yaml_s(include_version=False)
        return audeer.uid(from_string=string, short=True)

    def to_dict(
        self,
        *,
        include_version: bool = True,
        flatten: bool = False,
        root: str | None = None,
    ) -> Mapping[str, resolver.DefaultValueType]:
        r"""Converts object to a dictionary.

        Includes items from :attr:`audobject.Object.arguments`.
        If an argument has a resolver, its value is encoded.
        Usually, the object can be re-instantiated using
        :meth:`audobject.Object.from_dict`.
        However, if ``flatten=True``, this is not possible.

        Args:
            include_version: add version to class name
            flatten: flatten the dictionary
            root: if file is written to disk, set to target directory

        Returns:
            dictionary that represent the object

        Examples:
            >>> import audobject.testing
            >>> o = audobject.testing.TestObject("test", point=(1, 1))
            >>> o.to_dict(include_version=False)
            {'$audobject.core.testing.TestObject': {'name': 'test', 'point': [1, 1]}}
            >>> o.to_dict(flatten=True)
            {'name': 'test', 'point.0': 1, 'point.1': 1}

        """  # noqa: E501
        name = utils.create_class_key(self.__class__, include_version)
        encoded = {
            key: self._encode_variable(key, value, include_version, root)
            for key, value in self.arguments.items()
        }
        if flatten:
            return Object._flatten(encoded)
        return {name: encoded}

    def to_yaml(
        self,
        path_or_stream: str | io.IOBase,
        *,
        include_version: bool = True,
    ):
        r"""Save object to YAML file.

        If a path is given,
        missing parent directories are created.
        If a stream is given
        that has no path-like ``name``,
        e.g. :class:`io.StringIO`,
        ``root`` is set to ``None``
        when the object is converted to a dictionary.

        Args:
            path_or_stream: file path or stream
            include_version: add version to class name

        """
        if isinstance(path_or_stream, str):
            path = audeer.safe_path(path_or_stream)
            audeer.mkdir(os.path.dirname(path))
            with open(path, "w") as fp:
                return self.to_yaml(fp, include_version=include_version)

        # In-memory streams have no ``name``
        # and streams opened from a file descriptor
        # report an integer, neither points to a folder
        stream_name = getattr(path_or_stream, "name", None)
        if isinstance(stream_name, (str, bytes, os.PathLike)):
            root = os.path.dirname(stream_name)
        else:
            root = None
        return yaml.dump(
            self.to_dict(
                include_version=include_version,
                root=root,
            ),
            path_or_stream,
        )

    def to_yaml_s(
        self,
        *,
        include_version: bool = True,
    ) -> str:
        r"""Convert object to YAML string.

        Args:
            include_version: add version to class name

        Returns:
            YAML string

        Examples:
            >>> import audobject.testing
            >>> o = audobject.testing.TestObject("test", point=(1, 1))
            >>> print(o.to_yaml_s(include_version=False))
            $audobject.core.testing.TestObject:
              name: test
              point:
              - 1
              - 1

        """  # noqa: E501
        return yaml.dump(self.to_dict(include_version=include_version))

    def _encode_variable(
        self,
        name: str,
        value: object,
        include_version: bool,
        root: str | None,
    ):
        r"""Encode value.

        Encode a value
        by first looking for a custom resolver,
        otherwise switch to default encoder.

        """
        value = self._resolve_value(name, value, root)
        return Object._encode_value(value, include_version)

    @staticmethod
    def _encode_value(
        value: object,
        include_version: bool,
    ):
        r"""Default value encoder.

        Objects are converted to dictionaries,
        lists and dictionaries are encoded recursively.

        Raises:
            RuntimeError: if value is a callable
                that is not covered by one of the other cases

        """
        if value is None:
            return None
        if isinstance(value, Object) or utils.is_class(value):
            return value.to_dict(include_version=include_version)
        if isinstance(value, define.DEFAULT_VALUE_TYPES):
            return value
        if isinstance(value, list):
            return [Object._encode_value(item, include_version) for item in value]
        if isinstance(value, dict):
            return {
                Object._encode_value(key, include_version): Object._encode_value(
                    val, include_version
                )
                for key, val in value.items()
            }
        if callable(value):
            raise RuntimeError(f"Cannot encode type '{type(value)}'.")
        # unknown type, pass it on as is but tell the user
        warnings.warn(
            f"No default encoding exists for type '{type(value)}'. "
            "Consider to register a custom resolver.",
            RuntimeWarning,
        )
        return value

    @staticmethod
    def _flatten(
        d: Mapping[str, resolver.DefaultValueType],
    ):
        r"""Flattens a dictionary.

        Keys of nested dictionaries
        and indices of nested sequences
        are joined with ``"."``.
        Class keys do not contribute to the joined key.

        """

        def helper(target: dict, source: dict, prefix: str):
            for key, value in source.items():
                if utils.is_class(key):
                    # skip the class key and descend into its arguments
                    helper(target, value, prefix)
                    continue
                path = f"{prefix}.{key}" if prefix else key
                if isinstance(value, (list, tuple)):
                    # treat a sequence like a dictionary with index keys
                    helper(target, dict(enumerate(value)), path)
                elif isinstance(value, dict):
                    helper(target, value, path)
                else:
                    target[path] = value

        flat = {}
        helper(flat, d, "")
        return flat

    def _resolve_value(
        self,
        name: str,
        value: object,
        root: str | None,
    ) -> resolver.DefaultValueType:
        r"""Encode value with custom resolver if one is registered.

        ``None`` is never passed to a resolver.

        """
        resolvers = self.resolvers
        if value is not None and name in resolvers:
            # let resolver know if we write to a stream
            resolvers[name].__dict__[define.ROOT_ATTRIBUTE] = root
            value = resolvers[name].encode(value)
        return value

    def __hash__(self) -> int:
        r"""Hash based on object ID."""
        return hash(self.id)

    def __eq__(self, other: object) -> bool:
        r"""Check if two objects are equal."""
        if isinstance(other, type(self)):
            return self.id == other.id
        # it might happen that we need to compare
        # to other non-supported types, compare
        # https://github.com/audeering/audinterface/issues/68
        return False

    def __repr__(self) -> str:  # noqa: D105
        return str(self.to_dict(include_version=False))

    def __str__(self) -> str:  # noqa: D105
        return self.to_yaml_s(include_version=False)