Source code for deluca.agents._gpc

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""deluca.agents._gpc"""
from numbers import Real
from typing import Callable

import jax
import jax.numpy as jnp
import numpy as np
from jax import grad
from jax import jit

from deluca.agents._lqr import LQR
from deluca.agents.core import Agent


def quad_loss(x: jnp.ndarray, u: jnp.ndarray) -> Real:
    """
    Quadratic loss.

    Args:
        x (jnp.ndarray): state vector
        u (jnp.ndarray): action vector

    Returns:
        Real: x^T x + u^T u
    """
    return jnp.sum(x.T @ x + u.T @ u)
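
# Worked example: for x = [[1.], [2.]] and u = [[3.]],
# x.T @ x == [[5.]] and u.T @ u == [[9.]], so quad_loss(x, u) == 14.0.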


class GPC(Agent):
    def __init__(
        self,
        A: jnp.ndarray,
        B: jnp.ndarray,
        Q: jnp.ndarray = None,
        R: jnp.ndarray = None,
        K: jnp.ndarray = None,
        start_time: int = 0,
        cost_fn: Callable[[jnp.ndarray, jnp.ndarray], Real] = None,
        H: int = 3,
        HH: int = 2,
        lr_scale: Real = 0.005,
        decay: bool = True,
    ) -> None:
        """
        Description: Initialize the dynamics of the model.

        Args:
            A (jnp.ndarray): system dynamics matrix
            B (jnp.ndarray): control matrix
            Q (jnp.ndarray): state cost matrix (i.e. cost = x^T Q x + u^T R u)
            R (jnp.ndarray): action cost matrix (i.e. cost = x^T Q x + u^T R u)
            K (jnp.ndarray): starting policy (optional). Defaults to LQR gain.
            start_time (int): start time of the controller
            cost_fn (Callable[[jnp.ndarray, jnp.ndarray], Real]): loss function. Defaults to quad_loss.
            H (positive int): history of the controller
            HH (positive int): history of the system
            lr_scale (Real): learning rate scale
            decay (bool): whether to decay the learning rate over time
        """
        cost_fn = cost_fn or quad_loss

        d_state, d_action = B.shape  # State & Action Dimensions

        self.A, self.B = A, B  # System Dynamics

        self.t = 0  # Time Counter (for decaying learning rate)

        self.H, self.HH = H, HH

        self.lr_scale, self.decay = lr_scale, decay

        # Model Parameters
        # initial linear policy / perturbation contributions / bias
        # TODO: need to address problem of LQR with jax.lax.scan
        self.K = K if K is not None else LQR(self.A, self.B, Q, R).K

        self.M = jnp.zeros((H, d_action, d_state))

        # Past H + HH noises ordered increasing in time
        self.noise_history = jnp.zeros((H + HH, d_state, 1))

        # past state and past action
        self.state, self.action = jnp.zeros((d_state, 1)), jnp.zeros((d_action, 1))

        def last_h_noises():
            """Get the most recent H noises."""
            return jax.lax.dynamic_slice_in_dim(self.noise_history, -H, H)

        self.last_h_noises = last_h_noises

        def policy_loss(M, w):
            """Surrogate cost function"""

            def action(state, h):
                """Action function"""
                return -self.K @ state + jnp.tensordot(
                    M, jax.lax.dynamic_slice_in_dim(w, h, H), axes=([0, 2], [0, 1])
                )

            def evolve(state, h):
                """Evolve function"""
                return self.A @ state + self.B @ action(state, h) + w[h + H], None

            final_state, _ = jax.lax.scan(evolve, np.zeros((d_state, 1)), np.arange(HH - 1))
            return cost_fn(final_state, action(final_state, HH - 1))

        self.policy_loss = policy_loss
        self.grad = jit(grad(policy_loss, (0, 1)))
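
    # The learned policy is a disturbance-action controller: it plays
    #     u_t = -K x_t + sum_{i=1}^{H} M_i w_{t-i},
    # where the w's are past observed disturbances. policy_loss rolls this
    # policy forward over the remembered noise and returns a differentiable
    # surrogate cost, which drives the gradient step in `update` below.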

    def __call__(self, state: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Return the action based on current state and internal parameters.

        Args:
            state (jnp.ndarray): current state

        Returns:
            jnp.ndarray: action to take
        """
        action = self.get_action(state)
        self.update(state, action)
        return action

    def update(self, state: jnp.ndarray, u: jnp.ndarray) -> None:
        """
        Description: update agent internal state.

        Args:
            state (jnp.ndarray): current state
            u (jnp.ndarray): most recent action

        Returns:
            None
        """
        # store the newest noise at slot 0, then roll so it sits at the end
        # (noise_history is ordered increasing in time)
        noise = state - self.A @ self.state - self.B @ u
        self.noise_history = self.noise_history.at[0].set(noise)
        self.noise_history = jnp.roll(self.noise_history, -1, axis=0)

        delta_M, delta_bias = self.grad(self.M, self.noise_history)  # delta_bias is unused

        lr = self.lr_scale
        lr *= (1 / (self.t + 1)) if self.decay else 1
        self.M -= lr * delta_M

        # update state
        self.state = state

        self.t += 1

    def get_action(self, state: jnp.ndarray) -> jnp.ndarray:
        """
        Description: get action from state.

        Args:
            state (jnp.ndarray): current state

        Returns:
            jnp.ndarray: action
        """
        return -self.K @ state + jnp.tensordot(self.M, self.last_h_noises(), axes=([0, 2], [0, 1]))
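
# Usage sketch (illustrative, not part of the original module): run GPC on a
# toy double-integrator system. The matrices and the noiseless rollout below
# are assumptions for demonstration, not requirements of the API.
if __name__ == "__main__":
    A = jnp.array([[1.0, 1.0], [0.0, 1.0]])  # double-integrator dynamics
    B = jnp.array([[0.0], [1.0]])

    agent = GPC(A, B)  # K falls back to the LQR gain (see __init__)

    state = jnp.zeros((2, 1))
    for _ in range(10):
        action = agent(state)  # picks an action, then updates M from observed noise
        state = A @ state + B @ action  # evolve the true system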