Курс Python → Тестирование модели в PyTorch

Для того чтобы эффективно оценивать работу нашей модели машинного обучения, необходимо определить метод тестирования. Этот метод позволит нам проверить качество работы модели на тестовом наборе данных и вывести точность предсказаний. Основное отличие метода тестирования от обучения заключается в том, что в процессе тестирования мы используем функцию model.eval(), чтобы перевести модель в режим тестирования. Также важно использовать torch.no_grad(), чтобы отключить вычисление градиента, поскольку во время тестирования обратное распространение не требуется.

Для начала необходимо перевести модель в режим тестирования с помощью функции model.eval(). Это гарантирует, что все слои модели будут работать в режиме тестирования, что может влиять на поведение некоторых слоев, таких как Dropout или BatchNorm. Затем мы используем torch.no_grad(), чтобы временно отключить автоматическое дифференцирование и вычисление градиента. Это позволяет ускорить процесс тестирования, поскольку не нужно хранить градиенты для обновления весов модели.


model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

Наконец, после прохождения всех тестовых данных, мы вычисляем средние потери для всего тестового набора и общую точность предсказаний. Это позволяет оценить, насколько хорошо модель обучилась и способна предсказывать значения на новых данных. Результаты тестирования помогут нам понять, какие улучшения можно внести в модель для повышения ее эффективности и точности предсказаний.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Метод rxor для операции побитового исключающего «или»
  2. Сортировка в Python
  3. Работа с deque из collections
  4. Генераторы словарей и множеств
  5. Работа с множествами в Python
  6. Работа со строками в Python
  7. Использование html-скриптов в Jupyter Notebook
  8. Выход из профиля в Django
  9. Генераторы в Python
  10. Просмотр внешних файлов в %pycat
  11. Оператор zip в Python
  12. Добавление элемента в список.
  13. Объединение, распаковка и деструктуризация
  14. Создание списков в Python
  15. Работа с изменяемыми коллекциями
  16. Перевод эмодзи и эмотиконов.
  17. Переопределение метода __and__
  18. Добавление элементов в список
  19. Порядок и длина множеств в Python
  20. Создание класса очереди
  21. Функции высшего порядка в Python
  22. Установка и использование библиотеки google
  23. Получение текущей даты в Python
  24. Возврат нескольких значений
  25. Передача аргументов в Python
  26. Замена атрибута в именованном кортеже
  27. Декораторы в Python
  28. Объединение строк с помощью метода join
  29. Проверка типа данных
  30. Оператор деления для класса Rational
  31. Метод __ixor__ для побитового исключающего ИЛИ
  32. Создание детектора плагиата
  33. Простой калькулятор Python
  34. Объединение словарей в Python 3.5+
  35. Списковые включения в Python
  36. Итераторы в Python
  37. Работа с getopt
  38. Работа с zip()
  39. Функция pow() — возвести число в степень
  40. Поиск наиболее частого элемента списке
  41. Работа с эмодзи в Python
  42. Python 3.12: Псевдонимы типов
  43. Работа с Event() в threading
  44. Python reversed() vs срез[::-1]
  45. Методы и функции в Python
  46. Склеивание строк через метод join()
  47. Списки в Python: синтаксис представления

Marketello читают маркетологи из крутых компаний