NumbaによるPythonコードの高速化: C++で書き直した場合と比較してみる

はじめに

Pythonは簡潔にコードが書け、ライブラリが充実しており、アイデアをすばやく形にするのに向いた言語だと思う。
しかし、実行速度については、CやC++などのように静的に型付けされ、機械語に直接変換されて実行される言語に比べると、どうしても遅くなってしまう。

そのため、時間がかかる処理はCやC++で記述し、Pythonから呼び出して手軽に使えるようにするという手段もとられる。 その他、PythonコードをPythonコードのまま高速化する手段の1つとしてNumbaがある。

本記事では、Numbaを用いて、適当なコードの高速化を試してみる。 特に、Pythonコードを人力でC++コードに書き直した場合との速度比較も行う。もし、Numbaで高速化した場合とあまり変わらなければ、一度Pythonで書いたものを単に高速化のために別の言語で書き直す必要がなくなる。

なお、どの程度高速化されるかは、対象のコードや環境によって大きく変化するため、あくまでも1つの例として考えてほしい。

環境

  • Windows 10 Pro 21H2
  • CPU: AMD Ryzen 9 3900X
  • Python
    • Python 3.10.5 (AMD64)
    • numpy 1.22.4
    • numba 0.55.2
  • C++
    • Microsoft Visual Studio 2022 (Microsoft C/C++ Optimizing Compiler Version 19.32.31332 for x64)
    • clang version 13.0.1 (Target: x86_64-pc-windows-msvc)

Numbaについて

Numbaは、Pythonの関数を実行時に最適化された機械語に変換することで、高速な実行を実現する。 実行時コンパイルにはLLVMを使用している。

使用方法はとても簡単で、基本的には@njitなどのデコレータを高速化したい関数に適用するだけでよい。

from numba import njit

@njit
def test(n):
    res = 1
    for i in range(1, n+1):
        res *= i
    return res

注意としては、NumbaはPythonの機能をすべてサポートしているわけではないので、対象関数のコードの書き換えが必要な場合がある。

また、Numbaにはnopythonモードとobjectモードがあり、objectモードの場合はあまり高速化されない(または、遅くなる)場合があるため、nopythonモードの使用が推奨される(@njitは、@jit(nopython=True)と同じ意味で、nopythonモードを強制する)。 ただし、nopythonモードを使用する場合は、サポートされるPythonの機能がより限定される。

Numbaの特徴として、NumPyを使用したコードの高速化を意識して設計されているため、実際にはNumPyをメインに使用したコードに対してNumbaを使用することが多いと思う。

また、念のため記載しておくが、Pythonの標準機能やNumPyによる数値計算的な処理が実行時間の多くを占め、かつ、ある程度実行時間が長い場合でないと、この記事で扱っているような高速化は実質的には意味がない場合がある。例えば、(Numbaで高速化できない)ネットワークやディスクアクセスが処理時間のほとんどを占めるなら、数値計算的な処理の部分だけ高速化してもユーザーにとってはほとんど意味がない。

Cython

Numbaの他に、有名なPythonの高速化手段としてCythonがある。Cythonは、Pythonにオプションで静的型付けを導入したような言語(あるいはそのコンパイラのこと)で、Cythonコンパイラは、CythonコードをCまたはC++コードに変換する。その後変換されたC/C++コードをC/C++コンパイラでコンパイルすることで、Pythonの拡張モジュールができ、簡単に普通のPythonコードから利用できるようになる。

既存のPythonコードの高速化のためには、(記事執筆時点の筆者の感覚としては)Numbaの方が手軽に感じているため、この記事ではNumbaを主体に置いているが、今後Cythonの記事も書くかもしれない。

Numbaによる高速化

今回は、ナンバープレイスという数字パズルを解くコードを高速化する。なぜこのコードかと言うと、筆者は昔このパズルをやっていたから、そして、あまりにも単純なコードでは少し恣意的な気がするからである。

ナンバープレイスはナンプレ、数独などとも呼ばれるパズルで、ルールを知らない場合はWikipediaの説明などを参照してほしい。

