Курс Python → Тестирование модели в PyTorch

Для того чтобы эффективно оценивать работу нашей модели машинного обучения, необходимо определить метод тестирования. Этот метод позволит нам проверить качество работы модели на тестовом наборе данных и вывести точность предсказаний. Основное отличие метода тестирования от обучения заключается в том, что в процессе тестирования мы используем функцию model.eval(), чтобы перевести модель в режим тестирования. Также важно использовать torch.no_grad(), чтобы отключить вычисление градиента, поскольку во время тестирования обратное распространение не требуется.

Для начала необходимо перевести модель в режим тестирования с помощью функции model.eval(). Это гарантирует, что все слои модели будут работать в режиме тестирования, что может влиять на поведение некоторых слоев, таких как Dropout или BatchNorm. Затем мы используем torch.no_grad(), чтобы временно отключить автоматическое дифференцирование и вычисление градиента. Это позволяет ускорить процесс тестирования, поскольку не нужно хранить градиенты для обновления весов модели.


model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

Наконец, после прохождения всех тестовых данных, мы вычисляем средние потери для всего тестового набора и общую точность предсказаний. Это позволяет оценить, насколько хорошо модель обучилась и способна предсказывать значения на новых данных. Результаты тестирования помогут нам понять, какие улучшения можно внести в модель для повышения ее эффективности и точности предсказаний.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Работа с IP-адресами в Python
  2. Извлечение статей с newspaper3k
  3. Обработка ошибки IndexError
  4. Управление памятью в Python
  5. Получение идентификатора объекта в памяти
  6. Создание циклической ссылки
  7. Переименование файлов в Python
  8. Список переменных в Python
  9. Взаимодействие с внешними процессами в Python
  10. Функции высшего порядка в Python
  11. Реверс строки в Python
  12. Область видимости переменных
  13. Удаление ключей из словаря
  14. Метод __iand__ для пользовательских классов
  15. Создание файла с проверкой ошибки
  16. Удаление URL-адресов в Python
  17. Сравнение def и lambda функций в Python
  18. Путь к интерпретатору Python
  19. Работа с областями видимости переменных
  20. Функции any() и all() в Python
  21. Python Менеджер контекста
  22. Проверка файла .py на синтаксис.
  23. Удаление элементов во время итерации
  24. Транспонирование 2D-массива с помощью zip
  25. Функции all() и any() в Python
  26. Срезы в Python
  27. Удаление дубликатов из списка с помощью dict.fromkeys
  28. Импорт модулей в Python 3.12
  29. Работа с асинхронными задачами в Python
  30. Concrete Paths в Python
  31. Блок else в обработке исключений
  32. Модуль inspect
  33. Разница между датами
  34. Управление памятью в numpy.
  35. Удаление элементов из списка
  36. Работа с парами ключ-значение
  37. Умножение строк и списков
  38. Парсинг веб-страниц с Beautiful Soup
  39. Метод join() для объединения элементов строки
  40. Создание веб-приложения с Flask
  41. Генератор бросков кубиков
  42. Объединение списков в Python.
  43. Проверка типов с помощью isinstance
  44. Работа с типами данных в Python с помощью pydantic.
  45. Enum в Python: создание и использование перечислений
  46. Расширение информации об ошибке в Python

Marketello читают маркетологи из крутых компаний