メモ化再帰ちょっと理解した
概要
メモ化再帰をほんの少し理解したので書く。
ただしこれはプログラマ文脈で言うところの「チョットデキル」的な意味合いではなく、むしろ「完全に理解した」的なニュアンスであることに注意されたし。
なお、これは再帰関数が難しいと感じられる人にも分かりやすいように工夫して書いたので、再帰関数を理解したい人にも読んでいただきたい。
フィボナッチ数列でヤる
フィボナッチ数列は漸化式が F(N) = F(N-1) + F (N-2)
となる数列です。
私の雑な理解としては、漸化式を書ける問題は大体再帰で解けます。そして再帰で書くと無駄な計算をしてしまうことが多いので、メモ化して高速化することが必要となるのです。
さて、というわけで早速フィボナッチ数列を求める再帰関数を書いてみましょう。
愚直コード
import java.util.*; /** * fibonacci on Java */ public class fibonacci { public static void main(String[] args) { //time count start long start = System.currentTimeMillis(); Scanner scan = new Scanner(System.in); int n = scan.nextInt(); scan.close(); long ans = calc(n); System.out.println(ans); //time count finish long fin = System.currentTimeMillis(); System.out.printf("%dms\n", fin - start); } public static long calc(int n) { if (n == 0) return 0; if (n == 1) return 1; long ret = 0; ret += calc(n-1) + calc(n-2); return ret; } }
実行時間計測のため幾つかコードを挿入していますが、基本的にはナイーブな解法です。
これでフィボナッチ数列を求めることができます。
しかし遅い
遅いです、実に遅いです。
手元環境だとN = 40 を求めるのに4358msもかかっています。(ただこれはJavaを使っているのと、私のPCのスペックがアなのもある。誰かお金ちょうだい)
これではTLE待ったなしです。
ここで、このアルゴリズムを見直してみます。 一般項を使えばO(1)とかいう数強はちょっと静かにしててね。
例えばフィボナッチ数列の5項目、即ちF(5)を求めるには、
F(5) = F(4) + F(3) F(4) = F(3) + F(2) F(3) = F(2) + F(1) F(2) = F(1) + F(0) F(1) = 1 F(0) = 0
以上に掲げた数式をなぞる必要があります。ここで関数の呼び出し回数を数えてみると、以下の画像のようになります。
赤文字が呼び出した順番を示しています。(指で流れを追ってもらえると、再帰関数がどう挙動するのかが分かりやすいかもしれない)
F(0)とF(1)はともかくとして、F(2)やF(3)を2回3回と呼び直しているのは非常に非効率です。
例えばF(3)を求めるにはF(2)とF(1)とF(0)を呼び出す必要があります。
F(2)を求めるにはF(1)とF(0)を呼び出さなければならないので、F(3)のために合計4回の呼び出しをする必要があります。
一度目は仕方ありませんが、それ以降にF(3)を呼び出すと、その度にまた計4回の呼び直しをしなければならなくなります。
F(3)を4回求めれば、合計16回の呼び出しが必要となるのです。
これでF(40)などの大きな数を求めてしまうと、莫大な数の呼び直しが発生してしまい、結果的にその分総呼び出し回数が増えて計算時間が長くなってしまいます。
そこでメモ化である
では、呼び直さなければ良いのです。
それを実現するのがメモ化という再帰を高速化する一般的なテク(実際にとても一般的)です。
どういうテクなのか
memo[N]という配列を用意して、一度F(N)を求めたらそいつにぶち込んでおき、再度F(N)を呼び出す必要が出た際にはmemo[N]を返すというテクです。
F(3)の例でいうと、最初は4回の呼び出しが必要になりますが、それ以降の呼び直しはmemo[N]を返すという1動作で完遂することができます。
つまりメモ化してF(3)を2回求めると 4 + 1 = 5
より呼び出し回数は5回となり、計算量をぐっと抑えることができます。
そしてこれはNが大きくなるほど違いが顕著になります。
例えばF(20)を求める際の総呼び出し回数は21891回ですが、メモ化することによって39回にまで減らすことができます。すごい。
実験
というわけでちょっと表を作ってみました。(50から実行時間がとても厳しくなってきたので45で妥協しました)
N | 愚直(メモ化なし)呼び出し回数 | メモ化あり呼び出し回数 |
---|---|---|
1 | 1 | 1 |
10 | 177 | 19 |
20 | 21891 | 39 |
30 | 2692537 | 59 |
40 | 331160281 | 79 |
45 | 3672623805 | 89 |
50 | TLE | 99 |
60 | TLE | 119 |
70 | TLE | 139 |
80 | TLE | 159 |
90 | TLE | 179 |
100 | TLE | 199 |
200 | TLE | 399 |
300 | TLE | 599 |
こんな感じになりました。アルゴリズムの勝利ですね。
最後に実際の実装をお見せして〆たいと思います。
実装
import java.util.*; /** * fibonacci on java */ public class fibonacci { //グローバル変数にメモ用の配列を宣言 static long[] memo; public static void main(String[] args) { long start = System.currentTimeMillis(); Scanner scan = new Scanner(System.in); int n = scan.nextInt(); scan.close(); memo = new long[n+10]; Arrays.fill(memo, 0); long ans = calc(n); System.out.println(ans); long fin = System.currentTimeMillis(); System.out.printf("%dms\n", fin - start); } public static long calc(int n) { //memo[n]が0じゃなければreturn if (memo[n] != 0) return memo[n]; //memo[n]が0だった場合は以下の処理 if (n == 0) return 0; if (n == 1) return 1; long ret = 0; ret += calc(n-1) + calc(n-2); //呼び直しを起こさないようmemo[n]に値を格納 memo[n] = ret; return ret; } }
まとめ
メモ化は再帰を覚えれば比較的簡単に書けます。
戦略の幅がぐっと広がるので、是非マスターしましょう!
皆精進しような!!