deluca.agents.Deep

class deluca.agents.Deep(*args, **kwargs)
Generic deep controller that uses zero-order methods to train on an environment.

Public Data Attributes:
Inherited from JaxObject
    name
    attrs

Public Methods:
__init__(env_state_size, action_space[, …])
    Description: initializes the Deep agent
reset()
    Description: reset agent
policy(state, w)
    Description: Policy that maps state to action parameterized by w
softmax_grad(softmax)
    Description: Vectorized softmax Jacobian
__call__(state)
    Description: provide an action given a state
feed(reward)
    Description: compute gradient and save with reward in memory for weight updates
update()
    Description: update weights

Inherited from Agent
__init_subclass__(*args, **kwargs)
    For avoiding a decorator for each subclass
__call__(state)
    Description: provide an action given a state
reset()
    Description: reset agent
feed(reward)
    Description: compute gradient and save with reward in memory for weight updates

Inherited from JaxObject
__new__(cls, *args, **kwargs)
    For avoiding super().__init__()
__init_subclass__(*args, **kwargs)
    For avoiding a decorator for each subclass
__str__()
    Return str(self).
__setattr__(key, val)
    Implement setattr(self, name, value).
save(path)
load(path)
throw(err, msg)
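softmax_grad is summarized only as a "vectorized softmax Jacobian". As a minimal sketch, assuming its argument is a 1-D probability vector s produced by a softmax, the Jacobian of the softmax with respect to its logits is diag(s) − s sᵀ. The helper name `softmax_jacobian` is hypothetical; the example cross-checks the closed form against `jax.jacobian`.

```python
import jax
import jax.numpy as jnp


def softmax_jacobian(s: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: Jacobian of softmax w.r.t. its logits, given output s.

    J[i, j] = s[i] * (delta_ij - s[j]), i.e. diag(s) - s s^T.
    """
    s = s.reshape(-1, 1)              # column vector, shape (n, 1)
    return jnp.diagflat(s) - s @ s.T  # shape (n, n)


logits = jnp.array([0.5, -1.0, 2.0])
s = jax.nn.softmax(logits)
J = softmax_jacobian(s)

# Cross-check against automatic differentiation of jax.nn.softmax.
J_auto = jax.jacobian(jax.nn.softmax)(logits)
print(jnp.allclose(J, J_auto, atol=1e-6))  # True
```

Each row of this Jacobian sums to zero, because the softmax outputs always sum to one.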

__call__(state: jnp.ndarray)
Description: provide an action given a state
- Parameters
    state (jnp.ndarray) – the current state of the environment
- Returns
    action to take
- Return type
    jnp.ndarray
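__call__ is documented only as providing an action for a given state. One plausible reading, sketched here under stated assumptions (a discrete action space and a linear-softmax policy; the function `act` and the weight matrix `W` are hypothetical, not taken from deluca), is to turn the state into action probabilities and sample one action with `jax.random.choice`.

```python
import jax
import jax.numpy as jnp


def act(key: jnp.ndarray, state: jnp.ndarray, W: jnp.ndarray,
        action_space: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical __call__: sample one action from a linear-softmax policy."""
    probs = jax.nn.softmax(state @ W)  # one probability per action
    return jax.random.choice(key, action_space, p=probs)


key = jax.random.PRNGKey(0)
state = jnp.array([0.1, -0.3, 0.7, 0.2])  # assumed 4-dimensional state
W = jnp.zeros((4, 2))                     # assumed shape (state_dim, n_actions)
action_space = jnp.arange(2)              # two discrete actions: 0 and 1

# With W = 0 the policy is uniform, so either action may be drawn.
action = act(key, state, W, action_space)
```

Sampling rather than taking the most probable action keeps the policy stochastic, which a gradient-based update on log-probabilities relies on for exploration.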
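
__init__(env_state_size, action_space, learning_rate: numbers.Real = 0.001, gamma: numbers.Real = 0.99, max_episode_length: int = 500, seed: int = 0) → None
Description: initializes the Deep agent
- Parameters
    env_state_size – size of the environment's state vector
    action_space – the environment's action space
    learning_rate (Real) – step size for weight updates (default 0.001)
    gamma (Real) – discount factor for future rewards (default 0.99)
    max_episode_length (int) – maximum number of steps per episode (default 500)
    seed (int) – random seed (default 0)
- Returns
    None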
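The __init__ signature includes gamma (default 0.99) without explaining how it is used. In policy-gradient agents, gamma is conventionally the discount applied to future rewards. Assuming that is its role here, the discounted return-to-go G_t = r_t + γ·G_{t+1} over a recorded episode can be computed back to front with `jax.lax.scan(..., reverse=True)`. The helper `discounted_returns` is hypothetical.

```python
import jax
import jax.numpy as jnp


def discounted_returns(rewards: jnp.ndarray, gamma: float) -> jnp.ndarray:
    """Hypothetical helper: return-to-go G_t = r_t + gamma * G_{t+1}."""
    def step(carry, r):
        g = r + gamma * carry  # fold the current reward into the running return
        return g, g

    # reverse=True scans from the last step; outputs stay aligned with inputs.
    _, returns = jax.lax.scan(step, jnp.zeros((), rewards.dtype), rewards,
                              reverse=True)
    return returns


rewards = jnp.array([1.0, 1.0, 1.0])
returns = discounted_returns(rewards, gamma=0.5)  # values 1.75, 1.5, 1.0
```

A gamma close to 1 (such as the 0.99 default) weights distant rewards almost as heavily as immediate ones, and gamma = 0 reduces each G_t to the immediate reward r_t.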

feed(reward: numbers.Real) → None
Description: compute gradient and save with reward in memory for weight updates
- Parameters
    reward (Real) – the reward received after the most recent action
- Returns
    None
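feed is documented as computing a gradient and storing it with the reward for later weight updates. Assuming a linear-softmax policy π(a | s) = softmax(sᵀW)_a (the docstring does not state the policy form), the gradient of log π(a | s) with respect to W has the closed form s (e_a − π)ᵀ. The sketch below computes it and checks it with `jax.grad`. The final REINFORCE-style step, W ← W + lr · G_t · ∇ log π, is an assumption about what update does, not a description of deluca's code, and `log_policy_grad` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def log_policy_grad(state: jnp.ndarray, W: jnp.ndarray, action: int) -> jnp.ndarray:
    """Closed-form gradient of log softmax(state @ W)[action] w.r.t. W."""
    probs = jax.nn.softmax(state @ W)
    one_hot = jax.nn.one_hot(jnp.asarray(action), W.shape[1])
    return jnp.outer(state, one_hot - probs)  # shape (state_dim, n_actions)


state = jnp.array([0.5, -1.0, 2.0])
W = jnp.full((3, 2), 0.1)
action = 1

grad = log_policy_grad(state, W, action)

# Cross-check with autodiff of log pi(action | state).
auto = jax.grad(lambda w: jax.nn.log_softmax(state @ w)[action])(W)
print(jnp.allclose(grad, auto, atol=1e-6))  # True

# Hypothetical REINFORCE-style step: scale by the return observed after the action.
lr, G_t = 0.001, 1.75
W_new = W + lr * G_t * grad
```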

policy(state: jnp.ndarray, w: jnp.ndarray) → jnp.ndarray
Description: Policy that maps state to action parameterized by w
- Parameters
    state (jnp.ndarray) – the current state of the environment
    w (jnp.ndarray) – weights that parameterize the policy
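The policy docstring does not specify the functional form. A common choice for a discrete action space, and the one assumed in this sketch, is a linear-softmax policy: logits z = sᵀw followed by a softmax. The function name `linear_softmax_policy` is hypothetical. It shifts the logits by their maximum before exponentiating, which leaves the result unchanged but avoids overflow.

```python
import jax
import jax.numpy as jnp


def linear_softmax_policy(state: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
    """Map a state of shape (d,) to action probabilities of shape (n_actions,)."""
    z = state @ w        # logits, shape (n_actions,)
    z = z - jnp.max(z)   # shift for numerical stability
    e = jnp.exp(z)
    return e / jnp.sum(e)


state = jnp.array([1.0, 2.0])
w = jnp.array([[1.0, 0.0, -1.0],
               [0.5, 0.5, 0.5]])  # shape (d=2, n_actions=3)
probs = linear_softmax_policy(state, w)  # logits are [2, 1, 0]
print(jnp.allclose(probs, jax.nn.softmax(state @ w)))  # True
```

Under this assumption, w has shape (state dimension, number of actions), and the returned vector is a probability distribution over actions rather than a single action.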