Курс Python → Сохранение и загрузка модели в PyTorch

Для сохранения и загрузки модели в PyTorch необходимо использовать методы torch.save() и torch.load(). Для сохранения модели передайте model.state_dict() в качестве первого аргумента, это просто словарь, который содержит информацию о слоях модели и их параметрах (веса и смещения). Вторым аргументом укажите имя файла, в котором будет сохранена модель. Хорошей практикой является использование расширений .pth или .pt для сохранения моделей PyTorch. Также можно указать полный путь к файлу, если вы хотите сохранить модель в определенном каталоге.

Пример сохранения модели:


torch.save(model.state_dict(), "cifar_fc.pth")

Чтобы загрузить сохраненную модель для дальнейшего использования или логического вывода, используйте метод torch.load(). Затем можно загрузить параметры модели с помощью метода load_state_dict(). Это позволит восстановить состояние модели с сохраненными параметрами и продолжить обучение или использование модели для вывода.

Пример загрузки модели:


model = YourModelClass()
model.load_state_dict(torch.load("cifar_fc.pth"))
model.eval()

При загрузке модели убедитесь, что класс модели, для которой загружаются параметры, совпадает с классом модели, которая была сохранена. В противном случае возможны ошибки при загрузке параметров. Также рекомендуется использовать метод model.eval() после загрузки модели, чтобы переключить ее в режим оценки и отключить дополнительные режимы, такие как режим обучения.

Твои коллеги будут рады, поделись в

Автор урока

Дмитрий Комаровский
Дмитрий Комаровский

Автоматизация процессов
в КраснодарБанки.ру

Другие уроки курса "Python"

  1. Установка максимального количества цифр
  2. Генераторы в Python
  3. Работа с дробями в Python
  4. Форматирование строк с f-строками
  5. Многоточие в Python
  6. Работа с NumPy.linalg
  7. Создание новых списков в Python
  8. Сортировка данных с лямбда-функциями
  9. Работа с модулем glob в Python
  10. Цикл for в Python
  11. Вывод переменной и строки в Python
  12. Лямбда-функции в Python
  13. Проверка условий в Python
  14. Базовые объекты Python
  15. Проверка кортежей.
  16. Передача словаря через **kwargs
  17. Асинхронное выполнение задач в Python
  18. Повторение элементов в Python
  19. capitalize() — изменение регистра первого символа строки
  20. Лямбда-функции в Python
  21. Декоратор Property в Python
  22. Работа с эмодзи в Python
  23. Извлечение аудио из видео
  24. Объединение строк с помощью метода join
  25. Обратный список чисел
  26. Ускоренный импорт библиотек
  27. Проверка списка: any() и all()
  28. Операции с датами в Python
  29. Функциональное программирование.
  30. Оператор @ для умножения матриц
  31. Подсчет элементов с помощью Counter из collections
  32. Управление фоновыми задачами в Python
  33. Форматирование строк в Python
  34. Принцип одной функции
  35. Операции с кортежами
  36. Нарезка списков в Python
  37. Python: Фильтрация списков с помощью filter()
  38. Метод difference_update() — разность множеств
  39. Нахождение разницы между списками в Python
  40. Управление асинхронными задачами с помощью Semaphore
  41. Вложенные функции в Python
  42. *args и **kwargs в Python
  43. Метод rlshift для битового сдвига
  44. Удаление символа из строки
  45. Синтаксис переменных цикла в Python
  46. Методы работы со списками
  47. Метод invert для побитового отрицания
  48. Переворот последовательности

Marketello читают маркетологи из крутых компаний