"""Rate limiting middleware for FastAPI.
Usage:
from app.rate_limit import RateLimitMiddleware
# In main.py:
app.add_middleware(RateLimitMiddleware, requests_per_minute=60)
"""
import logging
from fastapi import Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from collections import defaultdict
from typing import Dict, List
import time
from app.dependencies import ensure_tenant_context
def _get_correlation_id(request: Request) -> str:
return getattr(getattr(request, "state", None), "correlation_id", None) or request.headers.get(
"x-correlation-id", ""
)
logger = logging.getLogger("kumiho.fastapi.rate_limit")
class RateLimitMiddleware(BaseHTTPMiddleware):
"""
Simple in-memory rate limiting middleware.
Tracks requests per client IP and enforces rate limits.
For production, consider using Redis or a dedicated rate limiting service.
Args:
requests_per_minute: Maximum requests allowed per minute per client
burst_size: Maximum burst size (allows short bursts above the rate)
"""
def __init__(self, app, requests_per_minute: int = 60, burst_size: int = 10):
super().__init__(app)
self.requests_per_minute = requests_per_minute
self.burst_size = burst_size
        # Maps rate key ("tenant:bucket:client_ip") -> list of request timestamps
        self.request_times: Dict[str, List[float]] = defaultdict(list)
self.cleanup_interval = 300 # Clean up old entries every 5 minutes
self.last_cleanup = time.time()
def _cleanup_old_entries(self):
"""Remove entries older than 2 minutes to prevent memory bloat."""
now = time.time()
if now - self.last_cleanup < self.cleanup_interval:
return
cutoff_time = now - 120 # 2 minutes ago
for client_ip in list(self.request_times.keys()):
# Filter out old timestamps
self.request_times[client_ip] = [
ts for ts in self.request_times[client_ip]
if ts > cutoff_time
]
# Remove client if no recent requests
if not self.request_times[client_ip]:
del self.request_times[client_ip]
self.last_cleanup = now
def _get_client_ip(self, request: Request) -> str:
"""Extract client IP from request."""
# Check forwarded headers first (for proxies/load balancers)
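        # Note: X-Forwarded-For / X-Real-IP are client-supplied; trust them only
        # when a proxy or load balancer in front of this service sets them.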
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# Fallback to direct connection
return request.client.host if request.client else "unknown"
async def dispatch(self, request: Request, call_next):
"""Process the request and enforce rate limiting."""
# Skip rate limiting for health checks and docs
if request.url.path in ["/", "/health", "/docs", "/redoc", "/openapi.json"]:
return await call_next(request)
try:
tenant_context = ensure_tenant_context()
tenant_id = tenant_context.tenant_id
except Exception:
            # If we can't load the tenant context (e.g. missing credentials),
            # fall back to a generic identifier so rate limiting still works per IP.
tenant_id = "unknown"
client_ip = self._get_client_ip(request)
now = time.time()
auth_header = request.headers.get("Authorization", "")
user_bucket = "auth" if auth_header.startswith("Bearer ") else "anon"
rate_key = f"{tenant_id}:{user_bucket}:{client_ip}"
# Periodic cleanup
self._cleanup_old_entries()
# Get request times for this client
times = self.request_times[rate_key]
# Remove timestamps older than 1 minute
cutoff_time = now - 60
times[:] = [ts for ts in times if ts > cutoff_time]
# Check rate limit
if len(times) >= self.requests_per_minute:
# Calculate when the oldest request will expire
oldest = min(times)
retry_after = int(60 - (now - oldest)) + 1
logger.warning(
"Rate limit exceeded for tenant=%s bucket=%s ip=%s path=%s",
tenant_id,
user_bucket,
client_ip,
request.url.path,
)
correlation_id = _get_correlation_id(request)
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={
"error": {
"code": "rate_limited",
"message": f"Rate limit exceeded. Maximum {self.requests_per_minute} requests per minute.",
"retryable": True,
"retry_after_ms": retry_after * 1000,
"details": {
"retry_after": retry_after,
"tenant_id": tenant_id,
"user_bucket": user_bucket,
},
},
"correlation_id": correlation_id,
# Transitional compatibility.
"detail": f"Rate limit exceeded. Maximum {self.requests_per_minute} requests per minute.",
"retry_after": retry_after,
"tenant_id": tenant_id,
"user_bucket": user_bucket,
},
headers={"Retry-After": str(retry_after), "x-correlation-id": correlation_id}
)
# Record this request
times.append(now)
# Process the request
response = await call_next(request)
# Add rate limit headers to response
response.headers["X-RateLimit-Limit"] = str(self.requests_per_minute)
response.headers["X-RateLimit-Remaining"] = str(self.requests_per_minute - len(times))
response.headers["X-RateLimit-Reset"] = str(int(min(times) + 60)) if times else str(int(now + 60))
response.headers["X-Kumiho-Tenant"] = tenant_id
response.headers["X-Kumiho-Rate-Bucket"] = user_bucket
return response
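
# ---------------------------------------------------------------------------
# Minimal wiring sketch (illustrative only; the ``/items`` route below is an
# assumption, not part of this module). It shows how the middleware is
# registered and what clients observe around the limit:
#
#     from fastapi import FastAPI
#     from app.rate_limit import RateLimitMiddleware
#
#     app = FastAPI()
#     app.add_middleware(RateLimitMiddleware, requests_per_minute=60, burst_size=10)
#
#     @app.get("/items")
#     async def list_items():
#         return {"items": []}
#
# Successful responses carry X-RateLimit-Limit/-Remaining/-Reset headers;
# requests beyond the per-minute limit receive HTTP 429 with a Retry-After
# header and the structured ``error`` body built in ``dispatch`` above.
# ---------------------------------------------------------------------------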