# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from os import path
import gym
import jax
import jax.numpy as jnp
import numpy as np
from deluca.envs.core import Env
from deluca.utils import Random
@jax.jit
def angle_normalize(x):
return ((x + jnp.pi) % (2 * jnp.pi)) - jnp.pi
def default_reward_fn(x, u):
return -(np.sum(angle_normalize(x[0]) ** 2 + 0.1 * x[1] ** 2 + 0.001 * (u ** 2)))
[docs]class Pendulum(Env):
max_speed = 8.0
max_torque = 2.0 # gym 2.
high = np.array([1.0, 1.0, max_speed])
action_space = gym.spaces.Box(low=-max_torque, high=max_torque, shape=(1,), dtype=np.float32)
observation_space = gym.spaces.Box(low=-high, high=high, dtype=np.float32)
[docs] def __init__(self, reward_fn=None, seed=0, horizon=50):
# self.reward_fn = reward_fn or default_reward_fn
self.dt = 0.05
self.viewer = None
self.state_size = 2
self.action_size = 1
self.action_dim = 1 # redundant with action_size but needed by ILQR
self.H = horizon
self.n, self.m = 2, 1
self.angle_normalize = angle_normalize
self.nsamples = 0
self.random = Random(seed)
self.reset()
# @jax.jit
def _dynamics(state, action):
self.nsamples += 1
th, thdot = state
g = 10.0
m = 1.0
ell = 1.0
dt = self.dt
# Do not limit the control signals
action = jnp.clip(action, -self.max_torque, self.max_torque)
newthdot = (
thdot + (-3 * g / (2 * ell) * jnp.sin(th + jnp.pi) + 3.0 / (m * ell ** 2) * action) * dt
)
newth = th + newthdot * dt
newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)
return jnp.reshape(jnp.array([newth, newthdot]), (2,))
@jax.jit
def c(x, u):
# return np.sum(angle_normalize(x[0]) ** 2 + 0.1 * x[1] ** 2 + 0.001 * (u ** 2))
return angle_normalize(x[0])**2 + .1*(u[0]**2)
self.reward_fn = reward_fn or c
self.dynamics = _dynamics
self.f, self.f_x, self.f_u = (
_dynamics,
jax.jacfwd(_dynamics, argnums=0),
jax.jacfwd(_dynamics, argnums=1),
)
self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
c,
jax.grad(c, argnums=0),
jax.grad(c, argnums=1),
jax.hessian(c, argnums=0),
jax.hessian(c, argnums=1),
)
[docs] def reset(self):
th = jax.random.uniform(self.random.generate_key(), minval=-jnp.pi, maxval=jnp.pi)
thdot = jax.random.uniform(self.random.generate_key(), minval=-1.0, maxval=1.0)
self.state = jnp.array([th, thdot])
return self.state
[docs] def render(self, mode="human"):
if self.viewer is None:
from gym.envs.classic_control import rendering
self.viewer = rendering.Viewer(500, 500)
self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)
rod = rendering.make_capsule(1, 0.2)
rod.set_color(0.8, 0.3, 0.3)
self.pole_transform = rendering.Transform()
rod.add_attr(self.pole_transform)
self.viewer.add_geom(rod)
axle = rendering.make_circle(0.05)
axle.set_color(0, 0, 0)
self.viewer.add_geom(axle)
fname = path.join(path.dirname(__file__), "assets/clockwise.png")
self.img = rendering.Image(fname, 1.0, 1.0)
self.imgtrans = rendering.Transform()
self.img.add_attr(self.imgtrans)
self.viewer.add_onetime(self.img)
self.pole_transform.set_rotation(self.state[0] + np.pi / 2)
if self.last_u:
self.imgtrans.scale = (-self.last_u / 2, np.abs(self.last_u) / 2)
return self.viewer.render(return_rgb_array=mode == "rgb_array")