quantize_model
Description
Quantizes a graph based on the quantization configuration file config_file, inserts the quantization operators, generates a quantization factor record file record_file, and returns the list of newly added operators.
Prototype
quant_add_ops = quantize_model(graph, config_file, record_file)
Parameters
Parameter | Input/Return | Description | Restrictions |
---|---|---|---|
graph | Input | A tf.Graph of the model to be quantized. | A tf.Graph. Must be an inference graph containing no training-mode operators. For example, is_training of the FusedBatchNormV3 operator must be False. |
config_file | Input | Quantization configuration file generated by the user. It specifies the configuration of each quantization layer in the tf.Graph of the model. | A string. |
record_file | Input | Path of the quantization factor record file, including the file name. | A string. |
quant_add_ops | Return | List of operators inserted for quantization. NOTE: The variables in this list do not exist in the model training parameter file, so restoring the training parameters directly fails with an error indicating that the variables cannot be found. Before restoring the model training parameters, remove the variables in the quant_add_ops list from the restore list. For details, see How Do I Restore the Model Training Parameters After Quantization Operators Are Inserted?. | A list of tf.Operations. |
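The NOTE on quant_add_ops can be illustrated with a minimal sketch of the filtering step. It is not taken from the AMCT documentation: the helper names `base_name` and `variables_to_restore`, the variable names, and the choice to match entries by name are all assumptions, and plain stand-in objects replace real TensorFlow variables so the sketch runs without TensorFlow installed.

```python
from types import SimpleNamespace


def base_name(name):
    """Strip a trailing output index such as ':0' from a TensorFlow-style name."""
    head, sep, tail = name.rpartition(":")
    return head if sep and tail.isdigit() else name


def variables_to_restore(all_variables, quant_add_ops):
    """Return the variables whose names do not match any entry of quant_add_ops.

    Assumption: entries are matched by their `name` attribute, ignoring the
    ':0' suffix that variable names carry and operation names do not.
    """
    added = {base_name(op.name) for op in quant_add_ops}
    return [v for v in all_variables if base_name(v.name) not in added]


# Stand-in objects with hypothetical names; in real use these would be the
# graph's variables and the list returned by quantize_model.
all_vars = [
    SimpleNamespace(name="conv1/kernel:0"),
    SimpleNamespace(name="conv1/bias:0"),
    SimpleNamespace(name="conv1/quant_scale:0"),
]
quant_add_ops = [SimpleNamespace(name="conv1/quant_scale")]

restore_list = variables_to_restore(all_vars, quant_add_ops)
print([v.name for v in restore_list])
```

With TensorFlow 1.x, `all_variables` would typically be `tf.global_variables()`, and the filtered list would be passed as `var_list` to `tf.train.Saver` before calling `restore`. The linked FAQ remains the authoritative procedure.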
Returns
Returns the list of operators inserted into the graph for quantization (quant_add_ops).
quantize_model performs BN fusion on the graph. If an output of the network model is a BN layer and that BN layer is fused, the output node of the network changes. For example, Conv+BN (or Conv+BiasAdd+BN) is fused into Conv+BiasAdd, so the output node equivalent to the original BN node is a BiasAdd node.
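To see why a Conv+BiasAdd pair can stand in for Conv+BiasAdd+BN, the sketch below applies the standard BN folding algebra to a single output channel, using a scalar multiply as a stand-in for the convolution. It illustrates the general technique only; the function names and numeric values are hypothetical, and it does not claim to reproduce what AMCT does internally.

```python
import math


def conv_bias_bn(x, weight, bias, gamma, beta, mean, variance, eps=1e-5):
    """Reference path: scalar 'conv' plus bias, followed by inference-mode BN."""
    y = weight * x + bias
    return gamma * (y - mean) / math.sqrt(variance + eps) + beta


def fold_bn(weight, bias, gamma, beta, mean, variance, eps=1e-5):
    """Fold the BN parameters into an equivalent weight and bias."""
    scale = gamma / math.sqrt(variance + eps)
    return weight * scale, (bias - mean) * scale + beta


# Hypothetical per-channel parameters.
w, b = 0.5, 0.1
gamma, beta, mean, var = 1.2, -0.3, 0.4, 0.25

w_fused, b_fused = fold_bn(w, b, gamma, beta, mean, var)

x = 2.0
original = conv_bias_bn(x, w, b, gamma, beta, mean, var)
fused = w_fused * x + b_fused  # only a multiply and a bias add remain

print(math.isclose(original, fused))
```

Because the BN computation is absorbed into the weight and bias, the last node producing this value is the bias addition, which is why code that looks up the network output by node name has to refer to the BiasAdd node after quantization.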
Example
    import tensorflow as tf
    import amct_tensorflow as amct

    # Create a network to be quantized.
    # build_network() is a user-defined function (not shown here).
    network = build_network()

    # Quantize the model.
    amct.quantize_model(
        graph=tf.get_default_graph(),
        config_file="./configs/config.json",
        record_file="./record_scale_offset.txt")
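Before calling quantize_model, it may help to check the restriction that the graph contains no training-mode operators. The following is a rough sketch, not part of AMCT: `find_training_mode_ops` is a hypothetical helper that relies only on `get_operations()`, `op.type`, `op.name`, and `op.get_attr()`, which tf.Graph and tf.Operation provide. It inspects only the FusedBatchNorm family mentioned in the parameter table, and stand-in classes are used here so the sketch runs without TensorFlow.

```python
def find_training_mode_ops(graph):
    """Return names of FusedBatchNorm* operators whose is_training attribute is True."""
    offending = []
    for op in graph.get_operations():
        if op.type.startswith("FusedBatchNorm") and op.get_attr("is_training"):
            offending.append(op.name)
    return offending


# Stand-ins that mimic the small part of the tf.Graph / tf.Operation
# interface used above; all operator names are hypothetical.
class FakeOp:
    def __init__(self, name, op_type, attrs=None):
        self.name = name
        self.type = op_type
        self._attrs = attrs or {}

    def get_attr(self, key):
        return self._attrs[key]


class FakeGraph:
    def __init__(self, ops):
        self._ops = ops

    def get_operations(self):
        return list(self._ops)


graph = FakeGraph([
    FakeOp("conv1/Conv2D", "Conv2D"),
    FakeOp("bn1/FusedBatchNormV3", "FusedBatchNormV3", {"is_training": False}),
    FakeOp("bn2/FusedBatchNormV3", "FusedBatchNormV3", {"is_training": True}),
])

print(find_training_mode_ops(graph))
```

With a real model, the same function could be passed the tf.Graph that is handed to quantize_model; an empty result means no FusedBatchNorm operator is in training mode, although other training-mode operators are not covered by this check.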