618ZXW

[Triton チュートリアル] 低メモリドロップアウト

Tritonは並列プログラミングのための言語とコンパイラです。カスタムDNN計算カーネルを効率的に記述し、最新のGPUハードウェア上で最大スループットで実行できるようにするためのPythonベースのプログラミング環境を提供するように設計されています。

Triton の中国語ドキュメントの詳細については、→ https://triton.hyper.ai/ をご覧ください。

このチュートリアルでは、状態が単一のint32シードで構成される、メモリ効率の高いDropout実装を作成します。これは、入力と同じ形状のビットマスクテンソルで構成される従来のDropout実装とは異なります。

このプロセスでは、次のことを学習します。

  • PyTorch での Dropout のネイティブ実装の制限。
  • Triton での並列疑似乱数生成。

導入

[SRIVASTAVA 2014]で導入されたドロップアウトは、データ量が少ない状況におけるディープニューラルネットワークの性能向上に用いられる手法であり、正則化によく用いられます。ドロップアウトはベクトルを入力として受け取り、同じ形状の出力ベクトルを生成します。出力の各スカラーの確率pは0に設定され、それ以外の場合は入力から直接コピーされます。これにより、入力が1−p個のスカラーのみであっても、ネットワークは良好な性能を発揮します。

評価フェーズでは、ネットワークの能力を最大限に活用するためにpは0に設定されます。しかし、単にpを0に設定すると出力のノルムが増加し、出力のソフトマックス温度が人為的に低下する可能性があります。これを防ぐため、出力は1/(1-p)にスケーリングされ、ドロップアウト確率に関わらず一貫したノルムが確保されます。

ベースライン

まず、ベースラインの実装を見てみましょう。

 import tabulate import torch import triton import triton.language as tl @triton.jit def _dropout( x_ptr, # 输入指针x_keep_ptr, # pointer to a mask of 0s and 1s 由0 和1 组成的掩码的指针output_ptr, # pointer to the output 输出指针n_elements, # number of elements in the `x` tensor `x` 张量的元素数量p, # probability that an element of `x` is changed to zero 元素`x` 被设置为0 的概率BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements # Load data # 加载数据x = tl.load(x_ptr + offsets, mask=mask) x_keep = tl.load(x_keep_ptr + offsets, mask=mask) # The line below is the crucial part, described in the paragraph above! # 下一行是上段描述的关键部分output = tl.where(x_keep, x / (1 - p), 0.0) # Write-back output # 写回输出tl.store(output_ptr + offsets, output, mask=mask) def dropout(x, x_keep, p): output = torch.empty_like(x) assert x.is_contiguous() n_elements = x.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) return output # Input tensor # 输入张量x = torch.randn(size=(10, )).cuda() # Dropout mask # Dropout 掩码p = 0.5 x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda() # output = dropout(x, x_keep=x_keep, p=p) print(tabulate.tabulate([ ["input"] + x.tolist(), ["keep mask"] + x_keep.tolist(), ["output"] + output.tolist(), ]))

外:

シードドロップアウト

上で説明したDropoutの実装はうまく機能しますが、Dropoutの状態管理は複雑になる可能性があります。特にバックプロパゲーションや再計算/チェックポイントのシナリオを考慮すると複雑になります。ここでは、以下の利点を持つ代替実装について説明します。

  1. メモリフットプリントが小さくなります。
  2. データの移動が少なくなります。
  3. カーネル関数が複数回呼び出された場合の永続的なランダム性の管理を簡素化します。

Tritonで擬似乱数を生成するのは簡単です!このチュートリアルでは、` triton.language.rand関数を使用します。この関数は、指定されたシードとint32オフセットに基づいて、範囲(0, 1)内に均一に分布するfloat32値のブロックを生成します。ただし、Tritonは必要に応じて他の乱数生成戦略も提供しています。

Triton の PRNG 実装は Philox アルゴリズムに基づいていることに注意してください (詳細については [SALMON2011] を参照してください)。

では、すべてをまとめてみましょう。

 @triton.jit def _seeded_dropout( x_ptr, output_ptr, n_elements, p, seed, BLOCK_SIZE: tl.constexpr, ): # compute memory offsets of elements handled by this instance # 计算由此实例处理的元素的内存偏移量pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) # load data from x # 从x 读取数据mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask) # randomly prune it # 随机修剪random = tl.rand(seed, offsets) x_keep = random > p # write-back # 写回output = tl.where(x_keep, x / (1 - p), 0.0) tl.store(output_ptr + offsets, output, mask=mask) def seeded_dropout(x, p, seed): output = torch.empty_like(x) assert x.is_contiguous() n_elements = x.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024) return output x = torch.randn(size=(10, )).cuda() # Compare this to the baseline - dropout mask is never instantiated! # 与基线相比- dropout 掩码从未被实例化! output = seeded_dropout(x, p=0.5, seed=123) output2 = seeded_dropout(x, p=0.5, seed=123) output3 = seeded_dropout(x, p=0.5, seed=512) print(tabulate.tabulate([ ["input"] + x.tolist(), ["output (seed = 123)"] + output.tolist(), ["output (seed = 123)"] + output2.tolist(), ["output (seed = 512)"] + output3.tolist(), ]))

外:

ミッション完了!同じシード値に対して一貫したドロップアウトマスクを適用できるTritonカーネルが完成しました。従来のドロップアウト実装と比較して、このアプローチはメモリオーバーヘッドを削減し、状態管理を簡素化します。

練習する

  1. カーネルは行列を処理できるように拡張され、シード ベクトル (行ごとに 1 つのシード) が使用されます。
  2. ストライドのサポートを追加します。
  3. (チャレンジ) シードを使用して毎回投影行列を動的に生成する、スパース ジョンソン-リンデンシュトラウス変換のカーネルを実装します。

参考文献

  • [SALMON2011] ジョン・K・サルモン、マーク・A・モラエス、ロン・O・ドロール、デビッド・E・ショー、「並列乱数:1、2、3のように簡単」、2011年
  • [SRIVASTAVA2014] Nitish Srivastava他「ドロップアウト:ニューラルネットワークの過剰適合を防ぐ簡単な方法」JMLR 2014

Jupyterノートブックをダウンロード: 04-low-memory-dropout.ipynb

Pythonソースコードをダウンロード: 04-low-memory-dropout.py

圧縮ファイルをダウンロード: 04-low-memory-dropout.zip