Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Лямбда-функции в Python
  2. Обезопасьте ввод данных
  3. Преобразование чисел в восьмеричную строку
  4. Срез в Python
  5. inspect в Python: анализ кода
  6. TypedDict для kwargs в Python 3.12
  7. Комплексные числа в Python
  8. Оценка выражений генератора в Python
  9. Библиотека Rich: форматирование текста
  10. Генерация случайных чисел в Python
  11. Перезапуск ячейки в Jupyter Notebook с dostoevsky
  12. globals и locals
  13. Удаление файлов в Python
  14. Генераторы в Python
  15. Замена элементов в списке с помощью генераторов списков
  16. Установка и использование TensorFlow
  17. kwargs в Python
  18. Метод __int__ в Python
  19. Работа с комплексными числами
  20. Метод __iand__ для пользовательских классов
  21. Нахождение самого длинного слова в списке с помощью max
  22. Декораторы в Python
  23. Игра «Угадывание чисел»
  24. Лямбда-функции в цикле
  25. Методы в Python
  26. Очистка данных с Pandas
  27. Проектирование Singleton с метаклассом
  28. Форматирование объектов с модулем pprint
  29. Проблемы с именами переменных
  30. Метод is_absolute() для PurePath
  31. Непрерывная проверка в Python
  32. Декораторы в Python
  33. Типы возвращаемых значений в Python
  34. Преобразование числа в восьмеричную строку
  35. Закрытие файла в Python
  36. Избегайте использования goto
  37. Именованные срезы в Python
  38. Работа со словарями
  39. Создание GUI с Tkinter: Entry
  40. Изменение списка срезами
  41. Оператор += в Python
  42. Распаковка элементов массива
  43. Декоратор total_ordering для сравнения объектов
  44. Аннотации типов в Python
  45. Встроенные функции Python
  46. Срез списка в Python

Marketello читают маркетологи из крутых компаний