Parameter Inference of Gene Transcription Model
In this example, we apply GradInf to a chemical reaction network (CRN) model of gene transcription. The considered model has four reactions:
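$$\emptyset \xrightarrow{\alpha} M, \qquad M \xrightarrow{\beta} M + P, \qquad M \xrightarrow{\gamma} \emptyset, \qquad P \xrightarrow{\delta} \emptyset$$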
corresponding to gene transcription, mRNA translation, mRNA degradation, and protein degradation, respectively. We simulate the reaction network using the Gillespie algorithm, and apply a loss function that sums the squared relative errors between the simulated time-averaged copy numbers of mRNA and protein and their ground-truth reference values.
Writing the Model
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE RankNTypes #-}
import Prelude
import Control.Applicative (ZipList(..), getZipList)
import Control.Monad
import Control.Monad.Bayes.Sampler.Lazy
import System.Random (setStdGen, mkStdGen)
import Data.List (transpose)
import Numeric.GradInf.Primitives.DeterministicPrimitives
import Numeric.GradInf.Primitives.Uniform
import Numeric.GradInf.Primitives.CategoricalMaxIndep
import Numeric.GradInf.Primitives.FoldP
import Numeric.GradInf.Primitives.IterateP
import Numeric.GradInf.Inference
import Numeric.GradInf.Inference.Base (Base(..))
import Numeric.GradInf
-- $setup (use a fixed seed)
-- >>> setStdGen (mkStdGen 42)
numObs :: Int
numObs = 5 -- number of independent traces to simulate; reduced in the tutorial for faster execution
-- | Position in the sequence produced by @iterateP@ at which a trace is read
-- off (see 'geneTranscriptionModel'); presumably the number of kernel
-- applications per trace -- TODO confirm against the @iterateP@ documentation.
-- The model raises an error if a trace has not passed 'capT' by this point.
n :: Int
n = 1000
-- | End of the simulated time window. The kernel stops updating a trace once
-- its clock exceeds this value, and the accumulated time integrals are
-- divided by it to obtain time averages.
capT :: Double
capT = 2.5
-- | One transition of the Gillespie simulation of the gene transcription
-- network.
--
-- The first argument holds the logarithms of the four rate constants
-- (transcription, translation, mRNA degradation, protein degradation); they
-- are exponentiated before use.
--
-- The state is @(x, t, acc)@ where @x = [mRNA count, protein count]@, @t@ is
-- the current simulation time and @acc@ holds the running time integrals of
-- the mRNA and protein counts.
--
-- Once @t@ exceeds 'capT' the state is returned unchanged. Otherwise the next
-- state is drawn with 'categoricalMaxIndep', a waiting time is drawn from
-- 'uniform', and the integrals are advanced up to (at most) 'capT'. The
-- returned clock is the unclamped event time, so it may exceed 'capT'.
geneTranscriptionKernel ::
  forall m d i b mat.
  ( DeterministicPrimitives d i b mat
  , CategoricalMaxIndep m d [i]
  , Uniform m d
  ) =>
  [d] ->
  ([i], d, [d]) ->
  m ([i], d, [d])
geneTranscriptionKernel theta state@(x, t, acc) =
  if extractDouble t > capT
    then return state
    else do
      let kTranscription = exp (theta !! 0)
          kTranslation = exp (theta !! 1)
          kMrnaDecay = exp (theta !! 2)
          kProteinDecay = exp (theta !! 3)
          mrna, protein :: i
          mrna = x !! 0
          protein = x !! 1
          -- Reaction propensities, in the same order as the candidate
          -- successor states below.
          propensities :: [d]
          propensities =
            [ kTranscription
            , kTranslation * (toDouble mrna)
            , kMrnaDecay * (toDouble mrna)
            , kProteinDecay * (toDouble protein)
            ]
          total :: d
          total = sum propensities
          weights :: [d]
          weights = map (/ total) propensities
          successors =
            [ [mrna + 1, protein]
            , [mrna, protein + 1]
            , [mrna - 1, protein]
            , [mrna, protein - 1]
            ]
      x' <- categoricalMaxIndep (weights, successors)
      u <- uniform
      let tNext = t + log (1 / u) / total
          -- Only the part of the holding interval inside [0, capT]
          -- contributes to the time integrals.
          tClamped = if extractDouble tNext > capT then fromDouble capT else tNext
          elapsed = tClamped - t
          integrals =
            [ acc !! 0 + toDouble mrna * elapsed
            , acc !! 1 + toDouble protein * elapsed
            ]
      return (x', tNext, integrals)
-- | Stochastic loss of the gene transcription model.
--
-- Arguments, in order: the number of traces to simulate, the position in the
-- 'iterateP' sequence at which each trace is read off, and the log rate
-- constants handed to 'geneTranscriptionKernel'.
--
-- Every trace starts from 5 mRNA and 40 protein molecules at time 0. The
-- time integrals of all traces are summed with 'foldrP', turned into
-- time-averaged copy numbers, and compared with the reference values 10.4
-- (mRNA) and 22.3 (protein). The result is the sum of the two squared
-- relative errors.
--
-- Calls 'error' if a trace has not reached 'capT' at the read-off position.
geneTranscriptionModel ::
  forall m d i b mat.
  ( DeterministicPrimitives d i b mat
  , IterateP m ([i], d, [d])
  , FoldrP m ([i], d, [d]) ([d], d)
  , CategoricalMaxIndep m d [i]
  , Uniform m d
  ) =>
  Int ->
  Int ->
  [d] ->
  m d
geneTranscriptionModel numObs n theta = do
  let initialState :: ([i], d, [d])
      initialState = ([5, 40], 0.0, [0.0, 0.0])
      oneTrace :: m ([i], d, [d])
      oneTrace = iterateP (geneTranscriptionKernel theta) initialState !! n
      -- Adds one finished trace to the running totals; the final copy
      -- numbers themselves are not needed.
      collect (_, tFinal, integrals) (sumX, sumT) =
        if extractDouble tFinal < capT
          then error ("Did not simulate the full time period (t = " ++ (show tFinal) ++ " < " ++ (show capT) ++ ")")
          else
            ( [sumX !! 0 + integrals !! 0, sumX !! 1 + integrals !! 1]
            , sumT + tFinal
            )
  (totals, _) <- foldrP collect ([0.0, 0.0], 0.0) (replicate numObs oneTrace)
  let -- Average over traces first, then over the simulated time window.
      timeAverage total = total / fromIntegral numObs / fromDouble capT
      -- Squared error relative to the reference value.
      relSq estimate reference =
        (estimate - reference) * (estimate - reference) / (reference * reference)
      mAvg = timeAverage (totals !! 0)
      pAvg = timeAverage (totals !! 1)
  return (relSq mAvg 10.4 + relSq pAvg 22.3)
We make use of a categoricalMaxIndep primitive, which tells GradInf to
employ a maximal coupling with independent residuals when
differentiating the gene transcription model.
-- | Five independent evaluations of the stochastic loss at the log rate
-- constants @log [18, 8, 1.5, 4]@, using the top-level 'numObs' and 'n'.
getSamples :: Sampler [Double]
getSamples = replicateM 5 (geneTranscriptionModel numObs n logRates)
  where
    logRates :: [Double]
    logRates = map log [18, 8, 1.5, 4]
Differentiating the Model
We now apply GradInf to compute gradient estimates with respect to all
four parameters simultaneously using the ZipList functor. This
computes the full gradient in a single call to gradInfAD.
-- | Ten gradient estimates of the loss from 'geneTranscriptionModel', each
-- obtained from one call to @gradInfAD@ at the log rate constants
-- @log [18, 8, 1.5, 4]@. Every estimate is a 'ZipList' with one entry per
-- parameter.
getGradientEstimates :: Sampler [ZipList Double]
getGradientEstimates =
-- twistFunc scores how far apart the two components of a coupled state are.
-- Its first (Int) argument is ignored.
let twistFunc ::
forall d i b mat. (Num d, Fractional d, Floating d, FromDouble d, ToDouble i d, ExtractDouble d) =>
Int ->
([Coupled i], Coupled d, [Coupled d]) ->
d
twistFunc _ (x, t, acc) = do
-- Each Coupled value is unpacked into its two components (suffixes A and B).
let Coupled (mA, mB) = x !! 0
let Coupled (pA, pB) = x !! 1
let Coupled (tA, tB) = t
let Coupled (accMA, accMB) = acc !! 0
let Coupled (accPA, accPB) = acc !! 1
-- Look-ahead window: at most 0.5, and no more than the time left until capT
-- as measured on the A component's clock.
-- NOTE(review): this is negative once tA exceeds capT (the kernel can return
-- a clock beyond capT) -- confirm that is intended.
let horizon = fromDouble $ min 0.5 (capT - extractDouble tA)
let timeScale = 1 / 20
-- Sum of three terms: the mRNA and protein discrepancies (difference of the
-- accumulated integrals plus the copy-number difference extrapolated over
-- the horizon), normalised by 10.4 and 22.3 -- the same values as the
-- reference means in geneTranscriptionModel -- and the scaled clock
-- difference.
abs ((accMB - accMA) + (toDouble mB - toDouble mA) * horizon) / fromDouble 10.4
+ abs ((accPB - accPA) + (toDouble pB - toDouble pA) * horizon) / fromDouble 22.3
+ abs (tB - tA) * timeScale
in replicateM 10
( gradInfAD
-- The parameters travel as a ZipList and are unwrapped to a plain list for
-- the model.
(geneTranscriptionModel numObs n . getZipList)
-- NOTE(review): the meaning of the literal 1 passed to TwistedSMCInference is
-- not visible in this file -- check the GradInf documentation.
( TwistedSMCInference 1 twistFunc ::
forall d i b mat. (DeterministicPrimitives d i b mat) =>
TwistedSMCInference Base CRN d i b
)
-- Logs of the four rate constants; the kernel exponentiates them again.
(ZipList (map log [18, 8, 1.5, 4]))
)
-- |
-- >>> fmap (map getZipList) (sampler getGradientEstimates)
-- [[0.14409024236638188,0.11626475191080754,-0.3959782466319224,-0.33542844390957105],[-0.19389663177549918,-8.764469447850404e-2,0.14729457653949962,3.047766914693925e-2],[0.2097233270616201,-0.5420223962041146,-0.9989225686730319,-0.14908300886071865],[-8.578939748394014e-2,-0.13719902372690249,5.3826963222109524e-2,-0.11435708776253722],[9.07704924279736e-2,0.11441913800866106,-1.936587936373914e-2,-8.527810690679449e-2],[-1.4101513738978218e-2,-0.10367086023403566,3.5507579198788726e-2,2.4173278724310077e-2],[4.96847450171267e-2,8.898381879440269e-2,-9.77076205342738e-3,-1.5242867343302848e-2],[7.160380585293154e-2,5.5657611194625485e-2,-3.522902896384387e-2,-5.078994328063205e-2],[8.518867942215486e-2,2.2521171867356336e-2,-0.12402426950342152,-0.10932506166065545],[-4.071094367489457e-2,0.16894068362793094,6.62798544875896e-2,0.16195007965891323]]
We can now calculate the component-wise empirical mean and variance of the gradient estimator with respect to each parameter.
-- | Empirical mean and sample variance of a list of values.
--
-- The variance divides by @length - 1@, so the result is not meaningful for
-- lists with fewer than two elements.
meanAndVariance :: [Double] -> (Double, Double)
meanAndVariance xs = (mean, variance)
  where
    count = length xs
    mean = sum xs / fromIntegral count
    squaredDeviation x = (x - mean) * (x - mean)
    variance = sum (map squaredDeviation xs) / fromIntegral (count - 1)
-- | Runs 'getGradientEstimates' and prints one @(mean, variance)@ pair per
-- parameter, one pair per line.
printMeanAndVariance :: IO ()
printMeanAndVariance = do
  estimates <- sampler getGradientEstimates
  -- Regroup from one list per estimate to one list per parameter.
  let perParameter = transpose (map getZipList estimates)
  mapM_ (print . meanAndVariance) perParameter
-- |
-- >>> printMeanAndVariance
-- (3.1656280547487654e-2,1.3847633481611804e-2)
-- (-3.0374979923977273e-2,4.3165330546016206e-2)
-- (-0.1280381781741399,0.11533864109220437)
-- (-6.429034921940491e-2,1.7362265555445507e-2)