618ZXW

[TVM チュートリアル] Tensorize を使用してハードウェア インライン関数を活用する

Apache TVMは、CPU、GPU、そして様々な機械学習アクセラレーションチップに適した、エンドツーエンドのディープラーニング構築フレームワークです。中国語版のTVMドキュメントは、→ https://tvm.hyper.ai/ をご覧ください。

著者: 劉亦志

この記事では、TVM でテンソル量子化を実行する方法について説明します。

スケジューリング プリミティブ テンソライズを使用すると、計算ユニットを対応するインライン関数に置き換えることができるため、手書きのマイクロカーネルの使用が可能になり、TVM を拡張して新しいハードウェア アーキテクチャをサポートできるようになります。

このチュートリアルの目的は、効果的なソリューションを提供することではなく、テンソライズの機能と使用方法を示すことです。

 from __future__ import absolute_import, print_function import tvm from tvm import te import tvm.testing import numpy as np

行列の乗算を定義する

行列の乗算を例に挙げると、Matmulはまず2つの行列の対応する要素を乗算し、次に特定の軸に沿って合計します。次のコードは、TVMにおけるA * B^Tの計算を記述しています。

 N, M, L = 1024, 512, 64 A = te.placeholder((N, L), name="A") B = te.placeholder((M, L), name="B") k = te.reduce_axis((0, L), name="k") C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k), name="C") s = te.create_schedule(C.op) print(tvm.lower(s, [A, B, C], simple_mode=True))

出力結果:

 @main = primfn(A_1: handle, B_1: handle, C_1: handle) -> () attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True} buffers = {A: Buffer(A_2: Pointer(float32), float32, [65536], []), B: Buffer(B_2: Pointer(float32), float32, [32768], []), C: Buffer(C_2: Pointer(float32), float32, [524288], [])} buffer_map = {A_1: A, B_1: B, C_1: C} preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 64], []), B_1: B_3: Buffer(B_2, float32, [512, 64], []), C_1: C_3: Buffer(C_2, float32, [1024, 512], [])} { for (i: int32, 0, 1024) { for (j: int32, 0, 512) { C[((i*512) + j)] = 0f32 for (k: int32, 0, 64) { let cse_var_1: int32 = ((i*512) + j) C[cse_var_1] = (C[cse_var_1] + (A[((i*64) + k)]*B[((j*64) + k)])) } } } }

スケジュール マトマル

行列ベクトル乗算(GEMV)をハードウェアプリミティブとしてサポートするアクセラレータがあるとします。このアクセラレータは任意のサイズのreduce軸を使用できますが、もう1つの軸は16以下である必要があります。matmulループを分解し、最内ループが(16x64)GEMVになるようにする必要があります。

 factor = 16 x, y = C.op.axis (z,) = C.op.reduce_axis yo, yi = s[C].split(y, factor=factor) s[C].reorder(x, yo, yi, z) print(tvm.lower(s, [A, B, C], simple_mode=True))

出力結果:

 @main = primfn(A_1: handle, B_1: handle, C_1: handle) -> () attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True} buffers = {A: Buffer(A_2: Pointer(float32), float32, [65536], []), B: Buffer(B_2: Pointer(float32), float32, [32768], []), C: Buffer(C_2: Pointer(float32), float32, [524288], [])} buffer_map = {A_1: A, B_1: B, C_1: C} preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 64], []), B_1: B_3: Buffer(B_2, float32, [512, 64], []), C_1: C_3: Buffer(C_2, float32, [1024, 512], [])} { for (i: int32, 0, 1024) { for (j.outer: int32, 0, 32) { for (j.inner: int32, 0, 16) { C[(((i*512) + (j.outer*16)) + j.inner)] = 0f32 for (k: int32, 0, 64) { let cse_var_1: int32 = (((i*512) + (j.outer*16)) + j.inner) C[cse_var_1] = (C[cse_var_1] + (A[((i*64) + k)]*B[(((j.outer*1024) + (j.inner*64)) + k)])) } } } } }

上記の印刷された IR に示されているように、内側のループ j.inner は k とともに GEMV 計算を構成します。最も内側の 2 つのループでは、インデックス i は固定されており、行列 A へのアクセスは k のみに依存するため、A のアクセス パターンは「ベクトル」になります。j.inner はテンソル化できるため、ハードウェアで想定されている GEMV 命令を利用できます。

GEMVテンソル化インライン関数を定義する

