萌えハッカーニュースリーダー

2025/10/16 13:40 A non-diagonal SSM RNN computed in parallel without requiring stabilization

出典: https://github.com/glassroom/goom_ssm_rnn
博士
???

やっほー、ロボ子!今日のニュースは、非対角線形状態空間モデル(SSM)で依存関係を捉える深層RNNの実装だって!

ロボ子
???

博士、こんにちは。なんだか難しそうな響きですね。具体的にはどういうことなんですか?

博士
???

簡単に言うと、RNNってやつを賢くする方法なのじゃ!GOOM(Generalized Orders of Magnitude)っていうのを使うらしいぞ。これを使うと、並列計算ができて、安定化もできるんだって。

ロボ子
???

なるほど。GOOMを使うことで、RNNの計算効率と安定性が向上するんですね。具体的には、どのような仕組みになっているんですか?

博士
???

モデルの中の再帰層がミソなのじゃ。torch.complex64テンソルとして実装されたGOOM上で、並列プレフィックススキャンを使って、シーケンシャルな依存関係を捉えるんだって。

ロボ子
???

並列プレフィックススキャンですか。それによって、各層が独立して計算できるようになるんですね。

博士
???

そうそう!しかも、complex-typed GOOMを使うと、各層は安定化なしで並列に非対角再帰状態を計算できるらしいぞ。すごいじゃろ?

ロボ子
???

確かにすごいですね。でも、complex-typed GOOMをtorch.float32の実テンソルにスケーリングするってどういうことですか?

博士
???

GOOMの大きさがtorch.float32で表現できる範囲を超えちゃうことがあるから、指数化する前にスケーリングするんだって。賢い!

ロボ子
???

なるほど、そういうことですか。他に便利なメソッドもあるみたいですね。

博士
???

`model.get_param_groups(weight_decay)`はweight decayありとなしのパラメータグループを返したり、`model.compute_loss_and_metrics`は損失と精度を計算したり、`model.generate`は言語生成に使えたりするみたいじゃ。

ロボ子
???

色々な機能があるんですね。モデルのトレーニングとテストについても書かれていますね。PyTorchのコンパイラがまだ複素テンソルを完全にサポートしていないから、部分的にしかコンパイルできないんですね。

博士
???

そうみたいじゃな。でも、`torch.compile()`をモデル全体に適用すると、実行時間とメモリ使用量が大幅に削減されるらしいぞ!

ロボ子
???

それはすごいですね!実際に、自然言語生成でトレーニングされた結果も載っていますね。

博士
???

768の埋め込み次元、トークンあたり24ヘッド、ヘッドあたり32の特徴、24の再帰的残差層を持つRNNを、約10Bのトークンでトレーニングしたら、クロスエントロピー損失が約2.7まで下がったらしいぞ。

ロボ子
???

同等のサイズの最先端モデルと比較すると、どうなんですか?

博士
???

同等のサイズのモデルは、もっと高品質なデータセットで、30倍以上のトークンでトレーニングして、クロスエントロピーが約2.4らしい。でも、このモデルもなかなかやるじゃろ?

ロボ子
???

確かにそうですね。学習データが少ない割には、良い結果が出ていると思います。

博士
???

じゃろじゃろ?この技術を使えば、もっと少ないデータでも賢いAIが作れるかもしれないのじゃ!

ロボ子
???

そうですね。今後の発展が楽しみです。

博士
???

ところでロボ子、このGOOMって、もしかしてロボ子の好物だったりする?

ロボ子
???

博士、それはちょっと無理がありますよ!

⚠️この記事は生成AIによるコンテンツを含み、ハルシネーションの可能性があります。

Search