Это лето обрадовало нас прорывом в обработке изображений с помощью нейросетей. Одна за другой выходят такие модели как Flux.1 Kontext, Qwen-Image-Edit, Gemini 2.4 Flash Image Preview (Nano Banana) демонстрируя недостижимый до сих пор уровень манипуляции цифровым контентом. Это не замена Фотошопу, а технология, открывающая врата в бесконечные визуальные миры и всё благодаря мощи Diffusion Transformer (DiT) архитектуры. Впечатлившись, я решил поближе познакомиться с диффузными трансформерами - собственноручно натренировать свою собственную DiT-модель. Об этом и будет эта статья.

Но начать стоит с малого.

Базовая модель

Как вообще работают эти диффузные модели? Есть нейросеть, которая принимает на вход зашумлённое изображение, а на выходе выдаёт шум (noise). И спрашивается зачем нам этот шум? А затем что теперь имея шумное изображение и предсказанный шум мы можем вычесть предсказанный шум из изображения и получить изображение с меньшим количеством шума. Я только что сказал "с меньшим количеством шума", но на самом деле это не так. На самом деле всё сложнее.

Небольшое отступление. Для создания модели я буду использовать библиотеку PyTorch. Все термины такие как тензор (tensor), батч (batch), измерение (dimension), шейп (shape) оттуда. Ожидаю от читателей хотя бы поверхностного понимания.

А что моделируем-то?

Вы не задумывались что вообще моделируют диффузные модели? А моделируют они трансформацию нормального распределения в целевое распределение. Возникает логичный вопрос: что ещё за "целевое распределение"? Легче всего это проиллюстрировать на примере двух измерений. Вот несколько сэмплов (точек на координатной плоскости) из нормального распределения. Отсюда и далее, когда я пишу "сэмпл", то имею ввиду двумерный тензор, который можно представить как точку на координатной плоскости:

https://i.imgur.com/gkUodK0.png
https://i.imgur.com/gkUodK0.png

Сэмплы из нормального распределения получить очень просто - вызываем torch.randn(2) сколько надо раз и всё.

А вот так на координатной плоскости могли бы выглядеть сэмплы из целевого распределения:

Это распределение можно определить как "все точки лежащие на кривых, образующих рисунок котёнка". В отличие от нормального распределения у нас нет готовой формулы, которая бы могла вытаскивать точки (сэмплы) именно из этого распределения.

И тут всем любителям котов на помощь приходят нейросети. Всего то и надо, что обучить функцию (модель), которая принимает на вход сэмпл из нормального распределения, а возвращает сэмпл уже из целевого распределения. Вот только такая постановка задачи - возвращать сэмплы целевого распределения - слишком сложная. Поэтому немного перефразируем:

Принимая на вход сэмпл из нормального распределения возвращать вектор направления, двигаясь по которому мы достигнем целевого распределения.

Получается что, например, для сэмпла (0.1, -0.5) наша модель предскажет вектор (1.05, 0.46). Теперь сложим этот вектор с исходным сэмплом. Получаем точку на целевом распределении. И так будет работать для каждого сэмпла из нормального распределения.

То что я только что описал (предсказание вектора) - это вариант диффузии называемый Rectified Flow. Он отличается от известного всем DDPM и отличается в лучшую сторону. Но я от этом рано заговорил, продолжаем.

Итак, можно сказать, что наша модель будет моделировать трансформацию простого распределения (нормальное, оно же Гауссово) в сложное целевое распределение.

Звучит хорошо, да вот только моделировать трансформацию в один шаг - это тяжеловата задача получается. Вон GAN-модели с таким подходом далеко не продвинулись. Ладно, с GAN-моделями я немного утрирую, но гораздо практичнее опираться именно на траекторию - как шаг за шагом простое распределение трансформируется в сложное.

Что я имею в виду когда говорю, что можно опираться на траекторию? Смотрите, изначально данные для тренировки нашей нейросети состояли только из сэмплов целевого распределения (точки, лежащие на графике кота), а также такого же количества сэмплов из нормального распределение (шума короче). Но раз мы догадались моделировать трансформацию распределения "по шагам", то мы можем обогатить наш датасет всеми промежуточными состояниями, то есть целевыми сэмплами, которые зашумлены на 10%, 20%, 76% или вообще на любой процент. Другими словами, точками, которые находятся где-то на полпути между нормальным распределением и целевым.

