# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# ------------------------------------
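"""ODE models of a two-input (C6/C12) LuxR/LasR relay circuit.

Defines `Relay_Constant_RHS`, the right-hand side of the relay ODE system, and
two `OdeModel` variants built on it: `Relay_Constant`, which uses constant
observation precisions, and `Relay_Constant_Precisions`, which appends
dynamic (neural) precision states to the ODE state.
"""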
import torch

from vihds.ode import OdeModel, OdeFunc, power
from vihds.precisions import ConstantPrecisions, NeuralPrecisions
from vihds.utils import variable_summaries

# pylint: disable = no-member, not-callable


class Relay_Constant_RHS(OdeFunc):
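    """Right-hand side of the relay-circuit ODE system.

    Kinetic parameters are read from the latent sample `theta` (one value per
    batch element and IWAE sample) and clamped where needed for numerical
    stability. If a `precisions` instance is given, its derivatives are
    appended to the species derivatives in `forward`; otherwise the precisions
    are expected to be handled as latent variables elsewhere.
    """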
    def __init__(
        self, config, theta, treatments, dev_1hot, condition_on_device, precisions=None, version=1,
    ):
        super(Relay_Constant_RHS, self).__init__(config, theta, treatments, dev_1hot, condition_on_device)

        # Pass in a class instance for dynamic (neural) precisions. If None, then it's expected that you have latent
        # variables for the precisions, assigned as part of BaseModel.expand_precisions_by_time()
        self.precisions = precisions

        self.n_batch = theta.get_n_batch()
        self.n_iwae = theta.get_n_samples()
        self.n_species = 10

        # tile treatments, one per iwae sample
        treatments_transformed = torch.clamp(torch.exp(treatments) - 1.0, 1e-12, 1e6)
        c6a, c12a = torch.unbind(treatments_transformed, axis=1)
        c6 = torch.transpose(c6a.repeat([self.n_iwae, 1]), 0, 1)
        c12 = torch.transpose(c12a.repeat([self.n_iwae, 1]), 0, 1)

        # need to clip these to avoid overflow
        self.r = torch.clamp(theta.r, 0.0, 4.0)
        self.K = torch.clamp(theta.K, 0.0, 4.0)
        self.tlag = theta.tlag
        self.rc = theta.rc
        self.a530 = theta.a530
        self.a480 = theta.a480

        self.drfp = torch.clamp(theta.drfp, 1e-12, 2.0)
        self.dyfp = torch.clamp(theta.dyfp, 1e-12, 2.0)
        self.dcfp = torch.clamp(theta.dcfp, 1e-12, 2.0)
        self.dR = torch.clamp(theta.dR, 1e-12, 5.0)
        self.dS = torch.clamp(theta.dS, 1e-12, 5.0)

        self.dlasI = torch.clamp(theta.dlasI, 1e-12, 5.0)
        self.dluxI = torch.clamp(theta.dluxI, 1e-12, 5.0)

        self.e76 = theta.e76
        self.e81 = theta.e81
        self.aCFP = theta.aCFP
        self.aYFP = theta.aYFP
        self.KGR_76 = theta.KGR_76
        self.KGS_76 = theta.KGS_76
        self.KGR_81 = theta.KGR_81
        self.KGS_81 = theta.KGS_81

        self.KC6 = theta.KC6
        self.KC12 = theta.KC12
        self.Klux = theta.Klux
        self.Klas = theta.Klas

        # Condition on device information by mapping param_cond = f(param, d; \phi) where d is one-hot rep of device
        # if condition_on_device:
        #     ones = torch.tensor([1.0]).repeat([self.n_batch, self.n_iwae])
        #     self.aR = self.device_conditioner(ones, 'aR', dev_1hot)
        #     self.aS = self.device_conditioner(ones, 'aS', dev_1hot)
        # else:
        #     self.aR = theta.aR
        #     self.aS = theta.aS

        self.aR = theta.aR
        self.aS = theta.aS

        # Activation constants for convenience
        nR = torch.clamp(theta.nR, 0.5, 3.0)
        nS = torch.clamp(theta.nS, 0.5, 3.0)
        lb = 1e-12
        ub = 1e0
        if version == 1:
            KR6 = torch.clamp(theta.KR6, lb, ub)
            KR12 = torch.clamp(theta.KR12, lb, ub)
            KS6 = torch.clamp(theta.KS6, lb, ub)
            KS12 = torch.clamp(theta.KS12, lb, ub)
            self.fracLuxR = (power(KR6 * c6, nR) + power(KR12 * c12, nR)) / power(1.0 + KR6 * c6 + KR12 * c12, nR)
            self.fracLasR = (power(KS6 * c6, nS) + power(KS12 * c12, nS)) / power(1.0 + KS6 * c6 + KS12 * c12, nS)
        else:
            raise Exception("Unknown version of Relay_Constant: %d" % version)

    def forward(self, t, state):
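        """Evaluate d(state)/dt at time `t`.

        The first `self.n_species` columns of `state` are unpacked into the
        biological species, growth is switched on smoothly after the lag time
        `tlag`, promoter activities P76 and P81 are computed from bound
        LuxR/LasR, and the stacked derivatives (including C6 and C12
        production) are returned, optionally concatenated with the precision
        derivatives.
        """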
        x, rfp, yfp, cfp, f530, f480, luxR, lasR, luxI, lasI = torch.unbind(state[:, :, : self.n_species], axis=2)

        # Cells growing or not (not before lag-time)
        gr = self.r * torch.sigmoid(4.0 * (t - self.tlag))

        # Specific growth and dilution
        g = 1.0 - x / self.K
        gamma = gr * g

        # Promoter activity
        boundLuxR = luxR * luxR * self.fracLuxR
        boundLasR = lasR * lasR * self.fracLasR
        P76 = (self.e76 + self.KGR_76 * boundLuxR + self.KGS_76 * boundLasR) / (
            1.0 + self.KGR_76 * boundLuxR + self.KGS_76 * boundLasR
        )
        P81 = (self.e81 + self.KGR_81 * boundLuxR + self.KGS_81 * boundLasR) / (
            1.0 + self.KGR_81 * boundLuxR + self.KGS_81 * boundLasR
        )

        # Right-hand sides
        d_x = gamma * x
        d_rfp = self.rc - (gamma + self.drfp) * rfp
        d_yfp = self.rc * self.aYFP * P81 - (gamma + self.dyfp) * yfp
        d_cfp = self.rc * self.aCFP * P76 - (gamma + self.dcfp) * cfp
        d_f530 = self.rc * self.a530 - gamma * f530
        d_f480 = self.rc * self.a480 - gamma * f480
        d_luxR = self.rc * self.aR - (gamma + self.dR) * luxR
        d_lasR = self.rc * self.aS - (gamma + self.dS) * lasR

        d_luxI = self.rc * P81 - (gamma + self.dluxI) * luxI
        d_lasI = self.rc * P76 - (gamma + self.dlasI) * lasI

        d_c6 = (self.KC6 * self.rc * x * luxI) / (1.0 + luxI / self.Klux)
        d_c12 = (self.KC12 * self.rc * x * lasI) / (1.0 + lasI / self.Klas)

        dX = torch.stack(
            [d_x, d_rfp, d_yfp, d_cfp, d_f530, d_f480, d_luxR, d_lasR, d_luxI, d_lasI, d_c6, d_c12], axis=2,
        )
        if self.precisions is not None:
            dV = self.precisions(t, state, None, self.n_batch, self.n_iwae)
            return torch.cat([dX, dV], dim=2)
        else:
            return dX


class Relay_Constant(OdeModel):
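    """Relay model with constant (time-independent) observation precisions.

    Latent precisions for the four observed channels (prec_x, prec_rfp,
    prec_yfp, prec_cfp) are handled by `ConstantPrecisions`; the 12-species
    state is integrated with `Relay_Constant_RHS`.
    """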
    def __init__(self, config):
        super(Relay_Constant, self).__init__(config)
        self.precisions = ConstantPrecisions(["prec_x", "prec_rfp", "prec_yfp", "prec_cfp"])
        self.species = [
            "OD",
            "RFP",
            "YFP",
            "CFP",
            "F530",
            "F480",
            "LuxR",
            "LasR",
            "LuxI",
            "LasI",
            "C6",
            "C12",
        ]
        self.n_species = 12
        self.device = config.device
        self.version = 1

    def initialize_state(self, theta, _treatments):
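        """Build the initial state, shaped [n_batch, n_iwae, n_species].

        Initial conditions for cells, reporters and regulators come from
        `theta`; the F530/F480 channels start at zero, and the C6/C12 columns
        are set to the exp-transformed treatment concentrations, tiled once
        per IWAE sample.
        """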
        n_batch = theta.get_n_batch()
        n_iwae = theta.get_n_samples()
        zero = torch.zeros([n_batch, n_iwae], device=self.device)

        treatments_transformed = torch.clamp(torch.exp(_treatments) - 1.0, 1e-12, 1e6)
        c6a, c12a = torch.unbind(treatments_transformed, axis=1)
        c6 = torch.transpose(c6a.repeat([n_iwae, 1]), 0, 1)
        c12 = torch.transpose(c12a.repeat([n_iwae, 1]), 0, 1)

        x0 = torch.stack(
            [
                theta.init_x,
                theta.init_rfp,
                theta.init_yfp,
                theta.init_cfp,
                zero,
                zero,
                theta.init_luxR,
                theta.init_lasR,
                theta.init_luxI,
                theta.init_lasI,
                c6,
                c12,
            ],
            axis=2,
        )
        return x0

    def gen_reaction_equations(self, config, theta, treatments, dev_1hot, condition_on_device=True):
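        """Construct the ODE right-hand-side function for this model.

        Also caches the production-rate parameters aR and aS from the RHS so
        that `summaries` can log them.
        """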
        func = Relay_Constant_RHS(config, theta, treatments, dev_1hot, condition_on_device, version=self.version)
        self.aR = func.aR
        self.aS = func.aS
        return func

    def summaries(self, writer, epoch):
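        """Log distribution summaries of the aR and aS parameters."""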
        variable_summaries(writer, epoch, self.aR, "aR.conditioned")
        variable_summaries(writer, epoch, self.aS, "aS.conditioned")


class Relay_Constant_Precisions(OdeModel):
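    """Relay model with dynamic (neural) precisions.

    The 12 biological species are augmented with 4 precision states whose
    dynamics are given by a `NeuralPrecisions` network, so the integrated
    state has 16 columns.
    """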
    def __init__(self, config):
        super(Relay_Constant_Precisions, self).init_with_params(config)
        self.species = [
            "OD",
            "RFP",
            "YFP",
            "CFP",
            "F530",
            "F480",
            "LuxR",
            "LasR",
            "LuxI",
            "LasI",
            "C6",
            "C12",
        ]
        self.n_species = 12
        self.precisions = NeuralPrecisions(self.n_species, config.params.n_hidden_decoder_precisions, 4)
        self.version = 1

    def initialize_state(self, theta, _treatments):
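        """Build the initial state, shaped [n_batch, n_iwae, n_species + 4].

        As in `Relay_Constant.initialize_state`, but the four initial
        precision values from `theta` are appended after the C6/C12 columns.
        """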
        n_batch = theta.get_n_batch()
        n_iwae = theta.get_n_samples()
        zero = torch.zeros([n_batch, n_iwae])

        treatments_transformed = torch.clamp(torch.exp(_treatments) - 1.0, 1e-12, 1e6)
        c6a, c12a = torch.unbind(treatments_transformed, axis=1)
        c6 = torch.transpose(c6a.repeat([n_iwae, 1]), 0, 1)
        c12 = torch.transpose(c12a.repeat([n_iwae, 1]), 0, 1)

        x0 = torch.stack(
            [
                theta.init_x,
                theta.init_rfp,
                theta.init_yfp,
                theta.init_cfp,
                zero,
                zero,
                theta.init_luxR,
                theta.init_lasR,
                theta.init_luxI,
                theta.init_lasI,
                c6,
                c12,
                theta.init_prec_x,
                theta.init_prec_rfp,
                theta.init_prec_yfp,
                theta.init_prec_cfp,
            ],
            axis=2,
        )
        return x0

    def gen_reaction_equations(self, config, theta, treatments, dev_1hot, condition_on_device=True):
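        """Construct the RHS function, passing in the neural precisions.

        The returned `Relay_Constant_RHS` appends the precision derivatives to
        the species derivatives; aR and aS are cached for `summaries`.
        """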
        func = Relay_Constant_RHS(
            config, theta, treatments, dev_1hot, condition_on_device, precisions=self.precisions, version=self.version,
        )
        self.aR = func.aR
        self.aS = func.aS
        return func

    def summaries(self, writer, epoch):
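        """Log distribution summaries of aR, aS and the neural precisions."""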
        variable_summaries(writer, epoch, self.aR, "aR.conditioned")
        variable_summaries(writer, epoch, self.aS, "aS.conditioned")
        self.precisions.summaries(writer, epoch)