Курс Python → Тестирование модели в PyTorch

Для того чтобы эффективно оценивать работу нашей модели машинного обучения, необходимо определить метод тестирования. Этот метод позволит нам проверить качество работы модели на тестовом наборе данных и вывести точность предсказаний. Основное отличие метода тестирования от обучения заключается в том, что в процессе тестирования мы используем функцию model.eval(), чтобы перевести модель в режим тестирования. Также важно использовать torch.no_grad(), чтобы отключить вычисление градиента, поскольку во время тестирования обратное распространение не требуется.

Для начала необходимо перевести модель в режим тестирования с помощью функции model.eval(). Это гарантирует, что все слои модели будут работать в режиме тестирования, что может влиять на поведение некоторых слоев, таких как Dropout или BatchNorm. Затем мы используем torch.no_grad(), чтобы временно отключить автоматическое дифференцирование и вычисление градиента. Это позволяет ускорить процесс тестирования, поскольку не нужно хранить градиенты для обновления весов модели.


model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

Наконец, после прохождения всех тестовых данных, мы вычисляем средние потери для всего тестового набора и общую точность предсказаний. Это позволяет оценить, насколько хорошо модель обучилась и способна предсказывать значения на новых данных. Результаты тестирования помогут нам понять, какие улучшения можно внести в модель для повышения ее эффективности и точности предсказаний.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Активация Matplotlib в Jupyter
  2. Работа с базами данных SQLite
  3. Многопоточность в Python
  4. Метод join() для объединения элементов строки
  5. Настройка Cron
  6. Объединение множеств в Python
  7. Декоратор Property в Python
  8. Профилирование данных с Pandas
  9. Работа с каталогами в Python
  10. Bootle — простой веб-фреймворк
  11. Метод split() в Python
  12. Область видимости переменных в Python
  13. Утечки переменных цикла в Python 3.x
  14. Подписка на SelectelNews в Twitter
  15. Основные методы NumPy
  16. Поиск кода
  17. Тип CodeType в Python.
  18. Получение имени функции с помощью inspect
  19. Обновление ключей в Python
  20. Модуль Operator в Python
  21. Блок else в циклах.
  22. Pretty-printing JSON в Python
  23. Декораторы с аргументами в Python
  24. Генераторы списков
  25. Структурирование данных с Pydantic
  26. Курсы Яндекс Практикум
  27. Установка пакета в Python
  28. Генераторы и сеты в Python
  29. Нахождение самого длинного слова в списке с помощью max
  30. Инициализация структур данных
  31. Модуль xkcd: добавление юмора в Python
  32. Объединение словарей в Python
  33. Нахождение разницы между списками в Python
  34. Отображение графиков в Jupyter с Matplotlib
  35. Упрощенный вывод данных в Python
  36. Тестирование времени с Freezegun
  37. Тернарный оператор в Python
  38. Отладка в командной строке
  39. Резервирование символов в Python
  40. Метод ne для сравнения объектов
  41. Добавление Progressbar в Python
  42. Автоматизация действий с Pyautogui
  43. Обработка ошибок в Python
  44. Создание циклической ссылки
  45. Обрезка изображения с Pillow
  46. Запуск файлового сервера
  47. Аргумент по умолчанию
  48. Python reversed() функция

Marketello читают маркетологи из крутых компаний