Source code for audobject.core.resolver

import ast
import datetime
import inspect
import os
import textwrap
import types
import typing
import warnings

import audeer
import audobject.core.define as define


# Union of the value types a resolver handles in encoded form:
# it annotates the argument of ``decode()``
# and the return value of ``encode()`` of the resolver base class.
DefaultValueType = typing.Union[
    bool,
    datetime.datetime,
    dict,
    float,
    int,
    list,
    None,
    str,
]


class Base:
    r"""Abstract resolver class.

    Derive from this class to support arguments
    whose type is not one of:

    * ``bool``
    * ``datetime.datetime``
    * ``dict``
    * ``float``
    * ``int``
    * ``list``
    * ``None``
    * ``Object``
    * ``str``

    """
    def __init__(self):
        # the root folder is kept directly in the instance dictionary
        # under the attribute name defined in ``define.ROOT_ATTRIBUTE``
        vars(self)[define.ROOT_ATTRIBUTE] = None

    @property
    def root(self) -> typing.Optional[str]:
        r"""Root folder.

        Holds the root folder
        while the object is serialized to or from a file,
        otherwise it is ``None``.

        Returns:
            root directory

        """
        return vars(self)[define.ROOT_ATTRIBUTE]

    def decode(self, value: DefaultValueType) -> typing.Any:
        r"""Decode value.

        Converts an encoded value back to its original type.
        Has to be implemented by the derived class.

        Args:
            value: value to decode

        Returns:
            decoded value

        """
        raise NotImplementedError  # pragma: no cover

    def encode(self, value: typing.Any) -> DefaultValueType:
        r"""Encode value.

        Has to be implemented by the derived class.
        The returned value must have one of the following types:

        * ``bool``
        * ``datetime.datetime``
        * ``dict``
        * ``float``
        * ``int``
        * ``list``
        * ``None``
        * ``Object``
        * ``str``

        Args:
            value: value to encode

        Returns:
            encoded value

        """
        raise NotImplementedError  # pragma: no cover

    def encode_type(self) -> type:
        r"""Return encoded type.

        Has to be implemented by the derived class.

        Returns:
            encoded type

        """
        raise NotImplementedError  # pragma: no cover
class FilePath(Base):
    r"""File path resolver.

    Stores a file path relative to the target folder
    when the object is serialized to a file
    and expands it again when the object is read.

    Examples:
        >>> resolver = FilePath()
        >>> resolver._object_root_ = '/some/root'  # usually set by object
        >>> value = '/some/where/else'
        >>> encoded_value = resolver.encode(value)
        >>> encoded_value  # doctest: +SKIP
        '../where/else'
        >>> resolver.decode(encoded_value)  # doctest: +SKIP
        '/some/where/else'

    """
    def decode(self, value: str) -> str:
        r"""Decode file path.

        When a root folder is set,
        the relative path is joined with it
        and passed through ``audeer.safe_path()``.
        Without a root folder the value is returned unchanged.

        Args:
            value: relative file path

        Returns:
            expanded file path

        """
        root = self.root
        if root is None:
            return value
        return audeer.safe_path(os.path.join(root, value))

    def encode(self, value: str) -> str:
        r"""Encode file path.

        When a root folder is set,
        the path is expressed relative to it.
        Without a root folder the value is returned unchanged.

        Args:
            value: original file path

        Returns:
            relative file path

        """
        root = self.root
        if root is None:
            return value
        return os.path.relpath(value, root)

    def encode_type(self) -> type:
        r"""Return encoded type.

        Returns:
            encoded type

        """
        return str
