The most surprising thing about serving TensorFlow models with Flask and FastAPI is how little of the framework actually touches your model.

Let’s see it in action. Imagine you’ve trained a simple Keras model to classify MNIST digits.

import numpy as np
from tensorflow import keras

# Load a pre-trained model (or train your own)
model = keras.models.load_model('mnist_model.h5')

# Prepare some dummy data and run it through the model as a sanity check
dummy_data = np.random.rand(1, 28, 28, 1).astype(np.float32)
print(model.predict(dummy_data).shape)

Now, let’s serve this with FastAPI.

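from fastapi import FastAPI
from pydantic import BaseModel
from tensorflow import keras
import numpy as np

# Load the model once at startup, not on every request
model = keras.models.load_model('mnist_model.h5')

class Image(BaseModel):
    data: list[list[float]]  # Expecting a 28x28 array of floats

app = FastAPI()

# Plain `def` rather than `async def`: model.predict() blocks, and FastAPI
# runs plain functions in a threadpool so the event loop stays free
@app.post("/predict/")
def predict(image: Image):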
    # Convert the incoming list of lists to a NumPy array
    image_array = np.array(image.data).astype(np.float32)

    # Reshape to match the model's expected input shape (batch_size, height, width, channels)
    # Assuming the model expects (None, 28, 28, 1)
    image_array = image_array.reshape((1, 28, 28, 1))

    # Make the prediction
    prediction = model.predict(image_array)

    # Return the prediction (e.g., class probabilities)
    return {"prediction": prediction.tolist()}

# To run this:
# 1. Save the code as main.py
# 2. Install: pip install fastapi uvicorn tensorflow pydantic numpy
# 3. Run: uvicorn main:app --reload

With Flask, it looks very similar:

from flask import Flask, request, jsonify
import numpy as np
import tensorflow as tf
from tensorflow import keras

app = Flask(__name__)

# Load the model
model = keras.models.load_model('mnist_model.h5')

@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json()
    # Assuming the incoming JSON has a 'data' key with a list of lists
    image_array = np.array(data['data']).astype(np.float32)

    # Reshape
    image_array = image_array.reshape((1, 28, 28, 1))

    # Predict
    prediction = model.predict(image_array)

    # Return
    return jsonify({"prediction": prediction.tolist()})

# To run this:
# 1. Save the code as app.py
# 2. Install: pip install Flask tensorflow numpy
# 3. Run: flask run

The core problem these frameworks solve is handling HTTP requests, parsing incoming data, and sending back responses. Your TensorFlow model, however, lives in its own world. You load it, you feed it data, and you get results. The web framework is just the messenger, translating between the web request/response cycle and the model’s input/output.

The model’s input shape is critical. If your model expects (None, 28, 28, 1) (batch size, height, width, channels), you must reshape the incoming data to match it. MNIST images are grayscale, hence the single channel; RGB color images would have 3. The None in the shape means the batch dimension is not fixed, so the model accepts any batch size. When you serve a single prediction, the batch size is 1.
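As a minimal sketch of that reshaping step, assuming the payload is a 28x28 nested list of numbers, the shape can be checked before the data reaches the model. The helper name `to_model_input` is hypothetical and not part of either framework.

```python
import numpy as np

def to_model_input(rows):
    """Turn a 28x28 nested list into a (1, 28, 28, 1) float32 array.

    Hypothetical helper: the 28x28 size matches the MNIST example above.
    """
    arr = np.asarray(rows, dtype=np.float32)
    if arr.shape != (28, 28):
        # Fail early with a clear message instead of a cryptic reshape error
        raise ValueError(f"expected shape (28, 28), got {arr.shape}")
    # Add the batch axis in front and the channel axis at the end
    return arr[np.newaxis, :, :, np.newaxis]

batch = to_model_input([[0.0] * 28 for _ in range(28)])
print(batch.shape)  # (1, 28, 28, 1)
```

Checking the shape explicitly matters because reshape() only checks the total element count: a 14x56 payload also has 784 values and would be reshaped without complaint.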

You control the interface with the model. In FastAPI, BaseModel from Pydantic defines the expected structure of the incoming JSON. In Flask, you’re manually parsing request.get_json(). Both are essentially validating and deserializing the incoming HTTP payload into a format Python can understand, which is then converted into a NumPy array for TensorFlow.
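FastAPI rejects a body that does not match the Pydantic model with a 422 response, but Flask does no schema validation for you, so a malformed payload would surface as a KeyError or a reshape error inside the handler. One possible sketch of a manual check follows; `validate_payload` is a hypothetical helper that returns an error message, or None when the payload looks fine.

```python
def validate_payload(payload, size=28):
    """Return an error string if the payload is malformed, else None.

    Hypothetical helper; size=28 assumes the MNIST example above.
    """
    if not isinstance(payload, dict) or "data" not in payload:
        return "body must be a JSON object with a 'data' key"
    rows = payload["data"]
    if not isinstance(rows, list) or len(rows) != size:
        return f"'data' must be a list of {size} rows"
    for row in rows:
        if not isinstance(row, list) or len(row) != size:
            return f"each row must be a list of {size} numbers"
        # bool is a subclass of int, so reject it explicitly
        if any(isinstance(v, bool) or not isinstance(v, (int, float)) for v in row):
            return "pixel values must be numbers"
    return None

print(validate_payload({"data": [[0.0] * 28 for _ in range(28)]}))  # None
print(validate_payload({"pixels": []}))  # body must be a JSON object with a 'data' key
```

In the Flask handler, a non-None result could be returned as `jsonify({"error": message}), 400` before the model is ever called.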

The model.predict() call is where the heavy lifting happens. TensorFlow takes the NumPy array, runs it through the model’s layers, and produces an output. This output is then converted back into a Python list (.tolist()) so it can be serialized into JSON for the HTTP response.
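To illustrate the output side, here is a small sketch that turns a prediction array into a JSON-safe dictionary. It uses a hand-written array as a stand-in for the model.predict() output, and `format_prediction` is a hypothetical helper name.

```python
import json
import numpy as np

def format_prediction(probs):
    """Convert a (1, num_classes) array of scores into a JSON-safe dict."""
    probs = np.asarray(probs)
    return {
        # int() converts NumPy's integer type, which json cannot serialize
        "digit": int(np.argmax(probs[0])),
        # tolist() converts NumPy floats into plain Python floats
        "probabilities": probs[0].tolist(),
    }

# Stand-in for model.predict() output: one sample, ten class scores
fake_output = np.array(
    [[0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0]], dtype=np.float32
)
body = format_prediction(fake_output)
print(body["digit"])  # 3
payload = json.dumps(body)  # serializes without a TypeError
```

Returning the predicted class alongside the raw scores saves every client from repeating the argmax step.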

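A common pitfall is not handling the data type correctly. TensorFlow models usually expect float32, but JSON numbers arrive as Python floats or ints, so np.array() gives you float64 or an integer dtype by default. Depending on the model and the TensorFlow version, that mismatch can cause an error, a warning, or a silent cast. Explicitly casting with .astype(np.float32) removes the ambiguity.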
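The default dtypes are easy to check without a model. This short sketch parses two small JSON arrays and shows the cast; the division by 255 is an assumption that only applies if the model was trained on inputs scaled to [0, 1].

```python
import json
import numpy as np

# JSON numbers become Python floats/ints, which NumPy stores as 64-bit floats
# or a platform integer type by default
as_floats = np.array(json.loads("[[0.0, 0.5], [1.0, 0.25]]"))
as_ints = np.array(json.loads("[[0, 128], [255, 64]]"))
print(as_floats.dtype)  # float64

# Cast explicitly to what the model expects, then scale (assumed preprocessing)
pixels = as_ints.astype(np.float32) / 255.0
print(pixels.dtype)  # float32
```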

The real value of these frameworks isn’t in running the model itself, but in how easily the surrounding service scales. You can run multiple worker processes or instances of your FastAPI/Flask app behind a load balancer to handle many prediction requests concurrently, keeping in mind that each process loads its own copy of the model. Health check endpoints, logging, and request throttling, which are essential for production deployments, can be added through the framework itself, middleware, or extensions.

When you send a POST request with JSON data to /predict/ in FastAPI, Uvicorn receives it, FastAPI parses the JSON into the Image Pydantic model, you convert that data to a NumPy array, TensorFlow makes a prediction, and FastAPI serializes the result back into JSON to send as the HTTP response.
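The client side of that flow can be sketched with the standard library alone. The URL below is an assumption (uvicorn’s default host and port plus the /predict/ route from the FastAPI example), and `build_request` is a hypothetical helper.

```python
import json
import urllib.request

def build_request(image, url="http://127.0.0.1:8000/predict/"):
    """Build a POST request carrying the image as {"data": [...]} JSON.

    The default URL is an assumption: uvicorn's default host and port.
    """
    body = json.dumps({"data": image}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

blank_image = [[0.0] * 28 for _ in range(28)]
req = build_request(blank_image)
print(req.get_method())  # POST

# With the server running, send it and decode the JSON reply:
# with urllib.request.urlopen(req) as resp:
#     print(json.load(resp))
```

The same request works against the Flask version if the URL is changed to match its route and port.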

The next step after getting basic predictions working is often handling batch predictions, where you send multiple samples in a single request to maximize throughput.
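A rough sketch of the batching step, assuming the request carries a list of 28x28 images rather than a single one. The helper name `to_batch` is hypothetical.

```python
import numpy as np

def to_batch(images):
    """Stack N 28x28 images into one (N, 28, 28, 1) float32 array."""
    batch = np.asarray(images, dtype=np.float32)
    if batch.ndim != 3 or batch.shape[1:] != (28, 28):
        raise ValueError(f"expected shape (N, 28, 28), got {batch.shape}")
    # Append the channel axis; the leading axis is already the batch size
    return batch[..., np.newaxis]

three_images = [[[0.0] * 28 for _ in range(28)] for _ in range(3)]
print(to_batch(three_images).shape)  # (3, 28, 28, 1)
```

A single model.predict() call on the stacked array then returns one row of scores per image, so the handler can return a list of predictions in the same order as the inputs.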
