Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Генератор чисел Фибоначчи
  2. Отладка в Python
  3. Транспонирование 2D-массива с помощью zip
  4. Проверка списка: any() и all()
  5. Логические операторы в Python
  6. Функция zip() в Python
  7. Функции map, filter и reduce
  8. Метод rpow в Python
  9. Хеширование паролей с солью
  10. Возврат нескольких значений из функции
  11. Генераторы в Python
  12. Преобразование числа в список цифр
  13. Блок try-except-else
  14. Разность множеств
  15. Создание уникального проекта
  16. Абстракции словарей и множеств в Python
  17. Создание циклической ссылки
  18. Отправка HTTP-запросов с User-Agent
  19. Создание словаря и множества
  20. Лямбда-функции для min/max
  21. Улучшенные подсказки для импорта в Python 3.12
  22. Метод get для словарей
  23. Конкатенация строк в Python
  24. Декоратор проверки активности
  25. Имена объектов в Python
  26. Асинхронный код в Python
  27. Парсинг веб-страниц с Beautiful Soup
  28. Проверка на истинность объектов в Python
  29. Добавление элемента в список.
  30. Оптимизация сравнения в Python
  31. Установка и использование howdoi
  32. Создание GUI на Tkinter
  33. Импорт модулей в Python 3.12
  34. Импорт объектов из модулей
  35. Обработка ошибок в JSON данных
  36. Поиск с библиотекой Google
  37. Раздувающийся словарь в Python
  38. Использование html-скриптов в Jupyter Notebook
  39. Проверка однородности элементов списка
  40. Подсчет элементов в списке с Counter
  41. Именованные срезы в Python
  42. Преобразование данных в Python
  43. Порядок операций в Python
  44. Работа с геоданными с помощью geopy
  45. Многопоточность и асинхронное программирование в Python
  46. Проверка окончания строки с помощью str.endswith()

Marketello читают маркетологи из крутых компаний