Курс Python → Тестирование модели в PyTorch

Для того чтобы эффективно оценивать работу нашей модели машинного обучения, необходимо определить метод тестирования. Этот метод позволит нам проверить качество работы модели на тестовом наборе данных и вывести точность предсказаний. Основное отличие метода тестирования от обучения заключается в том, что в процессе тестирования мы используем функцию model.eval(), чтобы перевести модель в режим тестирования. Также важно использовать torch.no_grad(), чтобы отключить вычисление градиента, поскольку во время тестирования обратное распространение не требуется.

Для начала необходимо перевести модель в режим тестирования с помощью функции model.eval(). Это гарантирует, что все слои модели будут работать в режиме тестирования, что может влиять на поведение некоторых слоев, таких как Dropout или BatchNorm. Затем мы используем torch.no_grad(), чтобы временно отключить автоматическое дифференцирование и вычисление градиента. Это позволяет ускорить процесс тестирования, поскольку не нужно хранить градиенты для обновления весов модели.


model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

Наконец, после прохождения всех тестовых данных, мы вычисляем средние потери для всего тестового набора и общую точность предсказаний. Это позволяет оценить, насколько хорошо модель обучилась и способна предсказывать значения на новых данных. Результаты тестирования помогут нам понять, какие улучшения можно внести в модель для повышения ее эффективности и точности предсказаний.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Работа с функцией next() в Python
  2. Ветвление выражения в Python
  3. Объединение кортежей в Python
  4. Дизассемблирование Python кода
  5. Распаковка с оператором *
  6. Генераторы в Python
  7. Инициализация структур данных
  8. Декораторы с аргументами
  9. Сортировка элементов с OrderedDict
  10. Оператор морж в Python 3.8
  11. Метод clear для коллекций
  12. Счетчик в Python: most_common()
  13. Функции any() и all() в Python
  14. Метод join() для объединения элементов строки
  15. Профилирование кода на Python
  16. Поиск email
  17. JSON-esque в Python
  18. Правила именования переменных
  19. Показ всплывающих окон Tkinter
  20. Преобразование символов в нижний регистр
  21. Работа с комплексными числами в Python
  22. Пропуск строк в файле с itertools
  23. Анализ кода — Python
  24. Работа с путями в Python
  25. Работа с массивами в Python
  26. Карта бомбоубежищ в Москве и Питере
  27. Форматирование данных с помощью pprint
  28. Удаление URL-адресов в Python
  29. Функции map() и reduce() в Python
  30. Оператор match в Python
  31. Сравнение неупорядоченных списков
  32. Операторы увеличения и уменьшения переменной
  33. Удаление элементов из списка в Python.
  34. Numpy: объединение массивов
  35. Проверка элементов списка условием
  36. Комментарии в Python.
  37. Округление дробей в Python
  38. Проверка надежности пароля на Python
  39. Разделение строки с помощью re.split()
  40. Проверка строки на палиндром
  41. Проверка на палиндром
  42. Отладка производительности Python
  43. Декораторы в Python
  44. Округление банкира в Python
  45. Создание панели меню Tkinter

Marketello читают маркетологи из крутых компаний