Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Многоточие в Python
  2. Отладка производительности Python
  3. Функция enumerate() в Python
  4. Базовые объекты Python
  5. Создание функций высшего порядка
  6. Декоратор Ajax required
  7. Капитализация строк
  8. Запрос пароля с помощью getpass
  9. Вызов функций по строке в Python.
  10. Проверка подстроки в строке с помощью in
  11. Объединение словарей в Python
  12. Объединение словарей в Python
  13. Работа со стеком в Python
  14. Работа с SQLite в Python
  15. Нан-рефлексивность в Python
  16. Сортировка данных в Python
  17. Поиск индекса элемента в списке
  18. Открытие и редактирование скриптов Python
  19. Combobox в Tkinter
  20. Форматирование данных с помощью pprint
  21. Перегрузка операторов в Python
  22. Метод setdefault() в Python
  23. Избегайте изменяемых аргументов
  24. Проверка элементов списка условием
  25. Поиск email
  26. Модуль xkcd: загрузка комиксов
  27. Функция enumerate() в Python
  28. Функция с *args.
  29. Удаление элементов по срезу
  30. Операции с комплексными числами
  31. Работа с контекстными переменными
  32. Вычисление натурального логарифма в NumPy
  33. Списковое включение в Python
  34. Выход из профиля в Django
  35. Сортировка в Python
  36. Оформление кода на Python
  37. Запуск Python из интерпретатора
  38. Представление бесконечности в Python
  39. Скрытие вывода данных
  40. Транспонирование 2D-массива с помощью zip
  41. Простой калькулятор Python
  42. Псевдонимы в Python
  43. Избегайте двойного подчеркивания
  44. Извлечение аудио из видео
  45. Курс по дообучению ChatGPT
  46. Функции map() и reduce() в Python
  47. Создание класса очереди
  48. Работа с изображениями Pillow

Marketello читают маркетологи из крутых компаний