FrameworkPTAdapter 2.0.1 PyTorch Network Model Porting and Training Guide 01

Exporting an ONNX Model

Introduction

The Ascend AI Processor deploys PyTorch models through ONNX, using the ONNX export support built into PyTorch. ONNX is a mainstream model format in the industry and is widely used for model sharing and deployment. This section describes how to export a checkpoint file as an ONNX model by using the torch.onnx.export() API.

Using the .pth or .pt File to Export the ONNX Model

To restore a saved .pth or .pt file, build the model in PyTorch and load the saved weights into it. You can then export the model to ONNX. The following is an example.
import torch
import torch.onnx
import torchvision.models as models

# Use the CPU to export the model.
device = torch.device("cpu")


def convert():
    # The model definition comes from torchvision. The model file in this example is based on ResNet-50.
    model = models.resnet50(pretrained=False)
    # Load the saved weights onto the CPU and restore them into the model.
    resnet50_model = torch.load("resnet50.pth", map_location=device)
    model.load_state_dict(resnet50_model)

    batch_size = 1  # Batch size
    input_shape = (3, 224, 224)  # Input shape (C, H, W). Replace it with the actual shape.

    # Set the model to inference mode.
    model.eval()

    dummy_input = torch.randn(batch_size, *input_shape)  # Dummy input with the defined shape
    torch.onnx.export(model,
                      dummy_input,
                      "resnet50_official.onnx",
                      input_names=["input"],  # Name of the model input
                      output_names=["output"],  # Name of the model output
                      opset_version=11,  # Currently, the ATC tool supports only opset_version=11.
                      dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}})  # Make the batch dimension of the input and output dynamic.


if __name__ == "__main__":
    convert()
  • Before exporting the ONNX model, you must call model.eval() to set the dropout and batch normalization layers to inference mode.
  • The model in the sample script is defined in the torchvision module. To export your own model, replace it with your own model definition.
  • The constructed input and output must match the input and output used during training. Otherwise, inference cannot be performed properly.
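The dynamic_axes argument in the example above maps each input or output name to the axes whose size may change at inference time. When a model has several inputs or outputs, a small helper can build this mapping so that it stays consistent with input_names and output_names. The following is a minimal sketch; make_dynamic_axes is a hypothetical helper, not part of torch.onnx, and it assumes that only one axis (the batch dimension, axis 0 by default) is dynamic.

```python
from typing import Dict, Sequence


def make_dynamic_axes(
    input_names: Sequence[str],
    output_names: Sequence[str],
    axis: int = 0,
    label: str = "batch_size",
) -> Dict[str, Dict[int, str]]:
    """Hypothetical helper: mark the same axis of every input and output as dynamic.

    The returned dict has the form used by the dynamic_axes argument of
    torch.onnx.export(), for example {"input": {0: "batch_size"}}.
    """
    names = list(input_names) + list(output_names)
    if len(set(names)) != len(names):
        # Duplicate names would silently overwrite each other in the dict.
        raise ValueError("input and output names must be unique")
    return {name: {axis: label} for name in names}


if __name__ == "__main__":
    axes = make_dynamic_axes(["input"], ["output"])
    print(axes)  # {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
```

The result could then be passed as dynamic_axes=make_dynamic_axes(["input"], ["output"]) in the torch.onnx.export() call, which produces the same mapping as the one written out by hand in the example.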

Using the .pth.tar File to Export the ONNX Model

Before exporting the ONNX model from a .pth.tar file, check the information saved in the file. The saved node names may differ from the node names in the model definition; for example, a prefix or suffix may have been added. You can rename the nodes during the conversion. The following is an example of the conversion.
import torch
import torch.onnx
from collections import OrderedDict
import mobilenet  # Module that defines mobilenet_v2() in this example. Replace it with your own model definition.

# In this example, the prefix "module." was added to the node names when the .pth.tar file was saved.
# Traverse the state dict and remove the prefix.
def proc_nodes_module(checkpoint, AttrName):
    new_state_dict = OrderedDict()
    for key, value in checkpoint[AttrName].items():
        if key.startswith("module."):
            name = key[len("module."):]
        else:
            name = key
        new_state_dict[name] = value
    return new_state_dict


def convert():
    # Load the checkpoint onto the CPU.
    checkpoint = torch.load("./mobilenet_cpu.pth.tar", map_location=torch.device("cpu"))
    # Rename the nodes so that they match the model definition.
    checkpoint["state_dict"] = proc_nodes_module(checkpoint, "state_dict")
    model = mobilenet.mobilenet_v2(pretrained=False)
    model.load_state_dict(checkpoint["state_dict"])
    # Set the model to inference mode before exporting.
    model.eval()
    input_names = ["actual_input_1"]
    output_names = ["output1"]
    dummy_input = torch.randn(1, 3, 224, 224)
    torch.onnx.export(model, dummy_input, "mobilenetV2_npu.onnx",
                      input_names=input_names, output_names=output_names, opset_version=11)


if __name__ == "__main__":
    convert()
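The proc_nodes_module() function above handles only the "module." prefix (the prefix that a model wrapped in torch.nn.DataParallel typically adds to its state dict keys). Because saved node names may also carry a suffix, the following sketch generalizes the renaming and then reports which keys still do not match the model definition. The function names rename_state_dict_keys and compare_keys, and the sample keys, are hypothetical; the sketch works on plain dictionaries so that it can be checked without loading a real checkpoint. With a real model, the expected keys would come from model.state_dict().keys().

```python
from collections import OrderedDict


def rename_state_dict_keys(state_dict, prefix="module.", suffix=""):
    """Hypothetical helper: strip an optional prefix and suffix from every key.

    Key order is preserved and values are passed through unchanged.
    """
    renamed = OrderedDict()
    for key, value in state_dict.items():
        name = key
        if prefix and name.startswith(prefix):
            name = name[len(prefix):]
        if suffix and name.endswith(suffix):
            name = name[: -len(suffix)]
        if name in renamed:
            # Two saved keys collapsed into one name; refuse to overwrite silently.
            raise ValueError(f"renaming produced a duplicate key: {name}")
        renamed[name] = value
    return renamed


def compare_keys(checkpoint_keys, model_keys):
    """Return (missing, unexpected) as sorted lists.

    missing: keys the model expects but the checkpoint does not contain.
    unexpected: keys in the checkpoint that the model does not define.
    """
    ckpt, model = set(checkpoint_keys), set(model_keys)
    return sorted(model - ckpt), sorted(ckpt - model)


if __name__ == "__main__":
    saved = OrderedDict([
        ("module.features.0.0.weight", 1),
        ("module.classifier.1.bias", 2),
    ])
    expected = ["features.0.0.weight", "classifier.1.bias"]
    fixed = rename_state_dict_keys(saved)
    print(list(fixed))  # ['features.0.0.weight', 'classifier.1.bias']
    print(compare_keys(fixed.keys(), expected))  # ([], [])
```

Two empty lists indicate that the renamed keys match the model definition, so model.load_state_dict() should succeed in strict mode. With a real model, calling model.load_state_dict(..., strict=False) also returns the missing and unexpected keys, which can serve the same check.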

Update Date: 2021-06-10
Document ID: EDOC1100191782