Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Деление в Python
  2. Удаление элементов из списка в Python
  3. Работа с модулем glob в Python
  4. Работа с модулем Calendar
  5. Преобразование символов с помощью map
  6. Фильтрация списка от «ложных» значений
  7. Оператор continue в Python
  8. Избегание циклических зависимостей классов в Python
  9. Группировка элементов в словарь
  10. Проверка типа данных
  11. Итераторы в Python
  12. Глубокое копирование объектов
  13. Создание списков в Python
  14. Работа с базами данных SQLite
  15. Транспонирование 2D-массива с помощью zip
  16. Модуль xkcd: добавление юмора в Python
  17. Создание копии списка в Python
  18. Работа с URL-адресами в Python
  19. Фильтрация данных в Python.
  20. Получение идентификатора объекта в памяти
  21. Аргументы *args и **kwargs
  22. Итерация по итерируемым объектам
  23. Замена переменных в Python
  24. Метод ifloordiv для пользовательских классов
  25. Python Enum Weekday Usage
  26. Сортировка списка по индексам
  27. Просмотр атрибутов и методов класса
  28. Путь к интерпретатору Python
  29. Экспорт функций в Python
  30. Метод eq для сравнения объектов
  31. Метод __int__ в Python
  32. Искажение имен в Python
  33. Основы Python за 14 дней
  34. Избегайте двойного подчеркивания
  35. Подписка на @SelectelNews
  36. Декораторы с аргументами
  37. Операции со строками в Python
  38. Переопределение метода len
  39. Concrete Paths — метод .with_suffix()
  40. Работа с каталогами в Python
  41. Просмотр внешних файлов в %pycat
  42. Добавление вложенных списков
  43. Именование переменных в Python
  44. Удаление специальных символов
  45. Функция format() в Python
  46. Функция findall() для поиска вхождений строки
  47. Многострочные строки в Python

Marketello читают маркетологи из крутых компаний