Source code for app.rate_limit

"""Rate limiting middleware for FastAPI.

Usage:
    from app.rate_limit import RateLimitMiddleware
    
    # In main.py:
    app.add_middleware(RateLimitMiddleware, requests_per_minute=60)
"""
import logging
import time
from collections import defaultdict
from typing import Dict

from fastapi import Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from app.dependencies import ensure_tenant_context


logger = logging.getLogger("kumiho.fastapi.rate_limit")


def _get_correlation_id(request: Request) -> str:
    return getattr(getattr(request, "state", None), "correlation_id", None) or request.headers.get(
        "x-correlation-id", ""
    )


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple in-memory rate limiting middleware.

    Tracks requests per tenant, auth bucket, and client IP and enforces rate
    limits. For production, consider using Redis or a dedicated rate limiting
    service.

    Args:
        requests_per_minute: Maximum requests allowed per minute per client
        burst_size: Maximum burst size (reserved for allowing short bursts
            above the rate; not currently enforced in ``dispatch``)
    """

    def __init__(self, app, requests_per_minute: int = 60, burst_size: int = 10):
        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        self.burst_size = burst_size
        # Maps rate key ("tenant:bucket:ip") -> list of request timestamps
        self.request_times: Dict[str, list] = defaultdict(list)
        self.cleanup_interval = 300  # Clean up old entries every 5 minutes
        self.last_cleanup = time.time()

    def _cleanup_old_entries(self):
        """Remove entries older than 2 minutes to prevent memory bloat."""
        now = time.time()
        if now - self.last_cleanup < self.cleanup_interval:
            return

        cutoff_time = now - 120  # 2 minutes ago
        for rate_key in list(self.request_times.keys()):
            # Filter out old timestamps
            self.request_times[rate_key] = [
                ts for ts in self.request_times[rate_key] if ts > cutoff_time
            ]
            # Remove the key if it has no recent requests
            if not self.request_times[rate_key]:
                del self.request_times[rate_key]

        self.last_cleanup = now

    def _get_client_ip(self, request: Request) -> str:
        """Extract the client IP from the request."""
        # Check forwarded headers first (for proxies/load balancers)
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            return forwarded.split(",")[0].strip()

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        # Fall back to the direct connection
        return request.client.host if request.client else "unknown"
    async def dispatch(self, request: Request, call_next):
        """Process the request and enforce rate limiting."""
        # Skip rate limiting for health checks and docs
        if request.url.path in ["/", "/health", "/docs", "/redoc", "/openapi.json"]:
            return await call_next(request)

        try:
            tenant_context = ensure_tenant_context()
            tenant_id = tenant_context.tenant_id
        except Exception:
            # If the tenant context can't be loaded (e.g. missing credentials),
            # fall back to a generic identifier so rate limiting still works per-IP.
            tenant_id = "unknown"

        client_ip = self._get_client_ip(request)
        now = time.time()
        auth_header = request.headers.get("Authorization", "")
        user_bucket = "auth" if auth_header.startswith("Bearer ") else "anon"
        rate_key = f"{tenant_id}:{user_bucket}:{client_ip}"

        # Periodic cleanup
        self._cleanup_old_entries()

        # Get request times for this rate key
        times = self.request_times[rate_key]

        # Remove timestamps older than 1 minute
        cutoff_time = now - 60
        times[:] = [ts for ts in times if ts > cutoff_time]

        # Check rate limit
        if len(times) >= self.requests_per_minute:
            # Calculate when the oldest request will expire
            oldest = min(times)
            retry_after = int(60 - (now - oldest)) + 1
            logger.warning(
                "Rate limit exceeded for tenant=%s bucket=%s ip=%s path=%s",
                tenant_id,
                user_bucket,
                client_ip,
                request.url.path,
            )
            correlation_id = _get_correlation_id(request)
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "error": {
                        "code": "rate_limited",
                        "message": f"Rate limit exceeded. Maximum {self.requests_per_minute} requests per minute.",
                        "retryable": True,
                        "retry_after_ms": retry_after * 1000,
                        "details": {
                            "retry_after": retry_after,
                            "tenant_id": tenant_id,
                            "user_bucket": user_bucket,
                        },
                    },
                    "correlation_id": correlation_id,
                    # Transitional compatibility with the legacy flat error shape.
                    "detail": f"Rate limit exceeded. Maximum {self.requests_per_minute} requests per minute.",
                    "retry_after": retry_after,
                    "tenant_id": tenant_id,
                    "user_bucket": user_bucket,
                },
                headers={"Retry-After": str(retry_after), "x-correlation-id": correlation_id},
            )

        # Record this request
        times.append(now)

        # Process the request
        response = await call_next(request)

        # Add rate limit headers to the response
        response.headers["X-RateLimit-Limit"] = str(self.requests_per_minute)
        response.headers["X-RateLimit-Remaining"] = str(self.requests_per_minute - len(times))
        response.headers["X-RateLimit-Reset"] = str(int(min(times) + 60)) if times else str(int(now + 60))
        response.headers["X-Kumiho-Tenant"] = tenant_id
        response.headers["X-Kumiho-Rate-Bucket"] = user_bucket
        return response
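
# ---------------------------------------------------------------------------
# Illustrative sketches (not part of the middleware; route names, limits, and
# helper names below are assumptions for demonstration only).
#
# 1) Exercising the middleware with Starlette's TestClient. The hypothetical
#    "/items" route and a limit of 2 requests/minute are chosen only to make
#    the 429 path easy to reach; the headers and payload shape match what
#    dispatch() emits above.
#
#     from fastapi import FastAPI
#     from starlette.testclient import TestClient
#
#     app = FastAPI()
#     app.add_middleware(RateLimitMiddleware, requests_per_minute=2)
#
#     @app.get("/items")
#     def list_items():
#         return {"ok": True}
#
#     client = TestClient(app)
#     client.get("/items")
#     client.get("/items")
#     resp = client.get("/items")
#     assert resp.status_code == 429
#     print(resp.headers["Retry-After"], resp.json()["error"]["code"])
#
# 2) A minimal sketch of the Redis-backed approach the class docstring
#    suggests for production, using a fixed one-minute window via INCR/EXPIRE.
#    It assumes a reachable Redis instance and the redis-py client; it is not
#    a drop-in replacement for the in-memory logic above.
#
#     import redis
#
#     r = redis.Redis()
#
#     def allow(rate_key: str, limit: int = 60) -> bool:
#         window_key = f"rl:{rate_key}:{int(time.time() // 60)}"
#         count = r.incr(window_key)
#         if count == 1:
#             r.expire(window_key, 60)
#         return count <= limit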