Nontrivial gmm derivative checked against finite differences

This commit is contained in:
Andrew Fitzgibbon 2018-11-16 10:12:08 +00:00
Родитель b6a86b4a0c
Коммит 3ea0a3123d
1 изменённых файлов: 25 добавлений и 8 удалений

Просмотреть файл

@ -52,12 +52,12 @@
(let (K (size alphas))
(+ (- (sum (build n (lam (i : Integer)
(logsumexp (build K (lam (k : Integer)
(let (mahal_vec
(mul$Mat$Vec (gmm_knossos_makeQ (index k qs) (index k ls))
(sub$VecR$VecR (index i x) (index k means))))
(let ((Q (gmm_knossos_makeQ (index k qs) (index k ls)))
(mahal_vec (mul$Mat$Vec Q
(sub$VecR$VecR (index i x) (index k means)))))
(- (+ (index k alphas) (sum (index k qs)))
(* 0.500000 (sqnorm mahal_vec))))))))))
(* (to_float n) (logsumexp alphas)))
(* (* wishart_gamma (to_float n) {- wishart_gamma just here to test-}) (logsumexp alphas)))
(* 0.5 (sum (build K (lam (k : Integer)
(+ (sqnorm (exp$VecR (index k qs)))
(sqnorm (index k ls))))))))))))
@ -65,16 +65,33 @@
(def mkvec ((n : Integer))
(build n (lam (j : Integer) (* 2.0 (+ 1.0 (to_float j))))))
(def main ()
(let (x (build 10 (lam (i : Integer) (mkvec 3))))
(def f ((x : Vec Vec Float)
(gamma : Float)
(m : Float))
(let ((alphas (build 10 (lam (i : Integer) 7.0)))
(mus (build 10 (lam (i : Integer) (mkvec 3))))
(qs (build 10 (lam (i : Integer) (mkvec 3))))
(ls (build 10 (lam (i : Integer) (mkvec 3)))))
(gmm_knossos_gmm_objective x alphas mus qs ls gamma m)))
(def main ()
(let (x (build 18 (lam (i : Integer) (mkvec 3))))
(let ((alphas (build 10 (lam (i : Integer) 7.0)))
(mus (build 10 (lam (i : Integer) (mkvec 3))))
(qs (build 10 (lam (i : Integer) (mkvec 3))))
(ls (build 10 (lam (i : Integer) (mkvec 3))))
(z10x3 (* mus 0))
(zeros_x (* x 0))
(delta 0.001))
(pr x
(mul$Mat$Vec (gmm_knossos_makeQ (index 0 qs) (index 0 ls)) (index 0 x))
(gmm_knossos_gmm_objective x alphas mus qs ls 1.3 1.2)
(D$gmm_knossos_gmm_objective x alphas mus qs ls 1.3 1.2)
(fwd$gmm_knossos_gmm_objective x alphas mus qs ls 1.3 1.2
x alphas mus qs ls 1.3 1.2)
(fwd$f x 1.3 1.2
zeros_x delta 0.0)
(- (f x (+ 1.3 delta) 1.2)
(f x 1.3 1.2))
))))