Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Игра Виселица на Python
  2. Импорт с альтернативным именем
  3. Генераторы в Python
  4. Имена объектов в Python
  5. Классы данных в Python
  6. Закрытие файла в Python
  7. Переворот списка в Python
  8. Установка и загрузка Instaloader
  9. Функции map, filter и reduce
  10. Удаление элемента по индексу в Python
  11. Инициализация переменных
  12. Работа с CSV в Python
  13. Фильтрация входных данных в Python
  14. enumerate() в Python для работы с индексами
  15. Создание новых списков в Python
  16. Работа с модулем cmath
  17. Установка и использование Logzero
  18. Список методов и атрибутов
  19. Перевод текста с Python Translator
  20. Сокращение ссылок с pyshorteners
  21. Выражения-генераторы в Python
  22. Структура данных словарь в Python
  23. Оператор in в Python
  24. Разделение строк в Python
  25. Многоточие в Python
  26. Генерация тестовых данных с factory_boy
  27. Проверка условий в Python
  28. Взаимодействие с внешними процессами в Python
  29. Работа с дробями в Python
  30. Лямбда-функции в Python
  31. Обработка исключений в Python
  32. Сумма элементов списка
  33. Операторы увеличения и уменьшения в Python
  34. Получение ID процесса
  35. Удаление дубликатов с сохранением порядка с помощью dict.fromkeys
  36. Проверка условий: all и any
  37. Проверка типов с использованием isinstance
  38. Преобразование чисел в слова
  39. Цикл for с enumerate() в Python
  40. Определение локальных переменных в Python
  41. Руководство по библиотеке pydantic
  42. Регистрация на курсы SF Education
  43. Работа с NumPy массивами
  44. Оболочка Python
  45. Модуль math: константы π и e

Marketello читают маркетологи из крутых компаний