Nontrivial gmm derivative checked against finite differences
This commit is contained in:
Родитель
b6a86b4a0c
Коммит
3ea0a3123d
|
@ -52,12 +52,12 @@
|
|||
(let (K (size alphas))
|
||||
(+ (- (sum (build n (lam (i : Integer)
|
||||
(logsumexp (build K (lam (k : Integer)
|
||||
(let (mahal_vec
|
||||
(mul$Mat$Vec (gmm_knossos_makeQ (index k qs) (index k ls))
|
||||
(sub$VecR$VecR (index i x) (index k means))))
|
||||
(let ((Q (gmm_knossos_makeQ (index k qs) (index k ls)))
|
||||
(mahal_vec (mul$Mat$Vec Q
|
||||
(sub$VecR$VecR (index i x) (index k means)))))
|
||||
(- (+ (index k alphas) (sum (index k qs)))
|
||||
(* 0.500000 (sqnorm mahal_vec))))))))))
|
||||
(* (to_float n) (logsumexp alphas)))
|
||||
(* (* wishart_gamma (to_float n) {- wishart_gamma just here to test-}) (logsumexp alphas)))
|
||||
(* 0.5 (sum (build K (lam (k : Integer)
|
||||
(+ (sqnorm (exp$VecR (index k qs)))
|
||||
(sqnorm (index k ls))))))))))))
|
||||
|
@ -65,16 +65,33 @@
|
|||
(def mkvec ((n : Integer))
|
||||
(build n (lam (j : Integer) (* 2.0 (+ 1.0 (to_float j))))))
|
||||
|
||||
(def main ()
|
||||
(let (x (build 10 (lam (i : Integer) (mkvec 3))))
|
||||
(def f ((x : Vec Vec Float)
|
||||
(gamma : Float)
|
||||
(m : Float))
|
||||
(let ((alphas (build 10 (lam (i : Integer) 7.0)))
|
||||
(mus (build 10 (lam (i : Integer) (mkvec 3))))
|
||||
(qs (build 10 (lam (i : Integer) (mkvec 3))))
|
||||
(ls (build 10 (lam (i : Integer) (mkvec 3)))))
|
||||
(gmm_knossos_gmm_objective x alphas mus qs ls gamma m)))
|
||||
|
||||
|
||||
|
||||
(def main ()
|
||||
(let (x (build 18 (lam (i : Integer) (mkvec 3))))
|
||||
(let ((alphas (build 10 (lam (i : Integer) 7.0)))
|
||||
(mus (build 10 (lam (i : Integer) (mkvec 3))))
|
||||
(qs (build 10 (lam (i : Integer) (mkvec 3))))
|
||||
(ls (build 10 (lam (i : Integer) (mkvec 3))))
|
||||
(z10x3 (* mus 0))
|
||||
(zeros_x (* x 0))
|
||||
(delta 0.001))
|
||||
(pr x
|
||||
(mul$Mat$Vec (gmm_knossos_makeQ (index 0 qs) (index 0 ls)) (index 0 x))
|
||||
(gmm_knossos_gmm_objective x alphas mus qs ls 1.3 1.2)
|
||||
(D$gmm_knossos_gmm_objective x alphas mus qs ls 1.3 1.2)
|
||||
(fwd$gmm_knossos_gmm_objective x alphas mus qs ls 1.3 1.2
|
||||
x alphas mus qs ls 1.3 1.2)
|
||||
(fwd$f x 1.3 1.2
|
||||
zeros_x delta 0.0)
|
||||
(- (f x (+ 1.3 delta) 1.2)
|
||||
(f x 1.3 1.2))
|
||||
|
||||
))))
|
||||
|
|
Загрузка…
Ссылка в новой задаче