Курс Python → Тестирование модели в PyTorch

Для того чтобы эффективно оценивать работу нашей модели машинного обучения, необходимо определить метод тестирования. Этот метод позволит нам проверить качество работы модели на тестовом наборе данных и вывести точность предсказаний. Основное отличие метода тестирования от обучения заключается в том, что в процессе тестирования мы используем функцию model.eval(), чтобы перевести модель в режим тестирования. Также важно использовать torch.no_grad(), чтобы отключить вычисление градиента, поскольку во время тестирования обратное распространение не требуется.

Для начала необходимо перевести модель в режим тестирования с помощью функции model.eval(). Это гарантирует, что все слои модели будут работать в режиме тестирования, что может влиять на поведение некоторых слоев, таких как Dropout или BatchNorm. Затем мы используем torch.no_grad(), чтобы временно отключить автоматическое дифференцирование и вычисление градиента. Это позволяет ускорить процесс тестирования, поскольку не нужно хранить градиенты для обновления весов модели.


model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

Наконец, после прохождения всех тестовых данных, мы вычисляем средние потери для всего тестового набора и общую точность предсказаний. Это позволяет оценить, насколько хорошо модель обучилась и способна предсказывать значения на новых данных. Результаты тестирования помогут нам понять, какие улучшения можно внести в модель для повышения ее эффективности и точности предсказаний.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. ChainMap избыточные ключи
  2. Конвертация коллекций в Python.
  3. Методы list в Python
  4. Повторение элементов в Python
  5. Частичное применение функций в Python
  6. Работа с индексами списков
  7. Частичное совпадение пользовательского ввода в Python 3.10
  8. Combobox в Tkinter
  9. Определение наиболее частого элемента с помощью collections.Counter
  10. Именование столбцов в Python с pandas
  11. Изменение логики работы с временем
  12. Чтение и запись TOML-конфигов
  13. Печать в одной строке
  14. Модуль itertools: комбинации и перестановки
  15. Работа с модулем glob в Python
  16. Python: изменяемые и неизменяемые коллекции
  17. Наследование в программировании
  18. GitHub в Telegram: подписка на уведомления
  19. Нан-рефлексивность в Python
  20. Срезы в Python
  21. Операции с датами в Python
  22. Функция enumerate в Python
  23. Функция zip() — объединение последовательностей
  24. Улучшение читаемости кода в Python
  25. Генераторы в Python
  26. Инвертирование словаря
  27. Установка и использование Python-dateutil
  28. Копирование и вставка текста в Python
  29. Автоматизация действий с Pyautogui
  30. Работа с аргументами командной строки
  31. Раздувающийся словарь в Python
  32. Подсказки типов в Python
  33. Работа с CSV файлами
  34. Управление сессиями в Python
  35. Работа с CSV файлами в Python
  36. Форматирование чисел в Python
  37. Функция enumerate() в Python
  38. Повторение элементов в Python
  39. Определение функций с необязательными аргументами
  40. Принципы Zen Python
  41. Генераторы в Python
  42. Объединение словарей в Python
  43. Функция format() в Python
  44. Модуль xkcd: добавление юмора в Python
  45. Оптимизация гиперпараметров в Python
  46. Парсинг статей с Newspaper3k

Marketello читают маркетологи из крутых компаний