テンソルをスケジュールする前に、まずGEMVのインライン関数を定義します。この関数は2つの部分から構成されます。最初の部分はGEMVの計算定義で、TVMはこれを使用して元のMatmulスケジュールの計算パターンと一致させます。2番目の部分はデバイス上でGEMVを実行する方法を指定します。これは以下のinrin_funcで実行されます。

 def intrin_gemv(m, l): a = te.placeholder((l,), name="a") b = te.placeholder((m, l), name="b") k = te.reduce_axis((0, l), name="k") c = te.compute((m,), lambda i: te.sum(a[k] * b[i, k], axis=k), name="c") Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[1]) Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="B", offset_factor=1, strides=[te.var("s1"), 1]) Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[1]) def intrin_func(ins, outs): ib = tvm.tir.ir_builder.create() aa, bb = ins cc = outs[0] ib.emit( tvm.tir.call_extern( "int32", "gemv_update", cc.access_ptr("w"), aa.access_ptr("r"), bb.access_ptr("r"), m, l, bb.strides[0], ) ) return ib.get() return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})

ここで、`te.decl_tensor_intrin` は計算 `c.op` の実行方法を宣言しています。実装では、入力と出力を受け取り、それらをポインタに変換し、外部関数呼び出しを提供するだけです。

テンソル化では、ユーザーが offset_factor を指定する必要があります。TVM はこの情報を使用して、データが元のデータ構造の開始アドレスとテンソル化に渡されたオフセットの間にアライメントされているかどうかを判断します。これにより、ベクトル化されたロードによる最適化が可能になります。簡略化のため、係数は 1 に設定してください。

入出力バッファの宣言は必須ではありませんが、宣言することでバッファが提供する追加情報を活用できるようになります。例えば、外部関数 gemv_update に引数として bb.strides[0] を渡します。これで bb.strides[0] == l となり、後ほどより複雑なスケジュールとの違いを見ていきます。

`te.var("s1")` は B の最初のステップ次元として使用されていることに注意してください。ステップ サイズを推測できる場合 (その場合、TVM はテンソル B がコンパクトであることを認識しているので、ステップ サイズは [L, 1] になります)、このようなプレースホルダーにより、TVM は推測された値を自動的にバインドできます。

 gemv = intrin_gemv(factor, L) s[C].tensorize(yi, gemv) print(tvm.lower(s, [A, B, C], simple_mode=True))

出力結果:

 @main = primfn(A_1: handle, B_1: handle, C_1: handle) -> () attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True} buffers = {A: Buffer(A_2: Pointer(float32), float32, [65536], []), B: Buffer(B_2: Pointer(float32), float32, [32768], []), C: Buffer(C_2: Pointer(float32), float32, [524288], [])} buffer_map = {A_1: A, B_1: B, C_1: C} preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 64], []), B_1: B_3: Buffer(B_2, float32, [512, 64], []), C_1: C_3: Buffer(C_2, float32, [1024, 512], [])} { for (i: int32, 0, 1024) { for (j.outer: int32, 0, 32) { @tir.call_extern("gemv_update", @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), C_2, ((i*512) + (j.outer*16)), 16, 2, dtype=handle), @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), A_2, (i*64), 64, 1, dtype=handle), @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), B_2, (j.outer*1024), 1024, 1, dtype=handle), 16, 64, 64, dtype=int32) } } }

yi をテンソル化することで、最も内側の2つのループが、以前に定義されたインライン関数に置き換えられました。モジュールをビルドして実行するために、外側の関数 gemv_update (デモ目的のみの GEMV の単純な実装)が定義されています。

 def gemv_impl(): cc_code = """ extern "C" int gemv_update(float *cc, float *aa, float *bb, int m, int l, int stride) { for (int i = 0; i < m; ++i) { for (int j = 0; j < l; ++j) { cc[i] += aa[j] * bb[i * stride + j]; } } return 0; } """ from tvm.contrib import utils, clang temp = utils.tempdir() ll_path = temp.relpath("temp.ll") # 从C 源代码创建LLVM ir ll_code = clang.create_llvm(cc_code, output=ll_path) return ll_code

テンソル量子化 GEMV を実行する前に、コンパイラ ディレクティブ属性 import_llvm を使用して llvm インライン asm をインポートします。

 s[C].pragma(x, "import_llvm", gemv_impl()) print(tvm.lower(s, [A, B, C], simple_mode=True))

