Курс Python → Тестирование модели в PyTorch

Для того чтобы эффективно оценивать работу нашей модели машинного обучения, необходимо определить метод тестирования. Этот метод позволит нам проверить качество работы модели на тестовом наборе данных и вывести точность предсказаний. Основное отличие метода тестирования от обучения заключается в том, что в процессе тестирования мы используем функцию model.eval(), чтобы перевести модель в режим тестирования. Также важно использовать torch.no_grad(), чтобы отключить вычисление градиента, поскольку во время тестирования обратное распространение не требуется.

Для начала необходимо перевести модель в режим тестирования с помощью функции model.eval(). Это гарантирует, что все слои модели будут работать в режиме тестирования, что может влиять на поведение некоторых слоев, таких как Dropout или BatchNorm. Затем мы используем torch.no_grad(), чтобы временно отключить автоматическое дифференцирование и вычисление градиента. Это позволяет ускорить процесс тестирования, поскольку не нужно хранить градиенты для обновления весов модели.


model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

Наконец, после прохождения всех тестовых данных, мы вычисляем средние потери для всего тестового набора и общую точность предсказаний. Это позволяет оценить, насколько хорошо модель обучилась и способна предсказывать значения на новых данных. Результаты тестирования помогут нам понять, какие улучшения можно внести в модель для повышения ее эффективности и точности предсказаний.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Метод radd для пользовательских чисел
  2. Хеширование паролей с использованием salt
  3. Логический оператор «and» в Python
  4. Многоточие в Python
  5. Вычисление натурального логарифма в NumPy
  6. Запрос пароля с помощью getpass
  7. Функция findall() для поиска вхождений строки
  8. Настройка нарезки списков
  9. Измерение времени выполнения кода
  10. Библиотека Rich: форматирование текста
  11. Установка и использование pyshorteners
  12. Операторы увеличения и уменьшения переменной
  13. Цепные операции в Python
  14. Работа с getopt
  15. Списки в Python: синтаксис представления
  16. Создание словарей с defaultdict()
  17. Работа с JSON в Python
  18. Получение срезов итераторов
  19. Закрытие файла в Python
  20. Переопределение метода __floordiv__
  21. Сравнение объектов в Python
  22. Функции в одну строку
  23. Метод rlshift для битового сдвига
  24. Перезапуск ячейки в Jupyter Notebook с dostoevsky
  25. Многопроцессорное программирование в Python
  26. Цикл for в Python
  27. Функция count() в Python
  28. Установка и использование howdoi
  29. Инверсия списка/строки в Python
  30. Сравнение строк в Python
  31. Работа с комбинациями в Python.
  32. PATCH-запрос с библиотекой requests
  33. Создание и удаление объектов
  34. Кортеж в Python: создание и использование
  35. Удаление URL-адресов в Python
  36. Обработка ошибок в Python
  37. Конкатенация строк в Python
  38. JSON в Python: модуль, dump, dumps, load
  39. Python enumerate() функции
  40. Преобразование числа в список цифр
  41. Создание итератора
  42. Обучение модели с указанием эпох
  43. Конкатенация строк с помощью join()

Marketello читают маркетологи из крутых компаний