618ZXW

[TVMチュートリアル] TEDDによる可視化

Apache TVMは、CPU、GPU、そして様々な機械学習アクセラレーションチップに適した、エンドツーエンドのディープラーニング構築フレームワークです。中国語版のTVMドキュメントは、→ https://tvm.hyper.ai/ をご覧ください。

著者: 顧永峰

この記事では、TEDD (Tensor Expression Debug Display) を使用してテンソル式を視覚化する方法を紹介します。

テンソル式はスケジューリングにプリミティブを使用します。個々のプリミティブは理解しやすいですが、組み合わせると複雑になります。テンソル式には、スケジューリングプリミティブの操作モデルが導入されています。

  • 異なるスケジューリングプリミティブ間の相互作用
  • スケジューリング プリミティブが最終的なコード生成に与える影響。

操作モデルは、データフローグラフ、スケジューリングツリー、およびIterVar関係グラフに基づいています。スケジューリングプリミティブは、これらの計算グラフ上で動作します。

TEDDは、指定されたスケジュールからこれら3つの計算グラフをレンダリングします。このチュートリアルでは、TEDDの使い方と、レンダリングされた計算グラフの解釈方法を説明します。

 import tvm from tvm import te from tvm import topi from tvm.contrib import tedd

バイアスとReLUを使用して畳み込みを定義し、スケジュールする

Bias と ReLU を使用して、畳み込みのテンソル式の例を構築し、最初に conv2d、add、および relu TOPI を連結し、次に一般的な TOPI スケジュールを作成します。

 batch = 1 in_channel = 256 in_size = 32 num_filter = 256 kernel = 3 stride = 1 padding = "SAME" dilation = 1 A = te.placeholder((in_size, in_size, in_channel, batch), name="A") W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W") B = te.placeholder((1, num_filter, 1), name="bias") with tvm.target.Target("llvm"): t_conv = topi.nn.conv2d_hwcn(A, W, stride, padding, dilation) t_bias = topi.add(t_conv, B) t_relu = topi.nn.relu(t_bias) s = topi.generic.schedule_conv2d_hwcn([t_relu])

TEDDを使用した計算グラフのレンダリング

計算グラフをレンダリングすることで、計算とそのスケジュールを確認できます。このチュートリアルをJupyter Notebookで実行している場合は、以下のコメント行を使用してSVGグラフをレンダリングし、Notebookに直接表示できます。

 tedd.viz_dataflow_graph(s, dot_file_path="/tmp/dfg.dot") # tedd.viz_dataflow_graph(s, show_svg = True)


1つ目はデータフローグラフです。各ノードはステージを表し、中央にステージ名とメモリ範囲、両側に入出力情報が表示されます。グラフのエッジはノード間の依存関係を示します。

 tedd.viz_schedule_tree(s, dot_file_path="/tmp/scheduletree.dot") # tedd.viz_schedule_tree(s, show_svg = True)

スケジューリングツリー図は上記に示されています。範囲が利用できないという警告に注意してください。これは、範囲情報を推測するために `normalize()` を呼び出す必要があることを示しています。最初のスケジューリングツリーの確認はスキップしてください。`normalize()` の前後の計算グラフを比較することで、影響を理解することをお勧めします。

 s = s.normalize() tedd.viz_schedule_tree(s, dot_file_path="/tmp/scheduletree2.dot") # tedd.viz_schedule_tree(s, show_svg = True)


2番目のスケジューリングツリーをよく見ると、ROOTの下の各ブロックはステージを表しています。ステージ名は上段に、計算内容は下段に表示されます。中央の行にはIterVarsが含まれており、外側に高い値、内側に低い値が割り当てられています。

`IterVar` 行には、そのインデックス、名前、型、およびその他のオプション情報が含まれています。`W.shared` フェーズを例にとると、最初の行は名前「W.shared」とメモリ範囲「Shared」です。計算は `W(ax0, ax1, ax2, ax3)` です。最外側のループ `IterVar` は `ax0.ax1.fused.ax2.fused.ax3.fused.outer` で、`kDataPar` 内でインデックス 0 で、`threadIdx.y` にバインドされ、範囲は (min=0, ext=8) です。

図に示すように、インデックス ボックスの色を使用して IterVar のタイプを決定することもできます。

ステージが他のステージで計算されていない場合は、そのステージにはルート ノードへの直接のエッジがあります。それ以外の場合は、中間計算ステージで rx.outer にアタッチされている W.shared など、アタッチされている IterVar を指すエッジがあります。

 tedd.viz_itervar_relationship_graph(s, dot_file_path="/tmp/itervar.dot") # tedd.viz_itervar_relationship_graph(s, show_svg = True)


最後はIterVar関係図です。各サブグラフはステージを表し、IterVarノードと変換ノードが含まれています。

例えば、W.shared には 3 つの分割ノードと 3 つの結合ノードがあります。残りのノードは、スケジュールツリーの IterVar と同じ行形式を持つ IterVar ノードです。ルート IterVar は、ax0 のようにどの変換ノードからも駆動されない IterVar です。リーフ IterVar はどの変換ノードからも駆動されず、インデックスが負でない値を持ちます。例えば、ax0.ax1.fused.ax2.fused.ax3.fused.outer で、インデックスは 0 です。

要約

このチュートリアルでは、TEDDの使い方を説明します。TOPIで構築された例を用いて、基礎となるスケジュールを示します。このスケジュールは、スケジューリングプリミティブの前後で使用して、その効果を確認することができます。

Pythonソースコードをダウンロード: tedd.py

Jupyter ノートブックをダウンロード: tedd.ipynb