Курс Python → Тестирование модели в PyTorch

Для того чтобы эффективно оценивать работу нашей модели машинного обучения, необходимо определить метод тестирования. Этот метод позволит нам проверить качество работы модели на тестовом наборе данных и вывести точность предсказаний. Основное отличие метода тестирования от обучения заключается в том, что в процессе тестирования мы используем функцию model.eval(), чтобы перевести модель в режим тестирования. Также важно использовать torch.no_grad(), чтобы отключить вычисление градиента, поскольку во время тестирования обратное распространение не требуется.

Для начала необходимо перевести модель в режим тестирования с помощью функции model.eval(). Это гарантирует, что все слои модели будут работать в режиме тестирования, что может влиять на поведение некоторых слоев, таких как Dropout или BatchNorm. Затем мы используем torch.no_grad(), чтобы временно отключить автоматическое дифференцирование и вычисление градиента. Это позволяет ускорить процесс тестирования, поскольку не нужно хранить градиенты для обновления весов модели.


model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

Наконец, после прохождения всех тестовых данных, мы вычисляем средние потери для всего тестового набора и общую точность предсказаний. Это позволяет оценить, насколько хорошо модель обучилась и способна предсказывать значения на новых данных. Результаты тестирования помогут нам понять, какие улучшения можно внести в модель для повышения ее эффективности и точности предсказаний.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Функция с **kwargs в Python
  2. Преобразование текста в речь с Python
  3. Вывод переменной и строки в Python
  4. Умножение строк и списков
  5. Декораторы в Python
  6. Условное добавление элементов в список
  7. Работа с итераторами через срезы
  8. JSON-esque в Python
  9. Порядок операций в Python
  10. Взаимодействие с внешними процессами в Python
  11. Создание новых списков
  12. Названия столбцов в Python таблицах
  13. Получение текущего времени в Python
  14. Работа со списками
  15. Работа с срезами в Python
  16. Оператор «not» в Python
  17. Структуры данных в Python
  18. Манипуляция формой массива в Numpy
  19. Функции range() в Python
  20. Работа с CSV файлами в Python
  21. Подсчет элементов с помощью Counter
  22. PUT запрос для обновления данных
  23. Декораторы с аргументами
  24. Перебор элементов списка в Python
  25. Метод radd для пользовательских чисел
  26. Построение графиков в Matplotlib
  27. 9 уловок для чистого кода
  28. Создание инструмента обнаружения плагиата
  29. Скрытие вывода данных
  30. Работа с массивами в Python
  31. Создание и использование ChainMap
  32. Методы split() и join() — Python строк.
  33. Виртуальное окружение Python
  34. Создание словарей и множеств в Python
  35. Экспорт данных в файл.
  36. Удаление ресурса в Python
  37. Методы обработки строк в Python
  38. Преобразование символов в нижний регистр
  39. Генераторы списков в Python
  40. Принципы LSP и ISP в Python
  41. Проверка однородности элементов списка
  42. Аннотации типов в Python
  43. Делегирование в Python
  44. Проверка наличия элемента в списке
  45. Работа с модулем random
  46. Создание веб-приложения с Flask
  47. Комментарии в Python

Marketello читают маркетологи из крутых компаний