競技プログラミング Edit

【python/アルゴリズム】メモ化再帰を基礎から解説

メモ化再帰についてのメモです。

競技プログラミングでもメモ化再帰を使って解く問題は定期的に出題されています。

メモ化再帰とは

メモ化再帰は簡単に言うと「計算結果の再利用」のアルゴリズムです。

アルゴリズムは大きく分類すると2つに分類されます。

アルゴリズムとは

  • 計算結果を再利用することで同じ計算を省く
  • 計算方法を工夫することで処理量を減らす

メモ化再帰はこのうち「計算結果を再利用」することで計算量を減らすアルゴリズムになっています。

たびすけ
例えば、\(1+2+3+4+5\)の答えは?と聞かれた時に一度目は計算しますが、結果をメモしておきます。
再度同じ問題が出されたらメモを見て答えるようなイメージです。そのままのネーミングですね!

メモ化再帰が使えるケース

メモ化再帰が使えるのは与えられた値と答えが\(1:1\)で対応する場合です。

メモ化再帰が使える場合

  • 与えられた値と答えに\(1:1\)の対応関係がある時

メモ化再帰が使える問題

「9時から15時までは何時間ですか?」という問題はメモ化再帰が使えます。

何回「9時から15時までは何時間ですか?」と聞かれても答えは「6時間」になります。

これは問題(与えられた値)と答え(計算結果)に\(1:1\)の対応関係があると言えるからです。

メモ化再帰が使えない問題

「9時から【今】までは何時間ですか?」という問題はメモ化再帰が使えません。

15時に「9時から【今】までは何時間ですか?」と聞かれた時は「6時間」が答えになりますが、16時に「9時から【今】までは何時間ですか?」と聞かれた場合の答えは「7時間」になってしまいます。

この問題(与えられた値)は答え(計算結果)と\(1:1\)の対応関係がりません。

たびすけ
メモ化再帰を使う問題としてメジャーなのが数え上げの問題です。
特に、深さ優先探索(DFS)と一緒に使うことが多いです。

pythonでメモ化再帰を実装

pythonでメモ化再帰を実装するメジャーな方法は2通りあります。

メモ化再帰の実装方法

  • (再帰)関数に「lru_cache」を追加する方法
  • dictionary(辞書)型を使う方法

個人的におすすめなのは、「dictionary(辞書)型を使う方法」です。

「(再帰)関数に【lru_cache」を追加する方法」はお手軽ですが、問題によってはTLEしてしまいます。

pythonは基本的には(再帰)関数の処理が遅い言語なので、競技プログラミングで使うことを考えると「dictionary(辞書)型を使う方法」の方をおススメします。

たびすけ
慣れてしまえば「dictionary(辞書)型を使う方法」も簡単に実装できるよ!

メモ化再帰のコーディング

メモ化再帰がpythonでどのように実装されるのか、実際の問題で確認してみましょう。

例題としてこちらの問題を使います。

(再帰)関数に「lru_cache」を追加する方法

import sys
from functools import lru_cache


@lru_cache(maxsize=None)
def alg_memoization_recursion(n):
    if n == 1:
        return 1
    else:
        return alg_memoization_recursion(n // 2) + alg_memoization_recursion(n // 2) + 1


def solve():
    input = sys.stdin.readline
    mod = 10 ** 9 + 7
    h = int(input().rstrip('\n'))
    print(alg_memoization_recursion(h))


if __name__ == '__main__':
    solve()

dictionary(辞書)型を使う方法

import sys
import collections


def solve():
    input = sys.stdin.readline
    mod = 10 ** 9 + 7
    h = int(input().rstrip('\n'))
    ql = [h]
    d = collections.defaultdict(int)
    d[0] = 0
    while len(ql):
        vl = ql[-1]
        if vl // 2 in d:
            d[vl] = d[vl // 2] * 2 + 1
            ql.pop()
        else:
            ql.append(vl // 2)
    print(d[h])


if __name__ == '__main__':
    solve()

-競技プログラミング, Edit
-