class Function(Base):
    r"""Function resolver.

    Encodes source code of function
    and dynamically evaluates it
    when the value is decoded again.

    Note that a decoded function might raise a :class:`NameError`,
    if it relies on objects or functions
    that are not defined or imported inside the function.
    For instance, the following example will raise an error
    since ``plus_1()`` relies on ``_plus_1()``,
    which is defined outside the function:

    .. code-block:: python

        def _plus_1(x):
            return x + 1

        def plus_1(x):
            return _plus_1(x)  # calls local function -> not serializable

        resolver = Function()
        encoded_value = resolver.encode(plus_1)
        del _plus_1
        decoded_value = resolver.decode(encoded_value)
        decoded_value(1)

    Examples:
        >>> resolver = Function()
        >>> def func(x):
        ...     return x * 2
        >>> func(5)
        10
        >>> encoded_value = resolver.encode(func)
        >>> encoded_value
        'def func(x):\n    return x * 2\n'
        >>> decoded_value = resolver.decode(encoded_value)
        >>> decoded_value(5)
        10

    """
    def decode(self, value: str) -> typing.Callable:
        r"""Decode (lambda) function.

        The source code is attached to the returned function
        as attribute ``__source__``.

        Args:
            value: source code

        Returns:
            function object

        """
        # NOTE: ``value`` is compiled and executed as Python code,
        # so only source code from a trusted origin must be decoded.
        func = None
        # We must dynamically create the function
        # from the original source code we stored in YAML.
        # For a regular function we can do this
        # by calling ``exec()`` with a local namespace dictionary.
        # This will create the function in the namespace
        # from where we can return it.
        # This preserves defaults and keyword-only arguments.
        # For lambda expressions this is not possible,
        # as we would end up with an empty namespace
        # (a lambda has no name!).
        # Therefore we first compile the code
        # and then use ``types.FunctionType()``
        # to create the function object.
        # This does not preserve defaults and keyword-only arguments,
        # but fortunately this is not relevant for lambda expressions.
        if value.startswith('lambda'):
            code = compile(value, '<string>', 'exec')
            for var in code.co_consts:
                if isinstance(var, types.CodeType):
                    func = types.FunctionType(var, globals())
        else:
            namespace = {}
            exec(value, globals(), namespace)
            func_name = next(iter(namespace))
            func = namespace[func_name]
        # we cannot inspect the source code of
        # dynamically defined functions so we attach it
        func.__source__ = value
        return func

    def encode(
            self,
            value: typing.Callable,
    ) -> typing.Union[str, object]:
        r"""Encode (lambda) function.

        A plain function is encoded as its source code,
        an instance of ``audobject.Object`` is returned as is.

        Args:
            value: function object

        Returns:
            source code

        Raises:
            ValueError: if ``value`` is neither a function
                nor derived from ``audobject.Object``

        """
        from audobject.core.object import Object

        if isinstance(value, types.FunctionType):
            return self.get_source(value)
        elif isinstance(value, Object):
            return value
        else:
            raise ValueError(
                "Cannot decode object "
                "if it does not derive from "
                "'audobject.Object'."
            )

    def encode_type(self) -> type:
        r"""Returns encoded type.

        Returns:
            encoded type

        """
        return str

    def get_source(self, func: typing.Callable) -> str:
        r"""Obtain source code of (lambda) function.

        Retrieving the source of a lambda function
        can become tricky,
        see the following link for detailed discussion:
        http://xion.io/post/code/python-get-lambda-code.html

        Args:
            func: function object

        Returns:
            source code

        Raises:
            TypeError: if the source code of a lambda expression
                cannot be determined

        """
        # check if source code is attached
        # otherwise use inspect to get it
        if hasattr(func, '__source__'):
            return func.__source__

        if func.__name__ == "<lambda>":
            source = self._get_short_lambda_source(func)
            if source is None:
                # ``textwrap.dedent(None)`` used to fail here
                # with an unspecific ``TypeError``.
                # We keep the exception type,
                # but tell the user what went wrong.
                raise TypeError(
                    "Cannot obtain source code of lambda expression. "
                    "Only lambda expressions "
                    "defined on a single line are supported."
                )
        else:
            source = inspect.getsource(func)
        return textwrap.dedent(source)

    @staticmethod
    def _get_short_lambda_source(
            lambda_func: typing.Callable,
    ):  # pragma: no cover
        """Return the source of a (short) lambda function.

        If it's impossible to obtain, returns None.

        Original code:
        https://gist.github.com/Xion/617c1496ff45f3673a5692c3b0e3f75a

        """
        try:
            source_lines, _ = inspect.getsourcelines(lambda_func)
        except (IOError, TypeError):
            return None

        # skip `def`-ed functions and long lambdas
        if len(source_lines) != 1:
            return None

        source_text = os.linesep.join(source_lines).strip()

        # find the AST node of a lambda definition
        # so we can locate it in the source code
        try:
            source_ast = ast.parse(source_text)
        except SyntaxError:
            # the line is only a fragment of a larger statement
            # (e.g. the last line of a multi-line call)
            # and cannot be parsed on its own
            return None
        lambda_node = next(
            (
                node for node in ast.walk(source_ast)
                if isinstance(node, ast.Lambda)
            ),
            None,
        )
        if lambda_node is None:  # could be a single line `def fn(x): ...`
            return None

        # HACK: Since we can (and most likely will) get source lines
        # where lambdas are just a part of bigger expressions, they will have
        # some trailing junk after their definition.
        #
        # Unfortunately, AST nodes only keep their _starting_ offsets
        # from the original source, so we have to determine the end ourselves.
        # We do that by gradually shaving extra junk from after the definition.
        #
        # ``col_offset`` counts UTF-8 bytes, not characters,
        # so we slice the encoded source
        # to stay correct if non-ASCII characters precede the lambda.
        source_bytes = source_text.encode('utf-8')
        lambda_text = source_bytes[lambda_node.col_offset:].decode('utf-8')
        lambda_body_text = (
            source_bytes[lambda_node.body.col_offset:].decode('utf-8')
        )
        min_length = len('lambda:_')  # shortest possible lambda expression
        while len(lambda_text) > min_length:
            try:
                # What's annoying is that sometimes the junk even parses,
                # but results in a *different* lambda. You'd probably have to
                # be deliberately malicious to exploit it but here's one way:
                #
                #     bloop = lambda x: False, lambda x: True
                #     _get_short_lambda_source(bloop[0])
                #
                # Ideally, we'd just keep shaving until we get the same code,
                # but that most likely won't happen because we can't replicate
                # the exact closure environment.
                code = compile(lambda_body_text, '<unused filename>', 'eval')

                # Thus the next best thing is to assume some divergence due
                # to e.g. LOAD_GLOBAL in original code being LOAD_FAST in
                # the one compiled above, or vice versa.
                # But the resulting code should at least be the same *length*
                # if otherwise the same operations are performed in it.
                if len(code.co_code) == len(lambda_func.__code__.co_code):
                    return lambda_text
            except SyntaxError:
                pass
            lambda_text = lambda_text[:-1]
            lambda_body_text = lambda_body_text[:-1]

        return None
