vi-hds/models/prpr_constant.py


# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# ------------------------------------
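"""Constant- and neural-precision ODE models for the PRPR device.

Defines the right-hand side of a growth model for six observed species
(OD, RFP, YFP, CFP, F530, F480), together with two OdeModel wrappers:
PRPR_Constant (constant precisions treated as latent variables) and
PRPR_Constant_Precisions (precisions evolved by a neural network alongside
the species).
"""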
from vihds.ode import OdeModel, OdeFunc
from vihds.precisions import ConstantPrecisions, NeuralPrecisions
import torch
# pylint: disable = no-member, not-callable
class PRPR_Constant_RHS(OdeFunc):
    def __init__(self, config, theta, treatments, dev_1hot, precisions=None, version=1):
        super(PRPR_Constant_RHS, self).__init__(config, theta, treatments, dev_1hot)
        # Pass in a class instance for dynamic (neural) precisions. If None, then it's expected that you have latent
        # variables for the precisions, assigned as part of BaseModel.expand_precisions_by_time()
        self.precisions = precisions
        self.n_batch = theta.get_n_batch()
        self.n_iwae = theta.get_n_samples()
        self.n_species = 6
        # TODO: NN growth model for ethanol?
        # tile treatments, one per iwae sample
        # treatments_transformed = torch.clamp(torch.exp(treatments) - 1.0, 1e-12, 1e6)
        # c6a, c12a = torch.unbind(treatments_transformed, axis=1)
        # c6 = torch.transpose(c6a.repeat([self.n_iwae, 1]), 0, 1)
        # c12 = torch.transpose(c12a.repeat([self.n_iwae, 1]), 0, 1)
        # need to clip these to avoid overflow
        self.r = torch.clamp(theta.r, 0.0, 4.0)
        self.K = torch.clamp(theta.K, 0.0, 4.0)
        self.tlag = theta.tlag
        self.rc = theta.rc
        self.a530 = theta.a530
        self.a480 = theta.a480
        self.drfp = torch.clamp(theta.drfp, 1e-12, 2.0)
        self.dyfp = torch.clamp(theta.dyfp, 1e-12, 2.0)
        self.dcfp = torch.clamp(theta.dcfp, 1e-12, 2.0)
        self.aCFP = theta.aCFP_PR
        self.aYFP = theta.aYFP_PR
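
    # The forward() method below implements the following ODE right-hand side,
    # with gamma = r * sigmoid(4 * (t - tlag)) * (1 - x / K):
    #   d(x)/dt    = gamma * x
    #   d(rfp)/dt  = rc        - (gamma + drfp) * rfp
    #   d(yfp)/dt  = rc * aYFP - (gamma + dyfp) * yfp
    #   d(cfp)/dt  = rc * aCFP - (gamma + dcfp) * cfp
    #   d(f530)/dt = rc * a530 - gamma * f530
    #   d(f480)/dt = rc * a480 - gamma * f480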
    def forward(self, t, state):
        x, rfp, yfp, cfp, f530, f480 = torch.unbind(state[:, :, : self.n_species], axis=2)
        # Cells growing or not (not before lag-time)
        gr = self.r * torch.sigmoid(4.0 * (t - self.tlag))
        # Specific growth and dilution
        g = 1.0 - x / self.K
        gamma = gr * g
        # Right-hand sides
        d_x = gamma * x
        d_rfp = self.rc - (gamma + self.drfp) * rfp
        d_yfp = self.rc * self.aYFP - (gamma + self.dyfp) * yfp
        d_cfp = self.rc * self.aCFP - (gamma + self.dcfp) * cfp
        d_f530 = self.rc * self.a530 - gamma * f530
        d_f480 = self.rc * self.a480 - gamma * f480
        dX = torch.stack([d_x, d_rfp, d_yfp, d_cfp, d_f530, d_f480], axis=2)
        if self.precisions is not None:
            dV = self.precisions(t, state, None, self.n_batch, self.n_iwae)
            return torch.cat([dX, dV], dim=2)
        else:
            return dX
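

# Model variant with constant measurement precisions: the four precision
# parameters (prec_x, prec_rfp, prec_yfp, prec_cfp) are latent variables
# handled by ConstantPrecisions, so the RHS integrates only the six species.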
class PRPR_Constant(OdeModel):
    def __init__(self, config):
        super(PRPR_Constant, self).__init__(config)
        self.precisions = ConstantPrecisions(["prec_x", "prec_rfp", "prec_yfp", "prec_cfp"])
        self.species = ["OD", "RFP", "YFP", "CFP", "F530", "F480"]
        self.n_species = 6
        self.device = config.device
        self.version = 1

    def initialize_state(self, theta, _treatments):
        n_batch = theta.get_n_batch()
        n_iwae = theta.get_n_samples()
        zero = torch.zeros([n_batch, n_iwae], device=self.device)
        x0 = torch.stack([theta.init_x, theta.init_rfp, theta.init_yfp, theta.init_cfp, zero, zero], axis=2)
        return x0

    def gen_reaction_equations(self, config, theta, treatments, dev_1hot):
        func = PRPR_Constant_RHS(config, theta, treatments, dev_1hot, version=self.version)
        return func

    def summaries(self, writer, epoch):
        pass
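

# Model variant with dynamic (neural) precisions: NeuralPrecisions appends four
# precision states to the six species, so the initial state has ten components
# and the RHS concatenates the precision derivatives dV onto dX.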
class PRPR_Constant_Precisions(OdeModel):
    def __init__(self, config):
        super(PRPR_Constant_Precisions, self).__init__(config)
        self.species = ["OD", "RFP", "YFP", "CFP", "F530", "F480"]
        self.n_species = 6
        self.precisions = NeuralPrecisions(self.n_species, config.params.n_hidden_decoder_precisions, 4)
        self.version = 1

    def initialize_state(self, theta, _treatments):
        n_batch = theta.get_n_batch()
        n_iwae = theta.get_n_samples()
        zero = torch.zeros([n_batch, n_iwae])
        x0 = torch.stack(
            [
                theta.init_x,
                theta.init_rfp,
                theta.init_yfp,
                theta.init_cfp,
                zero,
                zero,
                theta.init_prec_x,
                theta.init_prec_rfp,
                theta.init_prec_yfp,
                theta.init_prec_cfp,
            ],
            axis=2,
        )
        return x0

    def gen_reaction_equations(self, config, theta, treatments, dev_1hot):
        func = PRPR_Constant_RHS(config, theta, treatments, dev_1hot, precisions=self.precisions, version=self.version)
        return func

    def summaries(self, writer, epoch):
        self.precisions.summaries(writer, epoch)