618ZXW

[Triton チュートリアル] triton.autotune

Tritonは並列プログラミングのための言語とコンパイラです。カスタムDNN計算カーネルを効率的に記述し、最新のGPUハードウェア上で最大スループットで実行できるようにするためのPythonベースのプログラミング環境を提供するように設計されています。

Triton の中国語ドキュメントの詳細については、→ https://triton.hyper.ai/ をご覧ください。

 triton.autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, warmup=25, rep=100, use_cuda_graph=False)

triton.jit 関数を自動的に調整するために使用されるデコレータ。

 @triton.autotune(configs=[ triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), ], key=['x_size'] # the two above configs will be evaluated anytime 上面两个配置会随时解析# the value of x_size changes 变量x_size 的值发生了变化) @triton.jit def kernel(x_ptr, x_size, **META): BLOCK_SIZE = META['BLOCK_SIZE']

知らせ:

  • すべての構成が解決されると、カーネルは複数回実行されます。つまり、カーネルによって更新される値は複数回更新されることになります。この望ましくない動作を回避するには、`reset_to_zero`パラメータを使用します。このパラメータは、構成が実行される前に、指定されたテンソル値をゼロにリセットします。

環境変数 TRITON_PRINT_AUTOTUNING が「1」に設定されている場合、Triton は自動カーネル チューニングのたびに、自動チューニングに費やされた時間と最適な構成を含むメッセージを標準出力 (stdout) に出力します。

パラメータ:

  • configs(list[triton.Config]) - triton.Config オブジェクトのリスト。
  • key (list[str]) - 値が変更されたときにすべての構成の解析をトリガーするパラメータ名のリスト。
  • `prune_configs_by` - プルーニング設定のための関数の辞書。以下のフィールドが含まれます。
  • 'perf_model': さまざまな構成の実行時間を予測し、実行時間を返すために使用されるパフォーマンス モデル。
  • 'top_k': ベンチマークに使用する構成の数
  • 'early_config_prune' (オプション): 事前に設定をプルーニングする関数 (例: num_stages)。以下の設定を受け取ります。
    List[Config]を入力として受け取り、トリミングされた構成を返します。
  • reset_to_zero (list[str]) - 構成が解決される前にゼロにリセットされるパラメータ名のリスト。
  • restore_value (list[str]) - 構成が解決された後に値が復元されるパラメータ名のリスト。
  • `pre_hook(lambda args , reset_only) ` - カーネルが呼び出される前に呼び出される関数。このパラメータは、`reset_to_zero` および `restore_value` のデフォルトの `pre_hook` をオーバーライドします。
  • 'args': カーネルに渡される引数のリスト
  • 'reset_only': pre_hook が値のリセットにのみ使用され、対応する post_hook がないかどうかを示すブール値。
  • `post_hook(lambda args, exception)`: カーネル呼び出し後に呼び出される関数。このパラメータは、`restore_value` のデフォルトの `post_hook` をオーバーライドします。
  • 'args': カーネルに渡される引数のリスト
  • 「例外」: コンパイル時または実行時エラーが発生した場合にカーネルによって発生する例外。
  • warmup (int ) - ベンチマークに渡されるウォームアップ時間 (ミリ秒単位)。デフォルト値は 25 です。
  • rep (int) - ベンチマークに渡すテスト繰り返しの継続時間(ミリ秒単位)。デフォルト値は100です。