Курс Python → Тестирование модели в PyTorch

Для того чтобы эффективно оценивать работу нашей модели машинного обучения, необходимо определить метод тестирования. Этот метод позволит нам проверить качество работы модели на тестовом наборе данных и вывести точность предсказаний. Основное отличие метода тестирования от обучения заключается в том, что в процессе тестирования мы используем функцию model.eval(), чтобы перевести модель в режим тестирования. Также важно использовать torch.no_grad(), чтобы отключить вычисление градиента, поскольку во время тестирования обратное распространение не требуется.

Для начала необходимо перевести модель в режим тестирования с помощью функции model.eval(). Это гарантирует, что все слои модели будут работать в режиме тестирования, что может влиять на поведение некоторых слоев, таких как Dropout или BatchNorm. Затем мы используем torch.no_grad(), чтобы временно отключить автоматическое дифференцирование и вычисление градиента. Это позволяет ускорить процесс тестирования, поскольку не нужно хранить градиенты для обновления весов модели.


model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

Наконец, после прохождения всех тестовых данных, мы вычисляем средние потери для всего тестового набора и общую точность предсказаний. Это позволяет оценить, насколько хорошо модель обучилась и способна предсказывать значения на новых данных. Результаты тестирования помогут нам понять, какие улучшения можно внести в модель для повышения ее эффективности и точности предсказаний.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Списковое включение в Python
  2. Работа с getopt
  3. Принципы LSP и ISP в Python
  4. Retrying в Python: повторные вызовы
  5. Создание детектора плагиата
  6. Проверка вхождения подстроки
  7. Протокол управления контекстом
  8. Ускорение выполнения кода в Python
  9. Объединение списков в Python
  10. Работа с индексами списков
  11. Enum в Python
  12. Роль ключевого слова self
  13. Метод сравнения объектов в Python
  14. Работа со строками в Python
  15. Сортировка с помощью key
  16. UserList в Python: Описание и примеры использования
  17. Логирование в Python
  18. Создание панели меню Tkinter
  19. Python: возвращение нескольких значений
  20. Перевод двоичного кода в целое число
  21. TON Smart Challenge #2: участие и подготовка
  22. Использование двоеточия в Python
  23. Регистрация на хакатоне
  24. Непрерывная проверка в Python
  25. Импорт модулей в Python 3.12
  26. Обязательные аргументы в Python
  27. Запуск внешних программ с subprocess
  28. Поиск файлов по шаблону
  29. Форматирование данных с pprint
  30. Применение функции к каждому элементу списка
  31. Анализ кода — Python
  32. Измерение времени выполнения кода с помощью time
  33. Метод lt для сортировки объектов
  34. Lambda-функция в Python: использование с map() и sum()
  35. Определение объема памяти объекта
  36. Транспонирование 2D-массива с помощью zip
  37. Метод gt в Python
  38. Шаблоны и наследование в Flask
  39. Сравнение def и lambda функций в Python
  40. Преобразование списков в словарь
  41. Синхронизация доступа к ресурсам
  42. Методы сравнения множеств
  43. Область видимости переменных
  44. Фильтрация данных в Python.

Marketello читают маркетологи из крутых компаний