618ZXW

[Triton チュートリアル] 融合ソフトマックス(Fused Softmax)

Tritonは並列プログラミングのための言語とコンパイラです。カスタムDNN計算カーネルを効率的に記述し、最新のGPUハードウェア上で最大スループットで実行できるようにするためのPythonベースのプログラミング環境を提供するように設計されています。

Triton の中国語ドキュメントの詳細については、→ https://triton.hyper.ai/ をご覧ください。

このチュートリアルでは、GPU の静的ランダム アクセス メモリ (SRAM) に収まる特定の種類の行列に対して、PyTorch のネイティブ操作よりも大幅に高速な融合ソフトマックス操作を記述します。

このプロセスを通じて、次のことを学びます。

  • カーネル フュージョンは、帯域幅が制限された操作に利点をもたらします。
  • トリトンでの業務を削減します。

モチベーション

要素ごとの加算用のカスタム GPU カーネルには教育的価値はありますが、実際には大きな進歩をもたらしません。
代わりに、単純な(数値的に安定した)ソフトマックス演算を考えてみましょう。

 import torch import triton import triton.language as tl from triton.runtime import driver def naive_softmax(x): """Compute row-wise softmax of X using native pytorch使用原生PyTorch 计算X 的逐行softmax We subtract the maximum element in order to avoid overflows. Softmax is invariant to this shift.我们减去最大元素以避免溢出。Softmax 对于这种偏移是不变的。 """ # read MN elements ; write M elements # 读取MN 个元素;写入M 个元素x_max = x.max(dim=1)[0] # read MN + M elements ; write MN elements # 读取MN + M 个元素;写入MN 个元素z = x - x_max[:, None] # read MN elements ; write MN elements # 读取MN 个元素;写入MN 个元素numerator = torch.exp(z) # read MN elements ; write M elements # 读取MN 个元素;写入M 个元素denominator = numerator.sum(dim=1) # read MN + M elements ; write MN elements # 读取MN + M 个元素;写入MN 个元素ret = numerator / denominator[:, None] # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements # 总计:读取5MN + 2M 个元素;写入3MN + 2M 个元素return ret

PyTorch で直接実装する場合、 y = naive_softmax(x)を計算するには、DRAM から 5MN+2M 要素を読み取り、3MN+2M 要素を書き戻す必要があります。

これは明らかに無駄です。X を一度だけ読み取り、チップ上で必要な計算をすべて実行するカスタム「融合」カーネルを使用する方がはるかに望ましいでしょう。

これには MN バイトの読み取りと書き込みのみが必要なので、理論的には約 4 倍の高速化が期待できます。

torch.jit.scriptフラグは、この「カーネル融合」を自動化することを目的としていますが、後で説明するように、まだ理想的ではありません。

計算カーネル

ソフトマックス カーネルは次のように動作します。各プログラムは入力行列 X から行セットを読み込み、プログラムの数に応じて段階的に処理し、正規化して、結果を出力 Y に書き戻します。

Triton の主な制限は、各ブロックの要素が 2 の累乗でなければならないことです。したがって、任意の入力形状を処理する場合は、各行を内部的に「パディング」し、メモリ操作を適切に保護する必要があります。

 @triton.jit def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr): # starting row of the program # 程序起始行row_start = tl.program_id(0) row_step = tl.num_programs(0) for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages): # The stride represents how much we need to increase the pointer to advance 1 row # 步长表示我们需要对指针增加多少以推进1 行row_start_ptr = input_ptr + row_idx * input_row_stride # The block size is the next power of two greater than n_cols, so we can fit each # 块大小是大于n_cols 的下一个二的幂,因此我们可以适配# row in a single block # 单个块中的行col_offsets = tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols # 将行加载到SRAM 中,使用掩码,因为BLOCK_SIZE 可能大于n_cols mask = col_offsets < n_cols row = tl.load(input_ptrs, mask=mask, other=-float('inf')) # Subtract maximum for numerical stability # 为了数值稳定性而减去最大值row_minus_max = row - tl.max(row, axis=0) # Note that exponentiation in Triton is fast but approximate (ie, think __expf in CUDA) # 请注意,Triton 中的指数运算速度很快,但是是近似的(例如,类似于CUDA 中的__expf)。 numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=0) softmax_output = numerator / denominator # Write back output to DRAM # 将输出写回DRAM output_row_start_ptr = output_ptr + row_idx * output_row_stride output_ptrs = output_row_start_ptr + col_offsets tl.store(output_ptrs, softmax_output, mask=mask)