以下は、記事執筆時点でWikipediaのページに掲載されているナンプレの問題である(画像はパブリックドメインのものを転載)。

ナンプレの問題

以下にナンプレを解く&実行時間の計測を行うコードを掲載する。 ナンプレの問題は上の画像と同じものを使用している。
ナンプレを解く部分に関しては、コード中のコメントを参照してほしい。 実行時間の計測については、具体的にはtest()の実行時間を10回測定する。 test()は、ナンプレを解くsolve()TEST_REPEAT(=今回は1万)回繰り返す。
(本当は多数の異なる問題を解くようにしたかったのだが、問題を用意するのが大変そうなのでこのようにした。不自然に見えるかもしれないがご容赦いただきたい。)

なお、以下のコードは「高速化後」のコードで、「高速化前」のコードは単にfrom numba import njitの行と@njitの行をコメントアウトしたものである(逆に言えば、これらをつけるだけで高速化できる)。

高速化後のコード
import time
import numpy as np
from numba import njit

TEST_REPEAT = 10000   # 速度テストでナンプレを解く回数


# 各マスについて、入る可能性がある数字を減らしていく
# 
# answerは、ナンプレの確定した数字を格納するnumpy.ndarray。shapeは(9, 9)。
#   answer[i, j]が1-9のいずれかなら、i+1行j+1列のマスにはその数字が入ることを表し、
#   0ならそのマスは未確定であることを表す。
# candは、ナンプレの各マスについて入る可能性がある数字を表すnumpy.ndarray。shapeは(9, 9, 9)。
#   cand[i, j, a]が1なら、i+1行j+1列のマスには、数字a+1は入らないことを表し、
#   0なら、入る可能性がある(その時点で入る可能性を否定できない)ことを表す。
@njit
def set_candidate(answer, cand):
    for i in range(9):
        for j in range(9):
            if answer[i, j] == 0:
                continue

            a = answer[i, j] - 1
            # 確定した数字がある場所と同じ行・列に、同じ数字が入らないようにする
            cand[i, :, a] = 1
            cand[:, j, a] = 1

            # 3x3のブロックに同じ数字が入らないようにする
            i0 = i // 3 * 3
            j0 = j // 3 * 3
            cand[i0:i0+3, j0:j0+3, a] = 1


# 各マスについて、入る可能性がある数字が1つなら、そのマスをその数字で確定させる
# -> 未確定のマスの個数を返す。ただし、どの数字も入らないマスがある場合(矛盾がある場合)は、-1を返す。
@njit
def set_answer(answer, cand):
    blank = 0

    for i in range(9):
        for j in range(9):
            if answer[i, j] > 0:
                continue

            s = 0
            c = -1
            for k in range(9):
                if cand[i, j, k] == 0:
                    s += 1
                    c = k

            if s == 0:  # 矛盾がある場合
                return -1

            if s == 1:
                answer[i, j] = c + 1    # 入る数字が確定
            else:
                blank += 1

    return blank


# set_candidateとset_answerを交互に呼び出し、未確定のマスの個数が減らなくなったら、
# 入る可能性がある数字を仮に入れてみる
# -> 問題が解けた場合(未確定のマスが0個)、Trueを返し、矛盾がある場合はFalseを返す。
@njit
def _solve(answer, cand):
    _blank = -1
    while True:
        set_candidate(answer, cand)
        blank = set_answer(answer, cand)
        if blank == 0:
            return True
        if blank < 0:
            return False
        if blank == _blank:
            break
        _blank = blank

    for i in range(9):
        for j in range(9):
            for k in range(9):
                if cand[i, j, k] == 0:
                    _answer = answer.copy()
                    _answer[i, j] = k + 1   # 仮に数字を入れてみる
                    if _solve(_answer, cand.copy()):
                        answer[:] = _answer
                        return True

    return False


# 問題を解く
# -> 問題が解けた場合(未確定のマスが0個)、Trueを返し、矛盾がある場合はFalseを返す。
@njit
def solve(answer):
    cand = np.zeros((9, 9, 9), dtype=np.uint8)
    return _solve(answer, cand)


