Курс Python → Тестирование модели в PyTorch

Для того чтобы эффективно оценивать работу нашей модели машинного обучения, необходимо определить метод тестирования. Этот метод позволит нам проверить качество работы модели на тестовом наборе данных и вывести точность предсказаний. Основное отличие метода тестирования от обучения заключается в том, что в процессе тестирования мы используем функцию model.eval(), чтобы перевести модель в режим тестирования. Также важно использовать torch.no_grad(), чтобы отключить вычисление градиента, поскольку во время тестирования обратное распространение не требуется.

Для начала необходимо перевести модель в режим тестирования с помощью функции model.eval(). Это гарантирует, что все слои модели будут работать в режиме тестирования, что может влиять на поведение некоторых слоев, таких как Dropout или BatchNorm. Затем мы используем torch.no_grad(), чтобы временно отключить автоматическое дифференцирование и вычисление градиента. Это позволяет ускорить процесс тестирования, поскольку не нужно хранить градиенты для обновления весов модели.


model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

Наконец, после прохождения всех тестовых данных, мы вычисляем средние потери для всего тестового набора и общую точность предсказаний. Это позволяет оценить, насколько хорошо модель обучилась и способна предсказывать значения на новых данных. Результаты тестирования помогут нам понять, какие улучшения можно внести в модель для повышения ее эффективности и точности предсказаний.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Кортеж в Python: создание, доступ, изменение
  2. Переопределение метода __rshift__
  3. Установка пакетов с помощью pip
  4. Метод count в Python: почему count(», ») возвращает 4?
  5. Структурирование именованных констант
  6. Работа с аргументами командной строки в Python
  7. Приближение чисел в Python
  8. Поиск самого частого элемента
  9. Именованные кортежи в Python
  10. ChainMap.new_child() — добавление нового словаря
  11. Генерация QR-кодов с библиотекой qrcode
  12. Вложенные генераторы в Python
  13. Работа с необработанными строками
  14. Передача словаря через **kwargs
  15. Enum в Python
  16. Извлечение чисел из текста
  17. Изменения в обработке логических значений
  18. Определение наиболее частого элемента с помощью collections.Counter
  19. Оператор continue в Python
  20. Переопределение оператора % для объектов
  21. Избегайте ошибку FileNotFoundError
  22. Генераторы в Python
  23. Создание словарей с defaultdict
  24. Гибкие функции Python
  25. Нахождение пересечения множеств
  26. Логирование с Logzero
  27. Создание namedtuple из словаря
  28. Генераторные выражения и islice.
  29. Метод get() для словарей
  30. Метод rxor для операции побитового исключающего «или»
  31. Мониторинг работы программы Py-spy
  32. Добавление цвета в консоли
  33. Поиск email
  34. Метод count() для списков
  35. Создание функций с произвольным количеством аргументов
  36. Colorama: окрашивание текста в Python
  37. Переопределение метода __eq__
  38. Списки в Python: основы
  39. Использование метода lower()
  40. Подсчет элементов в Python
  41. Установка и использование библиотеки google
  42. Метод __ixor__ для побитового исключающего ИЛИ
  43. Взаимодействие с sys
  44. Numpy: объединение массивов
  45. Тернарный оператор в Python

Marketello читают маркетологи из крутых компаний