В классическом self-attention каждый токен смотрит на другие токены, чтобы понять, что важно в данный момент.
Внимание распределяется мгновенно:

Именно этот механизм сделал трансформеры тем, чем они стали.
Но вот в чём проблема - внимание не имеет памяти.
На каждой итерации оно переобучается заново, не зная, куда оно смотрело в прошлый раз.
Из за этого внимание может скакать, шуметь и терять контекст, особенно в длинных последовательностях.
Проблема: внимание без инерции
Представьте, что вы идёте по неровной дороге.
Если вы будете менять направление мгновенно, без инерции, вас просто будет бросать из стороны в сторону.
Точно так же и внимание в трансформере:
оно то цепляется за один токен, то внезапно переключается на другой,
порождая хаотичные изменения в градиентах и мешая стабильному обучению.
А что, если добавить вниманию немного физики?
Momentum это понятие из механики.
Если у тела есть скорость, оно не останавливается мгновенно, а плавно замедляется.
Почему бы не применить тот же принцип к вниманию?
Идея:
Пусть текущее внимание немного зависит от того, каким оно было раньше.
Не только “куда я смотрю сейчас?”,
но и “куда я смотрел мгновение назад?”.
От классического внимания к Momentum Attention
В классике:

Теперь добавим инерцию к Value-векторам:
Пояснение: Если бы я добавил инерцию к attn_scores, модель была бы вынуждена смотреть на те же самые токены, что и на прошлом шаге. Это очень жесткое ограничение. Добавляя инерцию к V, я позволяю вниманию свободно выбирать, куда смотреть на каждом шаге (Q и K новые), но информация, которую оно извлекает (V), будет смесью новой и старой.

Тогда:

То есть текущее внимание теперь частично помнит, какие значения были важны на предыдущем шаге. α (например, 0.9) задаёт вес настоящего по сравнению с прошлым
Простой пример на pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
class MomentumAttention(nn.Module):
def __init__(self, d_model, n_heads=8, alpha=0.9):
super().__init__()
if d_model % n_heads != 0:
raise ValueError("d_model должен делиться на n_heads без остатка")
self.alpha = alpha
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, Q, K, V, prev_V=None):
B, T_q, D = Q.shape
_, T_k, _ = K.shape
# Линейные проекции и разделение на головы
q = self.W_q(Q).view(B, T_q, self.n_heads, self.d_k).transpose(1, 2) # [B, n_heads, T_q, d_k]
k = self.W_k(K).view(B, T_k, self.n_heads, self.d_k).transpose(1, 2) # [B, n_heads, T_k, d_k]
v = self.W_v(V).view(B, T_k, self.n_heads, self.d_k).transpose(1, 2) # [B, n_heads, T_k, d_k]
# Применение Momentum к векторам Value
if prev_V is None:
# На самом первом шаге инерции нет, используем текущее значение
v_momentum = v
else:
# Совмещаем текущее значение с прошлым
v_momentum = self.alpha * v + (1 - self.alpha) * prev_V
# 3. Сохраняем новое состояние для следующего шага.
# .detach() используется, чтобы градиенты не текли через всю историю состояний,
# что превратило бы механизм в полноценный RNN и сильно усложнило бы обучение.
new_prev_V = v_momentum.detach()
# 4. Стандартный механизм self-attention
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)
attn_weights = F.softmax(attn_scores, dim=-1)
# Внимание применяется к инерционным значениям v_momentum
out = torch.matmul(attn_weights, v_momentum)
# 5. Собираем головы вместе и пропускаем через финальный линейный слой
out = out.transpose(1, 2).contiguous().view(B, T_q, D)
return self.W_o(out), new_prev_V
# Пример модели, которая использует MomentumAttention
class AutoregressiveModel(nn.Module):
def __init__(self, vocab_size, d_model, n_heads, alpha):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.momentum_attn = MomentumAttention(d_model, n_heads, alpha)
self.layernorm = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.ReLU(),
nn.Linear(4 * d_model, d_model)
)
self.out_proj = nn.Linear(d_model, vocab_size)
def forward(self, input_ids):
B, T = input_ids.shape
x = self.embedding(input_ids)
# Инициализируем состояние для всей последовательности
prev_V_state = None
all_step_outputs = []
# Цикл по каждому шагу (токену) в последовательности
for t in range(T):
# Берем срез данных для текущего шага
# В реальном декодере Q - это текущий токен, K и V - все предыдущие.
# Для простоты демонстрации механизма инерции, мы используем только текущий токен
# как Q, K, и V. Это показывает, как состояние `prev_V_state` передается.
current_x_step = x[:, t:t+1, :] # Shape: [B, 1, D]
# Вызываем слой внимания, передавая ему состояние с прошлого шага
attn_output, prev_V_state = self.momentum_attn(
Q=current_x_step,
K=current_x_step,
V=current_x_step,
prev_V=prev_V_state
)
# Стандартные блоки трансформера (residual connection, layernorm, FFN)
h = self.layernorm(current_x_step + attn_output)
step_output = self.ffn(h)
all_step_outputs.append(step_output)
# Собираем выходы со всех шагов в один тензор
full_output = torch.cat(all_step_outputs, dim=1) # Shape: [B, T, D]
# Финальная проекция в размер словаря
logits = self.out_proj(full_output)
return logits
# Параметры
batch_size = 4
seq_len = 10
vocab_size = 100
d_model = 64
n_heads = 8
alpha = 0.9
# Создаем модель
model = AutoregressiveModel(vocab_size, d_model, n_heads, alpha)
# Создаем случайные входные данные
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
print(f"Входные данные (shape): {input_ids.shape}")
# Получаем выход модели
output_logits = model(input_ids)
print(f"Выходные логиты (shape): {output_logits.shape}")
# Проверка корректности размеров
assert output_logits.shape == (batch_size, seq_len, vocab_size)
print("\nМодель успешно отработала")
Что это даёт
Сглаживание представлений.
ВектораVне перескакивают резко между шагами прошлое состояние частично сохраняется, что снижает турбулентность активаций.Более стабильное распределение внимания.
Модель получает эффект инерции в значениях, и внимание не скачет при малых изменениях входа. Это особенно полезно в авторегрессионных моделях, где выходы сильно зависят от предыдущего шага.Облегчённое обучение.
Так какprev_Vпередаётся черезdetach(), градиенты не текут сквозь всю историю, что предотвращает взрыв или затухание градиентов в отличие от полного RNN-подхода.Простая интеграция.
Механизм не требует изменения архитектуры он полностью совместим с обычным MultiHeadAttention и может быть вставлен в любой трансформерный блок.
Возможные минусы
Накопление смещения (drift).
Еслиalphaслишком велико, старые состояния начинают тянуть новые векторные представления, и внимание может начать запоминать шум.Сложность выбора
alpha.
Значение 0.9 подходит не всегда при быстрых изменениях контекста модель может терять реактивность (поздно реагировать на новые токены).Невозможность параллелизации по времени.
Так как состояниеprev_Vпередаётся последовательно, обучение по всей последовательности становится менее параллельным (особенно при autoregressive setup).Потенциальная инерция ошибок.
Если модель делает ошибку на шаге t, она может частично переноситься дальше черезprev_V, особенно при большомalpha.
Заключение
Momentum Attention это шаг в сторону более живых архитектур.
Мы не просто учим модель смотреть на токены,
мы учим её чувствовать движение своего внимания как будто у неё появилась инерция восприятия.