Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Работа со случайными элементами
  2. Удаление дубликатов из списка
  3. Генераторы в Python
  4. Частичное совпадение пользовательского ввода в Python 3.10
  5. Создание словарей в Python
  6. Любовь к Python
  7. Философия Python
  8. Генераторные функции в Python
  9. TypedDict для kwargs в Python 3.12
  10. Генераторы в Python
  11. Разделение строки в Python
  12. Операторы увеличения и уменьшения переменной
  13. Классы данных в Python
  14. Виртуальные среды в Python
  15. Работа с эмодзи в Python
  16. Очистка строки в Python
  17. Очистка данных с Pandas
  18. Запрос пароля с помощью getpass
  19. EMOT преобразование эмодзи в текст
  20. Переменная Шредингера
  21. Условное добавление элементов в список
  22. Вывод переменной и строки в Python
  23. Уникальность ключей в словаре
  24. Beautiful Soup — извлечение данных из HTML
  25. Метод index() в Python
  26. Метод setitem в Python
  27. Роль object и type в Python
  28. Поиск уникальных элементов строкой в Python
  29. Python: цикл for и оператор присваивания
  30. Работа с модулем random
  31. Итерации в Python
  32. Открытие и запись файлов
  33. Оператор break в Python
  34. Ускоренный импорт библиотек
  35. Удаление дубликатов из списка с помощью dict.fromkeys
  36. Декораторы с аргументами в Python
  37. Конкатенация строк в Python
  38. Бесконечные списки в Python
  39. Python: отсутствие точек с запятыми
  40. Поиск шаблона в начале строки
  41. Проверка файла .py на синтаксис.
  42. Инверсия списков и строк в Python
  43. Однострочники Python
  44. Работа со словарями
  45. Работа с Enum в Python3.
  46. Переопределение метода __pow__
  47. Встроенные функции Python

Marketello читают маркетологи из крутых компаний