Курс Python → Тестирование модели в PyTorch

Для того чтобы эффективно оценивать работу нашей модели машинного обучения, необходимо определить метод тестирования. Этот метод позволит нам проверить качество работы модели на тестовом наборе данных и вывести точность предсказаний. Основное отличие метода тестирования от обучения заключается в том, что в процессе тестирования мы используем функцию model.eval(), чтобы перевести модель в режим тестирования. Также важно использовать torch.no_grad(), чтобы отключить вычисление градиента, поскольку во время тестирования обратное распространение не требуется.

Для начала необходимо перевести модель в режим тестирования с помощью функции model.eval(). Это гарантирует, что все слои модели будут работать в режиме тестирования, что может влиять на поведение некоторых слоев, таких как Dropout или BatchNorm. Затем мы используем torch.no_grad(), чтобы временно отключить автоматическое дифференцирование и вычисление градиента. Это позволяет ускорить процесс тестирования, поскольку не нужно хранить градиенты для обновления весов модели.


model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

Наконец, после прохождения всех тестовых данных, мы вычисляем средние потери для всего тестового набора и общую точность предсказаний. Это позволяет оценить, насколько хорошо модель обучилась и способна предсказывать значения на новых данных. Результаты тестирования помогут нам понять, какие улучшения можно внести в модель для повышения ее эффективности и точности предсказаний.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Работа со строками
  2. Создание и инициализация объектов
  3. Вычисление времени выполнения
  4. Метод lt для сортировки объектов
  5. Хеширование паролей с солью
  6. Генераторы в Python
  7. Разработка игры Pong с turtle
  8. Печать списка с помощью метода join
  9. Метод remove() для удаления элемента из списка
  10. Объединение списков в Python
  11. Мониторинг памяти с Pympler
  12. Получение текущего времени в Python
  13. Получение списка кортежей из словаря
  14. Списковое включение в Python
  15. Удаление файлов с shutil.os.remove()
  16. Стать Python-разработчиком
  17. Контекстный менеджер в Python
  18. Оператор морж в Python 3.8
  19. Переопределение метода divmod
  20. Переворот списка в Python
  21. Разделение строки с помощью re.split()
  22. Оператор is в Python
  23. Применение функции к элементам списка
  24. Удаление элементов во время итерации
  25. Тернарный оператор в Python
  26. Изменения в обработке логических значений
  27. Изменение списка срезами
  28. Оператор in для Python
  29. Математические функции в Python
  30. Библиотека wikipedia для Python
  31. Оператор «or» в Python
  32. Однострочники Python
  33. Тайное преобразование типа ключа
  34. Расчет времени выполнения
  35. Проблемы с dict в Python
  36. Нахождение самого длинного слова в списке с помощью max
  37. Работа со словарями в Python
  38. Преобразование списка в словарь через генератор
  39. Асинхронное выполнение задач в процессах
  40. Декораторы в Python
  41. Скрытие вывода данных
  42. Оператор * в Python
  43. Обновление данных через PUT запрос
  44. Функции классификации комплексных чисел

Marketello читают маркетологи из крутых компаний