Дистилляция моделей

Материал из MachineLearning.

Версия от 12:53, 16 июня 2026; Mihail Mishin (Обсуждение | вклад)
(разн.) ← Предыдущая | Текущая версия (разн.) | Следующая → (разн.)
Перейти к: навигация, поиск
Статья написана с использованием LLM Gemini 3.1 Pro и проверена участником М. Мишин 16:53, 16 июня 2026 (MSD)

Промпт приводится полностью в Обсуждение:Дистилляция моделей


Содержание

Дистилля́ция моде́лей (дистилляция знаний, англ. knowledge distillation) — метод сжатия моделей машинного обучения, при котором компактная модель (студент) обучается воспроизводить поведение более сложной и тяжелой модели или ансамбля моделей (учителя). Основная цель дистилляции — перенести обобщающую способность и внутренние репрезентации большой модели в меньшую по размеру, чтобы существенно ускорить инференс и снизить потребление памяти без значительной потери качества предсказаний.

В современной практике глубокого обучения, особенно в сфере NLP и больших языковых моделей (LLM), дистилляция является ключевым инструментом для создания эффективных локальных моделей (размером 1–8 млрд параметров), способных решать сложные аналитические и логические задачи на уровне флагманских архитектур.

Мотивация и основные идеи

Традиционно для достижения высокой точности на сложных задачах применяются огромные глубокие нейронные сети или композиции (ансамбли) множества моделей. Однако их развертывание в продуктивной среде (например, на мобильных устройствах, edge-устройствах или высоконагруженных серверах) часто невозможно. Главными барьерами выступают строгие ограничения на пропускную способность памяти (memory bandwidth), объем доступной видеопамяти (VRAM) и максимально допустимую задержку ответа (latency).

Центральная идея дистилляции заключается в следующем: вместо того чтобы обучать маленькую модель исключительно на жестких метках классов (hard labels) из оригинального набора данных, мы заставляем её предсказывать непрерывные распределения вероятностей (soft labels), выдаваемые предварительно обученной моделью-учителем.

Эти «мягкие» метки содержат огромное количество скрытой информации (dark knowledge). Например, в задаче классификации изображений учитель может предсказать, что объект на картинке с вероятностью 80% — собака, с вероятностью 19% — кошка, и с вероятностью 1% — автомобиль. Относительные вероятности ошибочных классов (то, что «кошка» в 19 раз вероятнее «автомобиля») описывают внутреннюю структуру данных и скрытые сходства объектов. Модель-студент, обучаясь на таких распределениях, получает богатый градиентный сигнал и сходится быстрее, достигая метрик, недостижимых при обычном обучении «с нуля».

Историческая справка

Идейные предпосылки метода были заложены в работе Кристиана Бусилы и соавторов (Bucila et al., 2006)[1]. В своем исследовании по сжатию моделей они успешно обучили одну быструю нейронную сеть имитировать предсказания громоздкого ансамбля деревьев решений, сохранив при этом высокое качество классификации.

Сам термин «дистилляция знаний» (knowledge distillation) и его современная строгая математическая формулировка с использованием механизма температурного скейлирования (temperature scaling) были введены в прорывной статье Джеффри Хинтона, Ориола Виньялса и Джеффа Дина (Hinton et al., 2015)[1]. Хинтон метафорично описал этот процесс как «дистилляцию» чистых знаний из сложной, перепараметризованной функции в компактную форму.

Математическая формулировка базовой дистилляции

В классической задаче классификации нейронная сеть на последнем слое предсказывает логиты (logits) z_i, которые затем преобразуются в итоговые вероятности q_i с помощью стандартной функции софтмакс:

q_i = \frac{\exp(z_i)}{\sum_{j} \exp(z_j)}

В методе Хинтона в эту формулу искусственно вводится гиперпараметр температуры T. При T=1 мы получаем стандартный софтмакс. Однако при увеличении температуры (T > 1) итоговое распределение вероятностей становится более «мягким» и сглаженным. Это делает вероятности маловероятных (ошибочных) классов более выраженными и отличными от нуля:

