618ZXW

わずか 30 行のコードで、500 万の長文テキストの推論処理が 8 倍高速化!「Tree Attention」により GPU の負荷がさらに高まります。

クロス GPU アテンション並列処理により、最大 8 倍の高速化が実現され最大 512 万のシーケンス長の推論がサポートされます。

Ring Attention の後継であるTree Attentionが登場しました。

最も重要な点は、通信ステップの数がデバイスの数に応じて線形ではなく対数的に増加することです。

つまり、ツリーアテンションの利点は、デバイスの数が増えるにつれてより顕著になります。実験では、カード枚数128枚、シーケンス長512万の設定で最大8倍の高速化が達成されました。

リングアテンションと比較すると、ピーク時のメモリ使用量も大幅に削減できます

関連するコードはGoogle JAXフレームワークをベースにしたオープンソースで、Flash Attentionと統合されています。実装に必要なコードはわずか30行です

論文が発表されると、業界からは「推論ニーズの高い大企業にとって重要」と賞賛された。

これは、ジェンセン・フアン氏の「GPUを多く買えば買うほど節約できる」という理論と一致しており、Nvidiaが再び勝利した。

注意メカニズムに関するエネルギー的視点

まず、比較に使用した、カリフォルニア大学バークレー校の Pieter Abeel 氏のチームによって提案されたリング アテンション モデルについて簡単に説明します。

リング アテンションは、以前の一連の大規模モデルを数百万のコンテキストに拡張できるようにするための鍵であると考えられており、Google Gemini 1.5 からその後の Llama 3.1 シリーズまでのいくつかのバリエーションで使用されてきました。

簡単に言えば、リングアテンションの核となる考え方は、長いシーケンスを複数のブロックに分割し、各GPUがそれぞれ1つのブロックを処理するというものです。位相的には、これはすべてのGPUがリング状に配置され、キーと値の情報を下位に渡すと同時に、前のGPUから情報を受け取ることに相当します。

計算時間がデータ転送時間よりも長い限り、このプロセスによって追加のオーバーヘッドは発生しません。

従来の近似法とは異なり、ループ アテンションは精度を失わず、完全なアテンション計算を維持します。

最新のツリー アテンション アプローチは、ブロック計算、デバイス間の並列性、精度の保持に基づいており、勾配を計算しツリー トポロジを利用することで複数の GPU 間の通信を最適化する自己アテンション エネルギー関数を提案しています

従来、クエリ ベクトルとキー ベクトルの類似性を一致させ、値ベクトルに対して加重合計を実行することに重点が置かれていました。

Tree Attention チームは、ホップフィールド ネットワークなどのエネルギーベースのモデルに関する研究を基に、注意を特定の変数に関するエネルギー関数の勾配として解釈します。

Key、Query、Value、および補助変数 ζ に依存するスカラー エネルギー関数 F が存在し、注意の結果は ζ=0 における ζ に関する F の勾配と正確に等しくなります。

自動微分などの手法を組み合わせ、エネルギーと勾配の観点から自己注意を見ると、 F が効率的に計算できる限り、自己注意も効率的に計算できることが示唆されます。

具体的には、言語モデルでは、KV ベースのデコードは次のようにエネルギー関数で表すことができます。

logsumexp と max の両方の演算は結合法則を満たしているため、最終結果に影響を与えることなく任意の順序で実行できます。

この前提の下、チームは新しい並列化アルゴリズムを設計しました。このアルゴリズムは、まず各 GPU 上で並列にローカル エネルギー関数を計算し、次にツリー状の Allreduce を通じて各場所からの結果を要約し、最後に自動微分を使用して勾配を取得することで注目出力を取得します。

プロセス全体に必要なのは、エネルギー関数の計算と同じ時間のオーバーヘッドのみで、ビデオ メモリの使用に追加の負担はほとんどありません。

Tree Attention は、その設計においてGPU クラスターの 2 レベル トポロジも最大限に活用します。つまり、同じノード内では高速 NVLink が使用され、ノード間では IB または Ethernet が使用されます。

対照的に、リング アテンションは本質的にこのトポロジには適しておらず、通信と計算を効果的にオーバーラップすることが困難であり、最終的には最も遅い相互接続帯域幅によって制約されます。

最後に、理論的には同様の戦略を使用して単一の GPU 内でプロセスを高速化できますが、現在のハードウェアではストリーミング プロセッサ (SM) 間の通信に共有メモリがまだ使用されているため、利点はそれほど大きくないことに言及する価値があります。

ただし、 NVIDIA は H100 上の SM 間のポイントツーポイント コマンドを実験的にサポートしており、これは将来のシングル カード アテンションの最適化に新たな可能性をもたらします。

最も過小評価されているAIラボの1つ

Tree Attention チームの中心メンバーは、新興 AI スタートアップ企業Zyphra出身で、 「現在最も過小評価されている AI ラボの 1 つ」と評されている。

ZyphraはエッジAIとデバイスサイドAIに注力しており、 Mambaアーキテクチャをベースにした基本モデルであるZambaをリリースしている

創設者のKrithik Puthalath氏、共著者のVasudev Shyam氏、Jonathan Pilault氏は、いずれも数学と理論物理学の学問的背景を持っています

論文の宛先:
https://arxiv.org/abs/2408.04093

参考リンク:
[1]https://x.com/ryu0000000001/s... [2]https://www.zyphra.com/post/t...