# Building APIs with FastAPI

## Why FastAPI?

FastAPI is a modern Python web framework designed for building APIs. It was created by Sebastián Ramírez in 2018 and has rapidly become one of the most popular frameworks for serving ML models in production.

### Key Advantages
| Feature | Description | Why It Matters for ML |
|---|---|---|
| Async support | Built on Starlette, supports async/await | Handle many concurrent prediction requests |
| Type hints | Uses Python type annotations natively | Self-documenting code, IDE autocompletion |
| Pydantic validation | Automatic request/response validation | Reject invalid model inputs before inference |
| Auto-generated docs | Swagger UI + ReDoc out of the box | Clients can explore and test your API instantly |
| High performance | One of the fastest Python frameworks | Low latency for real-time predictions |
| Standards-based | Built on OpenAPI and JSON Schema | Easy integration with any client or tool |
### WSGI vs ASGI

- **WSGI** (Web Server Gateway Interface): synchronous — one request at a time per worker (used by Flask)
- **ASGI** (Asynchronous Server Gateway Interface): asynchronous — a single worker can handle many concurrent requests (used by FastAPI)

For ML APIs receiving many simultaneous prediction requests, ASGI can significantly improve throughput.
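The concurrency benefit can be sketched with plain `asyncio`, independent of any framework: ten simulated I/O-bound requests, each waiting 0.1 s, finish in roughly 0.1 s total rather than 1 s, because the event loop interleaves them while they wait.

```python
import asyncio
import time

async def handle_request(i: int) -> str:
    # Simulate I/O-bound work (e.g. awaiting a database or upstream API)
    await asyncio.sleep(0.1)
    return f"response-{i}"

async def main() -> float:
    start = time.perf_counter()
    # All ten "requests" are awaited concurrently on one event loop
    results = await asyncio.gather(*(handle_request(i) for i in range(10)))
    assert len(results) == 10
    return time.perf_counter() - start

elapsed = asyncio.run(main())
print(f"10 concurrent requests finished in {elapsed:.2f}s")  # ~0.1s, not 1s
```

This is exactly what an ASGI server does with `async def` handlers: while one request awaits I/O, the worker serves others.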
## Installation and Setup

### Installing Dependencies

```bash
pip install fastapi uvicorn pydantic
pip install scikit-learn joblib numpy pandas
```
### Project Structure

A well-organized ML API project follows this structure:

```text
ml-api/
├── app/
│   ├── __init__.py
│   ├── main.py            # FastAPI application entry point
│   ├── models/
│   │   ├── __init__.py
│   │   └── schemas.py     # Pydantic request/response models
│   ├── routers/
│   │   ├── __init__.py
│   │   └── predictions.py # Prediction route handlers
│   ├── services/
│   │   ├── __init__.py
│   │   └── ml_service.py  # Model loading and inference logic
│   └── core/
│       ├── __init__.py
│       └── config.py      # Configuration settings
├── models/
│   └── model_v1.joblib    # Serialized ML model
├── requirements.txt
└── README.md
```
## Your First FastAPI Application

### Minimal Example

```python
from fastapi import FastAPI

app = FastAPI(
    title="ML Prediction API",
    description="API for serving machine learning predictions",
    version="1.0.0",
)

@app.get("/")
def root():
    return {"message": "ML Prediction API is running"}

@app.get("/health")
def health_check():
    return {"status": "healthy"}
```

Run the application:

```bash
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```

This gives you:

- API available at http://localhost:8000
- Swagger UI at http://localhost:8000/docs
- ReDoc at http://localhost:8000/redoc
## Pydantic Models for Request/Response Validation

Pydantic is the foundation of FastAPI's data validation. You define Python classes with type annotations, and Pydantic automatically validates incoming data.

### Defining Input Schemas

```python
from enum import Enum

from pydantic import BaseModel, ConfigDict, Field

class LoanPurpose(str, Enum):
    home = "home"
    car = "car"
    education = "education"
    personal = "personal"

class PredictionInput(BaseModel):
    """Input features for loan approval prediction."""

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "age": 35,
                "income": 55000.0,
                "credit_score": 720,
                "employment_years": 8.5,
                "loan_amount": 25000.0,
                "loan_purpose": "home",
            }
        }
    )

    age: int = Field(
        ...,
        ge=18,
        le=120,
        description="Applicant age in years",
        examples=[35],
    )
    income: float = Field(
        ...,
        gt=0,
        description="Annual income in USD",
        examples=[55000.0],
    )
    credit_score: int = Field(
        ...,
        ge=300,
        le=850,
        description="Credit score (FICO)",
        examples=[720],
    )
    employment_years: float = Field(
        ...,
        ge=0,
        description="Years of employment",
        examples=[8.5],
    )
    loan_amount: float = Field(
        ...,
        gt=0,
        description="Requested loan amount in USD",
        examples=[25000.0],
    )
    loan_purpose: LoanPurpose = Field(
        ...,
        description="Purpose of the loan",
        examples=["home"],
    )
```

Note that in Pydantic v2, per-field examples use `examples=[...]` and model-level configuration uses `model_config = ConfigDict(...)` rather than the v1-era `example=` keyword and inner `class Config`.
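To see the validation in action, here is a quick sketch (a trimmed-down, hypothetical version of the schema above, assuming Pydantic v2): out-of-range values raise a `ValidationError` before any inference code runs, and valid values are coerced to the annotated types.

```python
from pydantic import BaseModel, Field, ValidationError

# Trimmed-down stand-in for the PredictionInput schema above
class LoanInput(BaseModel):
    age: int = Field(..., ge=18, le=120)
    income: float = Field(..., gt=0)

# Valid data passes, and "35" is coerced to the int 35
ok = LoanInput(age="35", income=55000)
assert ok.age == 35

# Out-of-range data is rejected before it ever reaches the model
try:
    LoanInput(age=15, income=55000)
except ValidationError as exc:
    errors = exc.errors()
    assert errors[0]["loc"] == ("age",)
    print(errors[0]["msg"])  # explains the ge=18 violation
```

FastAPI turns the same `ValidationError` into a 422 response with a structured error body, so clients see exactly which field failed and why.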
### Defining Output Schemas

```python
from datetime import datetime
from typing import List, Optional

from pydantic import BaseModel, ConfigDict, Field

class PredictionOutput(BaseModel):
    """Prediction result from the ML model."""

    # Field names starting with "model_" clash with Pydantic v2's
    # protected namespace, so allow them explicitly
    model_config = ConfigDict(protected_namespaces=())

    prediction: str = Field(..., description="Predicted class label")
    probability: float = Field(
        ..., ge=0, le=1, description="Prediction confidence"
    )
    model_version: str = Field(..., description="Model version used")
    timestamp: datetime = Field(
        default_factory=datetime.utcnow,
        description="Prediction timestamp",
    )

class ErrorResponse(BaseModel):
    """Standard error response."""

    error_code: str
    message: str
    details: Optional[List[str]] = None
```
Commonly used Field constraints:

| Constraint | Usage | Example |
|---|---|---|
| ... (Ellipsis) | Required field | Field(...) |
| default= | Default value | Field(default=0.5) |
| ge=, gt= | Greater than (or equal) | Field(ge=0) |
| le=, lt= | Less than (or equal) | Field(le=100) |
| min_length= | Min string length | Field(min_length=1) |
| max_length= | Max string length | Field(max_length=255) |
| pattern= | Regex pattern matching | Field(pattern=r"^[a-z]+$") |

(Pydantic v2 renamed the v1 `regex=` keyword to `pattern=`.)
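A quick sketch exercising the string constraints from the table (hypothetical `Tag` model, assuming Pydantic v2):

```python
from pydantic import BaseModel, Field, ValidationError

class Tag(BaseModel):
    # pattern= is the Pydantic v2 name for the v1 regex= keyword
    name: str = Field(..., min_length=1, max_length=10, pattern=r"^[a-z]+$")

assert Tag(name="alpha").name == "alpha"

rejected = []
for bad in ["", "Alpha", "x" * 11]:  # too short, wrong case, too long
    try:
        Tag(name=bad)
    except ValidationError:
        rejected.append(bad)

assert len(rejected) == 3  # every candidate violates at least one constraint
```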
## Loading and Serving an ML Model

### The ML Service

Create a service class that loads the model once at startup and reuses it for every request:

```python
import joblib
import numpy as np
from pathlib import Path

class MLService:
    """Handles model loading and inference."""

    def __init__(self):
        self.model = None
        self.model_version = "unknown"
        self.feature_names = [
            "age", "income", "credit_score",
            "employment_years", "loan_amount",
        ]

    def load_model(self, model_path: str):
        """Load a serialized model from disk."""
        path = Path(model_path)
        if not path.exists():
            raise FileNotFoundError(f"Model not found: {model_path}")
        self.model = joblib.load(path)
        self.model_version = path.stem
        return self

    def predict(self, features: dict) -> dict:
        """Run inference on input features."""
        if self.model is None:
            raise RuntimeError("Model not loaded")
        feature_array = np.array([[
            features["age"],
            features["income"],
            features["credit_score"],
            features["employment_years"],
            features["loan_amount"],
        ]])
        prediction = self.model.predict(feature_array)[0]
        probabilities = self.model.predict_proba(feature_array)[0]
        return {
            "prediction": "approved" if prediction == 1 else "denied",
            "probability": float(max(probabilities)),
            "model_version": self.model_version,
        }

# Module-level singleton shared by all request handlers
ml_service = MLService()
```
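The decision logic at the end of `predict()` can be exercised without a trained model. The sketch below uses a hypothetical stub exposing the same `predict`/`predict_proba` interface that scikit-learn estimators provide (plain lists instead of NumPy arrays, for brevity):

```python
class StubModel:
    """Hypothetical stand-in for a trained binary classifier."""

    def predict(self, rows):
        # Approve (class 1) whenever credit_score (3rd feature) >= 650
        return [1 if row[2] >= 650 else 0 for row in rows]

    def predict_proba(self, rows):
        return [[0.2, 0.8] if row[2] >= 650 else [0.9, 0.1] for row in rows]

def to_response(model, row):
    # Same mapping MLService.predict applies to raw model outputs
    pred = model.predict([row])[0]
    proba = model.predict_proba([row])[0]
    return {
        "prediction": "approved" if pred == 1 else "denied",
        "probability": float(max(proba)),
    }

model = StubModel()
good = to_response(model, [35, 55000.0, 720, 8.5, 25000.0])
bad = to_response(model, [22, 18000.0, 480, 0.5, 40000.0])
assert good == {"prediction": "approved", "probability": 0.8}
assert bad == {"prediction": "denied", "probability": 0.9}
```

Because the service only relies on `predict`/`predict_proba`, any estimator with that interface can be swapped in without touching the API layer.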
### Wiring It Into FastAPI with Lifespan Events

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI

# Singleton defined in app/services/ml_service.py
from app.services.ml_service import ml_service

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model at startup, clean up at shutdown."""
    ml_service.load_model("models/model_v1.joblib")
    print(f"Model loaded: {ml_service.model_version}")
    yield
    print("Shutting down, releasing resources...")

app = FastAPI(
    title="ML Prediction API",
    version="1.0.0",
    lifespan=lifespan,
)
```

**Never load the model inside a request handler.** Deserializing a model from disk on every request adds massive latency. Use the lifespan event or a global singleton to load it once at startup.
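The lifespan pattern is ordinary `asynccontextmanager` behavior: everything before `yield` runs at startup, everything after runs at shutdown. A framework-free sketch of the ordering:

```python
import asyncio
from contextlib import asynccontextmanager

events = []

@asynccontextmanager
async def lifespan():
    events.append("startup: load model")          # before the app serves traffic
    yield
    events.append("shutdown: release resources")  # when the app stops

async def main():
    async with lifespan():
        events.append("handling requests")        # the app's serving phase

asyncio.run(main())
assert events == [
    "startup: load model",
    "handling requests",
    "shutdown: release resources",
]
```

FastAPI drives the same context manager for you: it enters it once before accepting requests and exits it when the server shuts down.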
## Creating the Prediction Endpoint

### Complete Prediction Route

```python
from datetime import datetime

from fastapi import HTTPException

@app.post(
    "/api/v1/predict",
    response_model=PredictionOutput,
    summary="Get a loan approval prediction",
    tags=["Predictions"],
)
def predict(input_data: PredictionInput):
    """
    Submit loan application features and receive
    an approval/denial prediction with confidence score.
    """
    try:
        # loan_purpose is not a model feature, so drop it before inference
        features = input_data.model_dump(exclude={"loan_purpose"})
        result = ml_service.predict(features)
        return PredictionOutput(
            prediction=result["prediction"],
            probability=result["probability"],
            model_version=result["model_version"],
            timestamp=datetime.utcnow(),
        )
    except RuntimeError as e:
        raise HTTPException(
            status_code=503,
            detail=f"Model not available: {str(e)}",
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}",
        )
```
### Health Check Endpoint

```python
@app.get("/health", tags=["System"])
def health_check():
    """Check if the API and model are ready."""
    model_loaded = ml_service.model is not None
    return {
        "status": "healthy" if model_loaded else "degraded",
        "model_loaded": model_loaded,
        "model_version": ml_service.model_version,
        "timestamp": datetime.utcnow().isoformat(),
    }
```
## Dependency Injection

FastAPI's dependency injection system lets you share logic across endpoints cleanly. This is useful for authentication, database connections, or model access.

```python
from fastapi import Depends, Header, HTTPException

async def verify_api_key(x_api_key: str = Header(...)):
    """Validate the API key from request headers."""
    valid_keys = {"sk_live_abc123", "sk_live_def456"}  # in production, load from config
    if x_api_key not in valid_keys:
        raise HTTPException(
            status_code=401,
            detail="Invalid API key",
        )
    return x_api_key

@app.post("/api/v1/predict", dependencies=[Depends(verify_api_key)])
def predict(input_data: PredictionInput):
    # Only reached if the API key is valid
    ...
```
## Error Handling

### Custom Exception Handlers

```python
from fastapi import Request
from fastapi.responses import JSONResponse

class ModelNotLoadedError(Exception):
    pass

class PredictionError(Exception):
    def __init__(self, detail: str):
        self.detail = detail

@app.exception_handler(ModelNotLoadedError)
async def model_not_loaded_handler(request: Request, exc: ModelNotLoadedError):
    return JSONResponse(
        status_code=503,
        content={
            "error_code": "MODEL_NOT_LOADED",
            "message": "The ML model is not available. Please try again later.",
        },
    )

@app.exception_handler(PredictionError)
async def prediction_error_handler(request: Request, exc: PredictionError):
    return JSONResponse(
        status_code=500,
        content={
            "error_code": "PREDICTION_FAILED",
            "message": exc.detail,
        },
    )
```
## Middleware

Middleware runs before every request and after every response. It's well suited to logging, timing, and adding headers.

### Request Timing Middleware

```python
import time

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

class TimingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        start_time = time.perf_counter()
        response = await call_next(request)
        duration_ms = (time.perf_counter() - start_time) * 1000
        response.headers["X-Response-Time-Ms"] = f"{duration_ms:.2f}"
        return response

app.add_middleware(TimingMiddleware)
```
### CORS Middleware

```python
from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "https://myapp.example.com",
    ],
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
```
## Async vs Sync Endpoints
FastAPI supports both synchronous and asynchronous endpoints. The choice depends on what your endpoint does.
| Scenario | Use | Why |
|---|---|---|
| ML inference (CPU-bound) | def (sync) | scikit-learn is not async; FastAPI runs it in a thread pool |
| Database queries (I/O-bound) | async def | Non-blocking I/O, better concurrency |
| File operations | async def with aiofiles | Doesn't block the event loop |
| External API calls | async def with httpx | Concurrent HTTP requests |
```python
import asyncio

# Sync — FastAPI runs this in a thread pool automatically
@app.post("/api/v1/predict")
def predict_sync(input_data: PredictionInput):
    result = ml_service.predict(input_data.model_dump())
    return result

# Async — runs on the event loop; don't do CPU-heavy work here directly
@app.post("/api/v1/predict-async")
async def predict_async(input_data: PredictionInput):
    loop = asyncio.get_running_loop()
    # Off-load the blocking predict() call to the default thread pool
    result = await loop.run_in_executor(
        None, ml_service.predict, input_data.model_dump()
    )
    return result
```
If you define async def but then call a blocking function (like joblib.load() or model.predict()), you will block the event loop and freeze all other requests. Use def (sync) for CPU-bound ML inference, or explicitly run it in an executor.
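The executor hand-off can be demonstrated with the standard library alone: two blocking 0.2-second calls finish in roughly 0.2 s when dispatched to the thread pool, instead of the 0.4 s they would take sequentially on the event loop.

```python
import asyncio
import time

def blocking_predict(x: int) -> int:
    time.sleep(0.2)  # stand-in for blocking inference work
    return x * 2

async def main():
    loop = asyncio.get_running_loop()
    start = time.perf_counter()
    # Both blocking calls run concurrently in the default thread pool
    results = await asyncio.gather(
        loop.run_in_executor(None, blocking_predict, 1),
        loop.run_in_executor(None, blocking_predict, 2),
    )
    return results, time.perf_counter() - start

results, elapsed = asyncio.run(main())
assert results == [2, 4]
assert elapsed < 0.39  # concurrent, not 0.4s sequential
```

Note that this works here because `time.sleep` releases the GIL; truly CPU-bound Python code gains less from threads, which is another reason to prefer plain `def` handlers for inference.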
## File Upload Endpoint

For models that process images, audio, or documents, you need file upload support.

```python
import io

from fastapi import File, HTTPException, UploadFile
from PIL import Image

@app.post("/api/v1/predict/image", tags=["Predictions"])
async def predict_image(
    file: UploadFile = File(..., description="Image file for classification"),
):
    if file.content_type not in ["image/jpeg", "image/png"]:
        raise HTTPException(
            status_code=400,
            detail="Only JPEG and PNG images are supported",
        )
    contents = await file.read()
    image = Image.open(io.BytesIO(contents))
    # Preprocess and predict (simplified; image_model is your loaded vision model)
    result = image_model.predict(image)
    return {
        "filename": file.filename,
        "prediction": result["label"],
        "confidence": result["confidence"],
    }
```
## Batch Prediction Endpoint

For efficiency, allow clients to submit multiple inputs in a single request.

```python
import time
from typing import List

from pydantic import BaseModel, Field

class BatchInput(BaseModel):
    inputs: List[PredictionInput] = Field(
        ..., min_length=1, max_length=100,
        description="List of prediction inputs (max 100)",
    )

class BatchOutput(BaseModel):
    predictions: List[PredictionOutput]
    total: int
    processing_time_ms: float

@app.post("/api/v1/predict/batch", response_model=BatchOutput, tags=["Predictions"])
def predict_batch(batch: BatchInput):
    start = time.perf_counter()
    results = []
    for item in batch.inputs:
        features = item.model_dump(exclude={"loan_purpose"})
        result = ml_service.predict(features)
        results.append(PredictionOutput(
            prediction=result["prediction"],
            probability=result["probability"],
            model_version=result["model_version"],
        ))
    duration = (time.perf_counter() - start) * 1000
    return BatchOutput(
        predictions=results,
        total=len(results),
        processing_time_ms=round(duration, 2),
    )
```
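Looping `predict()` per item works, but most scikit-learn models can score an entire 2-D batch in a single call, which is usually much faster than one call per row. A framework-free sketch of the idea, with a hypothetical stub standing in for a real estimator:

```python
class StubModel:
    """Hypothetical classifier that scores a whole batch at once."""

    def predict(self, rows):
        # One vectorized pass over all rows (credit_score is the 3rd feature)
        return [1 if row[2] >= 650 else 0 for row in rows]

def batch_predict(model, rows):
    # A single model call for the whole batch, not len(rows) calls
    labels = model.predict(rows)
    return ["approved" if y == 1 else "denied" for y in labels]

rows = [
    [35, 55000.0, 720, 8.5, 25000.0],
    [22, 18000.0, 480, 0.5, 40000.0],
]
assert batch_predict(StubModel(), rows) == ["approved", "denied"]
```

Applied to the endpoint above, this would mean stacking all inputs into one feature array and calling `ml_service` once, then unpacking the results into `PredictionOutput` objects.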
## Complete Application — Putting It All Together

```python
from contextlib import asynccontextmanager
from datetime import datetime

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict, Field

# --- Schemas ---

class PredictionInput(BaseModel):
    age: int = Field(..., ge=18, le=120)
    income: float = Field(..., gt=0)
    credit_score: int = Field(..., ge=300, le=850)
    employment_years: float = Field(..., ge=0)
    loan_amount: float = Field(..., gt=0)

class PredictionOutput(BaseModel):
    # "model_version" clashes with Pydantic v2's protected "model_" namespace
    model_config = ConfigDict(protected_namespaces=())

    prediction: str
    probability: float
    model_version: str
    timestamp: datetime = Field(default_factory=datetime.utcnow)

# --- ML Service ---

class MLService:
    def __init__(self):
        self.model = None
        self.version = "unknown"

    def load(self, path: str):
        self.model = joblib.load(path)
        self.version = "v1.0"

    def predict(self, features: dict) -> dict:
        arr = np.array([[
            features["age"], features["income"],
            features["credit_score"], features["employment_years"],
            features["loan_amount"],
        ]])
        pred = self.model.predict(arr)[0]
        proba = self.model.predict_proba(arr)[0]
        return {
            "prediction": "approved" if pred == 1 else "denied",
            "probability": float(max(proba)),
            "model_version": self.version,
        }

ml = MLService()

# --- Lifespan ---

@asynccontextmanager
async def lifespan(app: FastAPI):
    ml.load("models/model_v1.joblib")
    yield

# --- App ---

app = FastAPI(title="Loan Prediction API", version="1.0.0", lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/health", tags=["System"])
def health():
    return {"status": "healthy", "model": ml.version}

@app.post("/api/v1/predict", response_model=PredictionOutput, tags=["Predictions"])
def predict(data: PredictionInput):
    try:
        result = ml.predict(data.model_dump())
        return PredictionOutput(**result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
```
Run it:

```bash
uvicorn app.main:app --reload --port 8000
```

Test it:

```bash
curl -X POST http://localhost:8000/api/v1/predict \
  -H "Content-Type: application/json" \
  -d '{"age": 35, "income": 55000, "credit_score": 720, "employment_years": 8, "loan_amount": 25000}'
```
## Summary
| Topic | Key Takeaway |
|---|---|
| FastAPI | Modern, fast, type-safe Python framework for APIs |
| Pydantic | Automatic validation of input/output with clear error messages |
| Lifespan events | Load model once at startup, not per request |
| Dependencies | Reusable logic for auth, model access, etc. |
| Middleware | Cross-cutting concerns (CORS, timing, logging) |
| Sync vs Async | Use def for CPU-bound ML inference |
| File uploads | UploadFile for image/document prediction APIs |
| Batch predictions | Process multiple inputs in a single request |
## FastAPI Quick Reference
| Action | Code |
|---|---|
| Create app | app = FastAPI(title="...", version="...") |
| GET endpoint | @app.get("/path") |
| POST endpoint | @app.post("/path", response_model=Schema) |
| Run server | uvicorn app.main:app --reload |
| Access docs | http://localhost:8000/docs |
| Validate input | Define a BaseModel subclass |
| Add middleware | app.add_middleware(MiddlewareClass, ...) |
| Dependency injection | @app.post("/", dependencies=[Depends(fn)]) |
| Raise HTTP error | raise HTTPException(status_code=400, detail="...") |