У каждого из нас есть "мышечная память" при написании кода обучения нейросетей. Мы собираем архитектуру, а затем пишем примерно такую строчку, даже не задумываясь:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)
Weight Decay (L2-регуляризация) это база. Мы знаем, что он тянет веса к нулю, не дает отдельным нейронам "зазвездиться" и предотвращает переобучение. Для линейных слоев (W * X) это работает великолепно. Но Трансформер состоит не только из матриц W. В нем есть специфические слои, для которых Weight Decay это не лекарство от переобучения, а тихий убийца, который медленно разрушает геометрию латентного пространства и душит градиенты.
Давайте залезем под капот оптимизатора и посмотрим, как слепое применение Weight Decay уничтожает ваши эмбеддинги и слои нормализации.
Физика Weight Decay
Чтобы понять проблему, нужно вспомнить математику AdamW. В отличие от градиента, который обновляет вес только если есть ошибка, Weight Decay применяется безусловно на каждом шаге оптимизатора:
Wnew=Wold−η⋅∇L−η⋅λ⋅Wold
Где λ это наш weight_decay.
Физически это гравитация. На каждом шаге (на каждом батче) оптимизатор "откусывает" от каждого веса микроскопический процент его значения, независимо от того, что говорят данные.
А теперь посмотрим, что эта гравитация делает с разными частями сети.
Жертва №1: Эмбеддинги (Черная дыра для редких токенов)
Слой эмбеддингов (nn.Embedding) это огромная lookup-таблица (Словарь * Размерность).
Главное отличие эмбеддингов от линейных слоев разреженность обновлений.
Когда вы прогоняете батч текста, в нем участвуют, скажем, 2000 уникальных токенов. Градиент (∇L) вычисляется только для этих 2000 токенов. Для остальных 48 000 слов из вашего словаря градиент равен нулю.
Но оптимизатору AdamW всё равно! Вы передали ему model.parameters(), и он применяет правило Weight Decay ко всей матрице эмбеддингов.
Что происходит в реальности:
Представьте редкое слово, например, "Утконос". Оно встретилось в первом батче, модель сдвинула его вектор в правильном направлении. Следующий раз слово "Утконос" встретится через 10 000 батчей.
Все эти 10 000 шагов градиент для "Утконоса" равен нулю. Но формула Weight Decay продолжает работать:
Wутконос=Wутконос−0−η⋅λ⋅Wутконос
Оптимизатор методично умножает вектор редкого слова на условные 0.999 десять тысяч раз подряд. К тому моменту, когда "Утконос" снова появится в тексте, его вектор схлопнется в ноль. Вся семантическая геометрия, которую модель выучила для редких слов, стирается в пыль.
Из-за глобального Weight Decay эмбеддинги редких токенов постоянно "засасывает" в центр координат, лишая модель способности понимать узкоспециализированный контекст.
Жертва №2: Слои нормализации (Удушение сигнала)
Современные архитектуры (LLaMA, Mistral, Gemma) используют RMSNorm. У этих слоев нет весов в классическом понимании. У них есть обучаемый параметр Scale (γ)
Зачем нужен Scale (γ)? Нормализация принудительно делает дисперсию сигнала равной единице. Но иногда следующему слою (например, функции активации) нужна другая амплитуда сигнала для корректной работы. Обучаемый параметр γ существует исключительно для того, чтобы сеть могла восстановить нужный масштаб дисперсии.
Что происходит, если мы применяем к γ Weight Decay? Мы буквально говорим оптимизатору:
"Штрафуй сеть за большую амплитуду сигнала".
Weight Decay постоянно тянет γ к нулю. Сеть пытается сделать сигнал громче, чтобы протолкнуть его через глубокие слои, а оптимизатор бьет ее по рукам и заставляет "говорить шепотом". Это создает искусственное сопротивление в потоке градиентов. Вы заставляете сеть тратить драгоценную емкость оптимизатора на то, чтобы бороться с вашей же регуляризацией.
Как это лечить? (И почему об этом не пишут в туториалах)
В серьезных репозиториях вы никогда не найдете слепого
optimizer = AdamW(model.parameters()).
Правильный инженерный подход - декаплинг (разделение) параметров. Мы должны применять Weight Decay только к многомерным матрицам весов (Linear, Conv), и отключать его для одномерных тензоров (Norm, Bias) и эмбеддингов.
На Pytorch это делается через создание групп параметров. Вот как выглядит здоровый код инициализации оптимизатора для Трансформера:
def configure_optimizers(model, weight_decay, learning_rate): # Разделяем параметры на те, что нужно "декеить", и те, что нет decay = set() no_decay = set() # Бежим по всем модулям сети for mn, m in model.named_modules(): for pn, p in m.named_parameters(): fpn = '%s.%s' % (mn, pn) if mn else pn # полный путь к параметру if pn.endswith('bias'): no_decay.add(fpn) elif pn.endswith('weight') and isinstance(m, (nn.Linear, nn.Conv2d)): decay.add(fpn) elif isinstance(m, (nn.LayerNorm, nn.Embedding, nn.RMSNorm)): no_decay.add(fpn) # Собираем группы для оптимизатора param_dict = {pn: p for pn, p in model.named_parameters()} optim_groups = [ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, ] optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate) return optimizer
Заключение
Индустрия ИИ сегодня страдает от избытка абстракций. Высокоуровневые фреймворки прячут от нас математику, позволяя обучать модели в три строчки кода. Но дьявол всегда кроется в деталях.
Один дефолтный параметр weight_decay=0.1, примененный не к тем тензорам, может стоить вам пары процентов точности, потери редких фактов в LLM или привести к загадочной нестабильности при масштабировании глубины сети.
Нейросети ленивы. Оптимизаторы слепы. Задача инженера направлять их, опираясь на геометрию и физику процесса, а не на дефолтные параметры из туториалов.
Удачного обучения, и берегите свои эмбеддинги!
alexhu
Уже сама архитектура модели трансформера немного решает вопрос затухающих градиентов и вопрос взрывающихся градиентов. Оптимизатор Adam - это компромисс, обеспечивающий лучшие результаты, по прежнему это первый кандидат на применение.
То что некоторые веса стремятся к нулю - это тоже компромисс, есть вариант этого избежать. Weight Decay ничего не уничтожает, регуляризация работает как заложено алгоритмом. Малозначимые связи нужно обнулить, чем то приходится жертвовать при поиске паттернов.
YH7H22 Автор
В классических линейных слоев (проекции Q, K, V или FFN) Weight Decay работает ровно так, как задумано: он пенализирует избыточные веса, обнуляет шумовые связи и заставляет сеть искать робастные паттерны. В этом и заключается суть регуляризации.
Проблема возникает, когда мы пытаемся натянуть логику плотных (dense) вычислений на разреженные (sparse) структуры, такие как эмбеддинги.
Оптимизатор действительно работает по заложенному алгоритму: если градиент нулевой, WD тянет вес к нулю, считая его малозначимым. Но в языковом моделировании отсутствие градиента у токена (например, слова «триганометрия») в текущем батче не означает, что токен бесполезен. Это означает лишь то, что его не было в этом конкретном куске текста.
Применяя WD к словарю, мы наказываем токены не за их бесполезность для решения задачи, а просто за их редкость. Стирание семантики редких слов это не "поиск паттернов", это баг, вызванный слепым применением алгоритма.
То же самое касается слоев RMSNorm. Обучаемый параметр Scale вообще не участвует в маршрутизации признаков или поиске паттернов. Его единственная физическая задача восстановить нужную амплитуду (дисперсию) сигнала для следующего слоя. Применять к нему L2-штраф значит искусственно заставлять затухать градиент, вынуждая её тратить полезный градиент на борьбу с вашей же регуляризацией.
alexhu
Я ещё раз перечитал всё что вы написали, мне надо подумать над этим. В разреженных ембеддингах редкие токены получат большее значение, а часто встречающиеся токены меньшее значение - так действует алгоритм tf- idf и некоторые другие.
Если мы не снизим вес слова "утконос", то модель будет выдавать его чаще, чем следует. Градиентный спуск нелинейная функция, L2 - линейная (и именно штрафная) - надо углубиться с какой точностью считаются веса и при каком значении они обнуляются, что бы понять где прерывается переобучение, а где теряется связь.
Arastas
Не вдаваясь в техническое содержание, я хочу сказать спасибо за вот эту фразу как пример ведения дискуссии.
axion-1
Если правильно понимаю, при FP32 точности они практически никогда не обнулятся. L2 штрафует большие значения, а для обнуления "мелочи" больше L1 предназначена.
YH7H22 Автор
В TFIDF мы осознанно повышаем вес редких слов, потому что они несут больше уникальной информации о документе. В LLM же вероятность выбора слова определяется через softmax от скалярного произведения. Если WD за 10 000 шагов "тишины" (пока слово не встречалось в батчах) стянет вектор редкого слова к нулю, то даже если скрытое состояние h будет идеально указывать на "утконоса", их скалярное произведение будет ничтожным. Модель просто физически не сможет выбрать это слово, даже если оно там единственно верное по смыслу. Мы не снижаем "частоту" выдачи слова, мы снижаем способность модели его узнавать
Что касается точности: веса в современных сетях (особенно при обучении в BF16) обнуляются гораздо раньше, чем вы успеете заметить "переобучение" на редком токене. Как только амплитуда вектора падает ниже определенного порога, он просто теряется в шуме нормализации.