Source code for audobject.core.resolver

from __future__ import annotations

import ast
from collections.abc import Callable
import datetime
import inspect
import os
import textwrap
import types
import typing
import warnings

import asttokens

import audeer
import audobject.core.define as define


DefaultValueType = typing.Union[
    bool,
    datetime.datetime,
    dict,
    float,
    int,
    list,
    None,
    str,
]


[docs]class Base: r"""Abstract resolver class. Implement for arguments that are not one of: * ``bool`` * ``datetime.datetime`` * ``dict`` * ``float`` * ``int`` * ``list`` * ``None`` * ``Object`` * ``str`` """ def __init__(self): self.__dict__[define.ROOT_ATTRIBUTE] = None @property def root(self) -> str | None: r"""Root folder. Returns root folder when object is serialized to or from a file, otherwise ``None`` is returned. Returns: root directory """ return self.__dict__[define.ROOT_ATTRIBUTE]
[docs] def decode(self, value: DefaultValueType) -> object: r"""Decode value. Takes the encoded value and converts it back to its original type. Args: value: value to decode Returns: decoded value """ raise NotImplementedError # pragma: no cover
[docs] def encode(self, value: object) -> DefaultValueType: r"""Encode value. The type of the returned value must be one of: * ``bool`` * ``datetime.datetime`` * ``dict`` * ``float`` * ``int`` * ``list`` * ``None`` * ``Object`` * ``str`` Args: value: value to encode Returns: encoded value """ raise NotImplementedError # pragma: no cover
[docs] def encode_type(self) -> type: r"""Return encoded type. Returns: encoded type """ raise NotImplementedError # pragma: no cover
[docs]class FilePath(Base): r"""File path resolver. Turns file path to a relative path when object is serialized to a file and expands it again during reading. Examples: >>> resolver = FilePath() >>> resolver._object_root_ = "/some/root" # usually set by object >>> value = "/some/where/else" >>> encoded_value = resolver.encode(value) >>> encoded_value # doctest: +SKIP '../where/else' >>> resolver.decode(encoded_value) # doctest: +SKIP '/some/where/else' """
[docs] def decode(self, value: str) -> str: r"""Decode file path. If object is read from a file, this will convert a relative file path to an absolute path by expanding it with the source directory. Args: value: relative file path Returns: expanded file path """ if self.root is not None: root = self.root value = os.path.join(root, value) value = audeer.safe_path(value) return value
[docs] def encode(self, value: str) -> str: r"""Encode file path. If object is written to a file, this will convert a file path to a path that is relative to the target directory. Args: value: original file path Returns: relative file path """ if self.root is not None: root = self.root value = os.path.relpath(value, root) return value
[docs] def encode_type(self) -> type: r"""Return encoded type. Returns: encoded type """ return str
[docs]class Function(Base): r"""Function resolver. Encodes source code of function and dynamically evaluates it when the value is decoded again. Note that a decoded function might raise a :class:`NameError`, if it relies on objects or functions that are not defined or imported inside the function. For instance, the following example will raise an error since ``plus_1()`` relies on ``_plus_1()``, which is defined outside the function: .. code-block:: python def _plus_1(x): return x + 1 def plus_1(x): return _plus_1(x) # calls local function -> not serializable resolver = Function() encoded_value = resolver.encode(plus_1) del _plus_1 decoded_value = resolver.decode(encoded_value) decoded_value(1) Examples: >>> resolver = Function() >>> def func(x): ... return x * 2 >>> func(5) 10 >>> encoded_value = resolver.encode(func) >>> encoded_value 'def func(x):\n return x * 2\n' >>> decoded_value = resolver.decode(encoded_value) >>> decoded_value(5) 10 """
[docs] def decode(self, value: str) -> Callable: r"""Decode (lambda) function. Args: value: source code Returns: function object """ func = None # We must dynamically create the function # from the original source code we stored in YAML. # For a regular function we can do this # by calling ``exec()`` with a local namespace directory. # This will create the function in the namespace # from where we can return it. # This preserve defaults and keyword-only arguments. # For lambda expression this is not possible, # as we would end up with an empty namespace # (a lambda has no name!). # Therefore we first compile the code # and then use ``types.FunctionType()`` # to create the function object. # This does not preserve defaults and keyword-only arguments, # but fortunately this is not relevant for lambda expressions. if value.startswith("lambda"): code = compile(value, "<string>", "exec") for var in code.co_consts: if isinstance(var, types.CodeType): func = types.FunctionType(var, globals()) else: namespace = {} exec(value, globals(), namespace) func_name = next(iter(namespace)) func = namespace[func_name] # we cannot inspect the source code of # dynamically defined functions so we attach it func.__source__ = value return func
[docs] def encode( self, value: Callable, ) -> str | object: r"""Encode (lambda) function. Args: value: function object Returns: source code """ from audobject.core.object import Object if isinstance(value, types.FunctionType): return self.get_source(value) elif isinstance(value, Object): return value else: raise ValueError( "Cannot decode object if it does not derive from 'audobject.Object'." )
[docs] def encode_type(self) -> type: r"""Returns encoded type. Returns: encoded type """ return str
[docs] def get_source(self, func: Callable) -> str: r"""Obtain source code of (lambda) function. Retrieving the source of a lambda function can become tricky, see the following link for detailed discussion: http://xion.io/post/code/python-get-lambda-code.html Args: func: function object Returns: source code """ # check if source code is attached # otherwise use inspect to get it if hasattr(func, "__source__"): return func.__source__ else: if func.__name__ == "<lambda>": source = self._get_short_lambda_source(func) else: source = inspect.getsource(func) return textwrap.dedent(source)
@staticmethod def _get_short_lambda_source( lambda_func: Callable, ): # pragma: no cover """Return the source of a (short) lambda function. If it's impossible to obtain, returns None. """ try: source_lines, _ = inspect.getsourcelines(lambda_func) except (IOError, TypeError): return None # skip `def`-ed functions and long lambdas if len(source_lines) != 1: return None source_text = os.linesep.join(source_lines).strip() atok = asttokens.ASTTokens(source_text, parse=True) # Search for the first occurring lambda node in AST tree lambda_node = None for node in ast.walk(atok.tree): if isinstance(node, ast.Lambda): lambda_node = node break return None if lambda_node is None else atok.get_text(lambda_node)
[docs]class Tuple(Base): r"""Tuple resolver. Encodes tuple as a list. Examples: >>> resolver = Tuple() >>> value = (1, "a") >>> value (1, 'a') >>> encoded_value = resolver.encode(value) >>> encoded_value [1, 'a'] >>> decoded_value = resolver.decode(encoded_value) >>> decoded_value (1, 'a') """
[docs] def decode(self, value: list) -> tuple: r"""Decodes ``list`` as ``tuple``. Args: value: list Returns: tuple """ return tuple(value)
[docs] def encode(self, value: tuple) -> list: r"""Encodes ``tuple`` as ``list``. Args: value: tuple Returns: list """ return list(value)
[docs] def encode_type(self) -> type: r"""Return encoded type. Returns: encoded type """ return list
[docs]class Type(Base): r"""Type resolver. Encodes type as a string. Examples: >>> resolver = Type() >>> value = str >>> value <class 'str'> >>> encoded_value = resolver.encode(value) >>> encoded_value 'str' >>> decoded_value = resolver.decode(encoded_value) >>> decoded_value <class 'str'> """
[docs] def decode(self, value: str) -> type: r"""Decodes ``str`` as ``type``. Args: value: type string Returns: type """ return eval(value)
[docs] def encode(self, value: type) -> str: r"""Encodes ``type`` as ``str``. Args: value: type class Returns: string """ return str(value)[len("<class '") : -len("'>")]
[docs] def encode_type(self) -> type: r"""Return encoded type. Returns: encoded type """ return str
# deprecated classes # @audeer.deprecated( # removal_version='1.0.0', # alternative='resolver.Base', # ) # -> # TypeError: function() argument 1 must be code, not str # -> # as a workaround we raise the deprecation warning in __init__ class ValueResolver: # pragma: no cover # noqa: D101 def __init__(self): message = ( "ValueResolver is deprecated and will be removed " "with version 1.0.0. Use resolver.Base instead." ) warnings.warn(message, category=UserWarning, stacklevel=2) self.__dict__[define.ROOT_ATTRIBUTE] = None @property def root(self) -> str | None: # noqa: D102 return self.__dict__[define.ROOT_ATTRIBUTE] def decode(self, value: DefaultValueType) -> object: # noqa: D102 raise NotImplementedError def encode(self, value: object) -> DefaultValueType: # noqa: D102 raise NotImplementedError def encode_type(self) -> type: # noqa: D102 raise NotImplementedError @audeer.deprecated( removal_version="1.0.0", alternative="resolver.FilePath", ) class FilePathResolver(FilePath): # pragma: no cover # noqa: D101 pass @audeer.deprecated( removal_version="1.0.0", alternative="resolver.Function", ) class FunctionResolver(Function): # pragma: no cover # noqa: D101 pass @audeer.deprecated( removal_version="1.0.0", alternative="resolver.Tuple", ) class TupleResolver(Tuple): # pragma: no cover # noqa: D101 pass @audeer.deprecated( removal_version="1.0.0", alternative="resolver.Type", ) class TypeResolver(Type): # pragma: no cover # noqa: D101 pass