# Source code for deluca.core

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import pickle

import jax


def tree_flatten(obj):
    """Flatten an object into pytree leaves plus auxiliary data for JAX.

    Only ``obj.attrs`` contributes leaves. The treedef of ``obj.attrs``, the
    object's ``arguments_`` and its class are carried in the auxiliary data so
    that ``tree_unflatten`` can rebuild an equivalent instance.

    Args:
        obj: object exposing ``attrs`` (flattened with
            ``jax.tree_util.tree_flatten``) and ``arguments_``.

    Returns:
        ``(leaves, aux)`` where ``aux`` is a dict with the keys
        ``"treedef"``, ``"arguments_"`` and ``"class"``.
    """
    flat_leaves, attrs_treedef = jax.tree_util.tree_flatten(obj.attrs)
    # NOTE(review): JAX documents aux_data as needing to be hashable; a dict is
    # not. Confirm this does not matter for how these objects are jitted.
    aux_data = {}
    aux_data["treedef"] = attrs_treedef
    aux_data["arguments_"] = obj.arguments_
    aux_data["class"] = obj.__class__
    return flat_leaves, aux_data


def tree_unflatten(aux, leaves):
    """Rebuild an object from the ``(aux, leaves)`` produced by ``tree_flatten``.

    A fresh instance is created by calling ``aux["class"]`` with the stored
    constructor arguments, so the class's ``__new__``/``__init__`` run again.
    The leaves are then reassembled into the attribute mapping, and each entry
    is assigned onto the new instance through its ``__setattr__``.

    Args:
        aux: dict with keys ``"treedef"``, ``"arguments_"`` (an object
            exposing ``args`` and ``kwargs``) and ``"class"``.
        leaves: pytree leaves matching ``aux["treedef"]``.

    Returns:
        The reconstructed instance.
    """
    # NOTE(review): JAX may call unflatten with placeholder (non-array) leaves
    # during transformations; confirm the class's __setattr__ tolerates them.
    klass = aux["class"]
    ctor_args = aux["arguments_"]
    rebuilt = klass(*ctor_args.args, **ctor_args.kwargs)
    attr_tree = jax.tree_util.tree_unflatten(aux["treedef"], leaves)
    for attr_name, attr_value in attr_tree.items():
        rebuilt.__setattr__(attr_name, attr_value)
    return rebuilt


class JaxObject:
    """Base class whose subclasses are automatically registered as JAX pytrees.

    Constructor arguments are bound and assigned as attributes in ``__new__``,
    so subclasses need not call ``super().__init__()``. Every assignment that
    passes the filter in ``__setattr__`` is also recorded in ``attrs``, which
    is what ``tree_flatten`` exposes as pytree leaves.
    """

    def __new__(cls, *args, **kwargs):
        """Create the instance and record its bound constructor arguments.

        The call is bound against ``__init__``'s signature (with defaults
        applied). This raises ``TypeError`` on a mismatched call, and lets each
        argument be assigned as an attribute before ``__init__`` runs.
        """
        obj = object.__new__(cls)
        obj.__setattr__("attrs_", {})
        obj.__setattr__(
            "arguments_", inspect.signature(obj.__init__).bind(*args, **kwargs)
        )
        obj.arguments_.apply_defaults()
        for key, val in obj.arguments_.arguments.items():
            obj.__setattr__(key, val)
        return obj

    def __getnewargs_ex__(self):
        """Return the ``(args, kwargs)`` pickle should pass to ``__new__``.

        Without this, pickle protocols >= 2 (and ``copy``/``deepcopy``) call
        ``cls.__new__(cls)`` with no arguments. Binding then fails for any
        subclass whose ``__init__`` has required parameters. The pickled
        instance state overwrites whatever ``__new__`` assigns here.
        """
        return tuple(self.arguments_.args), dict(self.arguments_.kwargs)

    @classmethod
    def __init_subclass__(cls, *args, **kwargs):
        """Register each subclass as a pytree node, so no per-class decorator is needed."""
        super().__init_subclass__(*args, **kwargs)
        jax.tree_util.register_pytree_node(cls, tree_flatten, tree_unflatten)

    @property
    def name(self):
        """Name of the concrete class."""
        return self.__class__.__name__

    @property
    def attrs(self):
        """Dict of attributes tracked as pytree leaves."""
        return self.attrs_

    def __str__(self):
        return self.name

    def __setattr__(self, key, val):
        """Assign ``val`` to ``key``, also tracking it in ``attrs`` when eligible.

        A value is NOT tracked in ``attrs`` if any of these hold:

        - the key ends with ``_`` (e.g. ``attrs_``, ``arguments_``);
        - the key is ``observation_space`` or ``action_space``;
        - the value is callable, a ``str``, or ``None``.

        The value is always stored on the instance.
        """
        if (
            not key.endswith("_")
            and not callable(val)
            and key not in ("observation_space", "action_space")
            and not isinstance(val, str)
            and val is not None
        ):
            self.attrs[key] = val
        # Write straight into __dict__; object.__setattr__ is not used here.
        self.__dict__[key] = val

    def save(self, path):
        """Pickle this object to ``path``, creating parent directories as needed."""
        dirname = os.path.abspath(os.path.dirname(path))
        # exist_ok avoids a check-then-create race with concurrent writers.
        os.makedirs(dirname, exist_ok=True)
        with open(path, "wb") as file:
            pickle.dump(self, file)

    @classmethod
    def load(cls, path):
        """Unpickle and return the object stored at ``path``.

        The returned object's type is not checked against ``cls``.
        Security: unpickling can execute arbitrary code; only load trusted files.
        """
        with open(path, "rb") as file:
            return pickle.load(file)

    def throw(self, err, msg):
        """Raise ``err`` with ``msg`` prefixed by this object's class name."""
        raise err(f"Class {self.name}: {msg}")