Курс Python → Тестирование модели в PyTorch

Для того чтобы эффективно оценивать работу нашей модели машинного обучения, необходимо определить метод тестирования. Этот метод позволит нам проверить качество работы модели на тестовом наборе данных и вывести точность предсказаний. Основное отличие метода тестирования от обучения заключается в том, что в процессе тестирования мы используем функцию model.eval(), чтобы перевести модель в режим тестирования. Также важно использовать torch.no_grad(), чтобы отключить вычисление градиента, поскольку во время тестирования обратное распространение не требуется.

Для начала необходимо перевести модель в режим тестирования с помощью функции model.eval(). Это гарантирует, что все слои модели будут работать в режиме тестирования, что может влиять на поведение некоторых слоев, таких как Dropout или BatchNorm. Затем мы используем torch.no_grad(), чтобы временно отключить автоматическое дифференцирование и вычисление градиента. Это позволяет ускорить процесс тестирования, поскольку не нужно хранить градиенты для обновления весов модели.


model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

Наконец, после прохождения всех тестовых данных, мы вычисляем средние потери для всего тестового набора и общую точность предсказаний. Это позволяет оценить, насколько хорошо модель обучилась и способна предсказывать значения на новых данных. Результаты тестирования помогут нам понять, какие улучшения можно внести в модель для повышения ее эффективности и точности предсказаний.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Использование модуля __future__
  2. Генераторы списков
  3. Переопределение метода __floordiv__
  4. Проверка элемента в множестве.
  5. Перевод двоичного кода в целое число
  6. Безопасный доступ к значениям словаря
  7. Конкатенация строк с join() в Python
  8. Методы в Python
  9. Функция zip() — объединение последовательностей
  10. Декоратор Ajax required
  11. HTTP-запросы с библиотекой Requests
  12. Просмотр внешнего файла в Python
  13. Функция product() из itertools
  14. Считывание бинарного файла в Python
  15. Оператор умножения для вектора
  16. Класс Counter() для подсчета элементов
  17. Декораторы с аргументами в Python
  18. Очистка данных с Pandas
  19. Подсчет часто встречающихся элементов
  20. Проверка переменных окружения в Python
  21. Оптимизация сравнения в Python
  22. Замена элементов в списке с помощью генераторов списков
  23. Генерация фальшивых данных с Faker
  24. Python: библиотеки и функции
  25. Фильтрация списка чисел
  26. Руководство по библиотеке pydantic
  27. Работа с SQLite в Python
  28. Проверка условий: all и any
  29. Экспорт функций в Python
  30. Метаклассы в Python
  31. Переопределение метода __and__
  32. Искажение имен в Python
  33. Удаление ссылок в Python
  34. Создание списка через цикл
  35. Методы работы со списками
  36. Метод __ilshift__ для битового сдвига влево
  37. Различия символов в Python
  38. Измерение времени выполнения кода
  39. Работа с файлами в Python
  40. Итерация по копии коллекции
  41. Поиск частого элемента
  42. Функция enumerate в Python
  43. Извлечение статей с newspaper3k
  44. Работа с NumPy массивами
  45. Аннотации типов в Python
  46. Работа с атрибутом dict
  47. Декораторы классов
  48. Импорт модулей в Python 3.12

Marketello читают маркетологи из крутых компаний