Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Обработка исключений в Python
  2. Работа с необработанными строками
  3. Метод ifloordiv для пользовательских классов
  4. Работа с срезами в Numpy
  5. Бинарный поиск
  6. Блок try…finally в Python
  7. Передача неизвестных аргументов в Python.
  8. Работа с исключениями в Python
  9. Обработка исключений в Python
  10. Создание директории в Python
  11. Объединение словарей в Python
  12. Хеширование паролей с использованием salt
  13. Множественное наследование в Python
  14. F-строки в Python 3.8
  15. Генераторы данных
  16. Возвращение нескольких значений
  17. Использование метода lower()
  18. Monkey Patching в Python
  19. Функция all() в Python
  20. Многострочные строки в Python
  21. Установка и обучение ChatterBot
  22. Поиск частых элементов в списке
  23. Работа с deque в Python
  24. Генераторы в Python
  25. Функциональное программирование в Python
  26. Подсчет элементов в списке с Counter
  27. Транспонирование матрицы
  28. Конвертация коллекций в Python
  29. Генерация тестовых данных с factory_boy
  30. Вычисление логарифмов в Python
  31. Работа с буфером обмена на Python
  32. Подробная информация о %pinfo
  33. Генерация случайных чисел Python
  34. Декодирование строк в Python
  35. Модуль itertools: эффективная работа с итераторами
  36. Генерация чисел с range()
  37. Константы в модуле cmath
  38. Генераторы и сеты в Python
  39. Преобразование текста в речь с Python
  40. Атрибуты класса и экземпляра в Python
  41. Получение значений из словарей
  42. Обработка ошибок в Python
  43. Создание задания в Cron
  44. Тип данных TypeVarTuple
  45. Solidity для DeFi Ethereum
  46. Работа с timedelta в Python

Marketello читают маркетологи из крутых компаний