Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Преобразование регистра символов
  2. Операции с комплексными числами
  3. Python: Фильтрация списков с помощью filter()
  4. Antigravity модуль
  5. Замена текста с помощью sub
  6. Управление импортом в Python
  7. Howdoi — получение ответов из терминала
  8. Основные функции и модули Python
  9. Создание GUI с Tkinter: Entry
  10. Удаление символа из строки
  11. Обновление множества в Python
  12. Частичное совпадение ввода
  13. Хранение переменных в словаре.
  14. Функции в Python
  15. Асинхронное программирование с asyncio
  16. Лямбда-функции в Python
  17. Циклы в Python
  18. Управление мышью и клавиатурой с Pyautogui
  19. Создание словарей и множеств в Python
  20. Проверка строки на палиндром
  21. Отладчик pdb: начало работы
  22. Создание класса очереди
  23. Работа с географическими данными в Python
  24. Big O оптимизация
  25. Работа с файлами в Python
  26. Переопределение метода __lshift__
  27. Создание словарей с defaultdict()
  28. Работа с getopt
  29. Генератор списка с условием if
  30. Создание пользовательской коллекции в Python
  31. Хеши в Python
  32. Объединение множеств в Python
  33. Оператор in для Python
  34. Оператор @ для умножения матриц
  35. Считывание бинарного файла в Python
  36. Декораторы в Python
  37. Метод ior для битовых операций
  38. Библиотека Emoji: использование смайлов в Python
  39. Работа с enumerate()
  40. Работа с модулем random
  41. Иерархия классов в Python
  42. Анонимные функции Lambda
  43. Выборка чисел
  44. Открытие и редактирование скриптов Python
  45. Декоратор защиты анонимных пользователей
  46. Python Метод sleep() из time
  47. SciPy: широкий функционал для математических операций

Marketello читают маркетологи из крутых компаний