# 速度テスト用
@njit
def test(answer, q):
    for _ in range(TEST_REPEAT):
        answer[:] = q
        solve(answer)


# 表示用
def show(answer):
    for i in range(9):
        for j in range(9):
            c = answer[i, j]
            print("{}  ".format(c if c else "."), end="")
        print("")


# 速度テストを行う
def main():
    q = np.array([
        [5, 3, 0, 0, 7, 0, 0, 0, 0],
        [6, 0, 0, 1, 9, 5, 0, 0, 0],
        [0, 9, 8, 0, 0, 0, 0, 6, 0],
        [8, 0, 0, 0, 6, 0, 0, 0, 3],
        [4, 0, 0, 8, 0, 3, 0, 0, 1],
        [7, 0, 0, 0, 2, 0, 0, 0, 6],
        [0, 6, 0, 0, 0, 0, 2, 8, 0],
        [0, 0, 0, 4, 1, 9, 0, 0, 5],
        [0, 0, 0, 0, 8, 0, 0, 7, 9]
    ], dtype=np.uint8)

    answer = np.zeros((9, 9), dtype=np.uint8)

    print("TEST_REPEAT = {}".format(TEST_REPEAT))

    for _ in range(10):
        t0 = time.perf_counter()

        test(answer, q)

        t1 = time.perf_counter()
        elapsed = (t1 - t0) * 1_000_000
        print("{:.0f} us ({:.2f} us per loop)".format(elapsed, elapsed / TEST_REPEAT))

    print("Complete.\n")
    show(answer)


if __name__ == "__main__":
    main()

コードは、CPU使用率が低い状態(1%以下)で、十分空きメモリがあるときに実行を開始した。 なお、各コードは複数回実行し、大きく結果が変わらないことを確認しており、以下の結果はその代表例である。

高速化前のコードの実行結果
TEST_REPEAT = 10000
85607257 us (8560.73 us per loop)
86896520 us (8689.65 us per loop)
86999468 us (8699.95 us per loop)
86952021 us (8695.20 us per loop)
86896046 us (8689.60 us per loop)
86904262 us (8690.43 us per loop)
87908536 us (8790.85 us per loop)
88002306 us (8800.23 us per loop)
86212913 us (8621.29 us per loop)
86112582 us (8611.26 us per loop)
Complete.

5  3  4  6  7  8  9  1  2
6  7  2  1  9  5  3  4  8
1  9  8  3  4  2  5  6  7
8  5  9  7  6  1  4  2  3
4  2  6  8  5  3  7  9  1
7  1  3  9  2  4  8  5  6
9  6  1  5  3  7  2  8  4
2  8  7  4  1  9  6  3  5
3  4  5  2  8  6  1  7  9

高速化前では、1万回問題を解くのにかかる時間(test()の実行時間)は86.85秒(10回のtest()の平均)で、 これを1回問題を解く時間に換算すると8.7 msだった。

高速化後のコードの実行結果
TEST_REPEAT = 10000
1721713 us (172.17 us per loop)
112380 us (11.24 us per loop)
112737 us (11.27 us per loop)
111686 us (11.17 us per loop)
110989 us (11.10 us per loop)
112200 us (11.22 us per loop)
112369 us (11.24 us per loop)
111881 us (11.19 us per loop)
112128 us (11.21 us per loop)
112395 us (11.24 us per loop)
Complete.

5  3  4  6  7  8  9  1  2
6  7  2  1  9  5  3  4  8
1  9  8  3  4  2  5  6  7
8  5  9  7  6  1  4  2  3
4  2  6  8  5  3  7  9  1
7  1  3  9  2  4  8  5  6
9  6  1  5  3  7  2  8  4
2  8  7  4  1  9  6  3  5
3  4  5  2  8  6  1  7  9

高速化後のコードでは、Numbaによる実行時コンパイルが行われるので、 最初のtest()の実行では時間が多くかかる。 2回目以降のtest()の実行に着目すると、1万回問題を解くのにかかる時間(test()の実行時間)は0.1121秒(9回のtest()の平均)で、これを1回問題を解く時間に換算すると11 μsだった。

