Курс Python → Тестирование модели в PyTorch

Для того чтобы эффективно оценивать работу нашей модели машинного обучения, необходимо определить метод тестирования. Этот метод позволит нам проверить качество работы модели на тестовом наборе данных и вывести точность предсказаний. Основное отличие метода тестирования от обучения заключается в том, что в процессе тестирования мы используем функцию model.eval(), чтобы перевести модель в режим тестирования. Также важно использовать torch.no_grad(), чтобы отключить вычисление градиента, поскольку во время тестирования обратное распространение не требуется.

Для начала необходимо перевести модель в режим тестирования с помощью функции model.eval(). Это гарантирует, что все слои модели будут работать в режиме тестирования, что может влиять на поведение некоторых слоев, таких как Dropout или BatchNorm. Затем мы используем torch.no_grad(), чтобы временно отключить автоматическое дифференцирование и вычисление градиента. Это позволяет ускорить процесс тестирования, поскольку не нужно хранить градиенты для обновления весов модели.


model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = correct / total

Наконец, после прохождения всех тестовых данных, мы вычисляем средние потери для всего тестового набора и общую точность предсказаний. Это позволяет оценить, насколько хорошо модель обучилась и способна предсказывать значения на новых данных. Результаты тестирования помогут нам понять, какие улучшения можно внести в модель для повышения ее эффективности и точности предсказаний.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Работа с JSON данными в Python
  2. Оператор «моржа» (Walrus Operator)
  3. Обратный список чисел
  4. Декораторы в Python
  5. Блок else в циклах.
  6. Метод append() для списка
  7. Удаление элемента по индексу в Python
  8. Оператор Walrus в Python
  9. Работа с YAML в Python
  10. Удаление элемента по индексу
  11. Создание инструмента обнаружения плагиата
  12. Импорт модулей в Python 3.12
  13. Функция product() в Python
  14. Операции с датами в Python
  15. Переворот строки с использованием цикла
  16. Экспорт функций в Python
  17. Генерация случайных чисел в Python
  18. Модуль xkcd: добавление юмора в Python
  19. Срезы в Python
  20. Форматирование строк в Python
  21. Получение имени функции с помощью inspect
  22. Методы сравнения множеств
  23. Кортежи в Python: особенности и преимущества
  24. Генераторы в Python
  25. Конкатенация строк в Python
  26. Фильтрация данных в Python.
  27. Метод is_absolute() для PurePath
  28. Экспорт данных в файл.
  29. Работа со словарями с defaultdict из collections
  30. Создание словарей с defaultdict
  31. Получение размера объекта с sys.getsizeof()
  32. Оператор обр. импликации
  33. Имена объектов в Python
  34. Создание панели меню Tkinter
  35. Оператор объединения словарей
  36. Оператор += в Python
  37. Обновление данных через PUT запрос
  38. Объединение итераторов
  39. Обработка исключений в Python
  40. Проверка типа данных
  41. Замена подстроки
  42. Списковое включение в Python
  43. Работа с Path в Python
  44. Избегайте двойного подчеркивания
  45. Обход дочерних элементов BeautifulSoup

Marketello читают маркетологи из крутых компаний