GPT トランスフォーマーの仕組み:因果自己注意機構(Causal Self-Attention)の解説

因果マスキングと自己回帰的なトークン生成を示す、GPTデコーダー専用トランスフォーマーの詳細な技術アーキテクチャ図

GPT トランスフォーマーの仕組み:因果自己注意機構(Causal Self-Attention)の解説

近年、Generative Pre-trained Transformers(GPT)は人工知能に革命をもたらしました。コーディングアシスタントから対話型エージェントに至るまで、GPTベースのモデルは今日の最も先進的な生成アプリケーションを支えています。しかし、このテクノロジーは実際にどのように機能しているのでしょうか?

双方向に理解するのに対し、GPTは自己回帰的(autoregressive)な次のトークン予測のために設計された**デコーダー専用(Decoder-only)**のアーキテクチャです。このブログでは、GPTトランスフォーマーの仕組みを解き明かし、因果自己注意機構(causal self-attention)を深く掘り下げ、コードで実装します。


1. 自己回帰的な生成ループ

本質的に、GPTは自己回帰モデルです。これは、テキストのシーケンスを生成するために、すでに生成されたトークンを次の予測のコンテキストとして使用しながら、次のトークンを1つずつ予測することを意味します。

ワークフローは以下の手順に従います。

  1. 入力 (Input): モデルはプロンプトを受け取ります。例:"Deep learning is"
  2. 予測 (Prediction): モデルはこのプロンプトを処理し、語彙全体に対する確率分布を出力します。次のトークンをサンプリングします。例:"awesome"
  3. ループ (Loop): 新しいトークンが入力に追加され、"Deep learning is awesome"になります。このシーケンスが次のステップの入力になります。
  4. 終了 (Termination): モデルが特殊なシーケンス終了([EOS])トークンを出力するか、定義された長さ制限に達するまで、プロセスが繰り返されます。

2. 因果マスキング:デコーダーの核心

BERTのようなエンコーダー専用モデルでは、すべてのトークンが過去と未来の両方を見渡して、他のすべてのトークンに注意を向けることができます。しかし、次のトークンを予測する生成モデルにとって、トレーニング中に未来を見ることは「カンニング」になります。

モデルが未来のトークンを見るのを防ぐために、GPTは因果自己注意(Causal Self-Attention)(またはマスクされた自己注意)を使用します。

因果マスク行列

自己注意の計算中、クエリ(Queries, $Q$)とキー(Keys, $K$)の点積を取ることで、トークン間の類似度スコアを計算します。

$$\text{Scores} = QK^T$$

因果性を強制するために、対角線より上のすべての値が $-\infty$(負の無限大)に設定され、対角線上およびそれ以下の値が 0 であるマスク行列 $M$ を適用します。softmax関数を適用する前に、このマスクをスコアに加算します。

$$\text{Masked Scores} = \frac{QK^T}{\sqrt{d_k}} + M$$

$$M = \begin{pmatrix} 0 & -\infty & -\infty & \dots & -\infty \\ 0 & 0 & -\infty & \dots & -\infty \\ 0 & 0 & 0 & \dots & -\infty \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \dots & 0 \end{pmatrix}$$

softmax関数を適用すると、$e^{-\infty}$ は $0$ になります。その結果、未来のトークンに対する注意の重み(Attention Weights)は正確に0になり、未来 of トークンは現在のトークンから見えなくなります。


3. GPTレイヤーの主要なアーキテクチャブロック

GPTモデルは、スタックされたトランスフォーマーデコーダーレイヤーで構成されています。各レイヤーには、いくつかの重要なコンポーネントが含まれています。

A. 入力埋め込みと位置エンコーディング (Embeddings & Positional Encoding)

  • トークン化: 生テキストは、バイトペアエンコーディング(BPE)を使用してサブワードトークンに分割されます。
  • トークン埋め込み (Token Embeddings): 各トークンが高次元ベクトルにマッピングされます。
  • 学習された位置埋め込み: 自己注意には固有の順序感覚がないため、GPTはトークン埋め込みに学習された位置埋め込みベクトルを追加し、モデルがシーケンス内の各トークンの位置を把握できるようにします。