Numbaにより、簡単に800倍程度(コンパイル後の場合)も実行速度を高速化できた。 ただし、実行時コンパイルにある程度時間がかかることから、それが問題になる場合は、 Numbaのキャッシュ機能事前(Ahead-of-Time)コンパイル機能を使用する。

PythonコードをC++で書き直した場合との比較

上記のナンプレを解く&実行時間の計測を行うコードをC++に書き直したものを以下に掲載する(Pythonのコードとかなり対応したものになっていると思う)。

C++のコードと、比較用のNumbaで高速化したPythonコードについては、test()でナンプレを解く回数(TEST_REPEAT)を100万回とした(上記で1万回としたのは、高速化前のコードの実行時間が長くなりすぎるためである)。

なお、実行時間の測定対象のコードは、ほとんどCの機能しか使用していないため、このコードをさらにCに書き直したとしてもあまり結果は変わらないと思う。

コンパイラは、Microsoft Visual Studio 2022付属のC/C++コンパイラ(以後、MSVC)と、Clang(LLVMのフロントエンドのC/C++/Objective-C等のコンパイラ)を使用した。 Clangは、MSVC互換のドライバclang-clが存在し、Visual Studioから簡単にインストール・利用できる。

MSVC, Clang (clang-cl)ともに速度を優先した最適化オプション(/O2)を指定してコンパイルした。

最適化について

test()内のループは時間測定のためで、実質的には無駄なことをやっている(本当はループの中身を1回実行するだけで、実行時間以外については同じ結果が得られる)。このようなコードは、賢いコンパイラが最適化によりループを失くすなどの「ずるい」ことをしてしまわないか少し心配になる。(コンパイラの最適化は、どのくらいコードの意味を理解したものになるのだろうか。現状筆者は全くわかっていないのでいつか調べてみたい。)

念のためコンパイラが出力したアセンブリ言語のコードを確認したが、特にそのような最適化は行われておらず、指定した回数(TEST_REPEAT)だけ、ナンプレを解いているようだった。 ちなみに、MSVC, Clangともに、solve()がインライン展開されていた。また、ループのカウンタはインクリメントでなくデクリメントされていた(TEST_REPEATから始まり0になったときにループ脱出)。

実行時間の測定について

Pythonのtime.perf_counter()は、筆者の環境(Windows)では、Windows APIのQueryPerformanceCounterが使用される。 また、C++のstd::chrono::high_resolution_clockも、MSVC2022付属ライブラリの実装ではQueryPerformanceCounterが使用されている。 また、1回のtest()の実行時間は数秒以上(後述)あるため、Python, C++の時間測定方法自体の違いによる結果への影響は小さいと思う。

C++バージョン
#include <chrono>
#include <iostream>
#include <iomanip>
#include <cstring>

static const int TEST_REPEAT = 1000000;


using Cell = unsigned char;
using NumberPlace = Cell[9][9];
using Candidate = Cell[9][9][9];

void set_candidate(const NumberPlace answer, Candidate cand)
{
    for (int i = 0; i < 9; i++) {
        for (int j = 0; j < 9; j++) {
            if (answer[i][j] == 0) continue;

            int a = answer[i][j] - 1;
            for (int k = 0; k < 9; k++) {
                cand[i][k][a] = 1;
                cand[k][j][a] = 1;
            }

            int i0 = i / 3 * 3;
            int j0 = j / 3 * 3;
            for (int k = 0; k < 3; k++) {
                for (int m = 0; m < 3; m++) {
                    cand[i0 + k][j0 + m][a] = 1;
                }
            }
        }
    }
}

int set_answer(NumberPlace answer, const Candidate cand)
{
    int blank = 0;

    for (int i = 0; i < 9; i++) {
        for (int j = 0; j < 9; j++) {
            if (answer[i][j] > 0) continue;

            int s = 0;
            int c = -1;
            for (int k = 0; k < 9; k++) {
                if (cand[i][j][k] == 0) {
                    s++;
                    c = k;
                }
            }

            if (s == 0) return -1;

            if (s == 1) {
                answer[i][j] = c + 1;
            } else {
                blank++;
            }
        }
    }

    return blank;
}

