Nontrivial gmm derivative checked against finite differences
This commit is contained in:
Родитель
b6a86b4a0c
Коммит
3ea0a3123d
|
@ -52,12 +52,12 @@
|
||||||
(let (K (size alphas))
|
(let (K (size alphas))
|
||||||
(+ (- (sum (build n (lam (i : Integer)
|
(+ (- (sum (build n (lam (i : Integer)
|
||||||
(logsumexp (build K (lam (k : Integer)
|
(logsumexp (build K (lam (k : Integer)
|
||||||
(let (mahal_vec
|
(let ((Q (gmm_knossos_makeQ (index k qs) (index k ls)))
|
||||||
(mul$Mat$Vec (gmm_knossos_makeQ (index k qs) (index k ls))
|
(mahal_vec (mul$Mat$Vec Q
|
||||||
(sub$VecR$VecR (index i x) (index k means))))
|
(sub$VecR$VecR (index i x) (index k means)))))
|
||||||
(- (+ (index k alphas) (sum (index k qs)))
|
(- (+ (index k alphas) (sum (index k qs)))
|
||||||
(* 0.500000 (sqnorm mahal_vec))))))))))
|
(* 0.500000 (sqnorm mahal_vec))))))))))
|
||||||
(* (to_float n) (logsumexp alphas)))
|
(* (* wishart_gamma (to_float n) {- wishart_gamma just here to test-}) (logsumexp alphas)))
|
||||||
(* 0.5 (sum (build K (lam (k : Integer)
|
(* 0.5 (sum (build K (lam (k : Integer)
|
||||||
(+ (sqnorm (exp$VecR (index k qs)))
|
(+ (sqnorm (exp$VecR (index k qs)))
|
||||||
(sqnorm (index k ls))))))))))))
|
(sqnorm (index k ls))))))))))))
|
||||||
|
@ -65,16 +65,33 @@
|
||||||
(def mkvec ((n : Integer))
|
(def mkvec ((n : Integer))
|
||||||
(build n (lam (j : Integer) (* 2.0 (+ 1.0 (to_float j))))))
|
(build n (lam (j : Integer) (* 2.0 (+ 1.0 (to_float j))))))
|
||||||
|
|
||||||
(def main ()
|
(def f ((x : Vec Vec Float)
|
||||||
(let (x (build 10 (lam (i : Integer) (mkvec 3))))
|
(gamma : Float)
|
||||||
|
(m : Float))
|
||||||
(let ((alphas (build 10 (lam (i : Integer) 7.0)))
|
(let ((alphas (build 10 (lam (i : Integer) 7.0)))
|
||||||
(mus (build 10 (lam (i : Integer) (mkvec 3))))
|
(mus (build 10 (lam (i : Integer) (mkvec 3))))
|
||||||
(qs (build 10 (lam (i : Integer) (mkvec 3))))
|
(qs (build 10 (lam (i : Integer) (mkvec 3))))
|
||||||
(ls (build 10 (lam (i : Integer) (mkvec 3)))))
|
(ls (build 10 (lam (i : Integer) (mkvec 3)))))
|
||||||
|
(gmm_knossos_gmm_objective x alphas mus qs ls gamma m)))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
(def main ()
|
||||||
|
(let (x (build 18 (lam (i : Integer) (mkvec 3))))
|
||||||
|
(let ((alphas (build 10 (lam (i : Integer) 7.0)))
|
||||||
|
(mus (build 10 (lam (i : Integer) (mkvec 3))))
|
||||||
|
(qs (build 10 (lam (i : Integer) (mkvec 3))))
|
||||||
|
(ls (build 10 (lam (i : Integer) (mkvec 3))))
|
||||||
|
(z10x3 (* mus 0))
|
||||||
|
(zeros_x (* x 0))
|
||||||
|
(delta 0.001))
|
||||||
(pr x
|
(pr x
|
||||||
(mul$Mat$Vec (gmm_knossos_makeQ (index 0 qs) (index 0 ls)) (index 0 x))
|
(mul$Mat$Vec (gmm_knossos_makeQ (index 0 qs) (index 0 ls)) (index 0 x))
|
||||||
(gmm_knossos_gmm_objective x alphas mus qs ls 1.3 1.2)
|
(gmm_knossos_gmm_objective x alphas mus qs ls 1.3 1.2)
|
||||||
(D$gmm_knossos_gmm_objective x alphas mus qs ls 1.3 1.2)
|
(D$gmm_knossos_gmm_objective x alphas mus qs ls 1.3 1.2)
|
||||||
(fwd$gmm_knossos_gmm_objective x alphas mus qs ls 1.3 1.2
|
(fwd$f x 1.3 1.2
|
||||||
x alphas mus qs ls 1.3 1.2)
|
zeros_x delta 0.0)
|
||||||
|
(- (f x (+ 1.3 delta) 1.2)
|
||||||
|
(f x 1.3 1.2))
|
||||||
|
|
||||||
))))
|
))))
|
||||||
|
|
Загрузка…
Ссылка в новой задаче