任意の入力テンソルに対してカーネルとその (メタ) パラメータ キューを構築するためのヘルパー関数を作成できます。

 device = torch.cuda.current_device() properties = driver.active.utils.get_device_properties(device) NUM_SM = properties["multiprocessor_count"] NUM_REGS = properties["max_num_regs"] SIZE_SMEM = properties["max_shared_mem"] WARP_SIZE = properties["warpSize"] target = triton.runtime.driver.active.get_current_target() kernels = {} def softmax(x): n_rows, n_cols = x.shape # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x` # 每次循环迭代的块大小是大于`x` 列数的最小二的幂BLOCK_SIZE = triton.next_power_of_2(n_cols) # Another trick we can use is to ask the compiler to use more threads per row by # increasing the number of warps (`num_warps`) over which each row is distributed. # 另一个技巧是通过增加每行分配的线程数来要求编译器使用更多的线程块(`num_warps`) # You will see in the next tutorial how to auto-tune this value in a more natural # way so you don't have to come up with manual heuristics yourself. # 将在下一个教程中看到如何以更自然的方式自动调整此值,以免自己进行手动启发式处理。 num_warps = 8 # Number of software piepling stages. # 软件流水线阶段的数量num_stages = 4 if SIZE_SMEM > 200000 else 2 # Allocate output # 分配输出空间y = torch.empty_like(x) # pre-compile kernel to get register usage and compute thread occupancy. # 预编译内核以获取寄存器使用情况并计算线程占用情况。 kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0)) if kernel is None: kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, num_stages=num_stages, num_warps=num_warps, grid=(1, )) kernel._init_handles() n_regs = kernel.n_regs size_smem = kernel.metadata.shared occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) occupancy = min(occupancy, SIZE_SMEM // size_smem) num_programs = NUM_SM * occupancy kernels[BLOCK_SIZE] = (kernel, num_programs) num_programs = min(num_programs, n_rows) # Create a number of persistent programs. # 创建一些持久化程序。 kernel[(num_programs, 1, 1)]( y, x, x.stride(0), y.stride(0), n_rows, n_cols, ) return y

ユニットテスト

行と列の数が不規則な行列でカーネルをテストします。

これにより、パディング メカニズムが機能しているかどうかが確認されます。

 torch.manual_seed(0) x = torch.randn(1823, 781, device='cuda') y_triton = softmax(x) y_torch = torch.softmax(x, axis=1) assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

結果は予想通りでした。

ベンチマーク

ここでは、入力行列の列数の関数に基づいてベンチマークを行い、行数が 4096 であると仮定し、 naive_softmaxを定義します。

次にその性能を上記で定義した(1) torch.softmaxと(2) naive_softmaxと比較する。

 @triton.testing.perf_report( triton.testing.Benchmark( x_names=['N'], # argument names to use as an x-axis for the plot 用作图表x 轴的参数名x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name` `x_name` 的不同可能值line_arg='provider', # argument name whose value corresponds to a different line in the plot 参数名,其值对应于图表中不同线条line_vals=['triton', 'torch'], # possible values for `line_arg`` `line_arg` 的可能值line_names=[ "Triton", "Torch", ], # label name for the lines 线条的标签名称styles=[('blue', '-'), ('green', '-')], # line styles 线条的样式ylabel="GB/s", # label name for the y-axis y 轴的标签名称plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. 图表的名称,也用作保存图表的文件名args={'M': 4096}, # values for function arguments not in `x_names` and `y_name` `x_names` 和`y_name` 中未包含的函数参数的值)) def benchmark(M, N, provider): x = torch.randn(M, N, device='cuda', dtype=torch.float32) stream = torch.cuda.Stream() torch.cuda.set_stream(stream) if provider == 'torch': ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1)) if provider == 'triton': ms = triton.testing.do_bench(lambda: softmax(x)) gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms) benchmark.run(show_plots=True, print_data=True)








上の画像では次のことがわかります。

  • TritonはTorch JITの4倍高速です。これは、Torch JITがここで統合されていないのではないかという私たちの推測を裏付けています。
  • Triton は、読みやすく、理解しやすく、保守しやすいだけでなく、torch.softmax よりも大幅に高速です。

Jupyterノートブックをダウンロード: 02-fused-softmax.ipynb

Pythonソースコードをダウンロード: 02-fused-softmax.py

圧縮ファイルをダウンロード: 02-fused-softmax.zip