q_i = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)}

Процесс обучения студента сводится к минимизации комбинированной функции потерь (\mathcal{L}), которая состоит из двух независимых компонентов:

  1. Потеря дистилляции (L_{KD}): Вычисляется как дивергенция Кульбака-Лейблера (KL-divergence) между сглаженными предсказаниями студента и учителя при одинаково высокой температуре T.
  2. Потеря студента на реальных данных (L_{CE}): Вычисляется как стандартная кросс-энтропия между предсказаниями студента (при T=1) и истинными метками из датасета (hard labels).

Итоговая функция потерь взвешивается параметром \alpha \in [0, 1]:

\mathcal{L} = \alpha \cdot T^2 \cdot \text{KL}\left(P_{teacher}^{(T)} \parallel P_{student}^{(T)}\right) + (1 - \alpha) \cdot \text{CE}\left(y_{true}, P_{student}^{(1)}\right)

Умножение дивергенции на квадрат температуры (T^2) является критически важным математическим шагом. Поскольку градиенты KL-дивергенции, вычисленные по логитам, масштабируются пропорционально 1/T^2, это умножение необходимо для сохранения относительного веса двух компонентов функции потерь при варьировании температуры.

Основные архитектуры дистилляции

Помимо классической дистилляции по логитам (Logits-based distillation), описанной Хинтоном, существуют и более продвинутые архитектуры переноса знаний:

Дистилляция скрытых признаков (Feature-based distillation)

Предложена в концепции FitNets (Romero et al., 2014)[1]. В этом подходе модель-студент обучается воспроизводить не только финальные вероятности, но и промежуточные активации (карты признаков) внутренних слоев учителя. Функция потерь в таком случае включает среднеквадратичное отклонение (MSE) между тензорами признаков:

L_{feat} = \text{MSE}\left(\phi(F_{student}), F_{teacher}\right)

где F — активации скрытого слоя, а \phi — обучаемая проекционная матрица (адаптер), которая выравнивает размерность узкого слоя студента с широким слоем учителя.

Дистилляция отношений (Relation-based distillation)

Вместо того чтобы передавать информацию о каждом отдельном объекте изолированно, этот метод передает знания о взаимосвязях между объектами в батче. Например, студент учится сохранять ту же матрицу попарных косинусных расстояний между эмбеддингами изображений, которую формирует учитель.

Дистилляция больших языковых моделей (LLM)

С переходом индустрии к генеративному ИИ, фокус дистилляции сместился с вероятностных распределений классов на генерацию связных текстовых последовательностей. Современные подходы включают:

  • Дистилляция на уровне токенов (Token-level KD). Выравнивание распределений вероятностей для каждого следующего сгенерированного токена между открытой моделью-учителем (например, архитектурой уровня Llama 3 70B) и локальным студентом.
  • Дистилляция цепочек рассуждений (Chain-of-Thought Distillation). Одним из наиболее перспективных направлений является дистилляция математических и логических способностей. Процесс часто строится вокруг генерации массивов синтетических цепочек рассуждений (synthetic Chain-of-Thought, CoT) мощной моделью-учителем. Для повышения качества датасета на этапе генерации применяется сэмплирование Best-of-N (генерация множества ответов и выбор лучшего на основе Reward-модели). Отфильтрованные данные используются для дообучения компактной модели-студента (например, архитектуры с 1 млрд параметров). Для ускорения тонкой настройки (fine-tuning) и снижения требований к видеопамяти на этом этапе применяется Адаптация низкого ранга (LoRA). Практика показывает, что такой многоступенчатый пайплайн способен кардинально улучшить метрики — известны случаи семикратного роста точности (7x accuracy) небольших моделей на математических бенчмарках без изменения базового количества параметров.

Связь с другими методами сжатия

