618ZXW

[Triton チュートリアル] 行列の乗算

Tritonは並列プログラミングのための言語とコンパイラです。カスタムDNN計算カーネルを効率的に記述し、最新のGPUハードウェア上で最大スループットで実行できるようにするためのPythonベースのプログラミング環境を提供するように設計されています。

Triton の中国語ドキュメントの詳細については、→ https://triton.hyper.ai/ をご覧ください。

このチュートリアルでは、cuBLAS や rocBLAS のパフォーマンスに匹敵する、非常に短く高性能な FP16 行列乗算カーネルを作成します。

具体的には、次の内容を学習します。

  • ブロックレベルの行列乗算。
  • 多次元ポインタ演算。
  • L2 キャッシュ ヒット率を向上させるためにプログラムの並べ替えが実行されます。
  • 自動パフォーマンス最適化。

モチベーション

行列乗算は、ほとんどの現代の高性能コンピューティング システムの重要な構成要素です。

行列乗算は最適化が難しいため、通常、その実装はハードウェアベンダーによって、いわゆる「カーネルライブラリ」(cuBLAS など)の一部として行われます。

これらのライブラリは独自のものであることが多く、現代のディープラーニング ワークロード (融合活性化関数など) のニーズに合わせて簡単にカスタマイズすることはできません。

このチュートリアルでは、Triton を使用して、よりカスタマイズ可能かつ拡張可能な方法で効率的な行列乗算を実装する方法を学習します。

全体として、これから作成するカーネルは、(M, K) に (K, N) 行列を乗算したものを計算する次のブロック アルゴリズムを実装します。

 # Do in parallel # 并行进行for m in range(0, M, BLOCK_SIZE_M): # Do in parallel # 并行进行for n in range(0, N, BLOCK_SIZE_N): acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32) for k in range(0, K, BLOCK_SIZE_K): a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] acc += dot(a, b) C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc

二重ネストループの各反復は、専用の Triton プログラム インスタンスによって実行されます。

計算カーネル

実際、上記のアルゴリズムは Triton で実装するのが非常に簡単です。

主な難しさは、内側のループ内で読み取る必要があるブロックAとBのメモリ位置を計算することです。そのためには、多次元ポインタ演算が必要です。

ポインタ演算

したがって、行優先の2次元テンソルXの場合、 X[i, j]のメモリ位置は&X[i, j] = X + i*stride_xi + j*stride_xjで与えられます。

したがって、 A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N]のポインタブロックは、次のように擬似コードで定義できます。

 &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1); &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);

つまり、Tritonでは、ブロックAとBへのポインタは以下のように初期化(つまりk=0)できます。また、MがBLOCK_SIZE_Mの倍数でない場合、またはN 不是BLOCK_SIZE_Nは、これを処理するために追加の剰余演算が必要になることに注意してください。この場合、結果に影響を与えない不要な値でデータを埋め込むことができます。K次元については、後ほどマスクローディングセマンティクスを用いて処理します。

 offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak) b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)

次に、内部ループで次の内容を更新します。

 a_ptrs += BLOCK_SIZE_K * stride_ak; b_ptrs += BLOCK_SIZE_K * stride_bk;

L2キャッシュの最適化

前述のように、各プログラム インスタンスは C の [BLOCK_SIZE_M、BLOCK_SIZE_N] ブロックを計算します。

これらのブロックが計算される順序を覚えておくことは重要です。これはプログラムの L2 キャッシュ ヒット率に影響し、単純な行優先のソートは機能しません。

 pid = tl.program_id(axis=0) grid_n = tl.cdiv(N, BLOCK_SIZE_N) pid_m = pid // grid_n pid_n = pid % grid_n

考えられる解決策の 1 つは、データの再利用を促進するシーケンスでブロックを開始することです。

