Курс Python → Тестирование модели в PyTorch

Для того чтобы эффективно оценивать работу нашей модели машинного обучения, необходимо определить метод тестирования. Этот метод позволит нам проверить качество работы модели на тестовом наборе данных и вывести точность предсказаний. Основное отличие метода тестирования от обучения заключается в том, что в процессе тестирования мы используем функцию model.eval(), чтобы перевести модель в режим тестирования. Также важно использовать torch.no_grad(), чтобы отключить вычисление градиента, поскольку во время тестирования обратное распространение не требуется.

Для начала необходимо перевести модель в режим тестирования с помощью функции model.eval(). Это гарантирует, что все слои модели будут работать в режиме тестирования, что может влиять на поведение некоторых слоев, таких как Dropout или BatchNorm. Затем мы используем torch.no_grad(), чтобы временно отключить автоматическое дифференцирование и вычисление градиента. Это позволяет ускорить процесс тестирования, поскольку не нужно хранить градиенты для обновления весов модели.


model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

Наконец, после прохождения всех тестовых данных, мы вычисляем средние потери для всего тестового набора и общую точность предсказаний. Это позволяет оценить, насколько хорошо модель обучилась и способна предсказывать значения на новых данных. Результаты тестирования помогут нам понять, какие улучшения можно внести в модель для повышения ее эффективности и точности предсказаний.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Работа с YAML в Python
  2. Python enumerate() использование
  3. Курс Data Scientist в медицине
  4. CSV строка разделение в Python
  5. Декораторы в Python
  6. Именование переменных в Python
  7. Генератор списка в Python
  8. Измерение времени выполнения кода
  9. Изменения в обработке логических значений
  10. Работа с файлами в Python
  11. Обработка ошибки IndexError
  12. Работа с YAML в Python: PyYAML.
  13. Создание комплексных чисел
  14. Работа с модулем cmath
  15. Преобразование PowerPoint в PDF.
  16. Docstring в Python
  17. Кортеж в Python: создание, доступ, изменение
  18. Метод rmatmul для пользовательских матриц
  19. Переопределение метода
  20. Работа с модулем random
  21. Класс Counter() для подсчета элементов
  22. Проверка элемента в множестве.
  23. Форматирование заголовков в Python
  24. Retrying в Python: повторные вызовы
  25. Печать комбинаций в Python с Itertools
  26. Оценка выражений генератора в Python
  27. Модуль inspect: получение информации о объектах
  28. Деление в Python
  29. Перевод двоичного кода в целое число
  30. Манипуляция формой массива в Numpy
  31. Регулярные выражения: метод match
  32. Работа с функцией next() в Python
  33. Идентификатор объекта в Python
  34. Возврат нескольких значений
  35. Просмотр файла в Jupyter Noteboo
  36. Получение локальных переменных в Python
  37. Модуль pprint
  38. Обработка исключений в Python
  39. Метод clear для коллекций
  40. Работа с кортежами в Python
  41. Функции в одну строку
  42. Создание графиков в терминале
  43. Возврат нескольких значений из функции
  44. Метод rsub для пользовательских чисел
  45. Работа с асинхронными задачами в Python
  46. Оптимизация сравнения в Python
  47. Удаление ключа из словаря

Marketello читают маркетологи из крутых компаний