The most surprising thing about converting TensorFlow models to ONNX is that the conversion process itself often reveals hidden assumptions or bugs in the original TensorFlow graph that were previously masked by TensorFlow’s dynamic nature.

Let’s see this in action. Suppose you have a trained TensorFlow model, perhaps for image classification, saved in the SavedModel format. You want to convert it to ONNX to leverage ONNX Runtime for faster inference on various hardware.

Here’s a simplified TensorFlow SavedModel structure:

my_model/
  saved_model.pb
  variables/
    ...

You’d typically use the tf2onnx library for this. A basic conversion might look like this:

python -m tf2onnx.convert \
  --saved-model my_model \
  --output my_model.onnx \
  --opset 13 \
  --inputs input_tensor:0 \
  --outputs output_tensor:0

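This command tells tf2onnx to load your SavedModel, convert it to an ONNX file named my_model.onnx that targets ONNX opset 13, and explicitly name the input and output tensors (the :0 suffix means the first output of the named graph operation). For a SavedModel, tf2onnx can usually read these names from the model's serving signature, so --inputs and --outputs are often optional.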

The beauty of ONNX is its standardized representation. Once converted, you can load and run my_model.onnx with ONNX Runtime, which can be faster and more memory-efficient than TensorFlow's own inference path, depending on the model and hardware. This is especially true when it is paired with an execution provider for specialized hardware, such as CUDA or TensorRT on NVIDIA GPUs or OpenVINO on Intel hardware.

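import onnxruntime as ort
import numpy as np

# Load the ONNX model (CPU here; other execution providers can be listed too)
session = ort.InferenceSession("my_model.onnx", providers=["CPUExecutionProvider"])

# Inspect the model's first input
input_meta = session.get_inputs()[0]
# Shape example: [1, 224, 224, 3] for a batch of one image. Dynamic dimensions
# (often the batch axis) are reported as a string or None rather than an int,
# so substitute 1 for them when building dummy data.
input_shape = [d if isinstance(d, int) else 1 for d in input_meta.shape]
# input_meta.type is a string such as 'tensor(float)', which maps to np.float32;
# this example assumes a float32 input.
print("Input:", input_meta.name, input_meta.shape, input_meta.type)

# Create dummy data (e.g., a batch of random "images") with the expected dtype
dummy_input = np.random.rand(*input_shape).astype(np.float32)

# Run inference; passing None as the first argument returns every output
outputs = session.run(None, {input_meta.name: dummy_input})

# 'outputs' is a list of numpy arrays, one per model output
for meta, value in zip(session.get_outputs(), outputs):
    print(meta.name, value.shape)
print("Inference successful!")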

The mental model here is that TensorFlow, especially with eager execution, can be very dynamic: operations can be defined and executed on the fly. When you convert to ONNX, you're essentially freezing that computation into a static directed acyclic graph (DAG), and this static nature is what allows aggressive optimizations. The tf2onnx tool acts as a translator, mapping TensorFlow operations to their ONNX equivalents. The opset version is crucial because it defines which ONNX operators (and which versions of them) are available. A higher opset gives access to newer operators, but the ONNX Runtime release you deploy with must support that opset.
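To make the translator-plus-opset idea concrete, here is a toy sketch. The TF_TO_ONNX table and the translate() helper are hypothetical and heavily simplified, not tf2onnx's real handler registry (which also rewrites attributes, layouts, and multi-op patterns); the opset numbers are the versions in which each ONNX operator first appeared.

```python
# Hypothetical, heavily simplified mapping: TF op type -> (ONNX op type,
# first ONNX opset that defines it). Not tf2onnx's actual handler table.
TF_TO_ONNX = {
    "Conv2D": ("Conv", 1),
    "Relu": ("Relu", 1),
    "MatMul": ("MatMul", 1),
    "GatherNd": ("GatherND", 11),
    "Einsum": ("Einsum", 12),
}


def translate(tf_ops, opset):
    """Map a list of TF op types to ONNX op types for a target opset."""
    onnx_ops = []
    for op in tf_ops:
        if op not in TF_TO_ONNX:
            # No equivalent at all: a real converter would fail here too
            raise NotImplementedError(f"no ONNX mapping for TF op {op!r}")
        onnx_op, since = TF_TO_ONNX[op]
        if opset < since:
            # The operator exists in ONNX, but not in the requested opset
            raise ValueError(f"{onnx_op} needs opset >= {since}, got {opset}")
        onnx_ops.append(onnx_op)
    return onnx_ops


print(translate(["Conv2D", "Relu", "Einsum"], opset=13))
```

Asking for opset 11 with the same op list would raise a ValueError for Einsum, which mirrors why picking the opset matters.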

During conversion, tf2onnx analyzes the TensorFlow graph, identifies each TensorFlow operation, and finds the corresponding ONNX operator(s) in the specified opset. It then rebuilds the graph from those ONNX operators. Because the result is static, anything TensorFlow handled implicitly at runtime, such as variable-sized dimensions or dynamic control flow, has to be resolved explicitly during conversion, typically by inferring shapes or by relying on explicit input/output specifications.

Consider this: a tf.keras.layers.Conv2D layer has parameters like filters, kernel_size, strides, and padding, and tf2onnx maps it to the ONNX Conv operator. With padding='same', TensorFlow computes the padding amount implicitly; tf2onnx has to express it explicitly, either as a pads attribute on the ONNX Conv node or through the equivalent auto_pad='SAME_UPPER' setting. It also has to reconcile layouts: Keras Conv2D defaults to channels-last (NHWC), while ONNX Conv expects channels-first (NCHW), so the converter inserts Transpose operations around the convolution. If the original TensorFlow graph uses an operation with no counterpart in the target ONNX opset, tf2onnx may fail, and you might need to implement a custom ONNX operator or rewrite that part of the TensorFlow model.
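The 'same' padding arithmetic can be written out directly. This sketch follows the standard TensorFlow SAME rule (output size = ceil(input / stride), with any odd leftover pixel placed at the end); the helper names same_pads and onnx_conv_pads are hypothetical. ONNX's pads attribute lists the begin values for all spatial axes, followed by all end values.

```python
import math


def same_pads(input_size, kernel, stride=1, dilation=1):
    """Return (pad_begin, pad_end) for one spatial axis under TF 'SAME' padding."""
    out_size = math.ceil(input_size / stride)
    effective_kernel = (kernel - 1) * dilation + 1
    pad_total = max((out_size - 1) * stride + effective_kernel - input_size, 0)
    pad_begin = pad_total // 2               # smaller half goes first...
    return pad_begin, pad_total - pad_begin  # ...any odd pixel goes at the end


def onnx_conv_pads(spatial_shape, kernel_shape, strides):
    """Build an ONNX Conv `pads` list: all begin values, then all end values."""
    per_axis = [same_pads(i, k, s) for i, k, s in zip(spatial_shape, kernel_shape, strides)]
    begins = [b for b, _ in per_axis]
    ends = [e for _, e in per_axis]
    return begins + ends


print(onnx_conv_pads([224, 224], [3, 3], [2, 2]))
```

For a 224x224 input with a 3x3 kernel and stride 2, the total padding per axis is 1, so all of it lands at the end and the result is [0, 0, 1, 1]. That asymmetric split is exactly what auto_pad='SAME_UPPER' describes.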

One aspect often overlooked is how TensorFlow's tf.function decorator affects conversion. With @tf.function, TensorFlow traces your Python code into a static graph behind the scenes, and tf2onnx works best with traced graphs like these. If your model relies heavily on eager-only Python logic, conversion is less straightforward: tf2onnx may try to trace it for you, or you may need to trace it explicitly, for example by wrapping the model call in a tf.function with an explicit input_signature. The tracing step captures the graph structure that tf2onnx then consumes.
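Here is a minimal tracing sketch, assuming TensorFlow 2.x is installed. dense_relu and W are hypothetical stand-ins for your model. Because the tf.function has an input_signature with a None batch dimension, it yields a single concrete function, and that function's graph is the static structure a converter works from. tf2onnx also exposes a tf2onnx.convert.from_function entry point for tf.function objects; check its signature for your installed version before relying on it.

```python
import tensorflow as tf

# Hypothetical fixed weight matrix, captured by the traced function
W = tf.ones([4, 2])


# input_signature pins dtype and rank; None keeps the batch dimension dynamic
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32, name="x")])
def dense_relu(x):
    return tf.nn.relu(tf.matmul(x, W))


# Tracing happens here: Python code becomes a static tf.Graph
concrete = dense_relu.get_concrete_function()
op_types = [op.type for op in concrete.graph.get_operations()]
print(op_types)
print(concrete.structured_outputs.shape)
```

The printed op types are the graph nodes a converter would translate one by one, and the output shape keeps None for the batch axis.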

The next challenge you’ll likely encounter is optimizing the ONNX model for specific hardware targets using ONNX Runtime’s execution providers.