class Tuple(Base):
    r"""Tuple resolver.

    Stores a tuple as a list
    and turns it back into a tuple when decoded.

    Examples:
        >>> resolver = Tuple()
        >>> value = (1, 'a')
        >>> value
        (1, 'a')
        >>> encoded_value = resolver.encode(value)
        >>> encoded_value
        [1, 'a']
        >>> decoded_value = resolver.decode(encoded_value)
        >>> decoded_value
        (1, 'a')

    """
    def decode(self, value: list) -> tuple:
        r"""Decodes ``list`` as ``tuple``.

        Args:
            value: list

        Returns:
            tuple holding the items of the list

        """
        return (*value,)

    def encode(self, value: tuple) -> list:
        r"""Encodes ``tuple`` as ``list``.

        Args:
            value: tuple

        Returns:
            list holding the items of the tuple

        """
        return [*value]

    def encode_type(self) -> type:
        r"""Return encoded type.

        Returns:
            encoded type

        """
        return list
class Type(Base):
    r"""Type resolver.

    Stores a type by its name
    and looks the name up again when decoded.

    Examples:
        >>> resolver = Type()
        >>> value = str
        >>> value
        <class 'str'>
        >>> encoded_value = resolver.encode(value)
        >>> encoded_value
        'str'
        >>> decoded_value = resolver.decode(encoded_value)
        >>> decoded_value
        <class 'str'>

    """
    def decode(self, value: str) -> type:
        r"""Decodes ``str`` as ``type``.

        Args:
            value: type string

        Returns:
            type

        """
        # NOTE: the string is passed to ``eval()``,
        # so only values from a trusted origin must be decoded
        return eval(value)

    def encode(self, value: type) -> str:
        r"""Encodes ``type`` as ``str``.

        Args:
            value: type class

        Returns:
            string

        """
        # cut the name out of the representation ``<class 'name'>``
        head = "<class '"
        tail = "'>"
        representation = str(value)
        return representation[len(head):-len(tail)]

    def encode_type(self) -> type:
        r"""Return encoded type.

        Returns:
            encoded type

        """
        return str
# deprecated classes


# Decorating the class with
#
# @audeer.deprecated(
#     removal_version='1.0.0',
#     alternative='resolver.Base',
# )
#
# results in
#
# TypeError: function() argument 1 must be code, not str
#
# as a workaround the deprecation warning is raised in __init__
class ValueResolver:  # pragma: no cover  # noqa: D101

    def __init__(self):
        message = (
            'ValueResolver is deprecated '
            'and will be removed with version 1.0.0. '
            'Use resolver.Base instead.'
        )
        warnings.warn(message, category=UserWarning, stacklevel=2)
        vars(self)[define.ROOT_ATTRIBUTE] = None

    @property
    def root(self) -> typing.Optional[str]:  # noqa: D102
        return vars(self)[define.ROOT_ATTRIBUTE]

    def decode(self, value: DefaultValueType) -> typing.Any:  # noqa: D102
        raise NotImplementedError

    def encode(self, value: typing.Any) -> DefaultValueType:  # noqa: D102
        raise NotImplementedError

    def encode_type(self) -> type:  # noqa: D102
        raise NotImplementedError


@audeer.deprecated(
    removal_version='1.0.0',
    alternative='resolver.FilePath',
)
class FilePathResolver(FilePath):  # pragma: no cover  # noqa: D101
    pass


@audeer.deprecated(
    removal_version='1.0.0',
    alternative='resolver.Function',
)
class FunctionResolver(Function):  # pragma: no cover  # noqa: D101
    pass


@audeer.deprecated(
    removal_version='1.0.0',
    alternative='resolver.Tuple',
)
class TupleResolver(Tuple):  # pragma: no cover  # noqa: D101
    pass


@audeer.deprecated(
    removal_version='1.0.0',
    alternative='resolver.Type',
)
class TypeResolver(Type):  # pragma: no cover  # noqa: D101
    pass