これは、次の列に移動する前に GROUP_M 行のブロックを「スーパーグループ化」することで実現できます。

 # Program ID # 程序ID pid = tl.program_id(axis=0) # Number of program ids along the M axis # M 轴上程序id 的数量num_pid_m = tl # Number of programs ids along the N axis # N 轴上程序id 的数量num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # Number of programs in group # 组中程序数量num_pid_in_group = GROUP_SIZE_M * num_pid_n # Id of the group this program is in # 本程序所在的组id group_id = pid // num_pid_in_group # Row-id of the first program in the group # 组内第一个程序的行id first_pid_m = group_id * GROUP_SIZE_M # If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller # 如果`num_pid_m` 不能被`GROUP_SIZE_M` 整除,最后一组会比较小group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) # *Within groups*, programs are ordered in a column-major order # 在组内,程序按列主序排序。 # Row-id of the program in the *launch grid* # 程序的行id pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) # Col-id of the program in the *launch grid* # 启动网格中的程序的行id pid_n = (pid % num_pid_in_group) // group_size_m

例えば、次の行列乗算の例では、各行列は9x9のブロックです。ご覧のとおり、行優先で出力を計算する場合、最初の9つの出力ブロックを計算するために90個のブロックをSRAMにロードする必要がありますが、グループ優先で計算する場合は54個のブロックをロードするだけで済みます。

実際、このアプローチにより、A100 などの特定のハードウェア アーキテクチャ上の行列乗算コアのパフォーマンスが大幅に向上し、パフォーマンスの向上は 10% を超えて 220 ~ 245 TFLOPS の範囲になります。

最終結果