Давайте ещё раз посмотрим как подход "выучить траекторию за раз" и "выучить траекторию по шагам" меняет нашу модель.

В первом случае наша модель функционирует вот так:
vector_to_target = model.predict(normal_noise_sample) - на вход только сэмпл из нормального распределения.

А в случае траектории по шагам модель будет работать так:
vector_to_target = model.predict(point_between_noise_and_target_distribution, time)
Здесь time - это доля пути, который точка прошла от шума, до целевого распределения. Процент зашумлённости, другими словами. Важно понимать: по сути, задача трансформацию распределения разбивается на мелкие подзадачи: "научись предсказывать путь до целевого распределения при 10% шума", "научись предсказывать путь до целевого распределения при 15% шума" и т.д. Такой подход позволяет гораздо лучше смоделировать трансформацию распределений, что повышает точность "предсказания".

За счёт чего увеличивается точность? А за счёт того, что теперь имея модель с дополнительным условием time мы можем вытаскивать сэмплы из целевого распределения ни с единственной попытки, а делая сколько угодно "уточнений" траектории. Легче понять если взглянуть на код инференса:

noise = sample_noise()  # сэмпл из нормального распределения
steps = 200  # разобъем траекторию на 200 маленьких шагов
for step in range(steps):
	time = step / steps  # будет меняться в интервале [0, 1)
	
	# независимо от time этот вектор всегда одной длины
	# в идеале вообще всегда один и тот же
	predicted_vector = model.predict(noise, time)
	
	scaled_vector = predicted_vector * (1 / steps)  # одна 200-я пути
	noise = noise + scaled_vector  # Наш изначальный шум на 0.5% приблизился 
	                               # к целевому распределению
	                               
# После завершения цикла наш noise - это уже сэмпл из целевого распределения

Помните, что в нашем примере все сэмплы это двумерные векторы (x, y) - точки на плоскости. Пользуясь терминологие PyTorch - тензоры с шейпом (2). А так-то модель можно создавать для тензоров любой формы.

Довольно теории

Для демонстрации нам понадобится датасет. Все элементы из него будут считаться сэмплыми из целевого распределения. Например вот такого:

