Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Создание вкладок с TKinter
  2. Получение текущей даты и времени с помощью datetime
  3. Непрерывная проверка в Python
  4. Хранение переменных в Python.
  5. Работа с NumPy массивами
  6. Сортировка данных с лямбда-функциями
  7. Регистрация на курсы SF Education
  8. Псевдонимы в Python
  9. Создание пар из последовательностей
  10. Управление IP-адресами через прокси
  11. Манипуляция формой массива в Numpy
  12. Документирование функций в Python
  13. Разность множеств
  14. globals и locals
  15. Отрицательные индексы списков в Python
  16. Получение ID текущего процесса
  17. Официальный канал Python в Telegram
  18. Сортировка HTML-элементов
  19. Работа с аргументами командной строки в Python
  20. Получение имени функции с помощью inspect
  21. Работа со словарями с defaultdict из collections
  22. Оператор морж в Python 3.8
  23. Приоритет операций в Python
  24. Извлечение аудио из видео
  25. Метод title() в Python
  26. Тайное преобразование типа ключа
  27. Объединение списков в Python
  28. Метод __ixor__ для побитового исключающего ИЛИ
  29. Утечки переменных цикла в Python 3.x
  30. Тестирование функции сложения
  31. Оператор == в Python
  32. split() без разделителя
  33. Многострочные строки в Python
  34. Парсинг статей с Newspaper3k
  35. Пустой оператор pass в Python
  36. Повторение и перенос строки
  37. Модуль xkcd: добавление юмора в Python
  38. Декораторы в Python
  39. Измерение времени выполнения с помощью time
  40. Динамическая типизация в Python
  41. Подсказки типов в Python
  42. Поиск наиболее частого элемента списке
  43. Получение атрибутов и методов класса
  44. Ветвление выражения в Python
  45. Подсчет элементов в Python
  46. Очистка данных в Python
  47. Подсказки типов в Python
  48. Декоратор total_ordering для класса Point

Marketello читают маркетологи из крутых компаний