出力結果:

 @main = primfn(A_1: handle, B_1: handle, C_1: handle) -> () attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True} buffers = {A: Buffer(A_2: Pointer(float32), float32, [65536], []), B: Buffer(B_2: Pointer(float32), float32, [32768], []), C: Buffer(C_2: Pointer(float32), float32, [524288], [])} buffer_map = {A_1: A, B_1: B, C_1: C} preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 64], []), B_1: B_3: Buffer(B_2, float32, [512, 64], []), C_1: C_3: Buffer(C_2, float32, [1024, 512], [])} { attr [IterVar(i: int32, (nullptr), "DataPar", "")] "pragma_import_llvm" = "; ModuleID = '/tmp/tmpnmkhyqx0/input0.cc'\nsource_filename = \"/tmp/tmpnmkhyqx0/input0.cc\"\ntarget datalayout = \"em:e-i64:64-f80:128-n8:16:32:64-S128\"\ntarget triple = \"x86_64-pc-linux-gnu\"\n\n; Function Attrs: noinline nounwind optnone uwtable\ndefine dso_local i32 @gemv_update(float*, float*, float*, i32, i32, i32) #0 {\n %7 = alloca float*, align 8\n %8 = alloca float*, align 8\n %9 = alloca float*, align 8\n %10 = alloca i32, align 4\n %11 = alloca i32, align 4\n %12 = alloca i32, align 4\n %13 = alloca i32, align 4\n %14 = alloca i32, align 4\n store float* %0, float** %7, align 8\n store float* %1, float** %8, align 8\n store float* %2, float** %9, align 8\n store i32 %3, i32* %10, align 4\n store i32 %4, i32* %11, align 4\n store i32 %5, i32* %12, align 4\n store i32 0, i32* %13, align 4\n br label %15\n\n15: ; preds = %50, %6\n %16 = load i32, i32* %13, align 4\n %17 = load i32, i32* %10, align 4\n %18 = icmp slt i32 %16, %17\n br i1 %18, label %19, label %53\n\n19: ; preds = %15\n store i32 0, i32* %14, align 4\n br label %20\n\n20: ; preds = %46, %19\n %21 = load i32, i32* %14, align 4\n %22 = load i32, i32* %11, align 4\n %23 = icmp slt i32 %21, %22\n br i1 %23, label %24, label %49\n\n24: ; preds = %20\n %25 = load float*, float** %8, align 8\n %26 = load i32, i32* %14, align 4\n %27 = sext i32 %26 to i64\n %28 = getelementptr inbounds float, float* %25, i64 %27\n %29 = load float, float* %28, align 4\n %30 = load float*, float** %9, align 8\n %31 = load i32, i32* %13, align 4\n %32 = load i32, i32* %12, align 4\n %33 = mul nsw i32 %31, %32\n %34 = load i32, i32* %14, align 4\n %35 = add nsw i32 %33, %34\n %36 = sext i32 %35 to i64\n %37 = getelementptr inbounds float, float* %30, i64 %36\n %38 = load float, float* %37, align 4\n %39 = fmul float %29, %38\n %40 = load float*, float** %7, align 8\n %41 = load i32, i32* %13, align 4\n %42 = sext i32 %41 to i64\n %43 = getelementptr inbounds float, float* %40, i64 %42\n %44 = load float, float* %43, align 4\n %45 = fadd float %44, %39\n store float %45, float* %43, align 4\n br label %46\n\n46: ; preds = %24\n %47 = load i32, i32* %14, align 4\n %48 = add nsw i32 %47, 1\n store i32 %48, i32* %14, align 4\n br label %20\n\n49: ; preds = %20\n br label %50\n\n50: ; preds = %49\n %51 = load i32, i32* %13, align 4\n %52 = add nsw i32 %51, 1\n store i32 %52, i32* %13, align 4\n br label %15\n\n53: ; preds = %15\n ret i32 0\n}\n\nattributes #0 = { noinline nounwind optnone uwtable \"correctly-rounded-divide-sqrt-fp-math\"=\"false\" \"disable-tail-calls\"=\"false\" \"less-precise-fpmad\"=\"false\" \"min-legal-vector-width\"=\"0\" \"no-frame-pointer-elim\"=\"true\" \"no-frame-pointer-elim-non-leaf\" \"no-infs-fp-math\"=\"false\" \"no-jump-tables\"=\"false\" \"no-nans-fp-math\"=\"false\" \"no-signed-zeros-fp-math\"=\"false\" \"no-trapping-math\"=\"false\" \"stack-protector-buffer-size\"=\"8\" \"target-cpu\"=\"x86-64\" \"target-features\"=\"+cx8,+fxsr,+mmx,+sse,+sse2,+x87\" \"unsafe-fp-math\"=\"false\" \"use-soft-float\"=\"false\" }\n\n!llvm.module.flags = !{!0}\n!llvm.ident = !{!1}\n\n!0 = !{i32 1, !\"wchar_size\", i32 4}\n!1 = !{!\"clang version 9.0.0-2~ubuntu18.04.2 (tags/RELEASE_900/final)\"}\n"; for (i, 0, 1024) { for (j.outer: int32, 0, 32) { @tir.call_extern("gemv_update", @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), C_2, ((i*512) + (j.outer*16)), 16, 2, dtype=handle), @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), A_2, (i*64), 64, 1, dtype=handle), @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), B_2, (j.outer*1024), 1024, 1, dtype=handle), 16, 64, 64, dtype=int32) } } }