def make_simple_dataset(amount):  
    cluster_1 = torch.rand((amount // 2), 2) * 1.5 + 0.6  
    cluster_2 = torch.rand((amount // 4), 2) * 0.8 + torch.tensor([-.8, .6])  
    cluster_3 = torch.rand((amount // 4), 2) * torch.tensor([1.2, 0.4]) + torch.tensor([-2.4, 0.6])  
  
    return torch.cat([cluster_1, cluster_2, cluster_3], dim=0)

На координатной плоскости 800 сэмплов этого распределения будут выглядеть вот так:

Теперь пора приниматься за PyTorch модель.

Будет эта модель состоять из трёх частей:

  1. Энкодер, который проецирует входной вектор из двух измерений во внутренний вектор из большего числа измерений (16). Зачем нам большее количество измерений? Чтобы модели было где "развернуться".

  2. Основная модель, состоящая из нескольких MLP-блоков. Будет непосредственной заниматься денойзингом входящих "шумных" сэмплов. MLP - это multilayer perceptron, та самая "классическая" нейронная сеть.

  3. Декодер, который превращает 16-размерный вектор обратно в две координаты.

Схематичное изображение модели
Схематичное изображение модели

Вот так выглядит упрощённая схема. Linear A - это энкодер. Состоит из одного единственного слоя nn.Linear (здесь и далее всё типы из PyTorch). По сути просто матрица, для трансформации 2-размерного тензора в 16-размерный. Прямоугольник посередине - это сам денойзер, состоящий из нескольких последовательных блоков/уровней. Несколько небольших нейросетей, выстроенных в ряд, короче. Linear B - это декодер. Опять же, матрица для трансформации 16-размерного тензора обратно в 2 координаты на плоскости. На схеме вместе 2 и 16 написано (B, 2) и (B, 16), потому что модель принимает на вход сэмплы не по одному а сразу группой (батчем). B - это размер батча (количество элементов).

Давайте сразу разберёмся как устроены внутренние блоки. Заодно вспомним как сделать в PyTorch Multilayered perceptron.

Итак, как написать вот такую нейросеть (модель)?

Стереотипная 'нейронная сеть'
Стереотипная 'нейронная сеть'

На PyTorch проще простого:

model = nn.Sequential(
	nn.Linear(4, 3),  # полносвязный слой
	nn.SiLU(),  # функция активации
	nn.Linear(3, 5),  # полносвязный слой
)

Давайте только завернём это в отдельный класс:

class MyBlock(nn.Module):
	def __init__(self):  
	    super().__init__()
	    self.mlp = nn.Sequential(
			nn.Linear(4, 3),
			nn.SiLU(),
			nn.Linear(3, 5),
		)
		
	def forward(x):  # эта функция пропускает входные данные сквозь нашу модель
		return self.mlp(x)

nn.Module - это полезный класс, который облегчает последующую работу с моделью.

Думаю, код будет гибче, если вынести скалярные константы в конструктор:

class MyBlock(nn.Module):
	def __init__(self, hidden_dim, mlp_ratio):  
	    super().__init__()
	    self.mlp = nn.Sequential(
			nn.Linear(hidden_dim, hidden_dim * mlp_ratio),
			nn.SiLU(),
			nn.Linear(hidden_dim * mlp_ratio, hidden_dim),
		)
		
	def forward(x):
		return self.mlp(x)

Теперь "ширину" и размерность входного вектора для этой нейросети можно указывать при создании.

Полный код
import torch.nn as nn

class DenoiserBlock(nn.Module):
	def __init__(self, hidden_dim, mlp_ratio):  
	    super().__init__()
	    self.ln = nn.LayerNorm(hidden_dim)
	    self.mlp = nn.Sequential(
			nn.Linear(hidden_dim, hidden_dim * mlp_ratio),
			nn.SiLU(),
			nn.Linear(hidden_dim * mlp_ratio, hidden_dim),
		)
		
	def forward(self, x):
		z = self.ln(x)  # сначала прогоняем входной тензор через нормализацию
		return self.mlp(z)  # а потом через MLP

В итоге получилась модель внутреннего блока, из которых будет состоять наша основная модель-денойзер. Так, а откуда взялся nn.LayerNorm ? Сейчас вдаваться в подробности не буду, просто скажу, что nn.LayerNorm позволяет удерживать значения тензоров (точек) где-то в пределах [-2, 2].

Вот для примера данные до и после нормализации

Скомкивались в центре, но сохранили форму
Скомкивались в центре, но сохранили форму

Другими словами, нормализовать - это значит преобразовать данные так, чтобы средняя была 0, а стандартное отклонение 1. Обыкновенная формула из статистики. Таким образом nn.LayerNorm облегчает тренировку модели и имеет ещё одно полезное свойство, о котором я, возможно, расскажу позже.

Кстати, вопрос к залу: тензор какой формы сможет принимать на вход вот эта конкретная модель:

model = DenoiserBlock(25, 3)

Ответ: любой тензор, у которого последнее измерение равно 25. Например (25), (72, 25), (1, 25), (8, 3, 25, 25), и т.п.

Ладно, с единственным блоком разобрались, переходим к основной модели.

class Denoiser(nn.Module):  
    def __init__(self, hidden_dims):  
        super().__init__()  
        self.input_encoder = nn.Linear(2, hidden_dims)  
        self.blocks = []
        self.output_decoder = nn.Linear(hidden_dims, 2)  
  

Как и на схеме выше, слой для превращения входного 2-размерного тензор в 16-размерный, слой для превращения внутреннего 16-размерного тензора обратно в 2-размерный. Давайте ещё раз напомню, что здесь 2-размерный тензора это тензор с шейпом (резмером) (B, 2), а 16-размерный с шейпом (B, 16). Здесь B - это размер батча (группы). Сколько сэмплов (точек) мы обрабатываем за раз, другими словами. Если, например, входной тензор будет размером (64, 2), то пройдя через input_encoder он превратится в тензор (64, 16).

Добавляем внутренние блоки:

class Denoiser(nn.Module):  
    def __init__(self, hidden_dims, num_blocks):  
        super().__init__()  
        self.input_encoder = nn.Linear(2, hidden_dims)
        block_list = [DenoiserBlock(hidden_dims, 4) for _ in range(num_blocks)]
        self.blocks = nn.ModuleList(block_list)
        self.output_decoder = nn.Linear(hidden_dims, 2)  
  

Зачем оборачивать в nn.ModuleList? Всё для того, чтобы параметры всех внутренних моделей в списке были доступны внешней модели. Другими словами, чтобы Denoiser со списком DenoiserBlock внутри управлялся как единая модель.

Осталось только дописать метод forward

def forward(self, x):  
    hidden = self.input_encoder(x)  # (B, 2) -> (B, 16)
    for block in self.blocks:
        hidden = block(hidden)  # выход одного блока передаётся на вход второму
    return self.output_decoder(hidden)  # (B, 16) -> (B, 2)

А код-то неверный! Блоки внутри денойзера должны быть соединены не последовательно, а через skip-connection. Не забывайте, что конечная цель - это модель-трансформер, а у трансформеров слои (блоки) соединены через skip-connection, поэтому переделываем то, как соединены блоки и превращаем Denoiser в остаточную сеть:

def forward(self, x):  
    hidden = self.input_encoder(x)  # (B, 2) -> (B, 16)
    for block in self.blocks:
	    # Выход каждого блока прибавляется к начальному представлению и передаётся дальше
        hidden = hidden + block(hidden)  # 
    return self.output_decoder(hidden)  # (B, 16) -> (B, 2)
Всё вместе это теперь выглядит вот так
import torch.nn as nn

class DenoiserBlock(nn.Module):
	def __init__(self, hidden_dim, mlp_ratio):  
	    super().__init__()
	    self.ln = nn.LayerNorm(hidden_dims)
	    self.mlp = nn.Sequential(
			nn.Linear(hidden_dim, hidden_dim * mlp_ratio),
			nn.SiLU(),
			nn.Linear(hidden_dim * mlp_ratio, hidden_dim),
		)
		
	def forward(x):
		z = self.ln(x)  # сначала прогоняем входной тензор через нормализацию
		return self.mlp(z)  # а потом через MLP
		

class Denoiser(nn.Module):  
    def __init__(self, hidden_dims, num_blocks):  
        super().__init__()  
        self.input_encoder = nn.Linear(2, hidden_dims)
        block_list = [Denoiser2DBlock(hidden_dims, 4) for _ in range(num_blocks)]
        self.blocks = nn.ModuleList(block_list)
        self.output_decoder = nn.Linear(hidden_dims, 2)
        
	def forward(self, x):  
	    hidden = self.input_encoder(x)  # (B, 2) -> (B, 16)
	    for block in self.blocks:
		    # ВНИМАНИЕ: это остаточная (residual) сеть, то есть  
		    # после каждого слоя мы обновляем наше скрытое представление прибавляя к нему результат работы блока  
		    hidden = hidden + block(hidden + time_embedding)
	        hidden = block(hidden)  # выход одного блока передаётся на вход второму
	    return self.output_decoder(hidden)  # (B, 16) -> (B, 2)

Осталось только написать код для тренировки

В самом начале надо сформировать датасет:

BATCH_SIZE = 128
simple_dataset = TensorDataset(make_simple_dataset(4096))  
data_loader = DataLoader(dataset=simple_dataset,   
                         batch_size=BATCH_SIZE,   
                         shuffle=True)

Короче, data_loader - это итератор по датасету из 4096 элементов, который будет за раз возвращать по 128 элементов из этого датасета, в случайном порядке (но без повторений). Если все элементы кончатся, то просто начнёт сначала.

Сама тренировка, это вот такой цикл:

for x, in data_loader:
	# одна итерация тренировки

А почему x, а не просто x без запятой? Да потому, что data_loader возвращает список, ведь датасеты обычно состоят из пар типа (вопрос -> ответ) или (изображение -> описание).

Один раз пройти по датасету часто бывает недостаточно. Модель должна увидеть один и тот же сэмпл несколько раз. Один проход по датасету называется эпоха (epoch). Поэтому оборачиваем код в ещё один цикл:

EPOCH = 2000
for epoch in range(EPOCH):
	epoch_loss = 0  # чтобы отслеживать какая суммарная ошибка за эпоху
	for x, in data_loader:
		# одна итерация тренировки
		# накапливаем epoch_loss
	if epoch % 100 == 0:  # пишем в консоль каждые 100 эпох
		print(f"Epoch {epoch + 1} completed. Loss: {epoch_loss:.2f}")

Давайте теперь добавим модель для тренировки:

LR = 3e-4  # Learning rate. Насколько сильно за раз будем обновлять веса модели
DEVICE = "cuda"  # Ну не на CPU же.

# инициализируем модель с 8 блоками внутри
model = Denoiser(hidden_dims=16, num_blocks=8)
model.to(DEVICE)  # и отправляем веса модели на GPU

# оптимайзер, который будет по-умному обновлять веса у модели (с инерцией и т.п.)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)                           


EPOCH = 2000
for epoch in range(EPOCH):
	epoch_loss = 0  # чтобы отслеживать какая суммарная ошибка за эпоху
	for x, in data_loader:
		# здесь подготовка данных для модели
		# создание xt и true_vector на основе сэмпла x иными словами
		
		# прогоняем входные данные сквозь модель и получаем предсказание
		predicted_vector = model(xt)
		# сравниваем предсказанный моделью вектор с эталонным и вычисляем ошибку
		loss = torch.mean((true_vector - predicted_vector) ** 2)
		  
		optimizer.zero_grad()  # очищаем градиент оставшийся с предыдущего цикла
		loss.backward()  # вычисляем градиент методом backpropagation
		optimizer.step()  # обновляем веса 
	if epoch % 100 == 0:  # пишем в консоль каждые 100 эпох
		print(f"Epoch {epoch + 1} completed. Loss: {epoch_loss:.2f}")

Много существует статей, где подробно расписан механизм обратного распространения ошибки, поэтому опишу всё очень просто, пропуская важные подробности.

В общем, вот этот код создаёт граф вычислений, в котором сохранены все вычисления, которые произошли, когда данные через модель:

predicted_vector = model(xt)
loss = torch.mean((true_vector - predicted_vector) ** 2)

Получившийся тензор loss тоже является частью этого графа, и поэтому, когда мы вызываем

loss.backward()

PyTorch использует имеющийся граф вычислений, чтобы определить (вычислить) как должны поменяться параметры всех участвующих в графе тензоров, чтобы значение loss стало меньше. Производную считает, короче. После этого у каждого параметра внутри модели появилось ещё дополнительное число, которое и является этой производной. Это и называется градиент. Теперь за дело принимается optimizer. Он имеет доступ ко всем параметрам модели (посмотрите как он инициализировался), а значит и к градиенту. Команда optimizer.step() заставляет оптимайзер обновить все параметры в модели руководствуясь градиентом, learning rate и своим внутренним состоянием (хитро вычисляемая инерция). Кстати, заметьте, что параметры обновились, но градиент никуда не делся и так и остался привязан к параметрам. Поэтому, в следующем цикле и вызываем optimizer.zero_grad(), чтобы очистить его, иначе loss.backward() наложит старый градиент на новый. Это иногда бывает полезно, но в такие подробности вдаваться не будем.

Дописываем оставшуюся часть цикла:

BATCH_SIZE = 128  
LR = 3e-4
DEVICE = "cuda"
EPOCH = 1000

for epoch in range(EPOCH):
	epoch_loss = 0
	for x, in data_loader:
	    # копируем сэмпл на GRU
		x0 = x.to(DEVICE)  # (128, 2) - форма тензора
		
		# создаём сэмпл из случайного распределения. Тоже (128, 2)
		noise = torch.randn_like(x, device=DEVICE)  # сразу же окажется на GPU
		
		# а это уже сэмплы из равномерного распределения в интервале от 0 до 1
		time = torch.rand((BATCH_SIZE, 1), device=DEVICE)  # (128, 1)
		
		
		# Какой вектор надо прибавить к шуму, чтобы получить целевое распределение
		true_vector = x0 - noise  # сразу 128 векторов за раз
		# xt - это точки лежащие "на полпути" от целевого распределения до нормального
		xt = noise + true_vector * (1 - time)
		
		predicted_vector = model(xt)
		loss = torch.mean((true_vector - predicted_vector) ** 2)
		epoch_loss += loss.item()  # накапливаем ошибку для логирования
		  
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
	if epoch % 100 == 0:
		print(f"Epoch {epoch + 1} completed. Loss: {epoch_loss:.2f}")

Какая у нас задача во время тренировки? Есть сэмплы из целевого распределения x0, есть сэмплы из нормального распределения noise. Чтобы научить модель трансформировать шумное распределение в целевое мы создаём набор данных xt - это точки лежащие где-то на траектории от шумного (нормального) распределения до целевого. Вот вы спросите, а какая у каждой точки траектория вообще? Тут просто, на самом деле: для каждого сэмпла из целевого распределения мы берём соответствующий ему (просто совпавший по индексу) сэмпл из нормального распределения. Так-как это распределение случайное, то можно просто сказать, что каждому сэмплу (точке) из целевого распределения берётся случайная точка из нормального распределения. Не самая ближайшая, а просто случайная, да. В таком случае true_vector - это просто вектор, соединяющий эту пару точек. И таких пар 128 - по размеру батча. Теперь раз у нас есть точки (сэмплы) из случайного распределения и вектора, которые указывают направление до (какого-то) сэмпла из целевого распределения, нам ничего не стоит создать набор точек, лежащий на траектории от точек случайного распределения до точек целевого - просто прибавить к точкам случайного распределения соответствующий вектор, предварительно его масштабировав (10% или там 87%). Переменная time со случайными числами от 0 до 1 как раз для этого. Так и получаем набор данных xt который и скармливаем модели:

xt = noise + true_vector * (1 - time)
predicted_vector = model(xt)

Если интересно почему (1 - time), а не просто time, то это мы привязываемся на то, что чем больше time тем больше шума.

Теперь код обучения готов, но если мы попытаемся запустить этот модуль, то в результате лишь увидим с десяток записей в консоль. Мы даже модель не сохраняем для дальнейшего использования. И самое главное - не видим подтверждения того, что моделируемая нами трансформация простого распределения в сложное вообще работает. Что делать? Написать код сэмплирования (инференса), где мы будем использовать натренированную модель, чтобы извлекать сэмплы из целевого распределения (надеюсь).

Ладно, вот код инференса:

samples = torch.randn((400, 2), device=DEVICE)  
with torch.no_grad():  
    STEPS = 50  
    for step in range(STEPS, 0, -1):  
        predicted = model(samples)  
        samples += predicted * (1 / STEPS)

Что тут происходит? Мы генерируем сэмпл из случайного распределения, потом 50 раз прогоняем его сквозь модель каждый раз уточняя вектор predicted. После каждой итерации цикла samples становятся всё ближе и ближе к целевому распределению, ну
а torch.no_grad() отключает вычисление градиента. Полезно, если мы хотим заниматься инференсом прямо во время тренировки и не хотим, чтобы тестовый прогон модели как-то на эту тренировку влиял.

Осталось лишь нанести точки на график, предварительно импортировав pyplot:

import matplotlib.pyplot as plt

Добавляем точки на график:

samples = torch.permute(samples, (1, 0)).cpu().detach()  # (400, 2) -> (2, 400)
  
plt.figure(figsize=(6, 6))  
plt.xlim(-3, 3)  
plt.ylim(-3, 3)  
plt.grid(True)  
plt.axhline(0, color='black', linewidth=0.5)  
plt.axvline(0, color='black', linewidth=0.5)  
plt.scatter(samples[0], samples[1], s=4, c='blue')  
plt.show()

И в результате получаем вот такую визуализацию (сравнение с целевой):

Итак, сегодня мы научились проекти... Погодите-ка! Вот вам не кажется, что мы что-то упустили, нет? Внимательный читатель уже догадался - наша модель полностью игнорирует переменную time. А это значит модель училась предсказывать целевое распределение не получая дополнительной информации о том, на каком участке траектории находились переданные ей сэмплы. Неудивительно, что вместо целевого распределения на графике клякса какая-то!

Ок, нам надо каким-то образом передать в модель информацию о времени (шаге). На руках у нас только число от 0 до 1, но модель простые числа не переваривает - нужно векторное представление. Для простоты скажем что вектор должен быть 16-размерным - такой же длины как и скрытое представление модели. Получив этот вектор (time_embedding) уникальный для каждого числа в интервале от 0 до 1, мы просто будем прибавлять его к скрытому представлению на каждом уровне, вот так:

hidden = hidden + block(hidden + time_embedding)

Таким образом, "впечатывая" в скрытое представление информацию о том, на каком уровне "зашумления" находятся переданные в модель данные. Напомню ещё раз, это нужно для того, чтобы модели было легче моделировать трансформацию распределений - ведь теперь она сможет выявлять закономерности между уровнем шума (time) и переданными данными, таким образом обучаясь лучше (в теории).

Только теперь проблема: как же нам из числа получить 16-размерный вектор? Я хотел тут написать про sinusoidal и прочие экспоненты, но, если честно, то простой проекции хватит.

Добавляем внутрь конструктора Denoiser:

self.time_linear = nn.Sequential(  
    nn.Linear(1, hidden_dims),  
    nn.LayerNorm(hidden_dims)  
)

LayerNorm просто чтобы выходные значения были где-то в районе [-2, 2].

Меняем forward метод:

def forward(self, x, t):  # теперь на вход принимает и время
    hidden = self.input_encoder(x)  # (B, 2) -> (B, 16)  
    
    # делаем его поменьше, чтобы информации о времене была
    # но при этом не "перезаписать" само скрытое представление
    time_embedding = self.time_linear(t) * 0.02  
    for block in self.blocks:  
        hidden = hidden + block(hidden + time_embedding)  
    return self.output_decoder(hidden)

Осталось лишь поправить код инференса

samples = torch.randn((400, 2), device=DEVICE)  
with torch.no_grad():  
    STEPS = 50  
    for step in range(STEPS, 0, -1):
	    # тензор с шейпом (1)
        time = torch.tensor(step / STEPS, device=DEVICE)
        # расширяем его до шейпа (400, 1)
	    time = time.expand(samples.size(0), 1)
        predicted = model(samples, time)  
        samples += predicted * (1 / STEPS)

Запускаем тренировку с теми же самыми параметрами и получаем:

Вот так!

Тут пара анимацией процесса инференса

Бараний датасет

Бараний датасет
Бараний датасет

Важные детали:

  1. Давайте ещё раз скажу, модель выучила не целевое распределение, а именно трансформацию нормального распределения. Если запустить инференс с другим сидом, то точки уже будут в другом расположении, но всё ещё в пределах целевого распределения.

  2. Это не DDPM модель! При DDPM мы бы использовали формулу forward diffusion, а тут у нас rectified flow, поэтому для "зашумления" сэмплов мы используем простую линейную интерполяцию xt = x0 * (1 - time) + noise * time. И предсказывает модель flow-вектор (velocity), а не шум, как в DDPM.

  3. Весь написанный код можно найти здесь.

Заключение

В этой статье я шаг за шагом рассказал как с нуля обучить мини-диффузную модель использую простой синтетический датасет.

Ключевым моментом стала идея моделировать трансформацию простого распределения (шума) в сложное (наш датасет) не за один шаг, а пошагово, используя Rectified Flow вместо DDPM. Это позволило нам обогатить обучающий набор промежуточными состояниями и создать модель, которая предсказывает вектор направления (flow-vector) для каждого шага.

На собственном примере убедились, какое важное значение играет time-embedding для повышения точности - позволяет модели учитывать степень зашумления данных и лучше обучаться.

Думаю, я рассказал достаточно, чтобы у вас появилось базовое понимание работы диффузных моделей, а значит можно приниматься за что-то поинтереснее. Часть 2 будет про обучение уже на датасете EMNIST. Будем создавать модель, способную генерировать рисунки чисел и букв. Продолжение следует...

Комментарии (0)