Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Работа с файлами в Python
  2. Установка и загрузка Instaloader
  3. Модуль sys: основы
  4. Лямбда-функции в Python
  5. Функция enumerate() — Python
  6. Оптимизация поиска в словарях
  7. Оператор walrus в Python
  8. Анонимные функции Lambda
  9. Удаление и повторная вставка ключа в OrderedDict
  10. Создание именованных кортежей в Python
  11. Библиотека Emoji: использование смайлов в Python
  12. Проверка типа объекта в Python
  13. Мониторинг памяти с Pympler
  14. CLI-инструмент howdoi
  15. Отображение HTML кода в Python
  16. Метод clear для коллекций
  17. Виртуальное окружение Python
  18. Работа с контекстными менеджерами
  19. CSV строка разделение в Python
  20. Поиск наиболее частого элемента в списке
  21. Хэш-функции и метод цепочек
  22. Копирование словарей и списков в Python
  23. Асинхронный код в Python
  24. Функция enumerate() в Python
  25. Работа со слайсами
  26. Выход из профиля в Django
  27. Получение размера объекта с sys.getsizeof()
  28. Python и Монти Пайтон
  29. Локальные переменные.
  30. Ошибка NotImplemented в Python
  31. Частичное совпадение ввода
  32. Функция divmod() в Python
  33. Удаление элементов во время итерации
  34. Получение имени функции с помощью inspect
  35. Именованные срезы в Python
  36. Декораторы в Python
  37. Метод classmethod
  38. Создание веб-приложения с Flask
  39. Удаление дубликатов из списка

Marketello читают маркетологи из крутых компаний