Vision Transformer (ViT) — это архитектура, которая буквально произвела революцию в том, как машины «видят» мир.

В этой статье я не просто объясню, что такое ViT — я покажу вам, как создать эту магию своими руками, шаг за шагом, даже если вы никогда раньше не работали с трансформерами для задач с изображениями.

Для начала давайте взглянем на архитектуру Vision Transformer:

Vision Transformer architecture
Vision Transformer architecture

Мы напишем код полностью с нуля, а затем обучим модель на датасете CIFAR-10.

Давайте начнём с реализации Patch Embedding:

class PatchEmbedding(nn.Module):
    def __init__(self, img_size = 32, patch_size = 4, in_channels = 3, embed_dim=256):
        super().__init__()

        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.patch_size = patch_size
        self.num_patches = (img_size//patch_size)**2
        self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.conv(x) #(B, embed_dim, H/patch_size, W/patch_size)
        x = x.flatten(2).transpose(1, 2) #(B, num_patches, embed_dim)
        return x

Изображение будет разделено на патчи, и размер каждого патча можно задать с помощью параметра patch_size. При этом изображение не просто разбивается на патчи, но и пропускается через свёрточные ядра (CNN). В итоге мы получаем не просто патчи изображения — а встраивания (эмбеддинги) этих патчей.

Следующий шаг — реализовать самую интересную часть этой модели — механизм внимания (attention).

Self-Attention Mechanism
Self-Attention Mechanism

Q (Query) формально задаёт вопрос от каждого патча к другим патчам, K (Key) показывает, есть ли у каждого патча ответ на этот вопрос, а V (Value) содержит «значения» — фактические данные каждого патча, которые используются для формирования итогового представления.

Предположим, у нас есть X и Y, и мы хотим, чтобы X обращал внимание на Y. В этом случае матрица Query умножается на X, а матрицы Key и Value — на Y. Вместо прямого умножения на матрицы мы используем линейные слои.

attn_probs — это матрицы внимания, которые показывают, насколько токен i должен «обращать внимание» на токен j. Далее мы умножаем их на V, чтобы получить эмбеддинги изображения с учётом весов внимания attn_probs. V фактически хранит значения каждого патча изображения, а attn_probs показывает, сколько информации каждый патч должен получить от остальных патчей.

Вот как работает одна голова внимания; затем значения с всех голов объединяются. Такая конструкция основана на идее, что каждая голова фокусируется на разных аспектах.

class MultiHeadAttention(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        assert dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
        self.num_heads = num_heads

        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias = False)
        self.out = nn.Linear(dim, dim, bias = False)

        self.scale = 1.0 / (self.head_dim ** 0.5)

        self.attn_dropout = nn.Dropout(dropout)

    def forward(self,  x, mask = None,  return_attn=False):
        B, num_patches, embed_dim = x.shape

        qkv = self.qkv(x) # (B, num_patches, 3*embed_dim)
        qkv = qkv.reshape(B, num_patches, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4) #(3, B, num_heads, num_patches, head_dim)

        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, num_heads, num_patches, head_dim)

                                        #How important it is for token i to pay attention to token j.
        attn_scores = (q @ k.transpose(-2, -1)) * self.scale #[B, num_heads, N, N]

        if mask is not None:
            # mask: (B, 1, N, N) or (1, 1, N, N)
            attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))

        attn_probs = attn_scores.softmax(dim=-1) #[B, num_heads, N, N]
        attn_probs = self.attn_dropout(attn_probs)
        attn_output = attn_probs @ v  # (B, num_heads, num_patches, head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(B, num_patches, embed_dim)

        if return_attn:
          return self.out(attn_output), attn_probs
        else:
          return self.out(attn_output) #(B, num_patches, embed_dim)

Давайте перейдём к сборке блока Transformer Encoder:

class TransformerEncoderBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_dim, dropout):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(dim)

        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, return_attn=False):
      if return_attn:
        attn_out, attn_weights = self.attn(self.norm1(x), return_attn=True)
        x = x + self.dropout(attn_out)
        x = x + self.dropout(self.mlp(self.norm2(x)))
        return x, attn_weights
      else:
        x = x + self.dropout(self.attn(self.norm1(x)))
        x = x + self.dropout(self.mlp(self.norm2(x)))
        return x

Здесь мы просто следуем архитектуре нашей сети — все необходимые блоки мы уже реализовали.

Мы уже почти на финишной прямой — теперь соберём сам Vision Transformer:

