
Have a decorator to wrap universal functions? #34

Open
eserie opened this issue Apr 18, 2021 · 6 comments · May be fixed by #41
Comments

eserie (Contributor) commented Apr 18, 2021

To simplify writing universal functions, it would be great to have a decorator that hides the technical part of the code (conversion of the inputs and outputs of the wrapped function/method).
For example, the code:

def my_universal_function(a, b, c):
    # Convert all inputs to EagerPy tensors
    a, b, c = ep.astensors(a, b, c)

    # perform some computations
    result = (a + b * c).square()

    # and return a native tensor
    return result.raw

would become:

@eager_function
def my_universal_function(a, b, c):
    return (a + b * c).square()

In addition, we could add the feature that if the input tensors are already EagerPy tensors, then no conversion back to raw format is done on the output tensors.

I wrote a prototype of such a decorator. It will not work on arbitrary argument types, so its usage requires that the wrapped function has a rather "simple" signature (args and kwargs consisting of tensors or nested containers with tensors at the leaves: dict, list, tuple, or namedtuple-like containers).

Would you consider having this feature in eagerpy?

jonasrauber (Owner) commented
Would you consider having this feature in eagerpy?

Yes, a nice generic decorator that can handle an arbitrary number of arguments (and return values) would be great. I am pretty sure I thought about this before, but I cannot recall why I didn't do it.

In addition, we could add the feature that if the input tensors are already EagerPy tensors, then no conversion back to raw format is done on the output tensors.

Have you seen the ep.astensor_ and ep.astensors_ functions (with the underscore)? They already do exactly that: https://eagerpy.jonasrauber.de/guide/generic-functions.html (see the examples at the end).

eserie (Contributor, Author) commented Apr 19, 2021

Thanks for your response!
No, unfortunately I hadn't seen the functions astensor_ and astensors_ (only astensor). They should definitely be a good starting point! I think I have a POC for a version of that function that can manage more general input/output formats. I can try to integrate it into eagerpy and propose a PR in the coming days if that's OK with you.

eserie (Contributor, Author) commented Apr 20, 2021

Below is a first POC that I wrote for the wrapper function eager_function (which is working).
At this stage, it is not totally trivial to me how to integrate it into the code base.
I would appreciate some first feedback on this code in order to know whether we can go further in this direction.

import numbers
from collections import defaultdict
from functools import wraps
from typing import Any, Tuple

import numpy as np

import eagerpy as ep


def _tuple_as(template, data):
    data = list(data)
    try:
        # list, tuple case
        return type(template)(data)
    except TypeError:
        # named tuple case
        return type(template)(*data)


def _dict_as(template, data):
    """Create dictionary like data structure from template object.
    Parameters
    ----------
    template
        objecti used as template
    data
        data used to fill the created object.
    """
    if isinstance(template, defaultdict):
        return type(template)(template.default_factory, data)
    return type(template)(data)


def as_eager_tensors(data: Any) -> Any:
    return as_eager_tensors_(data)[0]


def as_eager_tensors_(data: Any) -> Tuple[Any, bool]:
    """Convert data to EagerPy tensors.

    Parameters
    ----------
    data : (tuple, list, dict, namedtuple, defaultdict)
        data structure to convert

    Returns
    -------
    converted
        the converted data structure
    unwrap : bool
        True if at least one tensor has been converted to an EagerPy
        tensor (so the outputs should be unwrapped back to raw tensors).
    """
    if isinstance(data, dict):
        # dict, defaultdict
        if not data:
            return data, False
        keys, res_values, unwrap_values = zip(
            *[(dim,) + as_eager_tensors_(var) for dim, var in data.items()]
        )
        unwrap = True in unwrap_values
        return _dict_as(data, dict(zip(keys, res_values))), unwrap
    elif isinstance(data, (list, tuple)):
        if not data:
            return data, False

        res_values, unwrap_values = zip(*[as_eager_tensors_(var) for var in data])
        unwrap = True in unwrap_values
        return _tuple_as(data, res_values), unwrap

    elif isinstance(data, ep.Tensor):
        return data, False
    elif isinstance(data, np.datetime64):
        # np.datetime64 is not handled by EagerPy tensors
        return data, False
    elif isinstance(data, numbers.Number):
        return data, False
    return ep.astensor(data), True


def as_raw_tensors(data):
    """Convert from eager tensors to raw tensors.

    Parameters
    ----------
    data
        data to convert

    """
    if isinstance(data, dict):
        return _dict_as(data, {dim: as_raw_tensors(var) for dim, var in data.items()})
    elif isinstance(data, (list, tuple)):
        return _tuple_as(data, (as_raw_tensors(var) for var in data))

    if isinstance(data, ep.Tensor):
        return data.raw
    else:
        return data


def restore_tensor_type(data: Any, unwrap: bool) -> Any:
    if unwrap:
        return as_raw_tensors(data)
    else:
        return data


def eager_function(func):
    @wraps(func)
    def eager_func(*args, **kwargs):
        self = None
        # heuristic: methods have a dotted __qualname__; keep `self` unconverted
        if len(func.__qualname__.split(".")) > 1:
            args = list(args)
            self = args.pop(0)
        args, args_unwrap = as_eager_tensors_(args)
        kwargs, kwargs_unwrap = as_eager_tensors_(kwargs)
        unwrap = args_unwrap or kwargs_unwrap
        if self is not None:
            args = [self] + args
        result = func(*args, **kwargs)
        return restore_tensor_type(result, unwrap)

    return eager_func

eserie (Contributor, Author) commented Apr 22, 2021

Another possibility would be to use the pytrees implemented in Jax. This would make it possible to handle more data structures and also to rely on the existing astensors_ implementation by working on flattened versions of the inputs and outputs. However, this would create a hard dependency on Jax in eagerpy, which is currently optional.

eserie linked a pull request on Apr 23, 2021 that will close this issue.
eserie (Contributor, Author) commented Apr 23, 2021

I propose an implementation based on pytrees in #41.
This approach implies a few changes, such as no longer registering JAXTensor as a pytree data structure and instead using the jax pytree utilities for more general data structure manipulation in eagerpy. The newly introduced data structure conversion functions also allow factorizing the method JAXTensor._value_and_grad_fn a bit (for which the initial registration of JAXTensor was tailored).

eserie (Contributor, Author) commented Apr 25, 2021

In fact, I think it's not a good idea to stop registering JAXTensor in pytrees; it would break compatibility with jax functionality. I will restore that in an update of the PR.
