Source code for deluca.envs.classic._planar_quadrotor

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from os import path

import jax
import jax.numpy as jnp
import numpy as ojnp
from scipy.optimize._numdiff import approx_derivative as scipy_approx_derivative

from deluca.envs.core import Env


def approx_derivative(f, x, method):
    """Estimate the derivative (Jacobian) of ``f`` at ``x`` numerically.

    Args:
        f: Callable taking a 1-D array and returning an array (or scalar).
        x: 1-D point at which the derivative is estimated.
        method: Either a dict with an ``"alg"`` key, or a plain string that
            is treated as shorthand for ``{"alg": <string>}``.  Supported
            algorithms:

            * ``"2-point"``, ``"3-point"``, ``"cs"``: grid-based finite
              differences delegated to SciPy.  An optional ``"rel_step"``
              entry is forwarded; it defaults to ``None``.
            * ``"es"``: Monte Carlo Gaussian-smoothing estimate built from
              ``method["n"]`` antithetic sample pairs with perturbation
              scale ``method["sigma"]``.

    Returns:
        The estimated derivative of ``f`` at ``x``.

    Raises:
        NotImplementedError: If ``method`` has no ``"alg"`` entry or names
            an unsupported algorithm.
    """
    # A bare algorithm name used to fall through to NotImplementedError,
    # because ``"alg" in "3-point"`` is a substring test, not a key lookup.
    if isinstance(method, str):
        method = {"alg": method}

    alg = method["alg"] if "alg" in method else None

    if alg in ("2-point", "3-point", "cs"):  # grid-based finite differences
        rel_step = method["rel_step"] if "rel_step" in method else None
        return scipy_approx_derivative(f, x, alg, rel_step=rel_step)

    if alg == "es":  # Monte Carlo Gaussian smoothing derivative
        n = method["n"]
        sigma = method["sigma"]

        directions = ojnp.random.randn(n, len(x))
        J_est = None

        for g in directions:
            # Antithetic pair: evaluate f on both sides of x along g.
            J_samp = ojnp.outer(f(x + sigma * g) - f(x - sigma * g), g) / 2
            J_est = J_samp if J_est is None else J_est + J_samp

        return J_est / (n * sigma)

    raise NotImplementedError(f"unsupported derivative method: {method!r}")