class VisualTransformer(nn.Module):
    def __init__(self,num_classes, img_size=32, patch_size=4, in_channels=3, embed_dim=256,
                 num_layers=6, num_heads=7, mlp_dim=512, dropout=0.1):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, 1 + self.patch_embed.num_patches, embed_dim))
        self.dropout = nn.Dropout(dropout)

        self.encoder_blocks = nn.ModuleList([
                TransformerEncoderBlock(embed_dim, num_heads, mlp_dim, dropout)
                for _ in range(num_layers)
            ])

        self.norm = nn.LayerNorm(embed_dim)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes)
        )

    def forward(self, x, return_attn = False):
        B = x.size(0)
        x = self.patch_embed(x)  # (B, N, D)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, 1+N, D)
        x = x + self.pos_embed
        x = self.dropout(x)

        attn_maps = []
        for block in self.encoder_blocks:
          if return_attn:
            x, attn = block(x, return_attn=True)
            attn_maps.append(attn)  # (B, heads, N, N)
          else:
            x = block(x) # (B, 1+N, D)

        x = self.norm(x)

        out = self.mlp_head(x[:, 0, :])

        if return_attn:
          return out, attn_maps
        else:
          return out

Здесь нужно уточнить несколько моментов. Что такое cls_token? Это специальный токен, который мы добавляем вручную, и он имеет тот же размер, что и патчи изображения. Его задача — использоваться позже для классификации изображения. Идея в том, что, проходя через блоки внимания, этот токен собирает информацию обо всём изображении.

Далее посмотрим на pos_embed. Поскольку мы делим изображение на патчи и выстраиваем их в последовательность — как будто работаем с текстом — модель изначально не понимает пространственные взаимосвязи между патчами. Чтобы это исправить, мы добавляем позиционную информацию к патчам. В нашем случае pos_embed — это обучаемый параметр.

Что касается mlp_head, здесь всё просто: он берёт cls_token, пропускает его через линейный слой и классифицирует изображение.

После сборки нашей модели давайте перейдём к обучению.

Для обучения мы будем использовать следующие гиперпараметры:

BATCH_SIZE = 128
EPOCHS = 80
LEARNING_RATE = 3e-4
PATCH_SIZE = 4
NUM_CLASSES = 10
IMAGE_SIZE = 32
CHANNELS = 3
EMBED_DIM = 256
NUM_HEADS = 8
DEPTH = 6
MLP_DIM = 512
DROP_RATE = 0.1

Давайте посмотрим на количество параметров:

Total parameters: 3,189,514
Trained parameters: 3,189,514

А также следующие аугментации:

train_transforms = transforms.Compose([
    transforms.Resize((70, 70)),
    transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), value='random'),
])

Мы обучаем модель и получаем следующие результаты:

 Графики процесса обучения
Графики процесса обучения

Давайте посмотрим на предсказания модели:

Предсказания модели
Предсказания модели

Вот метрики получившейся модели:

Метрики модели
Метрики модели
 Матрица ошибок (Confusion Matrix)
Матрица ошибок (Confusion Matrix)

А теперь к самой интересной части — вниманию. Давайте посмотрим, на что наша модель обращает внимание во время классификации:

 Карта внимания (Attention Map)
Карта внимания (Attention Map)
 Карта внимания (Attention Map)
Карта внимания (Attention Map)

В этой статье мы подробно рассмотрели реализацию Vision Transformer и его механизма внимания. Мы изучили, на что способна эта модель и как она «смотрит» на изображение с помощью механизма внимания. Vision Transformer открыл новые направления в исследовании компьютерного зрения, объединив идеи из NLP и обработки изображений. В будущем мы обязательно применим эту модель для задачи генерации подписей к изображениям (Image Captioning).

Полный код и процесс обучения вы можете найти на моём Kaggle:

https://www.kaggle.com/code/nickr0ot/visual-transformer-from-scratch

Комментарии (5)


  1. lgorSL
    04.07.2025 14:31

    А что если сделать эмбеддинг для относительной позиции, а не для абсолютной? Для этого можно взять rotary embeddings (отдельно по осям x и y).

    Я так пробовал, но кажется у меня были где-то ошибки и результат не получился.


    1. Flokis_guy
      04.07.2025 14:31

      Они не обладают тогда универсальной аппроксимацией.


  1. ioleynikov
    04.07.2025 14:31

    на редкость бестолкова статья. Ничего толком не объяснено. Вот архитектура, вот код и вперед. Что, зачем, для чего абсолютно не ясно. Самое плохое, что тема определения различий кошек от собак отстала от жизни лет на 10! Было бы куда полезней объяснить на пальцах как работают современные сети типа GPT-4 Vision для анализ и описание визуального контента.


    1. S_A
      04.07.2025 14:31

      по мне так код вполне понятный. как вводная история очень даже неплохо.


  1. ioleynikov
    04.07.2025 14:31

    Да я не оспариваю качество кода, который крайне избыточен и написан скорее всего совсем другим человеком. Зачем столько мути для решения простейшей задачи различий кошек и собак? Я так же назвал свои пожелания.