bool _solve(NumberPlace answer, Candidate cand)
{
    int _blank = -1;
    while (true) {
        set_candidate(answer, cand);
        int blank = set_answer(answer, cand);

        if (blank == 0) return true;
        if (blank < 0) return false;

        if (blank == _blank) break;
        _blank = blank;
    }

    for (int i = 0; i < 9; i++) {
        for (int j = 0; j < 9; j++) {
            for (int k = 0; k < 9; k++) {
                if (cand[i][j][k] == 0) {
                    NumberPlace _answer;
                    std::memcpy(_answer, answer, sizeof(_answer));

                    Candidate _cand;
                    std::memcpy(_cand, cand, sizeof(_cand));

                    if (_solve(_answer, _cand)) {
                        std::memcpy(answer, _answer, sizeof(_answer));
                        return true;
                    }
                }
            }
        }
    }

    return false;
}

bool solve(NumberPlace answer)
{
    Candidate cand = {};
    return _solve(answer, cand);
}

void test(NumberPlace answer, const NumberPlace q)
{
    for (int i = 0; i < TEST_REPEAT; i++) {
        std::memcpy(answer, q, sizeof(NumberPlace));
        solve(answer);
    }
}

void show(const NumberPlace answer)
{
    for (int i = 0; i < 9; i++) {
        for (int j = 0; j < 9; j++) {
            int a = answer[i][j];
            if (a == 0) {
                std::cout << ".  ";
            } else {
                std::cout << a << "  ";
            }
        }
        std::cout << std::endl;
    }
}

int main()
{
    namespace chrono = std::chrono;

    NumberPlace q = {
        {5, 3, 0, 0, 7, 0, 0, 0, 0},
        {6, 0, 0, 1, 9, 5, 0, 0, 0},
        {0, 9, 8, 0, 0, 0, 0, 6, 0},
        {8, 0, 0, 0, 6, 0, 0, 0, 3},
        {4, 0, 0, 8, 0, 3, 0, 0, 1},
        {7, 0, 0, 0, 2, 0, 0, 0, 6},
        {0, 6, 0, 0, 0, 0, 2, 8, 0},
        {0, 0, 0, 4, 1, 9, 0, 0, 5},
        {0, 0, 0, 0, 8, 0, 0, 7, 9}
    };

    std::cout << "TEST_REPEAT = " << TEST_REPEAT << std::endl;
    std::cout << std::fixed << std::setprecision(2);

    NumberPlace answer;
    for (int i = 0; i < 10; i++) {
        auto t0 = chrono::high_resolution_clock::now();

        test(answer, q);

        auto t1 = chrono::high_resolution_clock::now();
        auto elapsed = chrono::duration_cast<chrono::microseconds>(t1 - t0).count();

        std::cout << elapsed << " us ("
                  << (elapsed + 0.0) / TEST_REPEAT << " us per loop)"
                  << std::endl;
    }

    std::cout << "Complete." << std::endl << std::endl;
    show(answer);

    return 0;
}

コードは、CPU使用率が低い状態(1%以下)で、十分空きメモリがあるときに実行を開始した。 なお、各コードは複数回実行し、大きく結果が変わらないことを確認しており、以下の結果はその代表例である。

高速化後のPythonコードの実行結果
TEST_REPEAT = 1000000
11842108 us (11.84 us per loop)
10622070 us (10.62 us per loop)
10641747 us (10.64 us per loop)
10673202 us (10.67 us per loop)
10603304 us (10.60 us per loop)
10594038 us (10.59 us per loop)
10630965 us (10.63 us per loop)
10612910 us (10.61 us per loop)
10635690 us (10.64 us per loop)
10671376 us (10.67 us per loop)
Complete.

5  3  4  6  7  8  9  1  2
6  7  2  1  9  5  3  4  8
1  9  8  3  4  2  5  6  7
8  5  9  7  6  1  4  2  3
4  2  6  8  5  3  7  9  1
7  1  3  9  2  4  8  5  6
9  6  1  5  3  7  2  8  4
2  8  7  4  1  9  6  3  5
3  4  5  2  8  6  1  7  9