class PlanarQuadrotor(Env):
    """Planar quadrotor with forward-Euler dynamics and a quadratic cost.

    The state is ``(x, y, th, xdot, ydot, thdot)``.  The first two action
    entries ``(u1, u2)`` drive the dynamics; any further ``placebo`` entries
    are ignored by the dynamics but still penalised by the cost.  The goal is
    the origin at rest, with goal action ``m * g / 2`` on each of the first
    two inputs.
    """

    metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 30}

    def __init__(self, placebo=0, autodiff=True, method="3-point"):
        """Build the dynamics, the cost, and their derivatives.

        Args:
            placebo: Number of extra action dimensions that do not affect
                the dynamics.
            autodiff: If true, derivatives come from JAX; otherwise the
                dynamics Jacobians use ``approx_derivative`` and the cost
                derivatives use closed-form expressions.
            method: Numerical differentiation method used when ``autodiff``
                is false.  Either a dict with an ``"alg"`` key or a plain
                algorithm name such as ``"3-point"``.
        """
        # ``approx_derivative`` looks up ``method["alg"]``; wrap a bare
        # algorithm name so the default value works when autodiff is off.
        if isinstance(method, str):
            method = {"alg": method}

        self.nsamples = 0  # number of times the dynamics function has run
        self.m, self.l, self.g, self.dt, self.H = 0.1, 0.2, 9.81, 0.05, 100
        self.initial_state, self.goal_state, self.goal_action = (
            jnp.array([1.0, 1.0, 0.0, 0.0, 0.0, 0.0]),
            jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
            jnp.array([self.m * self.g / 2.0, self.m * self.g / 2.0]),
        )
        # Placebo inputs have a goal value of zero.
        self.goal_action = jnp.hstack((self.goal_action, jnp.zeros(placebo)))

        self.viewer = None
        self.action_dim, self.state_dim = 2 + placebo, 6
        self.last_u = jnp.zeros((2,))

        def f(x, u):
            """Advance the state by one Euler step of length ``dt``."""
            self.nsamples += 1
            state = x
            x, y, th, xdot, ydot, thdot = state
            u1, u2 = u[:2]  # only the first two inputs affect the dynamics
            m, g, l, dt = self.m, self.g, self.l, self.dt
            xddot = -(u1 + u2) * jnp.sin(th) / m
            yddot = (u1 + u2) * jnp.cos(th) / m - g
            thddot = l * (u2 - u1) / (m * l ** 2)
            state_dot = jnp.array([xdot, ydot, thdot, xddot, yddot, thddot])
            new_state = state + state_dot * dt
            return new_state

        def c(x, u):
            """Quadratic cost around the goal state and goal action."""
            du = u - self.goal_action
            dx = x - self.goal_state
            return 0.1 * du @ du + dx @ dx

        def f_x(x, u):
            """Numerical Jacobian of the dynamics with respect to the state."""
            return approx_derivative(lambda x: f(x, u), x, method=method)

        def f_u(x, u):
            """Numerical Jacobian of the dynamics with respect to the action."""
            return approx_derivative(lambda u: f(x, u), u, method=method)

        def c_x(x, u):
            return 2 * (x - self.goal_state)

        def c_u(x, u):
            return 2 * 0.1 * (u - self.goal_action)

        def c_xx(x, u):
            return 2 * jnp.eye(self.state_dim)

        def c_uu(x, u):
            return 2 * 0.1 * jnp.eye(self.action_dim)

        if autodiff:
            self.f, self.f_x, self.f_u = (
                f,
                jax.jacfwd(f, argnums=0),
                jax.jacfwd(f, argnums=1),
            )
            self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
                c,
                jax.grad(c, argnums=0),
                jax.grad(c, argnums=1),
                jax.hessian(c, argnums=0),
                jax.hessian(c, argnums=1),
            )
        else:
            self.f, self.f_x, self.f_u = f, f_x, f_u
            self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = c, c_x, c_u, c_xx, c_uu

    def reset(self):
        """Restore the initial state, clear the last action and step counter.

        Returns:
            The initial state.
        """
        self.state, self.last_u, self.h = self.initial_state, None, 0
        return self.state

    def step(self, u):
        """Apply action ``u`` for one time step.

        Returns:
            A tuple ``(state, cost, False, {})`` where ``cost`` is evaluated
            at the new state and ``u``.
        """
        self.last_u, self.h = u, self.h + 1
        self.state = self.f(self.state, u)
        return self.state, self.c(self.state, u), False, {}

    def render(self, mode="human"):
        """Draw the current state with gym's classic-control viewer.

        Args:
            mode: ``"human"`` or ``"rgb_array"``; the latter asks the viewer
                to return the frame as an array.
        """
        if self.viewer is None:
            # Imported lazily so gym is only needed when rendering.
            from gym.envs.classic_control import rendering

            asset_dir = path.join(path.dirname(__file__), "assets")

            self.viewer = rendering.Viewer(1000, 1000)
            self.viewer.set_bounds(-1.2, 1.2, -1.2, 1.2)

            self.img = rendering.Image(path.join(asset_dir, "drone.png"), 0.4, 0.17)
            self.img.set_color(1.0, 1.0, 1.0)
            self.imgtrans = rendering.Transform()
            self.img.add_attr(self.imgtrans)

            self.imgwind = rendering.Image(path.join(asset_dir, "wind.png"), 2.4, 2.4)
            self.imgwind.set_color(0.5, 0.5, 0.5)
            self.imgtranswind = rendering.Transform()
            self.imgwind.add_attr(self.imgtranswind)

        # One-time geometries must be queued again before every frame.
        self.viewer.add_onetime(self.imgwind)
        self.viewer.add_onetime(self.img)
        self.imgtrans.set_translation(self.state[0], self.state[1] + 0.04)
        self.imgtrans.set_rotation(self.state[2])
        return self.viewer.render(return_rgb_array=mode == "rgb_array")

    def close(self):
        """Close the viewer, if one was created, and forget it."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None