Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. CLI-инструмент howdoi
  2. Преобразование в float
  3. Обработка исключений в Python
  4. Проверка ввода с помощью isdigit
  5. Получение комбинаций в Python
  6. Функции map, filter и reduce
  7. Преобразование символов с помощью map
  8. Функция zip() в Python
  9. Поиск частых элементов в списке
  10. List Comprehension Tutorial
  11. Классы данных в Python
  12. Генераторы в Python
  13. Операторы присваивания в Python
  14. Запрос DELETE с библиотекой requests
  15. Получение размера объекта с sys.getsizeof()
  16. Работа с Path в Python
  17. Пропуск начальных строк с помощью dropwhile()
  18. PUT запрос для обновления данных
  19. Создание матрицы в Python
  20. Профилирование с cProfile
  21. Переменная Шредингера
  22. Разделение строки с регулярными выражениями
  23. Создание namedtuple списком полей
  24. Непрерывная проверка в Python
  25. Генераторы в Python
  26. Генераторы списков в Python
  27. Функция reversed() в Python
  28. Функции с необязательными аргументами
  29. Модуль pprint: улучшение вывода данных
  30. Возврат нескольких значений
  31. Копирование списков в Python
  32. Создание функций с произвольным количеством аргументов
  33. Объединение словарей в Python
  34. Настройка нарезки списков
  35. Работа с словарями в Python
  36. Избегайте изменяемых аргументов
  37. Обмен значений переменных в Python
  38. Библиотека Rich: форматирование текста
  39. Создание спинбокса в tkinter
  40. Обязательные аргументы в Python
  41. Основы работы со строками в Python
  42. Magic Commands — улучшение работы с Python
  43. Работа с IP-адресами в Python
  44. Установка пакета в Python

Marketello читают маркетологи из крутых компаний