import torch import triton import triton.language as tl def is_cuda(): return triton.runtime.driver.active.get_current_target().backend == "cuda" def is_hip_mi200(): target = triton.runtime.driver.active.get_current_target() return target.backend == 'hip' and target.arch == 'gfx90a' def get_cuda_autotune_config(): return [ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), # Good config for fp8 inputs. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4) ] def get_hip_autotune_config(): return [ triton.Config( {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=4, num_stages=0), triton.Config( {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, num_warps=8, num_stages=0), triton.Config( {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=8, num_stages=0), triton.Config( {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, num_warps=4, num_stages=0), triton.Config( {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, num_warps=4, num_stages=0), ] def get_autotune_config(): if is_cuda(): return get_cuda_autotune_config() else: return get_hip_autotune_config() # `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: # `triton.jit` 函数可以通过使用`triton.autotune` 装饰器进行自动调优,该装饰器接受以下内容: # - A list of `triton.Config` objects that define different configurations of # meta-parameters (eg, `BLOCK_SIZE_M`) and compilation options (eg, `num_warps`) to try # - 一组`triton.Config` 对象的列表,这些对象定义了不同的元参数配置(例如`BLOCK_SIZE_M`)和编译选项(例如`num_warps`)以供尝试。 # - An auto-tuning *key* whose change in values will trigger evaluation of all the # provided configs # - 一个自动调优的key,其值的变化将触发对所有提供的配置进行评估。 @triton.autotune( configs=get_autotune_config(), key=['M', 'N', 'K'], ) @triton.jit def matmul_kernel( # Pointers to matrices # 矩阵指针a_ptr, b_ptr, c_ptr, # Matrix dimensions # 矩阵维度M, N, K, # The stride variables represent how much to increase the ptr by when moving by 1 # element in a particular dimension. Eg `stride_am` is how much to increase `a_ptr` # by to get the element one row down (A has M rows). # 这些步幅变量表示在特定维度移动1 个元素时,`ptr` 应该增加多少。例如,`stride_am` 指示了为了访问下一行的元素(假设`A` 有`M` 行),需要增加多少`a_ptr`。 stride_am, stride_ak, # stride_bk, stride_bn, # stride_cm, stride_cn, # Meta-parameters # 元参数BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # GROUP_SIZE_M: tl.constexpr, # ACTIVATION: tl.constexpr # ): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) """ """计算矩阵乘法C = A x B 的核心算法。其中,A 的形状为(M, K),B 的形状为(K, N),C 的形状为(M, N)。 """ # ----------------------------------------------------------- # Map program ids `pid` to the block of C it should compute. # 将程序ID `pid` 映射到它应计算的C 块。 # This is done in a grouped ordering to promote L2 data reuse. # 这是按组顺序进行的,以促进L2 数据重用。 # See above `L2 Cache Optimizations` section for details. # 详细信息请参见上述的`L2 缓存优化` 部分。 pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m # ---------------------------------------------------------- # Create pointers for the first blocks of A and B. # 创建A 和B 第一个块的指针# We will advance this pointer as we move in the K direction # and accumulate # 在沿着K 方向移动时,我们将推进这个指针并累加# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # `a_ptrs` 是一个[BLOCK_SIZE_M, BLOCK_SIZE_K] 大小的指针块# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers # `b_ptrs` 是一个[BLOCK_SIZE_K, BLOCK_SIZE_N] 大小的指针块# See above `Pointer Arithmetic` section for details # 详细信息请参见上述的`指针算术` 部分。 offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. # 迭代计算C 矩阵的一个块。 # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # of fp32 values for higher accuracy. # 我们累加到一个`[BLOCK_SIZE_M, BLOCK_SIZE_N]` 大小的fp32 值块,以提高精度。 # `accumulator` will be converted back to fp16 after the loop. # `accumulator` 在循环结束后将转换回fp16。 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the K dimension. # 加载A 和B 的下一个块,通过检查K 维度生成一个掩码。 # If it is out of bounds, set it to 0. # 如果超出边界设为0 a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. # 通过着K 维度进行累加。 accumulator = tl.dot(a, b, accumulator) # Advance the ptrs to the next K block. # 指针前进到下一个K 块。 a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk # You can fuse arbitrary activation functions here # while the accumulator is still in FP32! # 在累加器仍然是FP32 的情况下,您可以在这里融合任意激活函数! if ACTIVATION == "leaky_relu": accumulator = leaky_relu(accumulator) c = accumulator.to(tl.float16) # ----------------------------------------------------------- # Write back the block of the output matrix C with masks. # 写回带有掩码的输出矩阵C 的块。 offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, c, mask=c_mask) # We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. # 我们可以通过在`matmul_kernel` 中将`leaky_relu` 作为`ACTIVATION` 元参数来融合`leaky_relu`。 @triton.jit def leaky_relu(x): return tl.where(x >= 0, x, 0.01 * x)これで、2つの入力テンソルのみを受け入れ、(1)形状制約をチェックし、(2)出力を割り当て、(3)上記のカーネルを開始する便利なラッパー関数を作成できます。 import torch import triton import triton.language as tl def is_cuda(): return triton.runtime.driver.active.get_current_target().backend == "cuda" def is_hip_mi200(): target = triton.runtime.driver.active.get_current_target() return target.backend == 'hip' and target.arch == 'gfx90a' def get_cuda_autotune_config(): return [ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), # Good config for fp8 inputs. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4) ] def get_hip_autotune_config(): return [ triton.Config( {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=4, num_stages=0), triton.Config( {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, num_warps=8, num_stages=0), triton.Config( {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=8, num_stages=0), triton.Config( {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, num_warps=4, num_stages=0), triton.Config( {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, num_warps=4, num_stages=0), ] def get_autotune_config(): if is_cuda(): return get_cuda_autotune_config() else: return get_hip_autotune_config() # `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: # `triton.jit` 函数可以通过使用`triton.autotune` 装饰器进行自动调优,该装饰器接受以下内容: # - A list of `triton.Config` objects that define different configurations of # meta-parameters (eg, `BLOCK_SIZE_M`) and compilation options (eg, `num_warps`) to try # - 一组`triton.Config` 对象的列表,这些对象定义了不同的元参数配置(例如`BLOCK_SIZE_M`)和编译选项(例如`num_warps`)以供尝试。 # - An auto-tuning *key* whose change in values will trigger evaluation of all the # provided configs # - 一个自动调优的key,其值的变化将触发对所有提供的配置进行评估。 @triton.autotune( configs=get_autotune_config(), key=['M', 'N', 'K'], ) @triton.jit def matmul_kernel( # Pointers to matrices # 矩阵指针a_ptr, b_ptr, c_ptr, # Matrix dimensions # 矩阵维度M, N, K, # The stride variables represent how much to increase the ptr by when moving by 1 # element in a particular dimension. Eg `stride_am` is how much to increase `a_ptr` # by to get the element one row down (A has M rows). # 这些步幅变量表示在特定维度移动1 个元素时,`ptr` 应该增加多少。例如,`stride_am` 指示了为了访问下一行的元素(假设`A` 有`M` 行),需要增加多少`a_ptr`。 stride_am, stride_ak, # stride_bk, stride_bn, # stride_cm, stride_cn, # Meta-parameters # 元参数BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # GROUP_SIZE_M: tl.constexpr, # ACTIVATION: tl.constexpr # ): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) """ """计算矩阵乘法C = A x B 的核心算法。其中,A 的形状为(M, K),B 的形状为(K, N),C 的形状为(M, N)。 """ # ----------------------------------------------------------- # Map program ids `pid` to the block of C it should compute. # 将程序ID `pid` 映射到它应计算的C 块。 # This is done in a grouped ordering to promote L2 data reuse. # 这是按组顺序进行的,以促进L2 数据重用。 # See above `L2 Cache Optimizations` section for details. # 详细信息请参见上述的`L2 缓存优化` 部分。 pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m # ---------------------------------------------------------- # Create pointers for the first blocks of A and B. # 创建A 和B 第一个块的指针# We will advance this pointer as we move in the K direction # and accumulate # 在沿着K 方向移动时,我们将推进这个指针并累加# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # `a_ptrs` 是一个[BLOCK_SIZE_M, BLOCK_SIZE_K] 大小的指针块# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers # `b_ptrs` 是一个[BLOCK_SIZE_K, BLOCK_SIZE_N] 大小的指针块# See above `Pointer Arithmetic` section for details # 详细信息请参见上述的`指针算术` 部分。 offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. # 迭代计算C 矩阵的一个块。 # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # of fp32 values for higher accuracy. # 我们累加到一个`[BLOCK_SIZE_M, BLOCK_SIZE_N]` 大小的fp32 值块,以提高精度。 # `accumulator` will be converted back to fp16 after the loop. # `accumulator` 在循环结束后将转换回fp16。 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the K dimension. # 加载A 和B 的下一个块,通过检查K 维度生成一个掩码。 # If it is out of bounds, set it to 0. # 如果超出边界设为0 a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. # 通过着K 维度进行累加。 accumulator = tl.dot(a, b, accumulator) # Advance the ptrs to the next K block. # 指针前进到下一个K 块。 a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk # You can fuse arbitrary activation functions here # while the accumulator is still in FP32! # 在累加器仍然是FP32 的情况下,您可以在这里融合任意激活函数! if ACTIVATION == "leaky_relu": accumulator = leaky_relu(accumulator) c = accumulator.to(tl.float16) # ----------------------------------------------------------- # Write back the block of the output matrix C with masks. # 写回带有掩码的输出矩阵C 的块。 offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, c, mask=c_mask) # We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. # 我们可以通过在`matmul_kernel` 中将`leaky_relu` 作为`ACTIVATION` 元参数来融合`leaky_relu`。 @triton.jit def leaky_relu(x): return tl.where(x >= 0, x, 0.01 * x)

 def matmul(a, b, activation=""): # Check constraints. # 检查约束assert a.shape[1] == b.shape[0], "Incompatible dimensions" assert a.is_contiguous(), "Matrix A must be contiguous" M, K = a.shape K, N = b.shape # Allocates output. # 分配输出c = torch.empty((M, N), device=a.device, dtype=torch.float16) # 1D launch kernel where each block gets its own program. # 1 维启动核心,其中每个块都有自己的程序。 grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) matmul_kernel[grid]( a, b, c, # M, N, K, # a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # c.stride(0), c.stride(1), # ACTIVATION=activation # ) return cユニットテストdef matmul(a, b, activation=""): # Check constraints. # 检查约束assert a.shape[1] == b.shape[0], "Incompatible dimensions" assert a.is_contiguous(), "Matrix A must be contiguous" M, K = a.shape K, N = b.shape # Allocates output. # 分配输出c = torch.empty((M, N), device=a.device, dtype=torch.float16) # 1D launch kernel where each block gets its own program. # 1 维启动核心,其中每个块都有自己的程序。 grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) matmul_kernel[grid]( a, b, c, # M, N, K, # a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # c.stride(0), c.stride(1), # ACTIVATION=activation # ) return c