最後に、テンソル量子化バージョンを numpy.dot によって生成されたバージョンと比較し、実装が正しいことを確認します。

 func = tvm.build(s, [A, B, C], target="llvm", name="gemv") from tvm.topi.utils import get_const_tuple dtype = A.dtype dev = tvm.device("cpu", 0) a = np.random.uniform(size=get_const_tuple(A.shape)).astype(dtype) b = np.random.uniform(size=get_const_tuple(B.shape)).astype(dtype) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=dtype), dev) func(tvm.nd.array(a, dev), tvm.nd.array(b, dev), c) tvm.testing.assert_allclose(c.numpy(), np.dot(a, bT), rtol=1e-3)

テンソライズのreduceを更新する

テンソライズの基本的な概念を学んだので、より複雑なケースを見てみましょう。

アクセラレータはベクトルと行列の乗算のみが可能で、ベクトルのサイズは 16 以下であると仮定します。このハードウェア制約を考慮すると、reduce 軸は次のように分割する必要があります。

 zo, zi = s[C].split(z, factor=factor) s[C].reorder(x, yo, zo, yi, zi)

tensorize インライン関数は現在、「body」関数を使用する代わりに、reduce 軸の一部のみをカバーするため、TVM では、reduce_reset 関数 (reduce for ループの前に呼び出される) と reduce_update 関数 (「更新」計算戦略を定義する) が必要です。

 def gemv_impl(): cc_code = """ extern "C" int gemv_update(float *cc, float *aa, float *bb, int m, int l, int stride) { for (int i = 0; i < m; ++i) { for (int j = 0; j < l; ++j) { cc[i] += aa[j] * bb[i * stride + j]; } } return 0; } extern "C" int gemv_reset(float *cc, int m) { for (int i = 0; i < m; ++i) { cc[i] = 0.0; } return 0; } """ from tvm.contrib import utils, clang temp = utils.tempdir() ll_path = temp.relpath("temp.ll") # 从C 源代码创建LLVM ir ll_code = clang.create_llvm(cc_code, output=ll_path) return ll_code def intrin_gemv(m, l): a = te.placeholder((l,), name="a") b = te.placeholder((m, l), name="b") k = te.reduce_axis((0, l), name="k") c = te.compute((m,), lambda i: te.sum(a[k] * b[i, k], axis=k), name="c") Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[1]) Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="B", offset_factor=1, strides=[te.var("s1"), 1]) Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[1]) def intrin_func(ins, outs): aa, bb = ins cc = outs[0] def _body(): ib = tvm.tir.ir_builder.create() ib.emit( tvm.tir.call_extern( "int32", "gemv_update", cc.access_ptr("w"), aa.access_ptr("r"), bb.access_ptr("r"), m, l, bb.strides[0], ) ) return ib.get() def _reduce_reset(): ib = tvm.tir.ir_builder.create() ib.emit(tvm.tir.call_extern("int32", "gemv_reset", cc.access_ptr("w"), m)) return ib.get() def _reduce_update(): return _body() return _body(), _reduce_reset(), _reduce_update() return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})

`intrin_func` は (body, reduce_reset, reduce_update) の3つを返すことに注意してください。テンソル化にすべての reduce 軸が含まれている場合は関数 `body()` が呼び出され、含まれていない場合は `reduce_reset()` と `reduce_update()` が一緒に使用されます。

この例では、body() と reduce_update() は同じ実装になっていますが、他のケースでは、これら2つの関数のハードウェア命令が異なる場合があります。さらに、タイリングのため、bb.strides[0] は l と異なります。

2乗された GEMV をテンソル化し、構築して、結果を確認します。

 gemv = intrin_gemv(factor, factor) s[C].tensorize(yi, gemv) s[C].pragma(yo, "import_llvm", gemv_impl()) func = tvm.build(s, [A, B, C], target="llvm", name="gemv") a = np.random.uniform(size=get_const_tuple(A.shape)).astype(dtype) b = np.random.uniform(size=get_const_tuple(B.shape)).astype(dtype) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=dtype), dev) func(tvm.nd.array(a, dev), tvm.nd.array(b, dev), c) tvm.testing.assert_allclose(c.numpy(), np.dot(a, bT), rtol=1e-3)

要約

このチュートリアルでは、TVMのインライン関数「tensorize」の使い方を説明します。「tensorize」は、マイクロカーネルを通じて完全に最適化されたスケジューリングを実現する手段を提供します。例えば、Intel CPUのINT8量子化では、「tensorize」を使用してAVX命令を直接呼び出します。さらに、TVMをASICにコンパイルすることも可能です。詳細は「VTA: Versatile Tensor Accelerator」をご覧ください。また、このドキュメントでは、インラインアセンブリインポートの使い方も説明しており、これによりユーザーは簡単にスケジューラにアセンブリを挿入できます。

Pythonソースコードをダウンロードする: tensorize.py

Jupyter Notebook をダウンロード: tensorize.ipynb