Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Декораторы классов
  2. Метод join() для объединения элементов строки
  3. Функции высшего порядка в Python
  4. Поток данных в Python
  5. Работа с библиотекой requests
  6. Просмотр атрибутов и методов класса
  7. Форматирование строк в Python
  8. Показ всплывающих окон Tkinter
  9. Улучшенные подсказки для импорта в Python 3.12
  10. Ограничение ресурсов в Python
  11. Именованные кортежи в Python
  12. Отображение графиков в Jupyter с Matplotlib
  13. Вложенные генераторы в Python
  14. Функция с *args.
  15. Python: Фильтрация списков с помощью filter()
  16. Операторы сравнения в Python
  17. Итерация по копии коллекции
  18. Бинарный поиск
  19. Асинхронное выполнение задач в Python
  20. Создание таблиц в Python с PrettyTable
  21. lru_cache оптимизация функций
  22. Преобразование списка в словарь через генератор
  23. Форматирование строк в Python
  24. Импорт и использование модулей в Python
  25. Удаление символа из строки
  26. Срезы в Python
  27. Лямбда-функции в defaultdict
  28. Преобразование Excel в PDF с Spire.XLS
  29. Метод add для класса Vector
  30. Многострочные комментарии в Python
  31. Python Менеджер контекста
  32. Поиск самого длинного слова в списке с использованием max()
  33. Руководство по Pymorphy2
  34. Управление памятью в numpy.
  35. Цепные операции в Python
  36. Библиотека itertools: объединение списков
  37. Оператор in в Python
  38. Выбор редактора кода.
  39. Обновление множества в Python
  40. Проверка окончания строки с помощью str.endswith()
  41. Работа с часовыми поясами в Python.
  42. Функции all и any в Python
  43. Измерение времени выполнения кода
  44. Antigravity модуль
  45. Автоматизация действий с Pyautogui
  46. Логирование в Python
  47. Преобразование Word в PDF с Spire.Doc
  48. Роль запятой в Python

Marketello читают маркетологи из крутых компаний