B. 前層正規化 (Pre-Layer Normalization / Pre-LN)

残差接続の加算の後にレイヤー正規化を適用していた元のトランスフォーマーアーキテクチャ(Post-LN)とは異なり、現代のGPTアーキテクチャは、注意層およびフィードフォワード層のにレイヤー正規化を適用します。

$$x_{l+1} = x_l + \text{Attention}(\text{LayerNorm}(x_l))$$

Pre-LNはトレーニング中の勾配を安定させ、数千億のパラメータを持つ非常に深いネットワークの安定したトレーニングを可能にします。

C. フィードフォワードネットワーク (Feed-Forward Network / FFN)

注意ブロックに続いて、各トークンの表現は、2つの線形変換と活性化関数(通常はGeLU)で構成される多層パーセプトロン(MLP)を通過します。

$$\text{FFN}(x) = \max(0, x W_1 + b_1) W_2 + b_2$$


4. サンプリングの仕組み(ロジットからトークンへ)

最終的なデコーダーブロックは、各位置の Logits と呼ばれる生スコアのベクトルを出力します。softmax関数を使用して、これらのロジットを確率に変換します。生成されるテキストのランダム性を制御するために、サンプリング中にパラメータを適用します。

  • 温度 ($T$): softmaxの前にロジットをスケーリングします。温度が低い(例: $T = 0.2$)と、モデルは決定論的かつ集中したものになり、温度が高い(例: $T = 0.8$)と、創造性と多様性が高まります。
  • Top-K: 次のトークンの選択肢を、確率の最も高い上位 $K$ 個のトークンに制限します。
  • Top-P(ニュークリアスサンプリング): 確率分布を累積し、累積確率が $P$(例: $P = 0.9$)を超える最小のトークンセットから選択します。

5. 因果自己注意機構の PyTorch 実装

以下は、因果マスキングを使用した因果自己注意機構を示す自己完結型の PyTorch 実装です。

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Query、Key、Valueの射影層
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        
    def forward(self, x):
        batch_size, seq_len, d_model = x.size()
        
        # 1. 入力を Q, K, V に射影
        Q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # 2. 生の注意スコアを計算
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        
        # 3. 因果マスクの作成と適用
        # 負の無限大で満たされた上三角マスク
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        scores = scores.masked_fill(mask, float('-inf'))
        
        # 4. Softmax により -inf は確率 0 に変換される
        attn_weights = F.softmax(scores, dim=-1)
        
        # 5. 値の加重平均を計算して出力
        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        
        return self.out_proj(context)

# クイック検証実行
if __name__ == "__main__":
    # バッチサイズ = 1, シーケンス長 = 4, モデル次元 = 8, ヘッド数 = 2
    x = torch.randn(1, 4, 8)
    attention_layer = CausalSelfAttention(d_model=8, n_heads=2)
    output = attention_layer(x)
    print("入力形状:", x.shape)
    print("出力形状:", output.shape)

6. アーキテクチャの比較

特徴 BERT(エンコーダーのみ) GPT(デコーダーのみ) 元のトランスフォーマー
主なタスク 理解 / 抽出 生成 / 統合 翻訳 / シーケンス変換
注意タイプ 双方向自己注意 因果マスク自己注意 双方向および因果クロス注意
マスキング マスクされたトークン ([MASK]) 因果三角マスキング デコーダー内の因果マスキング
処理 シーケンス全体を一度に処理 自己回帰的なトークン生成 エンコーダーが一度処理しデコーダーが生成

結論

エンコーダーを排除し、因果マスク自己注意に完全に焦点を当てることで、GPTは生成モデルのスケーリングの道を開きました。次のトークンを予測するというシンプルなルールと、大規模な並列トレーニングの組み合わせにより、GPTモデルは論理、コード、および言語の豊かな表現を捉えることができ、現代の認知AIの基礎を築いています。


Ghaznix ブログでさらに技術的な洞察を探索する →