カスタム行列乗算演算をテストし、ネイティブ Torch 実装 (cuBLAS など) と比較します。

 torch.manual_seed(0) a = torch.randn((512, 512), device='cuda', dtype=torch.float16) b = torch.randn((512, 512), device='cuda', dtype=torch.float16) triton_output = matmul(a, b) torch_output = torch.matmul(a, b) print(f"triton_output_with_fp16_inputs={triton_output}") print(f"torch_output_with_fp16_inputs={torch_output}") # Bigger tolerance for AMD MI200 devices. # 对于AMD MI200 设备,使用更大的容差。 # MI200 devices use reduced precision fp16 and bf16 and flush input and # output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices # MI200 设备使用降低精度的FP16 和BF16,并将输入和输出的非规格化值清零。详细信息在以下链接:https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices rtol = 1e-2 if is_hip_mi200() else 0 if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): print("✅ Triton and Torch match") else: print("❌ Triton and Torch differ") TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2") if TORCH_HAS_FP8 and is_cuda(): torch.manual_seed(0) a = torch.randn((512, 512), device="cuda", dtype=torch.float16) b = torch.randn((512, 512), device="cuda", dtype=torch.float16) a = a.to(torch.float8_e5m2) # pre-transpose b for efficiency. # 提前转置b 提高效率b = bT b = b.to(torch.float8_e5m2) triton_output = matmul(a, b) torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16)) print(f"triton_output_with_fp8_inputs={triton_output}") print(f"torch_output_with_fp8_inputs={torch_output}") if torch.allclose(triton_output, torch_output, atol=0.125, rtol=0): print("✅ Triton and Torch match") else: print("❌ Triton and Torch differ")

