Principle of virtual work

論文のリンクを貼っていくだけです。仮想仕事をしたいです。

確率モナドを使ったMCMCサンプラーをScalaで実装してみる

Qiitaから移行です。

一度、会社のブログに書いた内容なのですが、直したい部分があったため書き直しました。 ここで書いている内容はとりあえず動けばいいというスタンスで実験的にコードを書いていました。最近、以下のrepositoryで実装し直しています。stanzという名前はもともとstanっていう統計モデリングDSLで書けるツールがあるので、最後にzをつけた感じです。純粋関数型で実装されたライブラリは最後にzをつけることがあるみたいな? https://github.com/sdual/stanz

はじめに

 関数型言語において参照透過性が保たれるためには、ある値を代入された関数はいかなる場合でも同じ値を返さなければなりません。しかし、確率的事象を扱いたい場合には確率に従って毎回異なる値を返してしまうため、その性質が壊されてしまいます。離散確率については取りうる値の数が限られているため、それぞれの事象とその確率をリストとして保持し同時確率などの計算を出来るようにする確率モナドが知られています(すごいH本などに書いてあります)が、連続確率を使った計算を行いたい場合には必要なときに確率分布からのサンプリングを行わなければならず、結果的に参照透過を壊してしまいます。その解決策として連続確率をモナドとして扱い、参照透過な実装とサンプリングの操作を完全に切り離す方法が考えられます。この場合は確率分布からのサンプリングは実際に必要になるまではされず(遅延評価)、オブジェクトがネストした計算の構造のみが先に作られます。そして、それとは別にインタプリタと呼ばれるそれぞれのオブジェクトに対してどのような処理を実行すべきかの定義を実装し、最後にそのインタプリタを実行することで全てのサンプリングの処理を走らせます。  ここでは統計モデリングで使われる Metropolis-Hastings 法の実装をして、実際に簡単な問題に対してその事後確率の分布からのサンプリングをしてみます。この記事では全体を通して scala を使って実装していきます。  今回扱う内容は確率モナドの有用性を理解することを目的としており、実用性は重視していません。もし、業務などで統計モデリングが必要な場合は Stan なんかを使う方がよいかと思います。

ベイズの定理を表現する

 ベイズ統計ではデータだけでなく、その背後にあるパラメータも確率的に生成されていると考えます。例えば、コインを投げたときに表が出る確率を知りたいとします。 正常なコインであればその確率は $ \frac{1}{2} $ であるはずですが、それを実際にコインを繰り返し投げることで調べることができます。この場合、ベイズの定理を用いて事後分布を計算することで、表が出る確率自体の分布を得ることができます。  ベイズの定理は以下のように書かれ、ここでは $D$ を実際に観測したデータ、$ \theta $ をコインの表が出る確率を表します。

p(\theta | D) \propto p(D | \theta) \ p(\theta)

このベイズの定理により、$ p(\theta) $ という事前分布を与え、データの分布からコインの表が出る確率自体の分布である事後分布 $ p(\theta | D) $ を得ることができます。ここで尤度である $ p(D | \theta) $ はパラメータが与えられたときのデータの分布を表します。  もし、この尤度をScalaの関数で表現することができたら f: Coin => Distribution[Data] のように書かかれるでしょう。つまり、Distribution が以下のような flatMap を持っているとすれば、ベイズ統計の枠組みを表現できるかもしれません。

def flatMap(f: Coin => Distribution[Data]): Distribution[Data]

また、ご存知のように、flatMap を定義できるならば、Distributionモナドとして振る舞うことが期待できます。もちろん、モナド則など満たすべき条件は他にもありますが、ここでは詳細を調べることはしません。

