# Source code for deluca.agents._deep

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from numbers import Real

import jax
import jax.numpy as jnp
import numpy as np
import functools

from deluca.agents.core import Agent
from deluca.envs.core import Env
from deluca.utils import Random

# generic deep controller for 1-dimensional discrete non-negative action space
class Deep(Agent):
    """
    Generic deep controller that uses zero-order methods to train on an
    environment.

    The policy is a softmax over ``state @ W``. During an episode the agent
    records, for every step, the reward and the gradient of the log-probability
    of the sampled action with respect to ``W``; ``update`` then applies those
    gradients scaled by the discounted future reward.

    Actions are used directly as indices into the probability vector, so
    ``action_space`` is expected to hold the integers ``0 .. len - 1``.
    """

    def __init__(
        self,
        env_state_size,
        action_space,
        learning_rate: Real = 0.001,
        gamma: Real = 0.99,
        max_episode_length: int = 500,
        seed: int = 0,
    ) -> None:
        """
        Description: initializes the Deep agent

        Args:
            env_state_size: number of entries in a state vector (cast to int)
            action_space: sequence of available actions
            learning_rate (Real): step size applied to each gradient in `update`
            gamma (Real): discount factor applied to future rewards
            max_episode_length (int): capacity of the per-episode buffers
            seed (int): seed for the agent's random key generator

        Returns:
            None
        """
        self.env_state_size = int(env_state_size)
        self.action_space = action_space
        self.max_episode_length = max_episode_length
        self.lr = learning_rate
        self.gamma = gamma

        self.random = Random(seed)

        self.reset()

    def reset(self) -> None:
        """
        Description: reset agent (re-draw weights and clear episode buffers)

        Args:
            None

        Returns:
            None
        """
        # Init weight: one column of logit weights per action
        self.W = jax.random.uniform(
            self.random.generate_key(),
            shape=(self.env_state_size, len(self.action_space)),
            minval=0,
            maxval=1,
        )

        # Fixed-size per-episode buffers, filled by `feed` and consumed by
        # `update`
        self.current_episode_length = 0
        self.current_episode_reward = 0
        self.episode_rewards = jnp.zeros(self.max_episode_length)
        self.episode_grads = jnp.zeros(
            (self.max_episode_length, self.W.shape[0], self.W.shape[1])
        )

        # dummy values for attrs, needed to inform scan of traced shapes
        self.state = jnp.zeros((self.env_state_size,))
        self.action = self.action_space[0]
        ones = jnp.ones((len(self.action_space),))
        self.probs = ones / jnp.sum(ones)

    def policy(self, state: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Policy that maps state to action parameterized by w

        Args:
            state (jnp.ndarray): current state vector
            w (jnp.ndarray): weight matrix

        Returns:
            jnp.ndarray: softmax of ``state @ w`` (action probabilities)
        """
        z = jnp.dot(state, w)
        # Subtracting the max leaves the softmax unchanged but keeps exp()
        # from overflowing on large logits.
        exp = jnp.exp(z - jnp.max(z))
        return exp / jnp.sum(exp)

    def softmax_grad(self, softmax: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Vectorized softmax Jacobian

        Args:
            softmax (jnp.ndarray): output of a softmax

        Returns:
            jnp.ndarray: square matrix ``diag(s) - s s^T``
        """
        s = softmax.reshape(-1, 1)
        return jnp.diagflat(s) - jnp.dot(s, s.T)

    def __call__(self, state: jnp.ndarray):
        """
        Description: provide an action given a state

        Args:
            state (jnp.ndarray): current state vector

        Returns:
            jnp.ndarray: action to take
        """
        # state/probs/action are remembered so `feed` can compute the gradient
        # for this step once the reward is known.
        self.state = state
        self.probs = self.policy(state, self.W)
        self.action = jax.random.choice(
            self.random.generate_key(),
            # jax.random.choice expects an array (or int), not a Python list
            a=jnp.asarray(self.action_space),
            p=self.probs,
        )
        return self.action

    def feed(self, reward: Real) -> None:
        """
        Description: compute gradient and save with reward in memory for
        weight updates

        Args:
            reward (Real): reward received for the most recent action

        Returns:
            None
        """
        # Row of the softmax Jacobian for the chosen action, divided by that
        # action's probability: gradient of log pi(action) w.r.t. the logits.
        dsoftmax = self.softmax_grad(self.probs)[self.action, :]
        dlog = dsoftmax / self.probs[self.action]
        # Outer product with the state gives an array shaped like W.
        grad = self.state.reshape(-1, 1) @ dlog.reshape(1, -1)

        # `jax.ops.index_update` has been removed from JAX; `.at[...].set`
        # is the supported functional-update form.
        step = self.current_episode_length
        self.episode_rewards = self.episode_rewards.at[step].set(reward)
        self.episode_grads = self.episode_grads.at[step].set(grad)
        self.current_episode_length += 1

    def update(self) -> None:
        """
        Description: update weights from the steps recorded since the last
        update, then clear the episode counter

        Args:
            None

        Returns:
            None
        """
        n_steps = self.current_episode_length
        rewards = self.episode_rewards[:n_steps]

        for i in range(n_steps):
            # Loop through everything that happened in the episode and update
            # towards the log policy gradient times **FUTURE** reward.
            # The discount exponent is the number of steps ahead of step i
            # (not the reward value), and the return scales the gradient
            # (rather than being added to every weight).
            future = rewards[i:]
            discounts = self.gamma ** jnp.arange(future.shape[0])
            self.W += self.lr * self.episode_grads[i] * jnp.sum(future * discounts)

        # reset episode length
        self.current_episode_length = 0