Source code for deluca.envs.lung._balloon_lung

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp

from deluca.envs.lung import BreathWaveform
from deluca.envs.lung.core import Lung


def PropValve(x):
    y = 3.0 * x
    flow_new = 1.0 * (jnp.tanh(0.03 * (y - 130)) + 1.0)
    flow_new = jnp.clip(flow_new, 0.0, 1.72)
    return flow_new


def Solenoid(x):
    return x > 0


[docs]class BalloonLung(Lung): """ Lung simulator based on the two-balloon experiment Source: https://en.wikipedia.org/wiki/Two-balloon_experiment TODO: - time / dt - dynamics - waveform - phase """
[docs] def __init__( self, leak=False, peep_valve=5.0, PC=40.0, P0=0.0, C=10.0, R=15.0, dt=0.03, waveform=None, reward_fn=None, ): # dynamics hyperparameters self.min_volume = 1.5 self.C = C self.R = R self.PC = PC self.P0 = 0.0 self.leak = leak self.peep_valve = peep_valve self.time = 0.0 self.dt = dt self.waveform = waveform or BreathWaveform() self.r0 = (3.0 * self.min_volume / (4.0 * jnp.pi)) ** (1.0 / 3.0) # reset states self.reset()
[docs] def reset(self): self.time = 0.0 self.target = self.waveform.at(self.time) self.state = self.dynamics({'volume': self.min_volume, 'pressure': -1.0}, (0.0, 0.0)) return self.observation
@property def observation(self): return { "measured": self.state['pressure'], "target": self.target, "dt": self.dt, "phase": self.waveform.phase(self.time), }
[docs] def dynamics(self, state, action): """ state: (volume, pressure) action: (u_in, u_out) """ volume, pressure = state['volume'], state['pressure'] u_in, u_out = action flow = jnp.clip(PropValve(u_in) * self.R, 0.0, 2.0) flow -= jax.lax.cond( pressure > self.peep_valve, lambda x: jnp.clip(Solenoid(u_out), 0.0, 2.0) * 0.05 * pressure, lambda x: 0.0, flow, ) volume += flow * self.dt volume += jax.lax.cond( self.leak, lambda x: (self.dt / (5.0 + self.dt) * (self.min_volume - volume)), lambda x: 0.0, 0.0, ) r = (3.0 * volume / (4.0 * jnp.pi)) ** (1.0 / 3.0) pressure = self.P0 + self.PC * (1.0 - (self.r0 / r) ** 6.0) / (self.r0 ** 2.0 * r) # pressure = flow * self.R + volume / self.C + self.peep_valve return {'volume': volume, 'pressure': pressure}
[docs] def step(self, action): self.target = self.waveform.at(self.time) reward = -jnp.abs(self.target - self.state['pressure']) self.state = self.dynamics(self.state, action) self.time += self.dt return self.observation, reward, False, {}