Дистилляция часто применяется не изолированно, а в синергии с другими методами оптимизации нейросетей:

  • Квантование (Quantization): Снижение разрядности весов модели (например, с 16-битных чисел с плавающей точкой до 8-битных или 4-битных целых чисел). Дистилляция часто используется для восстановления точности модели после агрессивного квантования (Quantization-Aware Knowledge Distillation).
  • Прунинг (Pruning): Физическое удаление наименее значимых весов или целых слоев из архитектуры. Отрезанная (прореженная) модель может использовать исходную плотную сеть в качестве учителя для тонкой донастройки.

Практическая реализация на PyTorch

Внедрение базовой дистилляции по логитам не требует изменения архитектуры самой сети, достаточно лишь модифицировать функцию потерь на этапе обучения (training loop). Ниже представлен классический пример реализации «с нуля» без использования тяжелых сторонних фреймворков:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
def distillation_loss(student_logits, teacher_logits, true_labels, T=2.0, alpha=0.5):
    """
    Вычисляет комбинированную функцию потерь для дистилляции знаний.
 
    Параметры:
    student_logits (Tensor): Сырые предсказания модели-студента.
    teacher_logits (Tensor): Сырые предсказания модели-учителя.
    true_labels (Tensor): Истинные метки классов (hard labels).
    T (float): Температура для сглаживания распределений (T > 1).
    alpha (float): Вес для балансировки двух функций потерь (от 0 до 1).
    """
    # 1. Стандартная потеря (hard loss) на истинных метках при T=1
    hard_loss = F.cross_entropy(student_logits, true_labels)
 
    # 2. Дистилляционная потеря (soft loss) с повышенной температурой T
    # Вычисляем логарифм вероятностей для студента (требование KLDivLoss в PyTorch)
    student_soft = F.log_softmax(student_logits / T, dim=-1)
 
    # Вычисляем вероятности для учителя
    teacher_soft = F.softmax(teacher_logits / T, dim=-1)
 
    # Вычисляем KL-дивергенцию и масштабируем градиенты умножением на T^2
    kl_div = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
    soft_loss = kl_div * (T ** 2)
 
    # 3. Итоговая функция: взвешенная сумма двух компонентов
    return alpha * soft_loss + (1 - alpha) * hard_loss

Распространённые ошибки (Антипаттерны)

  • Неправильный подбор температуры (T). Если установить температуру слишком большой (например, T > 10 для простых задач), распределение вероятностей приблизится к полностью равномерному, и студент потеряет полезные сигналы о структуре классов. Обычно оптимальное значение T находится в диапазоне от 2 до 5.
  • Несоответствие мощностей (Capacity Gap). Попытка дистиллировать знания из гигантского ансамбля в нейронную сеть из пары слоев (underparameterized student) приведет к тому, что студент просто не сможет аппроксимировать настолько сложную функцию. Если разница в размерах слишком велика, применяют промежуточных «ассистентов» (Teacher Assistant Knowledge Distillation).
  • Отключение Hard Loss на реальных данных. В некоторых задачах полное отключение кросс-энтропии на истинных метках (\alpha = 1) приводит к нестабильности обучения и снижению финальных метрик на валидационной выборке. Студент всегда должен иметь доступ к «наземной правде» (ground truth).

См. также

Литература

  • Bucila C., Caruana R., Niculescu-Mizil A. Model compression // Proceedings of the 12th ACM SIGKDD. — 2006. — С. 535–541.
  • Hinton G., Vinyals O., Dean J. Distilling the knowledge in a neural network // arXiv preprint arXiv:1503.02531. — 2015.
  • Gou J., Yu B., Maybank S. J., Tao D. Knowledge distillation: A survey // International Journal of Computer Vision. — 2021. — Т. 129. — № 6. — С. 1789-1819.
  • Romero A., Ballas N., Kahou S. E., Chassang A., Gatta C., Bengio Y. Fitnets: Hints for thin deep nets // arXiv preprint arXiv:1412.6550. — 2014.