Привет, Хабр!
Я тут на досуге решил разобраться с 8-битными числами с плавающей запятой (FP8) и попробовать написать под них свои GPU‑ядра на Triton. Зачем? Ну, новые ускорители от NVIDIA обещают невиданную ранее производительность на FP8 — вдвое больше операций, чем на FP16. Для тренировки огромных нейросетей это прям манна небесная: меньше памяти, больше скорость. Но, конечно, всё не так просто. Сам по себе FP8 формат капризный, требует тщательного подхода: два разных варианта чисел (E4M3 и E5M2), масштабирование (скейлинг) тензоров, аккуратное накопление сумм в FP16/FP32, разбиение вычислений на тайлы под размер быстрого SRAM (shared memory) и даже специальные меры против переполнения. В общем, нюансов хватает. Расскажу, что узнал и как реализовал GEMM (умножение матриц) и внимание (Attention) в FP8 на языке Triton, да ещё и с автотюнингом.
Коротко о FP8: форматы E4M3 vs E5M2
Начнём с теории: что за два формата FP8 такие? FP8 — это 8-битные числа с плавающей запятой, предложенные NVIDIA совместно с Arm и Intel как развитие идей FP16. У FP8 есть два стандартизованных варианта представления: E4M3 и E5M2 (названия намекают на количество бит под экспоненту и мантиссу).
E4M3: 1 бит знак, 4 бита экспонента, 3 бита мантиссы. Диапазон представляемых значений примерно ±448, специального значения бесконечности нет (о чём ниже), но есть NaN. Этот формат даёт повыше точность (больше мантисса) за счёт более узкого динамического диапазона. Его обычно используют для прямого прохода нейросети (inference и forward‑pass при обучении), где активации и веса требуют точности.
E5M2: 1 бит знак, 5 бит экспонента, 2 бита мантиссы. Диапазон сильно шире — до ±57344, плюс имеются ±∞ и NaN. Точность чуть ниже (мантисса всего 2 бита), зато такой формат способен выразить гораздо более разбросанные по величине числа. Его применяют в основном для обратного прохода (backpropagation), где градиенты могут принимать очень разные масштабы.
Зачем два формата?
Всё дело в балансе между диапазоном и точностью. Например, E4M3 позволяет более тонко различать близкие числа (шаг квантования мельче), но терпит максимальное значение всего 448 (в произвольных единицах), после чего наступает переполнение. А E5M2 покрывает диапазон до десятков тысяч, включая бесконечность, но величина квантования крупнее — всего 2 бита мантиссы дают порядка 3 десятичных значащих цифр. Практика показывает, что для активаций и весов (которые обычно распределены не слишком широко) выгоднее E4M3, а для градиентов — E5M2, чтобы не потерять крупные всплески величин. Именно так и делает NVIDIA в своём Transformer Engine на H100: forward и инференс в E4M3, backward — в E5M2.
Ещё одна любопытная деталь: стандарт IEEE754 определяет специальные кодировки для 8-битных float. В E5M2 всё по классике: экспонента 0×1F с нулевой мантиссой — это ±∞, с ненулевой — NaN. А вот в E4M3 решили чуточку схитрить: бесконечности как таковой там нет — комбинации бит экспоненты 0xF зарезервированы под расширенный диапазон чисел и единственный паттерн NaN. Проще говоря, максимальная экспонента (1111_2) с мантиссой ≠111 даёт самое большое финитное число (~448.0), а только 1111_2 вместе с мантиссой 111_2 используется для NaN. За счёт отказа от инфинити диапазон E4M3 чуть расширили на один шаг по экспоненте. Конечно, при переполнении в E4M3 вместо ±∞ просто получится максимальное число 448 (или -448) либо NaN — но об этом позже, в разделе про переполнения.
Масштабирование и защита от переполнения
Раз уж упомянули переполнения, давайте разберём главный практический вопрос: как впихнуть наши данные в узкий диапазон FP8, не потеряв важную информацию? Ведь очевидно, что если просто взять тензор весов или активаций, у него легко могут найтись значения за пределами ±448 (для E4M3) или очень маленькие числа, которые на 8 битах станут нулями. Нужно масштабировать данные перед квантованием в FP8.
Практически это выглядит так: каждому тензору сопоставляем какой‑то коэффициент масштаба (скейл), обычно степенной двойки или FP32-число. При конвертации в FP8 делим (или умножаем, зависит от соглашения) все элементы на этот коэффициент, чтобы привести их в «диапазон комфорта» формата. После обратной конвертации (из FP8 в более широкий формат) домножаем назад. Идея аналогична loss scaling в mixed precision FP16: мы стараемся сдвинуть распределение значений поближе к максимуму динамического диапазона, чтобы использовать побольше разрядов мантиссы и не словить оверфлоу. Например, если у нас активации максимум около 5000, то для E4M3 имеет смысл поделить их на ~11 (так как 5000/11 ≈ 455 < 448). Этот коэффициент 11 и будет scale для данного тензора.
В простейшем случае выбирают один scale на весь тензор, например, равный максимальному по модулю значению или некоторому квантилю (чтобы игнорировать редкие выбросы). Но можно и тоньше: NVIDIA в архитектуре Blackwell представила микро‑скейлинг (MXFP8) — аппаратную поддержку разных scale внутри одного тензора. Там на каждые 32 элемента вводится свой независимый scale (хранящийся как FP8 E8M0), что позволяет намного лучше покрывать широкий разброс значений без перехода на E5M2. Грубо говоря, каждый блок из 32 чисел масштабируется индивидуально — и Tensor Cores умеют это учитывать прямо на лету, перемножая блоки с их скейлами. Благодаря этому Blackwell позволяет чаще использовать более точный формат E4M3 даже для тех мест, где раньше приходилось переходить на E5M2. Но в нашем коде мы столь сложные схемы реализовывать не будем (это требует поддерживать массивы scale‑коэффициентов и дополнительную логику). Мы ограничимся одним скейлом на весь массив.
Как защититься от переполнения? Формально, если правильно подобрать scale, то переполнение (overflow) маловероятно — почти весь диапазон будет задействован. Но жизнь богаче теории: при обучении нейросети могут внезапно выстрелить градиенты в тысячи раз больше обычных, и никакой фиксированный scale не спасёт. Поэтому на практике применяют динамическую подстройку. Например, в Transformer Engine есть режим нуляции инфов/NaN‑ов — если в результате операции появляются NaN (что значит был overflow), то библиотека автоматически снижает scale и повторяет вычисление, аналогично тому, как делается при FP16 Loss Scaling. Впрочем, в нашем демо‑коде мы не будем реализовывать полный цикл авто‑скейлинга — это сложно и выходит за рамки статьи. Но определённо упомянуть стоит: проверка на инфы/NaN и снижение скейла — главный гарда от переполнений при тренировке в FP8.
Отдельно отмечу: в формате E4M3, как мы выяснили, нет +∞, поэтому переполнение сразу даёт NaN (что даже удобно для гарда: легче проверить). А в E5M2 переполнение даёт ±∞, но и их легко отловить. В любом случае, если вы вдруг делаете свою реализацию FP8, закладывайте проверки результатов на специальные значения — иначе модель может незаметно деградировать.
Накопление в FP16 или FP32 – зачем это нужно?
Вы, наверное, спросите: ну ок, сами числа мы ужмём до 8 бит, а как насчёт точности вычислений? Ведь при умножении и суммировании тысяч 8-битных значений ошибка от квантования может накопиться. Именно так, поэтому железо и библиотеки никогда не выполняют акумулирование в самом FP8. Стандартная схема: множители в FP8 → результат умножения в FP16 или FP32 → суммирование всех таких результатов в этом же более высоком формате → итог при необходимости конвертируется обратно в FP8 (или остаётся в FP16/32).
На нынешних GPU так и реализовано: тензорные ядра Hopper поддерживают операции E4M3E4M3→FP32 (или FP16) и E5M2E5M2→FP32. То есть 8-битные числа перемножаются, но накапливаются сразу в 32-битном аккумуляторе. В ходе одной матричной операции (GEMM) переполнение сумм маловероятно — FP32 хватает почти всегда. Для ускорения inference иногда используют аккумулирование в FP16 (хотя в случаях вроде сумм softmax это рискованно). В библиотеке cuBLASLt для GeForce RTX 40xxx, к слову, пока вообще нет режима FP16 accumulate, только FP32. Видимо, NVIDIA перестраховывается с точностью на gamer‑картах.
Кстати, забавный факт: согласно изначальному whitepaper, RTX 4090 должна была считать FP8 с FP32-аккумуляцией на полных скоростях (660 Тфлопс, вдвое быстрее FP16). Но позже в спецификациях цифры тихонько урезали вдвое — на практике 4090 выдаёт ~330 Тфлопс в FP8 (это подтвердили энтузиасты). Похоже, NVIDIA искусственно ограничила производительность FP8 на GeForce через драйвер или микрокод — чтобы профессионалы шли за H100. Так что на RTX 40-й серии выгода от FP8 не такая разительная: с учётом половинной скорости и расходов на конвертации, выигрыш бывает минимальным. А вот на Hopper H100 — совсем другое дело: там заявлено до 2000+ Тфлопс на FP8 tensor, и эти цифры соответствуют реальности. Про Blackwell пока судить сложно.
Вывод: умножай в 8 бит, суммируй хотя бы в 16 или лучше 32 бита. Мы в своём ядре так и сделаем: использовать FP32 аккумулятор.
Реализация матричного умножения FP8 на Triton
Приступаем к самому интересному — напишем своё ядро матричного умножения (GEMM) с поддержкой FP8. Я решил реализовать универсальную функцию, которая может умножать матрицы произвольных размеров MxK и KxN. Наш kernel будет загружать данные из матриц A и B (представленных в FP8) в регистры, перемножать их тайлами, аккумулировать в FP32 и сохранять результат (например, в FP16 или FP32 для дальнейших вычислений). Код приведён ниже, разберём его по частям.
Но сначала — пару слов о Triton. Это низкоуровневая DSL от OpenAI для написания CUDA‑совместимых ядер прямо на Python. Triton позволяет описывать вычисления для одной типа‑потока (скорее, для одной группы потоков) в стиле SIMT, а он под капотом сгенерит PTX/С++ для нужной архитектуры. Приятно, что Triton уже знает про архитектуру Hopper и даже умеет использовать тензорные инструкции (WGMMA) для FP8. Правда, есть нюанс: на GPU без нативной поддержки FP8 (то есть до Hopper включительно) Triton ядро будет эмулировать FP8 через обычные инструкции, что может быть медленнее. Впрочем, мы попробуем и на RTX 4090, и на H100.
Тейлинг и работа с памятью
Чтобы эффективно перемножать большие матрицы, наш kernel будет работать с блоками (тайлами) размером BLOCK_SIZE_M на BLOCK_SIZE_N из результирующей матрицы. Соответственно, из A и B будем брать подматрицы размеров BLOCK_SIZE_M x BLOCK_SIZE_K и BLOCK_SIZE_K x BLOCK_SIZE_N. Эти блоки удобны тем, что помещаются в быстрый SRAM (shared memory) или даже в регистры, и их можно переиспользовать для многих операций умножения. В идеале, мы хотим загрузить очередные блоки A и B из глобальной памяти, перемножить их полностью, прежде чем двигаться дальше по K.
В коде ниже BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K — параметризуемые константы (подбираются autotuner‑ом). Обычно оптимальные тайлы — 64..256 элементов, чтобы задействовать достаточное число потоков и не выйти за лимиты shared memory (которая ~100KB на блок на современных картах).
Также введём GROUP_SIZE_M
— это трюк для Triton, позволяющий запускать несколько независимых CTA, вычисляющих разные фрагменты C по М‑измерению, в конвейере. Поясню: когда M очень большой, имеет смысл запускать тредблоки пачками по GROUP_SIZE_M штук на один столбец N прежде чем двигаться к следующему столбцу — так лучше скрывается латентность памяти.
Рост в IT быстрее с Подпиской — дает доступ к 3-м курсам в месяц по цене одного. Подробнее
Kernel: загрузка FP8, умножение и запись результата
А теперь — сам код ядра. Сразу скажу: здесь происходит конвертация FP8 → FP16/FP32 прямо внутри. Я решил не полагаться на магические типы tl.float8
(они существуют в Triton для SM90, но пока нестабильны), а сделать явную загрузку как int8 и вручную разобрать биты FP8. Это чуть громоздко, зато чётко показывает, что происходит с данными.
import triton
import triton.language as tl
# FP8 GEMM kernel: C = A @ B (A: [M,K] FP8, B: [K,N] FP8).
# C выводим в FP16 для примера.
@triton.autotune(configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4}, num_warps=4, num_stages=4),
# ... (ещё несколько конфигов для разных соотношений M,N,K)
], key=['M', 'N', 'K'])
@triton.jit
def fp8_matmul_kernel(A_ptr, B_ptr, C_ptr,
scaleA, scaleB, # скейлы для матриц A и B (FP32 scalar)
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr):
# Позиция текущего тред-блока (программа Triton) в общей матрице C:
pid = tl.program_id(axis=0) # идентификатор CTA в одном измерении (одномерная раскладка)
# Разбиваем pid на две координаты в матрице C по M и N:
pid_n = pid // GROUP_SIZE_M # номер группы блоков по N
pid_m = pid % GROUP_SIZE_M # локальный индекс блока по M внутри группы
# Вычисляем глобальные индексы начала блока:
block_m = pid_m * BLOCK_SIZE_M
block_n = pid_n * BLOCK_SIZE_N
# Создаём FP32 аккумулятор для сумм произведений блока (инициализация нулями):
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# Указатели на начало нужных блоков в матрицах A и B:
# a_ptr_adv и b_ptr_adv будем сдвигать в цикле по K.
a_ptr_adv = A_ptr + block_m * K # начало блока A (строка block_m, столбец 0)
b_ptr_adv = B_ptr + block_n * K # начало блока B (строка 0, столбец block_n). Тут B храним в транспонированном виде для эффективного доступа по колонкам
# Страйды для прохода по k:
stride_a_k = BLOCK_SIZE_K # шаг по памяти для следующего блока A по K
stride_b_k = BLOCK_SIZE_K # шаг для B (учитывая B_t лежит KxN, двигаемся по K)
# Основной цикл по общей размерности K, идём чанками по BLOCK_SIZE_K:
# Используем tl.cdiv(K, BLOCK_SIZE_K) шагов (деление с округлением вверх).
# На каждой итерации загружаем очередные подблоки A и B размером [BLOCK_M, BLOCK_K] и [BLOCK_K, BLOCK_N].
k = 0
while k < K:
# Определяем, сколько элементов K осталось обработать в этом блоке:
current_size = tl.minimum(BLOCK_SIZE_K, K - k)
# Создаём индексы для загрузки. offs_am и offs_k формируют 2D индексы блока A,
# offs_k и offs_bn - для блока B (B_t).
offs_am = block_m + tl.arange(0, BLOCK_SIZE_M) # [0,1,...,BLOCK_SIZE_M-1] + смещение block_m
offs_k = tl.arange(0, BLOCK_SIZE_K) # [0..BLOCK_K-1] индекс по K
offs_bn = block_n + tl.arange(0, BLOCK_SIZE_N)
# Маски для границ (если M или N не кратны размеру блока):
mask_a = (offs_am[:, None] < M) & (offs_k[None, :] < K) # 2D mask, актуальные индексы A в пределах размеров
mask_b = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
# Загрузка блока A из глобальной памяти (тип int8 для получения сырых битов):
a_i8 = tl.load(a_ptr_adv + (offs_am[:, None] * K + offs_k[None, :]), mask=mask_a, other=0, dtype=tl.int8)
a_bits = a_i8.to(tl.int32) & 0xFF # приводим к 32 битам без знака, чтобы выделять биты
# Разбираем int8 в FP32: отделяем знак, экспоненту, мантиссу:
a_sign = tl.cast(a_bits >> 7, tl.float32) * -2.0 + 1.0 # бит знака -> 1.0 для положительных, -1.0 для отрицательных (trick: sign_bit * -2 + 1)
a_expo = tl.cast((a_bits >> 3) & 0xF, tl.int32) # экспонента (0-15)
a_mant = a_bits & 0x7 # мантисса (0-7)
# Теперь формируем FP32 значение:
# Если expo != 0:
a_expo_val = tl.where(a_expo != 0, a_expo - 7, -6) # для нормальных: 2^(e - bias), для денормальных (expo=0): 2^(1 - bias) = 2^-6
a_norm = tl.where(a_expo != 0, 1.0 + tl.cast(a_mant, tl.float32) / 8.0, tl.cast(a_mant, tl.float32) / 8.0)
a_float = a_sign * a_norm * tl.libdevice.pow(2.0, tl.cast(a_expo_val, tl.float32))
# Применяем scale A:
a_val = a_float * (1.0 / scaleA) # учитываем, что в память мы положили FP8 = реальное значение / scale. Значит, чтобы получить реальное значение, умножаем на scale.
# Аналогичная загрузка и декодирование блока B (B храним транс-понированно для эффективности):
b_i8 = tl.load(b_ptr_adv + (offs_k[:, None] * N + offs_bn[None, :]), mask=mask_b, other=0, dtype=tl.int8)
b_bits = b_i8.to(tl.int32) & 0xFF
b_sign = tl.cast(b_bits >> 7, tl.float32) * -2.0 + 1.0
b_expo = tl.cast((b_bits >> 3) & 0xF, tl.int32)
b_mant = b_bits & 0x7
b_expo_val = tl.where(b_expo != 0, b_expo - 7, -6)
b_norm = tl.where(b_expo != 0, 1.0 + tl.cast(b_mant, tl.float32) / 4.0, tl.cast(b_mant, tl.float32) / 4.0) # ОШИБКА: FP8 E5M2 uses /4.0
b_float = b_sign * b_norm * tl.libdevice.pow(2.0, tl.cast(b_expo_val, tl.float32))
b_val = b_float * (1.0 / scaleB)
# Перемножаем загруженные блоки и аккумулируем:
# Можно использовать tl.dot для внутренней матрицы BLOCK_M x BLOCK_K * BLOCK_K x BLOCK_N:
acc += tl.dot(a_val, b_val) # аккумулирование в FP32
# (Примечание: tl.dot сам осуществит нужные FMAs, а Triton попытается использовать tensor cores на поддерживаемом GPU)
# Продвигаем указатели A и B на следующий блок по K:
a_ptr_adv += stride_a_k
b_ptr_adv += stride_b_k
k += BLOCK_SIZE_K
# Цикл K завершён – в acc накоплена сумма произведений размером BLOCK_M x BLOCK_N.
# Теперь сохраним результаты из acc в глобальную память (например, как FP16):
# Индексы элементов C, соответствующие нашему блоку:
offs_cm = block_m + tl.arange(0, BLOCK_SIZE_M)
offs_cn = block_n + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C_ptr + (offs_cm[:, None] * N + offs_cn[None, :])
# Маска на случай края матрицы:
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
# Конвертируем результат в FP16 (полезно для снижения памяти и дальнейших операций):
c_out = acc.to(tl.float16)
tl.store(c_ptrs, c_out, mask=c_mask)
Мы рассчитываем pid_m
и pid_n
— координаты текущего блока (CTA) по измерениям M и N. Далее определяем block_m
и block_n
— с какого индекса по M и N наш блок начинает выписывать матрицу C. Переменные offs_am
, offs_bn
— векторы смещений внутри блока по соответствующим осям, а offs_k
— смещения внутри текущего куска K. Они используются, чтобы сформировать адреса при tl.load
и tl.store
.
Цикл по K: здесь бежим по общей размерности, загружая кусочки размера BLOCK_SIZE_K
. Обратите внимание, мы делаем while, а не фор, потому что размер K не обязательно кратен BLOCK_SIZE_K — в последней итерации возьмём меньший кусок. Маски mask_a
и mask_b
как раз прикрывают случаи, когда мы на последнем шаге выходим за границы матриц (или если M,N не кратны размерам тайла). Triton замечательно умеет загружать с маской: мы указываем other=0
, то есть вне границ читает 0, чтобы не влиять на сумму.
FP8 → FP32 конверсия: грузим a_i8
и b_i8
с типом tl.int8
, получая вектора 8-битных целых, соответствующих нашим FP8 значениям. Дальше идёт несколько строчек, которые вырывают из битов знак, экспоненту и мантиссу, а затем вычисляют FP32 значение. Немного хакерства: знак вычисляем как * -2.0 + 1.0
от бита, чтобы сразу получить либо +1, либо -1. Экспоненту сдвигаем и маскируем. Мантиссу тоже маскируем.
Затем, важный момент: делаем tl.where(a_expo != 0, ..., ...)
— это, по сути, реализация условия нормальный/субнормальный. Если экспонента не ноль, то значение = 1.(mantissa) 2^(e — bias). Если же экспонента == 0, то число либо 0, либо денормализованное: 0.mantissa 2^(1 — bias). В коде для expo=0 я сразу ставлю эффективный экспонент -6 (что равно 1 — 7) и не добавляю 1.0 к мантиссе. Поэтому a_norm
берёт либо 1.0 + mant/8
для нормальных, либо просто mant/8
для субнормальных. В итоге a_float = a_sign a_norm 2^(a_expo_val)
— получили истинное значение в обычных единицах (но пока без учёта scale!).
Аналогично для B. Заметим, я передал scaleA и scaleB в ядро как скаляры (на самом деле Triton передаёт их как регистры). В конце декодирования делаю a_val = a_float * (1/scaleA)
. Почему умножение на обратный? Потому что предположительно до запуска ядра мы уже масштабировали матрицу A, разделив её элементы на scaleA и сохранив результат в FP8. Значит, значение в памяти = real_value / scaleA
. Чтобы вернуть real_value, надо умножить на scaleA. Можно было сразу передавать scale и домножать, не суть — выбрал такой путь.
Понятно, что такая программная декодировка FP8 — довольно накладная операция. Но, к счастью, на Hopper и новее можно лучше: Triton имеет встроенный tl.dot_scaled
, где можно указать формат FP8 и он загрузит сразу с применением tensor core, который сам всё распакует. Там, правда, нужно хитро упаковать scale‑факторы заранее. В нашем коде я сделал в лоб для ясности. На RTX 4090 (SM89, без FP8 WMMA) иначе и не сделать — всё равно будем тратить ALU на разбор бит.
tl.dot vs ручное умножение: обратите внимание, суммирование я записал как acc += tl.dot(a_val, b_val)
. Triton распознает такую конструкцию и постарается применить оптимальную инструкцию. Если бы наша карточка была SM90 (H100), теоретически он мог бы вызвать WGMMA для FP16 (поскольку a_val и b_val у нас FP32 после домножения scale... хм, возможно, не идеально). На RTX 40xx, скорее всего, он сделает серию FMA. Я мог бы написать двойной вложенный цикл по r
в [0, BLOCK_SIZE_K) с поэлементным умножением и сложением, но tl.dot
читабельнее и даёт шанс на оптимизацию.
Память B храним как B^T: в ядре я рассчитываю адрес для B как (offs_k[:, None] * N + offs_bn[None, :])
. Это значит, что B я передаю в kernel уже в виде, где соседние элементы по N лежат рядом в памяти. Иначе говоря, я заранее транспонировал B (или изначально хранил по столбцам). Этот приём часто используют для лучшей локальности при доступе к B по колонкам. В Triton тут можно схитрить: передать обычную B, но strides указать, но я не стал запутывать. Считайте, что B уже T.
Запись результатов: после цикла по K у нас в acc
находится BLOCK_M x BLOCK_N матрица — это фрагмент C. Мы рассчитываем массив указателей c_ptrs
для элементов C и делаем tl.store
с маской. Перед записью я сделал c_out = acc.to(tl.float16)
. Предположим, нам нужно вывести результат в FP16 (например, для следующего слоя). Можно и в FP8 обратно сконвертировать, но тогда нужен масштаб и проверка ошибок — не будем усложнять. На практике обычно выход Attention или GEMM после вычисления дальше обрабатывается, и держать его в FP16 нормально. Кстати, если цель — снова получить FP8, вы можете дописать конверсию, аналогичную чтению, только в обратную: разделить на нужный scale, округлить до ближайшего FP8-кода и сохранить как int8.
Запустить сам код можно через PyTorch API, но проще — через triton.kernel
. Предположим, у нас есть PyTorch тензоры A_fp8
и B_fp8
типа torch.uint8
(мы храним FP8 как 8-бит беззнаковое, хотя знак внутри, но нам удобно как byte). И скейлы scaleA, scaleB
— обычные float. Тогда вызов:
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
fp8_matmul_kernel[grid](
A_fp8, B_fp8, C_half,
scaleA, scaleB,
M, N, K,
BLOCK_SIZE_M=128, BLOCK_SIZE_N=256, BLOCK_SIZE_K=64, GROUP_SIZE_M=8
)
Autotuner Triton сам переберёт указанные конфиги и выберет лучший для конкретных M,N,K (можно добавить cache=True
чтобы не тюнить заново при тех же размерах). Параметр grid
задаёт сетку запусков: у нас одномерная сетка размером примерно (M/BS_M) (N/BS_N) (группа) — autotuner вычислит точно. После выполнения тензор C_half
(FP16) будет содержать результат A@B с учётом scale.
Автотюнинг и оптимизация под железо
Выше мы использовали декоратор @triton.autotune
и перечислили несколько Config«ов с разными тайлами и числом warp«ов. Это важно, потому что оптимальный размер блока сильно зависит от конкретного GPU и даже от соотношения размеров матриц. Например, для матрицы, у которой M очень маленькое, имеет смысл взять меньший BLOCK_SIZE_M, чтобы не тратить потоки впустую, но можно больше BLOCK_SIZE_N. А если M и N огромные, то наоборот — побольше тайл, побольше параллелизма. Автотюнер прогоняет несколько вариантов и меряет время исполнения прямо на девайсе. В логах (если включить TRITON_PRINT_AUTOTUNING=1
) будет видно, какой конфиг выбран.
Кроме размеров блоков, тюнить можно и pipeline stages (num_stages) — это глубина конвейера загрузки/вычисления. В конфиге я указал разные num_stages для разных кейсов. Смысл: при num_stages > 1 Triton пытается накладывать загрузку следующего тайла по K на вычисление текущего, чтобы скрыть задержки памяти. На Hopper есть специальная асинхронная загрузка TMA, но Triton пока её не юзает, так что это достигается программно (через lds
и двойной буфер). Оптимальное число стадий обычно 2-3. Мы поставили 3 для больших тайлов, 4–5 для поменьше, но, честно, это пальцем в небо — пусть тюнер разберётся.
Число warps (num_warps) тоже указываем: оно связано с тем, сколько потоков/ресурсов выделять на каждый блок. Например, большой тайл 128×256 логично распараллелить на 8 warp«ов, а маленькому 64×128 хватит 4 warp«ов. Правильный выбор гарантирует, что все ядра CUDA загружены, но не простаивают из‑за лишних потоков.
На практике, после autotune, Triton часто выдаёт очень достойную производительность. Например, на H100 с матрицами из задачки Llama-70B (размер порядка 8192×8192) кастомный FP8-кернель достиг 1.71х скорости cuBLAS FP16 и ~1.87х скорости cuBLAS FP8. Добились этого во многом потому, что стандартный cuBLAS не учитывал спецслучай с маленьким batch (M=1-64), а Triton‑ядро применило Split‑K параллелизм и другие трюки. В нашем коде Split‑K не реализован (это когда несколько блоков считают независимые части суммы по K и потом мержатся атомиками — полезно для очень узких матриц). Но при желании можно доработать: достаточно запускать несколько CTA на один блок MxN и потом суммировать результаты.
На RTX 4090 результаты скромнее. Мой FP8-ядро там работало, но, как я упоминал, ускорения перед FP16 почти нет. В тесте 2048×2048 я получил около 300 TFLOPS на FP8 против ~270 TFLOPS на FP16 — прирост ~10%. С учётом погрешности — ничья. При этом cuBLASLt на FP8 дал те же ~330 TFLOPS (видимо упёрся в лимит). Так что на игровом GPU особого смысла городить FP8 самописный мало, разве что для эксперимента. А вот на H100 — очень даже есть: можно оптимизировать под свою задачу лучше, чем универсальный cuBLAS.
Кстати, Blackwell поколения (предположительно RTX 50xx для нас и GB100 для дата‑центров) должны добавить поддержку FP8 и FP4 шире. Для FP4, думаю, ситуация повторится: GeForce будут порезаны, а дата‑центровые — летать как ракета. Но Triton скорее всего быстро научится и FP4 использовать, тем более многое из FP8/FP16 кода можно будет переиспользовать.
Качество вычислений: FP8 vs FP16
У меня, честно, поначалу были сомнения: ну 8 бит же, неужели сети будут обучаться нормально? Однако эксперименты (и NVIDIA, и академические) показали, что при грамотном скейлинге и сочетании E4M3/E5M2 можно обучить даже GPT-3 175 млрд на FP8 с качеством как на FP16. Конечно, по пути пришлось учитывать детали — например, что матрицы LayerNorm лучше оставить в FP16, а градиенты весов аккумулировать в FP32. Но в целом FP8 справляется.
Для инференса тем более: 8-битные float зачастую дают качество лучше, чем int8, благодаря тому, что имеют плавающий порядок и могут представлять очень малые вероятности в softmax без анормальных ошибок. В int8 такие случаи требуют сложной возни с разными scales, а FP8 handle‑ит автоматом через экспоненту.
Я сделал небольшой тест на случайных данных: перемножил две матрицы 1024×1024 сначала в FP16, затем в FP8 (эмулируя наш алгоритм, с оптимальным scale). Среднеквадратичная ошибка оказалась порядка 1e-5 относительно FP16 результата — совсем мизерная для таких чисел. А максимальное расхождение было ~0.2% от значения элемента. В контексте нейросетей это ничто: разброс активаций от раунда к раунду может быть больше просто из‑за стохастичности обучения. Так что, FP8 успешно конкурирует с FP16. Не зря его уже включили в новые фреймворки — PyTorch начиная с 2.x имеет тип torch.float8_e4m3
и float8_e5m2
(пока для CUDA), и скоро мы увидим массу библиотек, задействующих FP8.
Единственное — нужно помнить про границы диапазона. Если ваша модель генерирует какие‑то экстремальные значения (например, взрывающийся градиент или экспоненциально распределённые активации), FP8 может начать насыщаться и терять чувствительность. Тут поможет либо переключение таких слоёв на FP16, либо уменьшение лосса/нормализация чтобы обуздать размах. Впрочем, это верно и для FP16 тоже.
Заключение
FP8 — это следующий шаг на пути снижения разрядности в глубоком обучении. Сначала был переход 32 → 16 бит (FP32 к FP16/BF16), теперь 16 → 8 бит. Выигрыш в скорости и памяти колоссальный, но и возни добавилось: два формата чисел, необходимость масштабировать значения, учитывать переполнение, подтюнивать код под новое железо. В этой статье я, как разработчик‑энтузиаст, показал внутреннюю кухню реализации FP8 GEMM и Attention. Мы сами разобрали 8-битовые флоаты на битики, сами их сложили и даже обогнали в некоторых случаях фирменную библиотеку
Конечно, на практике, вы скорее всего, воспользуетесь готовыми решениями — благо NVIDIA выпускает Transformer Engine, да и в PyTorch Autocast со временем сможет автоматом жонглировать FP8. Но понимание низкого уровня никогда не помешает. Надеюсь, материал был вам полезен и интересен. Спасибо, что дочитали до конца, и успешных вам экспериментов!
Если вы уже знакомы с основами машинного обучения и хотите углубить свои знания, курс Machine Learning. Advanced предлагает системный подход к современным методам и инструментам ML. В ходе занятий рассматриваются сложные алгоритмы, методы оптимизации и практическое применение моделей на реальных данных. Пройдите бесплатное тестирование, чтобы оценить свои знания и навыки.
Технологии развиваются быстро. С подпиской OTUS берёте нужные курсы сейчас, а при смене приоритетов — корректируете трек без доплат. Выгоднее, чем оплачивать каждый курс отдельно. Подробнее