How to Build a Simple Image Classification API Using TensorFlow and Flask
Introduction
Image classification is one of the most common applications of machine learning. With tools like TensorFlow, you can train or use pre-trained models to classify images, and with Flask, you can serve your model as an API for real-world applications.
In this post, I’ll show you how to build a simple image classification API using TensorFlow and Flask. This API will take an image as input and return the predicted class. We’ll use a pre-trained model from TensorFlow Hub to keep things simple and focus on deployment.
Step 1: Setting Up the Environment
First, let’s set up the environment and install the required libraries.
# Create a virtual environment
python -m venv image_classification_env
source image_classification_env/bin/activate # On Windows, use `image_classification_env\Scripts\activate`
# Install required libraries
pip install tensorflow tensorflow-hub flask numpy pillow requests
Step 2: Load a Pre-Trained Model
We’ll use a pre-trained image classification model from TensorFlow Hub. For this example, we’ll use the MobileNetV2 model, which is lightweight and efficient.
Here’s how to load the model:
import tensorflow as tf
import tensorflow_hub as hub
import requests

# Load the pre-trained MobileNetV2 classifier from TensorFlow Hub
model = tf.keras.Sequential([
    hub.KerasLayer(
        "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4",
        input_shape=(224, 224, 3)
    )
])

# Load the labels that correspond to the model's outputs
labels_url = "https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt"
labels = requests.get(labels_url).text.split("\n")
The model expects 224x224 RGB images as input and outputs scores for 1,001 classes: the 1,000 ImageNet classes plus a "background" class at index 0, which matches the ImageNetLabels.txt file loaded above.
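A quick way to confirm that the pieces line up is to push a dummy batch through the model. This is just a sanity check for this particular TF Hub model, not part of the API itself:

import numpy as np

# A fake all-black "image" with the right shape and dtype
dummy = np.zeros((1, 224, 224, 3), dtype=np.float32)

# For this model the output should have shape (1, 1001)
print(model(dummy).shape)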
Step 3: Preprocess the Input Image
To classify an image, we need to preprocess it to match the model’s input requirements. This includes resizing the image to 224x224, normalizing pixel values, and converting it to a tensor.
Here’s the preprocessing function:
from PIL import Image
import numpy as np

def preprocess_image(image_path):
    # Load the image and make sure it has three channels
    image = Image.open(image_path).convert("RGB")
    # Resize the image to the 224x224 input size the model expects
    image = image.resize((224, 224))
    # Convert to a float32 array and normalize pixel values to [0, 1]
    image_array = np.array(image).astype(np.float32) / 255.0
    # Add a batch dimension: (224, 224, 3) -> (1, 224, 224, 3)
    image_tensor = tf.expand_dims(image_array, axis=0)
    return image_tensor
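You can sanity-check the preprocessing on any local image; example_image.jpg here is just the placeholder file name used throughout this post:

tensor = preprocess_image("example_image.jpg")
print(tensor.shape)  # should print (1, 224, 224, 3)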
Step 4: Make Predictions
Now that we have the model and the preprocessing function, we can write a function to make predictions.
def predict(image_path):
    # Preprocess the image
    image_tensor = preprocess_image(image_path)
    # Run the model; the output is a batch of class scores
    predictions = model(image_tensor).numpy()
    # Pick the class with the highest score
    predicted_class = np.argmax(predictions[0])
    # Look up the human-readable label for that class
    predicted_label = labels[predicted_class]
    return predicted_label
You can test this function by passing an image file path:
print(predict("example_image.jpg"))
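If you also want to know how confident the model is, one option is to apply a softmax to the raw scores and return the top few labels. This predict_top_k helper is a small sketch of my own, not part of the original API; the name and return format are just illustrative:

def predict_top_k(image_path, k=3):
    # Preprocess and run the model exactly as in predict()
    image_tensor = preprocess_image(image_path)
    logits = model(image_tensor)
    # Convert raw scores to probabilities
    probs = tf.nn.softmax(logits[0]).numpy()
    # Indices of the k highest-probability classes, best first
    top_indices = np.argsort(probs)[-k:][::-1]
    return [(labels[i], float(probs[i])) for i in top_indices]

print(predict_top_k("example_image.jpg"))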
Step 5: Create a Flask API
Next, we’ll create a Flask API to serve the model. The API will accept an image file via a POST request, process it, and return the predicted class.
Here’s the Flask app. Keep it in the same file as the model loading, preprocessing, and prediction code above, and save it as app.py, since that is what the Dockerfile in Step 7 runs:
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def classify_image():
    # Check if an image file is included in the request
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400
    # Get the image file from the request
    file = request.files['file']
    # Save the file temporarily
    file_path = "temp_image.jpg"
    file.save(file_path)
    # Make a prediction
    predicted_label = predict(file_path)
    # Return the prediction as a JSON response
    return jsonify({'prediction': predicted_label})

if __name__ == '__main__':
    # Bind to all interfaces so the API is reachable from outside a Docker container
    app.run(host='0.0.0.0', port=5000, debug=True)
Step 6: Test the API
To test the API, start the server with python app.py, then use a tool like Postman or write a simple Python client. Here’s an example of how to test it using Python:
import requests

# Send an image to the API
url = "http://127.0.0.1:5000/predict"
with open('example_image.jpg', 'rb') as f:
    response = requests.post(url, files={'file': f})

# Print the response
print(response.json())
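If you prefer the command line, the same request can be sent with curl (assuming the server is running locally and example_image.jpg exists):

curl -X POST -F "file=@example_image.jpg" http://127.0.0.1:5000/predict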
Step 7: Deploy the API
To make the API accessible, you can deploy it on a cloud platform like Heroku, AWS, or Google Cloud. Use Docker to containerize the application for easy deployment.
Here’s a simple Dockerfile:
# Use the official Python image
FROM python:3.9-slim
# Set the working directory
WORKDIR /app
# Copy the application files
COPY . /app
# Install dependencies
RUN pip install tensorflow tensorflow-hub flask numpy pillow requests
# Expose the port
EXPOSE 5000
# Run the Flask app
CMD ["python", "app.py"]
Build and run the Docker container:
docker build -t image-classification-api .
docker run -p 5000:5000 image-classification-api
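Flask’s built-in server is convenient for development but not meant for production traffic. One common swap is a WSGI server such as Gunicorn; the lines below are only a sketch of that change, and Gunicorn is not installed in the steps above, so you would also need to add it to the pip install line in the Dockerfile:

# Install Gunicorn alongside the other dependencies
RUN pip install gunicorn

# Serve the `app` object defined in app.py with 2 worker processes
CMD ["gunicorn", "-b", "0.0.0.0:5000", "-w", "2", "app:app"]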
Challenges I Faced
- Model Size: Some pre-trained models are too large for lightweight applications. MobileNetV2 was a great choice for balancing accuracy and efficiency.
- Input Preprocessing: Ensuring the input image matches the model’s requirements was tricky at first but manageable with TensorFlow utilities.
Key Learnings
- Pre-Trained Models Save Time: Using TensorFlow Hub models can save you weeks of training time.
- APIs Make AI Accessible: Flask is a simple yet powerful tool for serving machine learning models.
- Optimization Matters: For real-world applications, always choose models that balance accuracy and performance.
Final Thoughts
Building an image classification API is a great way to learn about deploying machine learning models. This project can be extended to include custom-trained models, additional endpoints, or even a front-end interface.
Next steps:
- Train a custom model for specific use cases.
- Add support for batch image classification.
- Deploy the API on a production server for real-world use.
If you have any questions or want to collaborate on similar projects, feel free to reach out!