Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Явный импорт переменных
  2. Переопределение метода __and__
  3. Логирование с Logzero
  4. Логирование с Logzero
  5. Обновление ключей в Python
  6. Тип данных TypeVarTuple
  7. Запуск Python из интерпретатора
  8. Оптимизация памяти с __slots__
  9. Преобразование в float
  10. Python и Юникод: работа с цифрами
  11. Комментарии в Python
  12. Анонимные функции в Python
  13. Хэш-функции в Python
  14. Копирование словарей и списков в Python
  15. Парсинг веб-страниц с Beautiful Soup
  16. Ограничение итераций в Python
  17. Аннотации типов в Python
  18. Вложенные циклы в Python
  19. Экранирование символов в Python
  20. Регулярные выражения в Python
  21. Управление браузером с Selenium
  22. Копирование списков в Python
  23. Создание детектора плагиата
  24. Глобальные переменные в Python
  25. Метод hash в Python
  26. Создание и операции с дробями
  27. Работа с часовыми поясами в Python.
  28. Работа с collections в Python
  29. Настройка шрифта и цвета в Tkinter
  30. Defaultdict в Python
  31. Доступ к локальным переменным
  32. Функции высшего порядка в Python
  33. Метод repr() в Python
  34. Разделение строк методом split()
  35. Обновление данных через PUT запрос
  36. Пересечение списков с использованием множеств
  37. Метаклассы в Python
  38. Логические значения в Python
  39. Создание даты из строки ISO
  40. Работа с Colorama
  41. Работа с deque из collections
  42. Работа с исключениями в Python
  43. Библиотека schedule: планировщик задач
  44. Удаление символа из строки
  45. Объединение коллекций в Python
  46. Повторение элементов списков
  47. Генератор данных в Keras
  48. Применение функции map() в Python

Marketello читают маркетологи из крутых компаний