Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Метод splitlines() для разделения строк
  2. Метод ifloordiv для пользовательских классов
  3. Многострочные строки в Python
  4. Анализ текста на русском языке с помощью Pymystem3
  5. F-строки в Python 3.8
  6. Оператор in в Python
  7. Dict Comprehension в Python
  8. Создание словарей и множеств в Python.
  9. Цикл while в Python
  10. Colorama: окрашивание текста в Python
  11. Хеширование паролей с использованием salt
  12. Управление доступом к модулю
  13. Декоратор Ajax required
  14. Реверс строки и списка в Python.
  15. Хэш-функции в Python
  16. Python и Юникод: работа с цифрами
  17. Кортежи в Python: особенности и преимущества
  18. Профилирование данных с Pandas
  19. Работа с itertools
  20. Antigravity модуль
  21. Метод pop() списка
  22. Форматирование строк с помощью f-строк
  23. Python: возвращение нескольких значений
  24. Методы classmethod и staticmethod
  25. Обезопасьте ввод данных
  26. Работа с контекстными менеджерами
  27. Игра «Камень, ножницы, бумага» — Python
  28. Декораторы для регистрации функций
  29. Работа с JSON в Python
  30. Поиск HTML-элементов с BeautifulSoup
  31. EMOT преобразование эмодзи в текст
  32. Работа с часовыми поясами в Python
  33. Область видимости переменных
  34. Генераторы списков в Python
  35. Работа с NumPy.linalg
  36. Преобразование генераторов в циклы
  37. Вложенные функции в Python
  38. Работа с Event() в threading
  39. Работа с словарями в Python
  40. Удаление дубликатов с помощью множеств
  41. Открытие и запись файлов
  42. Генераторы в Python
  43. Преобразование многоуровневого словаря
  44. Подсказки типов в Python
  45. Извлечение статей с newspaper3k
  46. Расчет времени выполнения
  47. Python Enum Weekday Usage
  48. Проверка списка: any() и all()

Marketello читают маркетологи из крутых компаний