外:

triton_output_with_fp16_inputs=テンソル([[-10.9531, -4.7109, 15.6953, ..., -28.4062, 4.3320, -26.4219],

 [ 26.8438, 10.0469, -5.4297, ..., -11.2969, -8.5312, 30.7500], [-13.2578, 15.8516, 18.0781, ..., -21.7656, -8.6406, 10.2031], ..., [ 40.2812, 18.6094, -25.6094, ..., -2.7598, -3.2441, 41.0000], [ -6.1211, -16.8281, 4.4844, ..., -21.0312, 24.7031, 15.0234], [-17.0938, -19.0000, -0.3831, ..., 21.5469, -30.2344, -13.2188]], device='cuda:0', dtype=torch.float16)

torch_output_with_fp16_inputs=テンソル([[-10.9531, -4.7109, 15.6953, ..., -28.4062, 4.3320, -26.4219],

[ 26.8438, 10.0469, -5.4297, ..., -11.2969, -8.5312, 30.7500],

[-13.2578, 15.8516, 18.0781, ..., -21.7656, -8.6406, 10.2031],

...、

[ 40.2812, 18.6094, -25.6094, ..., -2.7598, -3.2441, 41.0000],

[ -6.1211, -16.8281, 4.4844, ..., -21.0312, 24.7031, 15.0234],

[-17.0938, -19.0000, -0.3831, ..., 21.5469, -30.2344, -13.2188]], device='cuda:0', dtype=torch.float16)✅ TritonとTorchはtriton_output_with_fp8_inputs=tensor([[-21.4375, 13.1719, 6.0352, ..., 28.7031, 8.6719, -40.7500],

[ 10.0000, 37.0000, -5.5664, ..., 20.9844, 46.8125, 30.8281],

[ 19.5625, -3.0078, -20.0469, ..., -2.1309, -8.0625, 12.5625],

...、

[-18.1562, -34.1562, -27.4219, ..., -27.3906, -24.0938, -12.3516],

[ -3.3945, -8.6250, -23.6562, ..., -4.1094, -3.5332, -16.0781],

[-23.9688, -3.2637, -33.6875, ..., 17.3125, -36.6250, 25.8594]], デバイス='cuda:0', dtype=torch.float16)torch_output_with_fp8_inputs=tensor([[-21.4375, 13.1719, 6.0352, ..., 28.7031, 8.6719, -40.7500],

[ 10.0000, 37.0000, -5.5664, ..., 20.9844, 46.8125, 30.8281],

[ 19.5625, -3.0078, -20.0469, ..., -2.1309, -8.0625, 12.5625],

...、

[-18.1562, -34.1562, -27.4219, ..., -27.3906, -24.0938, -12.3516],

[ -3.3945, -8.6250, -23.6562, ..., -4.1094, -3.5332, -16.0781],

[-23.9688, -3.2637, -33.6875, ..., 17.3125, -36.6250, 25.8594]], device='cuda:0', dtype=torch.float16)✅ TritonとTorchは一致しています

