Курс Python → Тестирование модели в PyTorch

Для того чтобы эффективно оценивать работу нашей модели машинного обучения, необходимо определить метод тестирования. Этот метод позволит нам проверить качество работы модели на тестовом наборе данных и вывести точность предсказаний. Основное отличие метода тестирования от обучения заключается в том, что в процессе тестирования мы используем функцию model.eval(), чтобы перевести модель в режим тестирования. Также важно использовать torch.no_grad(), чтобы отключить вычисление градиента, поскольку во время тестирования обратное распространение не требуется.

Для начала необходимо перевести модель в режим тестирования с помощью функции model.eval(). Это гарантирует, что все слои модели будут работать в режиме тестирования, что может влиять на поведение некоторых слоев, таких как Dropout или BatchNorm. Затем мы используем torch.no_grad(), чтобы временно отключить автоматическое дифференцирование и вычисление градиента. Это позволяет ускорить процесс тестирования, поскольку не нужно хранить градиенты для обновления весов модели.


model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

Наконец, после прохождения всех тестовых данных, мы вычисляем средние потери для всего тестового набора и общую точность предсказаний. Это позволяет оценить, насколько хорошо модель обучилась и способна предсказывать значения на новых данных. Результаты тестирования помогут нам понять, какие улучшения можно внести в модель для повышения ее эффективности и точности предсказаний.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Аннотации типов в Python
  2. Оператор in в Python
  3. UserString в Python
  4. Логические операторы в Python
  5. f-строки в формате строк
  6. Оператор «not» в Python
  7. Создание новых списков
  8. Установка random seed в Python
  9. Поиск подстроки в строке
  10. Добавление элементов в список
  11. Работа со слайсами
  12. Печать календаря
  13. Генерация случайных чисел Python
  14. Использование функции enumerate()
  15. Функции в Python: создание и вызов
  16. Модуль math: константы π и e
  17. Метод join для наборов
  18. Форматирование заголовков в Python
  19. Генерация случайных чисел в Python
  20. Функция zip() в Python
  21. Оператор is в Python
  22. Создание OrderedDict
  23. Гибкие функции Python
  24. Python Calendar Usage
  25. Создание копии итератора
  26. Метод enumerate() в Python
  27. Метод count в Python: почему count(», ») возвращает 4?
  28. Правила именования переменных
  29. Встроенные функции Python
  30. Абстракции словарей и множеств в Python
  31. Метод __imod__ для Python
  32. Работа с кортежами
  33. Howdoi — получение ответов из терминала
  34. Каналы Senior: Python, Java, Frontend, SQL, C++
  35. Python Метод sleep() времени
  36. Распаковка элементов массива
  37. Модуль pprint
  38. Обмен переменными в Jupyter
  39. Оператор деления для класса Rational
  40. Экспорт данных с помощью writefile
  41. Добавление элемента в список.
  42. Оператор match в Python
  43. Метод add для класса Vector
  44. Частичное применение функций в Python
  45. Основы Python
  46. Функция zip() — объединение последовательностей
  47. Создание итератора

Marketello читают маркетологи из крутых компаний