Haskellでメモ化再帰

はじめに

Haskellでメモ化再帰が楽にできるようにしてみたかったので、やってみた。競技プログラミングに使えればいい、というぐらい。

IxData.Ix.Ix で、UnboxableData.Vector.Unboxing.Unboxable です。

memoize ::
  (Ix i, Unboxable e) =>
  (i, i) ->
  (forall m. Monad m => (i -> m e) -> i -> m e) ->
  i -> e

使用例

型を見ればなんとなくわかると思います。Data.Function.Fix.fix とほぼ同じです。Mintは modint です。

fib = memoize (0, 10^6) $ \fib n ->
  if n <= 1 then return $ Mint n
  else liftM2 (+) (fib $ n-1) (fib $ n-2) 

実装

本当はSTUArrayとかを使いたかったんですが、値が unboxed であるという制約を課すのが面倒で、Ixだけ拝借しました。ついでに、メモには modint をのせたいことが多いので、newtype フレンドリーな unboxed vector を使いました。

そこまで凝ったことはしていないので、特に解説はありません。

本質

高速化のための書き方は、実装を眺める上ではノイズになりがちです。省いたやつを書いてみます。どうだろう、修飾なしは逆にわかりにくいところもあるかもしれない。遅いはずですが、計測していません。

import Prelude hiding (read, replicate)
import Data.Vector ((!), freeze)
import Data.Vector.Mutable
import Data.Ix

memoize ::
  (Ix i) =>
  (i, i) ->
  (forall m. Monad m => (i -> m e) -> i -> m e) ->
  i -> e
memoize ran fun i = memo ! index ran i where
  size = rangeSize ran
  memo = runST $ do
    memo <- new size
    memd <- replicate size False
    forM_ (range ran) $ calc memo memd
    freeze memo
  calc memo memd i = do
    written <- read memd ix
    unless written $ do
      write memo ix =<< fun (calc memo memd) i
      write memd ix True
    read memo ix
    where ix = index ran i

実際の実装

高速化と修飾をしたやつです。unsafe まみれですが、memd がチェックしているので、あまり危険ではないはずです。

import qualified Data.Ix as I
import qualified Data.Vector.Unboxing as U
import qualified Data.Vector.Unboxing.Mutable as U

memoize ::
  (I.Ix i, U.Unboxable e) =>
  (i, i) ->
  (forall m. Monad m => (i -> m e) -> i -> m e) ->
  i -> e
memoize ran fun i = res where
  res = U.unsafeIndex memo I.index ran i
  size = I.rangeSize ran
  memo = runST $ do
    memo <- UM.unsafeNew size
    memd <- UM.replicate size False
    forM_ (I.range ran) $ calc memo memd
    U.unsafeFreeze memo
  calc memo memd i = do
    written <- UM.unsafeRead memd ix
    unless written $ do
      UM.unsafeWrite memo ix =<< fun (calc memo memd) i
      UM.unsafeWrite memd ix True
    UM.unsafeRead memo ix
    where ix = I.index ran i

おわりに

まだ実用していないのに記事に仕立て上げてしまった。

言語拡張を書くのを忘れていたことに気づきました。RankNTypes が必要です。