В классическом self-attention каждый токен смотрит на другие токены, чтобы понять, что важно в данный момент.
Внимание распределяется мгновенно:

Именно этот механизм сделал трансформеры тем, чем они стали.

Но вот в чём проблема - внимание не имеет памяти.
На каждой итерации оно переобучается заново, не зная, куда оно смотрело в прошлый раз.
Из за этого внимание может скакать, шуметь и терять контекст, особенно в длинных последовательностях.

Проблема: внимание без инерции

Представьте, что вы идёте по неровной дороге.
Если вы будете менять направление мгновенно, без инерции, вас просто будет бросать из стороны в сторону.
Точно так же и внимание в трансформере:
оно то цепляется за один токен, то внезапно переключается на другой,
порождая хаотичные изменения в градиентах и мешая стабильному обучению.

А что, если добавить вниманию немного физики?

Momentum это понятие из механики.
Если у тела есть скорость, оно не останавливается мгновенно, а плавно замедляется.
Почему бы не применить тот же принцип к вниманию?

Идея:

Пусть текущее внимание немного зависит от того, каким оно было раньше.
Не только “куда я смотрю сейчас?”,
но и “куда я смотрел мгновение назад?”.

От классического внимания к Momentum Attention

В классике:

Теперь добавим инерцию к Value-векторам:

Пояснение: Если бы я добавил инерцию к attn_scores, модель была бы вынуждена смотреть на те же самые токены, что и на прошлом шаге. Это очень жесткое ограничение. Добавляя инерцию к V, я позволяю вниманию свободно выбирать, куда смотреть на каждом шаге (Q и K новые), но информация, которую оно извлекает (V), будет смесью новой и старой.

Тогда:

То есть текущее внимание теперь частично помнит, какие значения были важны на предыдущем шаге. α (например, 0.9) задаёт вес настоящего по сравнению с прошлым

Простой пример на pytorch

import torch
import torch.nn as nn
import torch.nn.functional as F

class MomentumAttention(nn.Module):
    def __init__(self, d_model, n_heads=8, alpha=0.9):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model должен делиться на n_heads без остатка")
            
        self.alpha = alpha
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, prev_V=None):

        B, T_q, D = Q.shape
        _, T_k, _ = K.shape

        # Линейные проекции и разделение на головы
        q = self.W_q(Q).view(B, T_q, self.n_heads, self.d_k).transpose(1, 2) # [B, n_heads, T_q, d_k]
        k = self.W_k(K).view(B, T_k, self.n_heads, self.d_k).transpose(1, 2) # [B, n_heads, T_k, d_k]
        v = self.W_v(V).view(B, T_k, self.n_heads, self.d_k).transpose(1, 2) # [B, n_heads, T_k, d_k]

        # Применение Momentum к векторам Value
        if prev_V is None:
            # На самом первом шаге инерции нет, используем текущее значение
            v_momentum = v
        else:
            # Совмещаем текущее значение с прошлым
            v_momentum = self.alpha * v + (1 - self.alpha) * prev_V

        # 3. Сохраняем новое состояние для следующего шага.
        # .detach() используется, чтобы градиенты не текли через всю историю состояний,
        # что превратило бы механизм в полноценный RNN и сильно усложнило бы обучение.
        new_prev_V = v_momentum.detach()

        # 4. Стандартный механизм self-attention
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        
        # Внимание применяется к инерционным значениям v_momentum
        out = torch.matmul(attn_weights, v_momentum)

        # 5. Собираем головы вместе и пропускаем через финальный линейный слой
        out = out.transpose(1, 2).contiguous().view(B, T_q, D)
        
        return self.W_o(out), new_prev_V


# Пример модели, которая использует MomentumAttention

class AutoregressiveModel(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, alpha):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.momentum_attn = MomentumAttention(d_model, n_heads, alpha)
        self.layernorm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )
        self.out_proj = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):

        B, T = input_ids.shape
        x = self.embedding(input_ids)

        # Инициализируем состояние для всей последовательности
        prev_V_state = None
        all_step_outputs = []

        # Цикл по каждому шагу (токену) в последовательности
        for t in range(T):
            # Берем срез данных для текущего шага
            # В реальном декодере Q - это текущий токен, K и V - все предыдущие.
            # Для простоты демонстрации механизма инерции, мы используем только текущий токен
            # как Q, K, и V. Это показывает, как состояние `prev_V_state` передается.
            current_x_step = x[:, t:t+1, :] # Shape: [B, 1, D]
            
            # Вызываем слой внимания, передавая ему состояние с прошлого шага
            attn_output, prev_V_state = self.momentum_attn(
                Q=current_x_step, 
                K=current_x_step, 
                V=current_x_step, 
                prev_V=prev_V_state
            )
            
            # Стандартные блоки трансформера (residual connection, layernorm, FFN)
            h = self.layernorm(current_x_step + attn_output)
            step_output = self.ffn(h)
            all_step_outputs.append(step_output)

        # Собираем выходы со всех шагов в один тензор
        full_output = torch.cat(all_step_outputs, dim=1) # Shape: [B, T, D]
        
        # Финальная проекция в размер словаря
        logits = self.out_proj(full_output)
        return logits


# Параметры
batch_size = 4
seq_len = 10
vocab_size = 100
d_model = 64
n_heads = 8
alpha = 0.9

# Создаем модель
model = AutoregressiveModel(vocab_size, d_model, n_heads, alpha)

# Создаем случайные входные данные
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

print(f"Входные данные (shape): {input_ids.shape}")

# Получаем выход модели
output_logits = model(input_ids)

print(f"Выходные логиты (shape): {output_logits.shape}")

# Проверка корректности размеров
assert output_logits.shape == (batch_size, seq_len, vocab_size)

print("\nМодель успешно отработала")

Что это даёт

  • Сглаживание представлений.
    Вектора V не перескакивают резко между шагами прошлое состояние частично сохраняется, что снижает турбулентность активаций.

  • Более стабильное распределение внимания.
    Модель получает эффект инерции в значениях, и внимание не скачет при малых изменениях входа. Это особенно полезно в авторегрессионных моделях, где выходы сильно зависят от предыдущего шага.

  • Облегчённое обучение.
    Так как prev_V передаётся через detach(), градиенты не текут сквозь всю историю, что предотвращает взрыв или затухание градиентов в отличие от полного RNN-подхода.

  • Простая интеграция.
    Механизм не требует изменения архитектуры он полностью совместим с обычным MultiHeadAttention и может быть вставлен в любой трансформерный блок.

Возможные минусы

  • Накопление смещения (drift).
    Если alpha слишком велико, старые состояния начинают тянуть новые векторные представления, и внимание может начать запоминать шум.

  • Сложность выбора alpha.
    Значение 0.9 подходит не всегда при быстрых изменениях контекста модель может терять реактивность (поздно реагировать на новые токены).

  • Невозможность параллелизации по времени.
    Так как состояние prev_V передаётся последовательно, обучение по всей последовательности становится менее параллельным (особенно при autoregressive setup).

  • Потенциальная инерция ошибок.
    Если модель делает ошибку на шаге t, она может частично переноситься дальше через prev_V, особенно при большом alpha.

Заключение

Momentum Attention это шаг в сторону более живых архитектур.
Мы не просто учим модель смотреть на токены,
мы учим её чувствовать движение своего внимания как будто у неё появилась инерция восприятия.

Комментарии (0)