.. note:: :class: sphx-glr-download-link-note Click :ref:`here ` to download the full example code .. rst-class:: sphx-glr-example-title .. _sphx_glr_beginner_basics_saveloadrun_tutorial.py: `Learn the Basics `_ || `Quickstart `_ || `Tensors `_ || `Datasets & DataLoaders `_ || `Transforms `_ || `Build Model `_ || `Autograd `_ || `Optimization `_ || **Save & Load Model** Save and Load the Model ============================ In this section we will look at how to persist model state with saving, loading and running model predictions. .. code-block:: default import torch import torch.onnx as onnx import torchvision.models as models Saving and Loading Model Weights -------------------------------- PyTorch models store the learned parameters in an internal state dictionary, called ``state_dict``. These can be persisted via the ``torch.save`` method: .. code-block:: default model = models.vgg16(pretrained=True) torch.save(model.state_dict(), 'model_weights.pth') To load model weights, you need to create an instance of the same model first, and then load the parameters using ``load_state_dict()`` method. .. code-block:: default model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights model.load_state_dict(torch.load('model_weights.pth')) model.eval() .. note:: be sure to call ``model.eval()`` method before inferencing to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results. Saving and Loading Models with Shapes ------------------------------------- When loading model weights, we needed to instantiate the model class first, because the class defines the structure of a network. We might want to save the structure of this class together with the model, in which case we can pass ``model`` (and not ``model.state_dict()``) to the saving function: .. code-block:: default torch.save(model, 'model.pth') We can then load the model like this: .. code-block:: default model = torch.load('model.pth') .. note:: This approach uses Python `pickle `_ module when serializing the model, thus it relies on the actual class definition to be available when loading the model. Exporting Model to ONNX ----------------------- PyTorch also has native ONNX export support. Given the dynamic nature of the PyTorch execution graph, however, the export process must traverse the execution graph to produce a persisted ONNX model. For this reason, a test variable of the appropriate size should be passed in to the export routine (in our case, we will create a dummy zero tensor of the correct size): .. code-block:: default input_image = torch.zeros((1,3,224,224)) onnx.export(model, input_image, 'model.onnx') There are a lot of things you can do with ONNX model, including running inference on different platforms and in different programming languages. For more details, we recommend visiting `ONNX tutorial `_. Congratulations! You have completed the PyTorch beginner tutorial! Try `revisting the first page `_ to see the tutorial in its entirety again. We hope this tutorial has helped you get started with deep learning on PyTorch. Good luck! .. rst-class:: sphx-glr-timing **Total running time of the script:** ( 0 minutes 0.000 seconds) .. _sphx_glr_download_beginner_basics_saveloadrun_tutorial.py: .. only :: html .. container:: sphx-glr-footer :class: sphx-glr-footer-example .. container:: sphx-glr-download :download:`Download Python source code: saveloadrun_tutorial.py ` .. container:: sphx-glr-download :download:`Download Jupyter notebook: saveloadrun_tutorial.ipynb ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_