# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""deluca.agents._deep"""
from numbers import Real
import jax
import jax.numpy as jnp
import numpy as np
import functools
from deluca.agents.core import Agent
from deluca.envs.core import Env
from deluca.utils import Random
# generic deep controller for 1-dimensional discrete non-negative action space
class Deep(Agent):
    """
    Generic deep controller that uses zero-order methods to train on an
    environment.
    """

    def __init__(
        self,
        env_state_size,
        action_space,
        learning_rate: Real = 0.001,
        gamma: Real = 0.99,
        max_episode_length: int = 500,
        seed: int = 0,
    ) -> None:
        """
        Description: initializes the Deep agent

        Args:
            env_state_size (int): dimension of the environment's state vector
            action_space (jnp.ndarray): discrete set of actions the agent can take
            learning_rate (Real): step size for the weight update
            gamma (Real): discount factor applied to future rewards
            max_episode_length (int): maximum number of steps recorded per episode
            seed (int): seed for the internal random key generator

        Returns:
            None
        """
        self.env_state_size = int(env_state_size)
        self.action_space = action_space
        self.max_episode_length = max_episode_length
        self.lr = learning_rate
        self.gamma = gamma
        self.random = Random(seed)
        self.reset()

    def reset(self) -> None:
        """
        Description: reset agent

        Args:
            None

        Returns:
            None
        """
        # Initialize policy weights uniformly at random in [0, 1).
        self.W = jax.random.uniform(
            self.random.generate_key(),
            shape=(self.env_state_size, len(self.action_space)),
            minval=0,
            maxval=1,
        )

        # Preallocated per-episode buffers for rewards and policy gradients.
        self.current_episode_length = 0
        self.current_episode_reward = 0
        self.episode_rewards = jnp.zeros(self.max_episode_length)
        self.episode_grads = jnp.zeros(
            (self.max_episode_length, self.W.shape[0], self.W.shape[1])
        )

        # Dummy values for attrs, needed to inform scan of traced shapes.
        self.state = jnp.zeros((self.env_state_size,))
        self.action = self.action_space[0]
        ones = jnp.ones((len(self.action_space),))
        self.probs = ones / jnp.sum(ones)

    def policy(self, state: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Softmax policy that maps a state to a distribution over
        actions, parameterized by w: probs = softmax(state . w)

        Args:
            state (jnp.ndarray): current environment state
            w (jnp.ndarray): policy weight matrix

        Returns:
            jnp.ndarray: probability of each action
        """
        z = jnp.dot(state, w)
        exp = jnp.exp(z)
        return exp / jnp.sum(exp)
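
    # E.g. (illustrative values): for z = state . w = [0.0, 0.0] the policy
    # returns probs = [0.5, 0.5]; adding a constant to every entry of z leaves
    # the probabilities unchanged, since softmax is shift-invariant.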

    def softmax_grad(self, softmax: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Vectorized softmax Jacobian, J[i, j] = s[i] * (1[i == j] - s[j])

        Args:
            softmax (jnp.ndarray): softmax output s

        Returns:
            jnp.ndarray: Jacobian matrix of the softmax at s
        """
        s = softmax.reshape(-1, 1)
        return jnp.diagflat(s) - jnp.dot(s, s.T)
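
    # Quick numeric check of the Jacobian formula above (illustrative values):
    # for s = [0.5, 0.5], diagflat(s) = [[0.5, 0.0], [0.0, 0.5]] and
    # s @ s.T = [[0.25, 0.25], [0.25, 0.25]], so J = [[0.25, -0.25], [-0.25, 0.25]].
    # Every row and column sums to zero, as it must for a map whose outputs
    # always sum to one.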

    def __call__(self, state: jnp.ndarray):
        """
        Description: provide an action given a state

        Args:
            state (jnp.ndarray): current environment state

        Returns:
            jnp.ndarray: action to take, sampled from the policy distribution
        """
        self.state = state
        self.probs = self.policy(state, self.W)
        self.action = jax.random.choice(
            self.random.generate_key(), a=self.action_space, p=self.probs
        )
        return self.action

    def feed(self, reward: Real) -> None:
        """
        Description: compute gradient and save with reward in memory for weight updates

        Args:
            reward (Real): reward received for the most recent action

        Returns:
            None
        """
        # Score-function gradient: d log pi(a) = d pi(a) / pi(a), so divide the
        # Jacobian row for the chosen action by that action's probability, then
        # take the outer product with the state to get the gradient w.r.t. W.
        dsoftmax = self.softmax_grad(self.probs)[self.action, :]
        dlog = dsoftmax / self.probs[self.action]
        grad = self.state.reshape(-1, 1) @ dlog.reshape(1, -1)

        # Record the reward and gradient at the current step (functional updates).
        self.episode_rewards = self.episode_rewards.at[self.current_episode_length].set(reward)
        self.episode_grads = self.episode_grads.at[self.current_episode_length].set(grad)
        self.current_episode_length += 1

    def update(self) -> None:
        """
        Description: update weights

        Args:
            None

        Returns:
            None
        """
        for i in range(self.current_episode_length):
            # Loop through everything that happened in the episode and update
            # towards the log policy gradient times the discounted **FUTURE** reward.
            self.W += self.lr * self.episode_grads[i] * jnp.sum(
                jnp.array(
                    [
                        r * (self.gamma ** t)
                        for t, r in enumerate(
                            self.episode_rewards[i : self.current_episode_length]
                        )
                    ]
                )
            )

        # reset episode length
        self.current_episode_length = 0
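

# For reference, update() above implements the REINFORCE rule (a sketch in
# this file's notation):
#
#     W <- W + lr * grad_i * G_i,    G_i = sum_{t >= 0} gamma^t * r_{i+t}
#
# i.e. each stored log-policy gradient is scaled by the discounted return
# accumulated from its own step onward, not by past rewards.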