Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Выборка чисел
  2. Создание комплексных чисел
  3. Работа с кортежами в Python
  4. Измерение потребления памяти при сортировке
  5. Измерение времени выполнения кода с помощью time
  6. Функции map() и reduce() в Python
  7. Сортировка в Python
  8. Проверка надежности пароля на Python
  9. Управление памятью в Python
  10. Создание таблиц в Python с PrettyTable
  11. Тестирование с unittest
  12. Библиотека sh: удобные команды терминала
  13. Удаление элемента из списка
  14. Python Translator: создание локальных переводчиков
  15. Печать календаря
  16. Взаимодействие с sys
  17. Просмотр атрибутов и методов класса
  18. Создание пустых функций и классов в Python
  19. Проверка дублей в списке.
  20. Применение функции map() с лямбда-функциями
  21. Оператор Walrus: правильное использование
  22. Работа с библиотекой requests
  23. Beautiful Soup — извлечение данных из HTML
  24. Тестирование с responses
  25. Вычисление логарифмов в Python
  26. Значения по умолчанию в Python
  27. Порядок операций в Python
  28. Оператор in для проверки наличия элемента
  29. Дизассемблирование Python кода
  30. Enum в Python
  31. Оператор морж в Python 3.8
  32. Принципы LSP и ISP в Python
  33. Форматирование вывода списков
  34. Работа с буфером обмена на Python
  35. Проверка индексов коллекции
  36. Генераторы в Python
  37. Создание новых функций с помощью functools.partial
  38. Построение графиков в Matplotlib
  39. Конкатенация строк с помощью join()
  40. Python: Splat-оператор и splatty-splat
  41. Возврат нескольких значений
  42. Сложение матриц в NumPy
  43. Распаковка с оператором *
  44. Работа с NumPy
  45. Работа с файлами в Python
  46. Утечки переменных цикла в Python 3.x

Marketello читают маркетологи из крутых компаний