Numbaによる実行時コンパイルが行われるので、最初のtest()の実行では時間が多くかかる。 2回目以降のtest()の実行に着目すると、100万回問題を解くのにかかる時間(test()の実行時間)は10.63秒(9回のtest()の平均)で、これを1回問題を解く時間に換算すると11 μsだった。

C++コード(MSVCでコンパイル)の実行結果
TEST_REPEAT = 1000000
5315572 us (5.32 us per loop)
5323457 us (5.32 us per loop)
5315202 us (5.32 us per loop)
5307550 us (5.31 us per loop)
5320600 us (5.32 us per loop)
5317027 us (5.32 us per loop)
5310545 us (5.31 us per loop)
5318214 us (5.32 us per loop)
5318236 us (5.32 us per loop)
5322866 us (5.32 us per loop)
Complete.

5  3  4  6  7  8  9  1  2
6  7  2  1  9  5  3  4  8
1  9  8  3  4  2  5  6  7
8  5  9  7  6  1  4  2  3
4  2  6  8  5  3  7  9  1
7  1  3  9  2  4  8  5  6
9  6  1  5  3  7  2  8  4
2  8  7  4  1  9  6  3  5
3  4  5  2  8  6  1  7  9

C++バージョン(MSVC)では、100万回問題を解くのにかかる時間(test()の実行時間)は5.317秒(10回のtest()の平均)で、 これを1回問題を解く時間に換算すると5.3 μsだった。

C++コード(Clangでコンパイル)の実行結果
TEST_REPEAT = 1000000
4543116 us (4.54 us per loop)
4540318 us (4.54 us per loop)
4540863 us (4.54 us per loop)
4547494 us (4.55 us per loop)
4558024 us (4.56 us per loop)
4559926 us (4.56 us per loop)
4567010 us (4.57 us per loop)
4562800 us (4.56 us per loop)
4567270 us (4.57 us per loop)
4565093 us (4.57 us per loop)
Complete.

5  3  4  6  7  8  9  1  2
6  7  2  1  9  5  3  4  8
1  9  8  3  4  2  5  6  7
8  5  9  7  6  1  4  2  3
4  2  6  8  5  3  7  9  1
7  1  3  9  2  4  8  5  6
9  6  1  5  3  7  2  8  4
2  8  7  4  1  9  6  3  5
3  4  5  2  8  6  1  7  9

C++バージョン(Clang)では、100万回問題を解くのにかかる時間(test()の実行時間)は4.555秒(10回のtest()の平均)で、 これを1回問題を解く時間に換算すると4.6 μsだった。

C++で書き直した場合、Numbaで高速化したものより2倍程度実行速度が速くなった。 また、MSVCとClangでは、Clangでコンパイルした方が1~2割程度高速だった。

まとめ

今回の条件では、Numbaを用いた高速化は簡単な割に大きな効果が得られた。今回のナンプレを解くコードでNumbaを使用すると、コンパイル時間を除いた場合、実行速度が使用前のコードの800倍程度になった。 ただし、C++で書き直した場合の実行速度には及ばなかった。具体的には、コンパイル時間を除く実行速度は、C++で書き直した場合の方がNumbaを使う場合より2倍程度速くなった。

この結果をどう見るかは状況によって異なる。Numbaによる高速化で実行速度の問題が解決できることは多々あると思う。一方で、実行速度がとても重要であり、高速化しようとする部分が全体に大きな影響を与えるのであれば、2倍の差は十分CやC++で書き直す価値があるとも思った。

なお、冒頭でも書いたが、Numbaの使用やC/C++等での実装でどの程度高速化されるかは、対象のコードや環境によって大きく変化する。苦労してCやC++で書き直しても、Numbaで簡単に高速化した場合とあまり変わらないという状況も考えられる。Pythonでの(主に数値計算的な)処理が遅いと感じたら、ひとまずNumbaを使ってみてはどうだろうか。