ベンチマーク

このセクションでは、カーネルとcuBLASまたはrocBLASのパフォーマンスの違いを比較します。ここでは正方行列を例として使用しますが、必要に応じてスクリプトを調整することで、他の形状の行列をベンチマークできます。

 ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS' configs = [] for fp8_inputs in [False, True]: if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()): continue configs.append( triton.testing.Benchmark( x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot 作为绘图x 轴的参数名x_vals=[128 * i for i in range(2, 33)], # Different possible values for `x_name` `x_names` 参数的不同可能值line_arg="provider", # Argument name whose value corresponds to a different line in the plot 对应绘图中不同线的参数名# Possible values for `line_arg` `line_arg` 的可能值# Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment. 在fp8 情况下不与cuBLAS 比较,因为torch.matmul 目前不支持fp8。 line_vals=["triton"] if fp8_inputs else [ref_lib.lower(), "triton"], # Label name for the lines line_names=["Triton"] if fp8_inputs else [ref_lib, "Triton"], # Line styles styles=[("green", "-"), ("blue", "-")], ylabel="TFLOPS", # Label name for the y-axis y 轴的标签名称plot_name="matmul-performance-" + ("fp16" if not fp8_inputs else "fp8"), # Name for the plot, used also as a file 绘图名称,也用作保存绘图的文件名name for saving the plot. args={"fp8_inputs": fp8_inputs}, )) @triton.testing.perf_report(configs) def benchmark(M, N, K, provider, fp8_inputs): a = torch.randn((M, K), device='cuda', dtype=torch.float16) b = torch.randn((K, N), device='cuda', dtype=torch.float16) if TORCH_HAS_FP8 and fp8_inputs: a = a.to(torch.float8_e5m2) b = bT b = b.to(torch.float8_e5m2) quantiles = [0.5, 0.2, 0.8] if provider == ref_lib.lower(): ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) if provider == 'triton': ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) benchmark.run(show_plots=True, print_data=True)


外:

matmul-パフォーマンス-fp16:


matmul-パフォーマンス-fp8:


Jupyterノートブックをダウンロード: 03-matrix-multiplication.ipynb

Pythonソースコードをダウンロード: 03-matrix-multiplication.py

圧縮ファイルをダウンロード: 03-matrix-multiplication.zip