Codementor Events

How to Build a Simple Image Classification API Using TensorFlow and Flask

Published Dec 22, 2024

Introduction

Image classification is one of the most common applications of machine learning. With tools like TensorFlow, you can train your own models or use pre-trained ones to classify images, and with Flask, you can serve a model as an API that real-world applications can call.

In this post, I’ll show you how to build a simple image classification API using TensorFlow and Flask. This API will take an image as input and return the predicted class. We’ll use a pre-trained model from TensorFlow Hub to keep things simple and focus on deployment.

Step 1: Setting Up the Environment

First, let’s set up the environment and install the required libraries.

# Create a virtual environment
python -m venv image_classification_env
source image_classification_env/bin/activate  # On Windows, use `image_classification_env\Scripts\activate`

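# Install required libraries (tensorflow-hub and requests are used in Step 2;
# tf-keras lets hub.KerasLayer work with TensorFlow 2.16 and newer)
pip install tensorflow tensorflow-hub tf-keras flask numpy pillow requests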
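To confirm the environment has everything the code in this post imports, a small standard-library check can help before you go further. This is a sketch: the module list mirrors this post's imports, and `missing_modules` is a hypothetical helper name. Note that some pip package names differ from their import names (`pillow` is imported as `PIL`, `tensorflow-hub` as `tensorflow_hub`).

```python
import importlib.util

# Import names (not pip names) of the modules used in this post.
REQUIRED_MODULES = ["tensorflow", "tensorflow_hub", "flask", "numpy", "PIL", "requests"]


def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level module names that cannot be found, without importing them."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All required modules are installed.")
```

Because `find_spec` only locates modules, the check runs quickly even though TensorFlow itself is slow to import.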

Step 2: Load a Pre-Trained Model

We’ll use a pre-trained image classification model from TensorFlow Hub. For this example, we’ll use the MobileNetV2 model, which is lightweight and efficient.

Here’s how to load the model:

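import os

# hub.KerasLayer is a Keras 2 layer. On TensorFlow 2.16+ (where tf.keras is Keras 3),
# this setting makes tf.keras use the tf-keras package (pip install tf-keras) instead.
# It must be set before TensorFlow is imported and only matters on TensorFlow 2.16+.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

import requests
import tensorflow as tf
import tensorflow_hub as hub

# Load the pre-trained MobileNetV2 model from TensorFlow Hub
model = tf.keras.Sequential([
    hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4", input_shape=(224, 224, 3))
])

# Load the ImageNet labels that match the model's output indices
labels_url = "https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt"
response = requests.get(labels_url, timeout=30)
response.raise_for_status()
labels = response.text.splitlines()  # unlike split("\n"), no empty entry for a final newline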

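The model takes 224x224 RGB images with pixel values in [0, 1] and outputs 1,001 logits: one per ImageNet class plus a "background" class at index 0. ImageNetLabels.txt lists the labels in the same order, so an output index can be used directly to look up its label.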
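If the labels file ends with a trailing newline (as most text files do), splitting on `"\n"` produces one extra empty string at the end. It is never selected by the model, but it makes `len(labels)` misleading. The sketch below uses a three-line stand-in for ImageNetLabels.txt rather than downloading it; `parse_labels` and `check_labels` are hypothetical helper names, and the check assumes the "background"-first layout described above.

```python
def parse_labels(text):
    """Split newline-separated labels, ignoring a trailing newline at the end of the file."""
    return text.splitlines()


def check_labels(labels, num_outputs):
    """Hypothetical sanity check: one label per model output, with "background" first."""
    if len(labels) != num_outputs:
        raise ValueError(f"expected {num_outputs} labels, got {len(labels)}")
    if labels[0] != "background":
        raise ValueError(f"expected 'background' at index 0, got {labels[0]!r}")


# Three-line stand-in for ImageNetLabels.txt (not the real file)
sample_text = "background\ntench\ngoldfish\n"

print(sample_text.split("\n"))    # ['background', 'tench', 'goldfish', '']
print(parse_labels(sample_text))  # ['background', 'tench', 'goldfish']
check_labels(parse_labels(sample_text), num_outputs=3)  # passes silently
```

With the real labels list and MobileNetV2 classifier, the equivalent call would be `check_labels(labels, 1001)`.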

Step 3: Preprocess the Input Image

To classify an image, we need to preprocess it to match the model’s input requirements: convert it to RGB, resize it to 224x224, scale pixel values to the [0, 1] range, and add a batch dimension.

Here’s the preprocessing function:

from PIL import Image
import numpy as np
import tensorflow as tf

def preprocess_image(image_path):
    # Load the image (a path or a file-like object) and force 3 RGB channels
    image = Image.open(image_path).convert("RGB")

    # Resize the image to the model's 224x224 input size
    image = image.resize((224, 224))

    # Convert to a float32 array and scale pixel values from [0, 255] to [0, 1]
    image_array = np.asarray(image, dtype=np.float32) / 255.0

    # Add a batch dimension: (224, 224, 3) -> (1, 224, 224, 3)
    image_tensor = tf.expand_dims(image_array, axis=0)

    return image_tensor
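Since it’s easy to get the shape, dtype, or value range wrong at this step, it helps to check the output on a synthetic image. The sketch below is a NumPy-only variant of the preprocessing above: `preprocess_pil_image` is a hypothetical name, it takes a PIL image instead of a path, and it uses `np.expand_dims` in place of `tf.expand_dims` so it runs without TensorFlow. The float32 dtype is an assumption matching the float32 input that TensorFlow models usually expect.

```python
import numpy as np
from PIL import Image


def preprocess_pil_image(image, size=(224, 224)):
    """Return a float32 batch of shape (1, H, W, 3) with values in [0, 1]."""
    image = image.convert("RGB").resize(size)
    array = np.asarray(image, dtype=np.float32) / 255.0
    return np.expand_dims(array, axis=0)


# A grayscale ("L" mode) image of a different size exercises both the
# RGB conversion and the resize.
test_image = Image.new("L", (640, 480), color=128)
batch = preprocess_pil_image(test_image)
print(batch.shape, batch.dtype, float(batch.min()), float(batch.max()))
```

A shape of `(1, 224, 224, 3)`, a float32 dtype, and values near 128/255 confirm that the conversion, resize, scaling, and batch dimension all happened.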

Step 4: Make Predictions

Now that we have the model and the preprocessing function, we can write a function to make predictions.

def predict(image_path):
    # Preprocess the image into a (1, 224, 224, 3) batch
    image_tensor = preprocess_image(image_path)

    # Get predictions from the model: one row of 1,001 logits
    predictions = model(image_tensor).numpy()

    # Pick the class with the highest score (the same class as the highest probability)
    predicted_class = int(np.argmax(predictions[0]))

    # Look up the human-readable label for that class index
    predicted_label = labels[predicted_class]

    return predicted_label

You can test this function by passing an image file path:

print(predict("example_image.jpg"))
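`predict` returns only the single best label. The classifier’s raw outputs are typically unnormalized logits, so to report a confidence score you can apply a softmax and keep the top few classes. A NumPy sketch, where `top_k_predictions` is a hypothetical helper and the labels and logits in the demo are made-up stand-ins for real model output:

```python
import numpy as np


def top_k_predictions(logits, labels, k=5):
    """Turn one row of logits into (label, probability) pairs for the k best classes."""
    logits = np.asarray(logits, dtype=np.float64)
    # Softmax, shifted by the max logit so np.exp cannot overflow
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    # Indices sorted from most to least probable, truncated to k
    top = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in top]


# Made-up labels and logits standing in for real model output
demo_labels = ["background", "cat", "dog", "bird"]
demo_logits = [0.1, 2.0, 1.0, -1.0]
for label, prob in top_k_predictions(demo_logits, demo_labels, k=2):
    print(f"{label}: {prob:.3f}")
```

With the model and labels from Step 2, the call would look like `top_k_predictions(model(preprocess_image(path)).numpy()[0], labels)`. Softmax does not change which class scores highest, so the top entry always matches what `predict` returns.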

Step 5: Create a Flask API

Next, we’ll create a Flask API to serve the model. The API will accept an image file via a POST request, process it, and return the predicted class.

Here’s the Flask app:

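# app.py: this file also contains the model, labels, preprocess_image, and predict from Steps 2-4
from flask import Flask, request, jsonify
from PIL import UnidentifiedImageError

app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def classify_image():
    # Check if an image file is included in the request
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    # Get the image file from the request
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'Empty filename'}), 400

    # Read the upload directly from memory; a shared temp file on disk
    # would be overwritten when two requests arrive at the same time
    try:
        predicted_label = predict(file.stream)
    except UnidentifiedImageError:
        return jsonify({'error': 'File is not a valid image'}), 400

    # Return the prediction as a JSON response
    return jsonify({'prediction': predicted_label})

if __name__ == '__main__':
    # Listen on all interfaces so the server is reachable from outside a Docker
    # container; enable debug=True only for local development
    app.run(host='0.0.0.0', port=5000)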
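Flask can also cap upload size through its `MAX_CONTENT_LENGTH` config key; requests over the limit are rejected with 413 Request Entity Too Large. The extension allow-list below follows the pattern from Flask's file-upload documentation, but `ALLOWED_EXTENSIONS` and `allowed_file` are hypothetical additions, and the 5 MB limit and accepted formats are assumptions to adjust.

```python
from flask import Flask

app = Flask(__name__)

# Reject request bodies larger than 5 MB (assumed limit); Flask answers with 413.
app.config["MAX_CONTENT_LENGTH"] = 5 * 1024 * 1024

# Assumed set of image formats the API accepts.
ALLOWED_EXTENSIONS = {"jpg", "jpeg", "png"}


def allowed_file(filename):
    """Hypothetical helper: True if the filename ends in an accepted image extension."""
    return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
```

In `classify_image`, you would call `allowed_file(file.filename)` before predicting and return a 400 error if it fails. An extension check is only a first filter; the file still has to decode as an image.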

Step 6: Test the API

To test the API, you can use a tool like Postman or write a simple Python client. Here’s an example of how to test it using Python:

import requests

# Send an image to the API as a multipart file upload named "file"
url = "http://127.0.0.1:5000/predict"
with open('example_image.jpg', 'rb') as image_file:
    response = requests.post(url, files={'file': image_file})

# Print the response
print(response.json())
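Flask’s built-in test client (`app.test_client()`) can also exercise the endpoint without starting a server. In the sketch below, `predict` is a hypothetical stub standing in for the TensorFlow-backed function from Step 4, so the route logic can be tested without loading the model. The route mirrors the shape of the one in Step 5 rather than being a copy of it.

```python
import io

from flask import Flask, jsonify, request

app = Flask(__name__)


def predict(image_file):
    # Hypothetical stub standing in for the TensorFlow-backed predict() from Step 4
    return "stub-label"


@app.route("/predict", methods=["POST"])
def classify_image():
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400
    return jsonify({"prediction": predict(request.files["file"].stream)})


client = app.test_client()

# A (file object, filename) tuple in `data` is sent as a multipart file upload.
ok = client.post("/predict", data={"file": (io.BytesIO(b"fake image bytes"), "test.jpg")})
missing = client.post("/predict")

print(ok.status_code, ok.get_json())
print(missing.status_code, missing.get_json())
```

The same pattern works inside a pytest test function, which keeps API checks fast because the real model never loads.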

Step 7: Deploy the API

To make the API accessible, you can deploy it on a cloud platform like Heroku, AWS, or Google Cloud. Use Docker to containerize the application for easy deployment.

Here’s a simple Dockerfile, assuming all of the code above is saved as app.py next to it:

# Use the official Python image
FROM python:3.9-slim

# Set the working directory
WORKDIR /app

# Copy the application files
COPY . /app

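# Install dependencies (tf-keras lets hub.KerasLayer run on TensorFlow 2.16+)
RUN pip install --no-cache-dir tensorflow tensorflow-hub tf-keras flask numpy pillow requests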

# Expose the port
EXPOSE 5000

# Run the Flask app
CMD ["python", "app.py"]

Build and run the Docker container:

docker build -t image-classification-api .
docker run -p 5000:5000 image-classification-api
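Because app.py downloads the model and labels before `app.run()` starts listening, the container can take a while before it answers requests. A standard-library sketch that polls the API until it responds; `wait_for_api` is a hypothetical helper, and the URL and timeouts are assumptions. It waits for an HTTP response rather than just an open TCP port, since Docker’s port forwarding can accept connections before the app inside the container is listening.

```python
import http.client
import time
import urllib.error
import urllib.request


def wait_for_api(url="http://127.0.0.1:5000/predict", timeout=120.0, interval=2.0):
    """Poll url until the server sends any HTTP response; return False after timeout seconds."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            with urllib.request.urlopen(url, timeout=interval):
                return True
        except urllib.error.HTTPError:
            # Any HTTP status (e.g. 405 for a GET on a POST-only route) means Flask is up.
            return True
        except (OSError, http.client.HTTPException):
            # Connection refused, reset, or timed out: not ready yet.
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)


if __name__ == "__main__":
    print("API ready" if wait_for_api() else "API did not start in time")
```

`HTTPError` is caught before `OSError` because it is a subclass of `URLError`, which is itself an `OSError`.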

Challenges I Faced

  • Model Size: Some pre-trained models are too large for lightweight applications. MobileNetV2 was a great choice for balancing accuracy and efficiency.
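  • Input Preprocessing: Ensuring the input image matched the model’s requirements (224x224 RGB, float values in [0, 1], and a batch dimension) was tricky at first but manageable with Pillow, NumPy, and tf.expand_dims.

Key Learnings

  • Pre-Trained Models Save Time: A TensorFlow Hub model already trained on ImageNet skips data collection and training entirely, which can save weeks of work.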
  • APIs Make AI Accessible: Flask is a simple yet powerful tool for serving machine learning models.
  • Optimization Matters: For real-world applications, always choose models that balance accuracy and performance.

Final Thoughts

Building an image classification API is a great way to learn about deploying machine learning models. This project can be extended to include custom-trained models, additional endpoints, or even a front-end interface.

Next steps:

  • Train a custom model for specific use cases.
  • Add support for batch image classification.
  • Deploy the API on a production server for real-world use.
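For the batch-classification next step, the main change is stacking several preprocessed images into one (N, 224, 224, 3) array, so the model runs once per request instead of once per image. A minimal NumPy sketch: `classify_batch` is a hypothetical helper, and `fake_model` is a made-up stand-in so the logic can be checked without TensorFlow.

```python
import numpy as np


def classify_batch(batch_model, images, labels):
    """Classify several preprocessed images with a single model call.

    batch_model: callable mapping an (N, 224, 224, 3) array to (N, num_classes) scores.
    images: sequence of arrays, each shaped (224, 224, 3) with values in [0, 1].
    labels: list mapping class index -> label.
    """
    batch = np.stack(images, axis=0).astype(np.float32)  # (N, 224, 224, 3)
    scores = np.asarray(batch_model(batch))               # (N, num_classes)
    return [labels[i] for i in np.argmax(scores, axis=1)]


# Hypothetical stand-in model: scores each image by mean brightness,
# column 0 = "dark", column 1 = "bright".
def fake_model(batch):
    brightness = batch.mean(axis=(1, 2, 3))
    return np.stack([1.0 - brightness, brightness], axis=1)


images = [np.zeros((224, 224, 3)), np.ones((224, 224, 3))]
print(classify_batch(fake_model, images, ["dark", "bright"]))  # ['dark', 'bright']
```

With the real pieces, `batch_model` could be `lambda b: model(b).numpy()` and each image `preprocess_image(path)[0].numpy()`. In Flask, `request.files.getlist("files")` returns every file uploaded under the same field name.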

If you have any questions or want to collaborate on similar projects, feel free to reach out!

Discover and read more posts from Chidozie Managwu