確率モナド

 はじめに書いたように離散的な確率分布であれば確率変数はいくつかの決まった値しかとらないためリストを用いて表現できますが、連続的な確率を扱う場合は無限要素を持つリストを用意するわけにもいかないので、確率分布を用意し必要なときにサンプリングする方法を取ります。しかし、サンプリングされた値は毎回異なるため、参照透過性が壊れてしまいます。そのためここでは連続確率をモナドにすることで、参照透過でないサンプリング操作を分離します。(実装方法は Practical Probabilistic Programming with Monads http://mlg.eng.cam.ac.uk/pub/pdf/SciGhaGor15.pdf を参考にしています。)  確率分布を表現するために まず、Distribution という型を用意します。

trait Distribution[A]

これがモナドであるためには flatMappoint (PureとかReturnと同じ)が実装されている必要があります。具体的な処理の実装を切り離したいので、以下のような必要な処理のプレースホルダーとして以下のようなcase classを用意します。Conditional は事前分布と尤度関数を扱うために使うもので、 Primitive は組み込み関数としてもっている分布関数をラップして確率モナドとして扱えるようにするものです。

case class Point[A](value: A) extends Distribution[A]
case class FlatMap[A, B](dist: Distribution[A], f: A => Distribution[B]) extends Distribution[B]
case class Primitive[A](fa: Samplable[A]) extends Distribution[A]
case class Conditional[A](dist: Distribution[A], likelihood: A => Probability) extends Distribution[A]

また、実際の確率の値には Probability という型にしておきます。

type Probability = Double

これらのcase classを用いて確率モナドを以下のように定義します。bindflatMap 対応するものです。 モナドはファンクターである必要があるので、map の定義もしておきます。(ここではscalazの力を少し借りることにしました。)

import scalaz.{Functor, Monad}

sealed trait DistributionInstances {

  implicit val distributionInstance = new Functor[Distribution] with Monad[Distribution] {

    def point[A](a: => A): Distribution[A] = Point(a)

    def bind[A, B](fa: Distribution[A])(f: A => Distribution[B]): Distribution[B] = FlatMap(fa, f)

    override def map[A, B](fa: Distribution[A])(f: A => B): Distribution[B] = FlatMap(fa, (a: A) => Point(f(a)))

  }

}

object DistributionInstances extends DistributionInstances

ここまでを用意しておけば、Distributionモナドとして振る舞うことができ、flatMapの糖衣構文としてfor文を使うことがでます。今のところ具体的な処理を定義していないので、Point, FlatMap, Primitiveなどのcase classがネストしたオブジェクトが作られるだけです。これらの型に対応した処理を定義したインタプリタをつくります。  以下のように trait Distributionインタプリタとして sample メソッドを定義します。sample メソッドの中にそれぞれの型についてどのような動作をするか定義しておきます。FlatMap の処理がネストしているのは stackless にするためで、Scalaの場合、末尾再帰にしないと最適化されないためです。

import scala.annotation.tailrec

trait Distribution[A] {

  def sample(random: Random): A = {

    @tailrec
    def loop(dist: Distribution[A], random: Random): A = {
      dist match {
        case pt1: Point[A]      => pt1.value
        case fm1: FlatMap[A, _] => fm1.dist match {
          case pt2: Point[A]      => loop(fm1.f(pt2.value), random)
          case fm2: FlatMap[A, _] => loop(fm2.dist flatMap (a => fm2.f(a) flatMap fm1.f), random)
          case pr2: Primitive[A]  => loop(fm1.f(pr2.fa.sample(random)), random)
        }
        case pr1: Primitive[A] => pr1.fa.sample(random)
        case _ => throw new Exception("can't sample.")
      }
    }

    loop(this, random)
  }
}

 組み込みの確率分布関数などをラップして確率モナドとして扱うために導入した Primitive の使い方についてですが、例えば、scala.util.Random にある nextGaussian() という関数からサンプリングしたい場合、以下のようなclassを作って、sampleというメソッドを持たせます。しかし、このままでは Normal は確率モナドとして扱うことができないため、Primitive(new Normal(mean, stdDev)) のようにします。上記のインタプリタで内で Primitivesample が呼ばれたときには下で定義されている sampleメソッドが呼ばれ確率分布からのサンプリングが行われます。

import scala.util.Random

sealed trait Sampleable[A] {
  def sample(random: Random): A
}

class Normal(mean: Double, stdDev: Double) extends Sampleable[Double] {
  def sample(random: Random): Double = {
    val sampled = random.nextGaussian()
    (stdDev * sampled) + mean
  }
}

Metropolis-Hastings 法

 ここでは上で定義した確率モナドを使って実際にMetropolis-Hastings法の実装をしてみたいと思います。Metropolis-Hastings 法のアルゴリズムについてはここでは説明しないので、統計モデリングの本などを参照して頂ければと思います。Metropolis-Hastings 法のアルゴリズムを実装しているだけで、確率分布は全て確率モナドにしている以外は特別なことはしていません。

class MetropolisHastings {

  def run[A](n: Int, d: Distribution[A]): Distribution[List[A]] = {

    val proposal: Distribution[(A, Probability)] = Prior.prior(d)

    @tailrec
    def iterate(i: Int, prob: Distribution[List[(A, Probability)]]): Distribution[List[(A, Probability)]] = {
      i match {
        case 0 => prob
        case _ =>
          val nextDist = for {
            p <- prob
            (v1, p1) = p.head
            prop <- proposal
            (v2, p2) = prop
            isAccepted <- bernoulli(1.0 min p2 / p1)
            next = if (isAccepted) (v2, p2) else (v1, p1)
          } yield next :: p
          iterate(i - 1, nextDist)
      }
    }

    for {
      result <- iterate(n, proposal.map(x => List(x)))
    } yield result.map(x => x._1)

  }

}

bernoulliというメソッドは以前に説明した Primitive を使って、ベルヌーイ分布からのサンプリングを行います。

def bernoulli(prob: Probability): Distribution[Boolean] = {
  Primitive(new Bernoulli(prob))
}

Prior.prior(d) は提案分布を与えています。prior の処理は引数でテストデータ受け取って尤度関数を作ることです。

prior関数をStacklessにする

prior関数を参考にした論文通りに愚直に実装すると以下のようになりますが、Scalaの場合これでは末尾再帰になっていないため最適化されません。

object Prior {

  def prior[A](dist: Distribution[A]): Distribution[(A, Probability)] = {

    def loop(dist: Distribution[A]): Distribution[(A, Probability)] = {
      dist match {
        case cond: Conditional[A] =>
          for {
            vp <- loop(cond.dist)
            (v, p) = vp
          } yield (v, p * cond.likelihood(v))
        case fm: FlatMap[A, _] =>
          for {
            vp <- loop(fm.dist)
            (v, p) = vp
            y <- fm.f(v)
          } yield (y, p)
        case otherwise =>
          for {
            v <- otherwise
          } yield (v, 1.0)
      }
    }
  }
}

この関数を最適化するにはトランポリンと呼ばれる手法を使うしかありません。トランポリンはscalazの中にすでにFree Monadを使った実装があるのでそれを使うことにしました。トランポリンを使って実装し直すと以下のようになります。トランポリン自体もインタプリタを定義しており、確率モナドの実装と殆ど同じことをやっています。ここでは詳細は説明しません。

import scalaz.Free.Trampoline
import scalaz.Trampoline

object Prior {

  def prior[A](dist: Distribution[A]): Distribution[(A, Probability)] = {

    def loop(dist: Distribution[A]): Trampoline[Distribution[(A, Probability)]] = {
      dist match {
        case cond: Conditional[A] =>
          Trampoline.suspend(loop(cond.dist).map(l => l.map(vp => (vp._1, vp._2 * cond.likelihood(vp._1)))))
        case fm: FlatMap[A, _] =>
          Trampoline.suspend(loop(fm.dist).map(l => l.flatMap(vp => fm.f(vp._1).map(y => (y, vp._2)))))
        case otherwise =>
          Trampoline.done(otherwise.map(v => (v, 1.0)))
      }
    }
    loop(dist).run
  }

}

サンプリングを実行してみる

 データは Conditional を使って尤度関数と共に与えます。尤度関数はモデルパラメータを引数にとりデータが既に適用された状態、例えば、$y = ax + b$ のような線形回帰をする場合、尤度をガウス分布だとすると、データ x, y は既に適用され、パラメータを引数にとる確率密度を返す関数を用意します。

import org.apache.commons.math3.distribution.NormalDistribution

def func(param: (Double, Double)): Probability = {
  val norm = new NormalDistribution(param._1 * data._1 + param._2, 1.0)
  norm.density(data._2)
}

これを尤度関数として、Conditionalに入れて、全データを使って Condtional がネストした構造を作ります(foldLeftとか使えば簡単に作れる。)。

val dist = Conditional(...Conditional(Conditional(priorDist, func1), func2)...)

func1, func2, ...は上に定義したfuncにデータが適用されたもので、パラメータ $a$, $b$ を引数に取る関数です。最も内側にある priorDist は事前分布なのでパラメータ a, b の事前分布を与えます。これもガウス分布だとすると、

val priorDist: Distribution[(Double, Double)] = for {
  a <- Primitive(new Normal(0.0, 1.0))
  b <- Primitive(new Normal(0,0, 1.0))
} yield (a, b)

のように同時確率を作ります。このようにデータと事前分布を与えることで、priorの中で尤度が計算されます。ここまで実装すれば実際に動かすことができます。

import scala.util.Random

val r = new Random
val n = 100000 // 回数

val mh = new MetropolisHastings
val posterior: Distribution[List[(Double, Double)]] = mh.run(n, dist) // 上で作ったdistを代入

val result = posterior.sample(r)

prosterior が事後分布を表すオブジェクトになっていますが、これを作った時点ではまだサンプリングは行われてなく、ただcase classがネストした計算構造だけが作られています。最後の行の sample を実行して初めて必要なサンプリングが全て行われ、事後分布が数値として得られます。この実装はあたかも直接事後分布から直接サンプリングしているかのように見えます。

 今回は実装方法の説明なのであまり重要ではないですが、実際に $y = -0.5 x + 0.3$ にガウスノイズを加えたデータを使って計算してみたので、その結果を載せておきます。$a$ が-0.5、$b$ が0.3の周りに分布していることがわかります。

param_a.png

param_b.png

おわりに

 長々と説明してきましたが、注目すべきことは確率モナドを使うことで毎回結果の違うサンプリングという操作を他の参照透過な実装から完全に切り離すことが出来ることです。さらに構造だけを先に作って最後にサンプリングを行っているため、全ての必要なサンプリングが遅延評価され、あたかも事後分布から直接サンプリングしているかのように見えます。今回の例はそれほど実用性を重視していませんでしたが、確率モナドを理解するという点で非常に勉強になったと思います。  ここで使ったコードをgithubにあげておきます。 https://github.com/sdual/proba-monad

参考文献