# Source code for app.dependencies

"""FastAPI dependencies for Kumiho integration."""

import base64
import binascii
import json
import logging
import os
import time
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

from fastapi import Depends, Header, HTTPException, Request
import kumiho
from kumiho.discovery import DiscoveryError, DiscoveryManager

# Name of the environment variable holding the fallback service (control-plane) token.
_SERVICE_TOKEN_ENV_VAR = "KUMIHO_SERVICE_TOKEN"
# Context lifetime in seconds used when no explicit token expiry is available
# (environment-token contexts use at least 60 seconds).
_CACHE_TTL_SECONDS = int(os.environ.get("KUMIHO_SERVICE_TOKEN_CACHE_SECONDS", "30"))
# Env flag that enables verbose routing/discovery debug logging.
_DEBUG_ENDPOINTS_ENV = "KUMIHO_DEBUG_ENDPOINTS"
# Env flag that enables control-plane JWKS diagnostics on each client creation.
_DEBUG_JWKS_ENV = "KUMIHO_DEBUG_JWKS"
# HTTP timeout (seconds) for the JWKS diagnostics request.
_DEBUG_JWKS_TIMEOUT_SECONDS = float(os.environ.get("KUMIHO_DEBUG_JWKS_TIMEOUT_SECONDS", "3"))
# HTTP timeout (seconds) for anonymous tenant discovery requests.
_DISCOVERY_TIMEOUT_SECONDS = float(os.environ.get("KUMIHO_DISCOVERY_TIMEOUT_SECONDS", "10"))
# Base URL of the control plane used for discovery; empty values fall back to the default.
_CONTROL_PLANE_URL = os.environ.get("KUMIHO_CONTROL_PLANE_URL") or "https://control.kumiho.cloud"


@dataclass
class TenantContext:
    """Tenant routing and control-plane metadata derived from the control-plane JWT.

    Field order is part of the positional constructor signature; do not reorder.
    """

    # Bearer token handed to the Kumiho client.
    auth_token: str
    # Epoch-seconds expiry of this context.
    expires_at: float
    # Tenant identifier (header override, JWT claim, or KUMIHO_TENANT_ID).
    tenant_id: str
    # Optional tenant attributes copied from JWT claims or environment variables.
    tenant_slug: Optional[str]
    tenant_tier: Optional[str]
    region_code: Optional[str]
    neo4j_db_name: Optional[str]
    # "guardrails" claim when it is a dict, otherwise {}.
    guardrails: Dict[str, Any]
    # "roles" claim converted to a tuple, otherwise ().
    roles: Tuple[str, ...]
    # Routing endpoints resolved from discovery or environment variables.
    server_url: Optional[str]
    grpc_authority: Optional[str]

    def preferred_target(self) -> Optional[str]:
        """Return the endpoint to connect to: the gRPC authority if set, else the server URL."""
        if self.grpc_authority:
            return self.grpc_authority
        return self.server_url
logger = logging.getLogger("kumiho.fastapi.dependencies")


def _derive_audit_identity_from_claims(
    claims: Dict[str, Any],
    *,
    fallback_author: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
    """Derive best-effort ``(author, username)`` for response audit fields.

    ``author`` is the first non-blank value among tenant_display_name, tenant_name,
    tenant_slug, tenant_id, and ``fallback_author``. ``username`` is the first
    non-blank value among email, preferred_username, username, client_id, and sub.
    Values are stripped; if nothing qualifies the entry is ``None``.

    This is for display/audit convenience only; it does not grant access.
    """

    def _pick_string(*values: Any) -> Optional[str]:
        # First non-blank string wins; non-string claim values are ignored.
        for value in values:
            if isinstance(value, str) and value.strip():
                return value.strip()
        return None

    author = _pick_string(
        claims.get("tenant_display_name"),
        claims.get("tenant_name"),
        claims.get("tenant_slug"),
        claims.get("tenant_id"),
        fallback_author,
    )
    username = _pick_string(
        claims.get("email"),
        claims.get("preferred_username"),
        claims.get("username"),
        claims.get("client_id"),
        claims.get("sub"),
    )
    return author, username


def _env_truthy(name: str) -> bool:
    """Return True when env var *name* is "1", "true" or "yes" (trimmed, case-insensitive)."""
    value = os.getenv(name)
    if not value:
        return False
    return value.strip().lower() in {"1", "true", "yes"}


def _decode_jwt_claims(token: str) -> Dict[str, Any]:
    """Decode the payload segment of a JWT into a dict WITHOUT verifying its signature.

    Returns ``{}`` when the token is empty, has no ``.`` separator, or its payload
    is not valid base64url-encoded JSON object data.
    """
    if not token or "." not in token:
        return {}
    try:
        payload_b64 = token.split(".")[1]
        # JWT segments omit base64 padding; restore it before decoding.
        padding = "=" * (-len(payload_b64) % 4)
        decoded = base64.urlsafe_b64decode((payload_b64 + padding).encode("utf-8"))
        claims = json.loads(decoded)
        return claims if isinstance(claims, dict) else {}
    except (binascii.Error, ValueError, json.JSONDecodeError):
        return {}


def _discovery_refresh_deadline(record) -> Optional[float]:
    """Return ``record.cache_control.refresh_at`` as an epoch timestamp, or None if unavailable."""
    try:
        return record.cache_control.refresh_at.timestamp()
    except Exception:  # pragma: no cover - defensive
        return None


def _build_discovery_url(base_url: str) -> str:
    """Normalize a control-plane base URL to its ``/api/discovery/tenant`` endpoint."""
    base = base_url.rstrip("/")
    if base.endswith("/api/discovery/tenant"):
        return base
    if base.endswith("/api/discovery"):
        return f"{base}/tenant"
    if base.endswith("/api"):
        return f"{base}/discovery/tenant"
    return f"{base}/api/discovery/tenant"


def _http_error_message(exc: urllib.error.HTTPError) -> str:
    """Format an HTTPError as ``"HTTP <code>: <first 200 chars of the body>"``."""
    try:
        raw = exc.read().decode("utf-8", errors="replace")
    except Exception:
        raw = ""
    return f"HTTP {exc.code}: {raw[:200]}"


def _parse_json_object(raw: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Parse *raw* as a JSON object, returning ``(obj, None)`` or ``(None, error)``."""
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError) as exc:
        return None, f"invalid json: {exc}"
    if not isinstance(parsed, dict):
        return None, "invalid json: expected object"
    return parsed, None


def _post_json(
    url: str,
    payload: Dict[str, Any],
    *,
    timeout: float,
    headers: Optional[Dict[str, str]] = None,
) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """POST *payload* as JSON to *url* and return ``(response_object, error)``.

    Exactly one element of the returned tuple is non-None. Transport failures,
    HTTP errors, malformed URLs, and non-object JSON responses are all reported
    through the error string rather than raised.
    """
    body = json.dumps(payload).encode("utf-8")
    req_headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    if headers:
        req_headers.update(headers)
    try:
        # Request() raises ValueError for malformed URLs (e.g. a misconfigured
        # KUMIHO_CONTROL_PLANE_URL); keep it inside the try so callers get an error
        # string, consistent with _fetch_json.
        request = urllib.request.Request(url, data=body, headers=req_headers, method="POST")
        with urllib.request.urlopen(request, timeout=timeout) as response:
            raw = response.read().decode("utf-8", errors="replace")
    except urllib.error.HTTPError as exc:
        return None, _http_error_message(exc)
    except Exception as exc:
        return None, str(exc)
    return _parse_json_object(raw)


def _resolve_routing_anonymous(
    tenant_hint: str,
) -> Tuple[Optional[str], Optional[str], Optional[float]]:
    """Resolve ``(server_url, grpc_authority, None)`` for *tenant_hint* via unauthenticated discovery.

    Any failure is logged as a warning and yields ``(None, None, None)``.
    """
    if not tenant_hint:
        return None, None, None
    url = _build_discovery_url(_CONTROL_PLANE_URL)
    payload = {"tenant_hint": tenant_hint}
    response, error = _post_json(
        url,
        payload,
        timeout=_DISCOVERY_TIMEOUT_SECONDS,
        headers=None,
    )
    if error:
        logger.warning("Discovery lookup failed; falling back to env target: %s", error)
        return None, None, None
    if response and "error" in response:
        logger.warning("Discovery lookup failed; falling back to env target: %s", response.get("error"))
        return None, None, None
    region = response.get("region") if isinstance(response, dict) else None
    if not isinstance(region, dict):
        logger.warning("Discovery lookup failed; falling back to env target: invalid region payload")
        return None, None, None
    server_url = region.get("server_url") if isinstance(region.get("server_url"), str) else None
    grpc_authority = (
        region.get("grpc_authority") if isinstance(region.get("grpc_authority"), str) else None
    )
    if _env_truthy(_DEBUG_ENDPOINTS_ENV):
        logger.warning(
            "Discovery routing result (anonymous): tenant_hint=%s server_url=%s grpc_authority=%s",
            tenant_hint,
            server_url,
            grpc_authority,
        )
    return server_url, grpc_authority, None


def _resolve_routing_from_discovery(
    *,
    id_token: Optional[str],
    tenant_hint: Optional[str],
) -> Tuple[Optional[str], Optional[str], Optional[float]]:
    """Resolve ``(server_url, grpc_authority, refresh_deadline)`` for a tenant.

    With an ID token, uses the SDK's ``DiscoveryManager``. Without one, it falls
    back to anonymous discovery by *tenant_hint*. Failures are logged and yield
    ``(None, None, None)``.
    """
    if not id_token:
        if tenant_hint:
            if _env_truthy(_DEBUG_ENDPOINTS_ENV):
                logger.warning(
                    "Discovery routing using tenant_hint without id_token (tenant_hint=%s)",
                    tenant_hint,
                )
            return _resolve_routing_anonymous(tenant_hint)
        if _env_truthy(_DEBUG_ENDPOINTS_ENV):
            logger.warning(
                "Discovery routing skipped: no id_token (tenant_hint=%s)",
                tenant_hint,
            )
        return None, None, None

    try:
        # Construct inside the try so a construction failure also degrades to
        # "no route" instead of escaping as an unexpected error.
        manager = DiscoveryManager()
        record = manager.resolve(id_token=id_token, tenant_hint=tenant_hint)
        deadline = _discovery_refresh_deadline(record)
        if _env_truthy(_DEBUG_ENDPOINTS_ENV):
            logger.warning(
                "Discovery routing result: tenant_hint=%s server_url=%s grpc_authority=%s",
                tenant_hint,
                record.region.server_url,
                record.region.grpc_authority,
            )
        return (
            record.region.server_url,
            record.region.grpc_authority,
            deadline,
        )
    except DiscoveryError as exc:  # pragma: no cover - defensive
        logger.warning("Discovery lookup failed; falling back to env target: %s", exc)
    except Exception:  # pragma: no cover - defensive
        logger.exception("Unexpected discovery failure; falling back to env target")
    return None, None, None


def _infer_http_base(target: Optional[str]) -> Optional[str]:
    """Best-effort conversion of a gRPC/HTTP target into an HTTP(S) base URL.

    ``grpcs``/``https`` map to https and ``grpc``/``http`` map to http. Otherwise
    port 443 implies https. A scheme-less ``host[:port]`` defaults to port 8080.
    Returns ``None`` instead of raising when no usable host or port can be
    extracted, including non-numeric or out-of-range ports.
    """
    if not target:
        return None
    raw = target.strip()
    if not raw:
        return None

    port: Optional[int] = None
    if "://" in raw:
        parsed = urllib.parse.urlparse(raw)
        scheme = parsed.scheme.lower()
        host = parsed.hostname or ""
        try:
            port = parsed.port
        except ValueError:
            # urlparse raises on non-numeric or out-of-range ports; treat like the
            # scheme-less branch does and report "cannot derive".
            return None
        if not host:
            return None
        if scheme in {"grpcs", "https"}:
            scheme = "https"
        elif scheme in {"grpc", "http"}:
            scheme = "http"
        else:
            scheme = "https" if port == 443 else "http"
    else:
        trimmed = raw.split("/", 1)[0]
        if ":" in trimmed:
            host, port_str = trimmed.rsplit(":", 1)
            try:
                port = int(port_str)
            except ValueError:
                return None
            if not 0 <= port <= 65535:
                return None
        else:
            host = trimmed
        if not host:
            return None
        if port is None:
            # Bare host: assume 8080, matching the "localhost:8080" default used
            # by _resolve_effective_target.
            port = 8080
        scheme = "https" if port == 443 else "http"

    if port and port not in {80, 443}:
        return f"{scheme}://{host}:{port}"
    return f"{scheme}://{host}"


def _resolve_effective_target(target: Optional[str]) -> Tuple[str, str]:
    """Return ``(target, source)``, where source is "context", "env", or "default"."""
    if target:
        return target, "context"
    env_target = os.getenv("KUMIHO_SERVER_ENDPOINT") or os.getenv("KUMIHO_SERVER_ADDRESS")
    if env_target:
        return env_target, "env"
    return "localhost:8080", "default"


def _fetch_json(url: str, timeout: float) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """GET *url* expecting a JSON object and return ``(object, error)``; never raises."""
    try:
        request = urllib.request.Request(url, headers={"Accept": "application/json"})
        with urllib.request.urlopen(request, timeout=timeout) as response:
            body = response.read().decode("utf-8", errors="replace")
    except urllib.error.HTTPError as exc:
        return None, _http_error_message(exc)
    except Exception as exc:
        return None, str(exc)
    return _parse_json_object(body)


def _fetch_jwks_diagnostics(base_url: str) -> Tuple[Optional[str], Optional[Tuple[str, ...]], Optional[str]]:
    """Fetch ``<base_url>/api/_control-plane/jwks`` for debug logging.

    Returns ``(jwks_url, kids, error)``. ``kids`` is normalized to a tuple of
    strings, or None when absent. ``error`` carries a transport error or the
    payload's own ``"error"`` field.
    """
    url = f"{base_url.rstrip('/')}/api/_control-plane/jwks"
    payload, error = _fetch_json(url, _DEBUG_JWKS_TIMEOUT_SECONDS)
    if error:
        return None, None, error
    raw_jwks_url = payload.get("jwks_url")
    jwks_url = raw_jwks_url if isinstance(raw_jwks_url, str) else None
    kids = payload.get("kids")
    if isinstance(kids, list):
        kids = tuple(str(item) for item in kids)
    elif kids is not None:
        kids = (str(kids),)
    if "error" in payload:
        return jwks_url, kids, str(payload["error"])
    return jwks_url, kids, None


def _build_context_from_env(token: str) -> TenantContext:
    """Build a TenantContext purely from environment variables for an env-supplied token.

    Raises:
        HTTPException: 500 when KUMIHO_TENANT_ID or the server endpoint env var is missing.
    """
    tenant_id = os.getenv("KUMIHO_TENANT_ID")
    if not tenant_id:
        raise HTTPException(
            status_code=500,
            detail="KUMIHO_TENANT_ID must be set when using environment tokens.",
        )
    server_url = os.getenv("KUMIHO_SERVER_ENDPOINT") or os.getenv("KUMIHO_SERVER_ADDRESS")
    if not server_url:
        raise HTTPException(
            status_code=500,
            detail="KUMIHO_SERVER_ENDPOINT must be set when using environment tokens.",
        )
    now = time.time()
    return TenantContext(
        auth_token=token,
        expires_at=now + max(_CACHE_TTL_SECONDS, 60),
        tenant_id=tenant_id,
        tenant_slug=os.getenv("KUMIHO_TENANT_SLUG"),
        tenant_tier=os.getenv("KUMIHO_TENANT_TIER"),
        region_code=os.getenv("KUMIHO_REGION_CODE"),
        neo4j_db_name=os.getenv("KUMIHO_NEO4J_DB"),
        guardrails={},
        roles=(),
        server_url=server_url,
        grpc_authority=os.getenv("KUMIHO_SERVER_AUTHORITY"),
    )


def _build_tenant_context(credentials: Dict[str, Any], tenant_id_override: Optional[str] = None) -> TenantContext:
    """Build a TenantContext from per-request credentials, routing via discovery.

    Args:
        credentials: Dict with "control_plane_token" and/or "id_token", and optionally
            "cp_expires_at"/"expires_at" (epoch seconds).
        tenant_id_override: Tenant ID that takes precedence over the token's claim.

    Raises:
        HTTPException: 401 if no token string is present; 500 if no tenant ID can be found.

    Note: may perform blocking network I/O (discovery).
    """
    now = time.time()
    token = credentials.get("control_plane_token") or credentials.get("id_token")
    if not isinstance(token, str):
        raise HTTPException(status_code=401, detail="Control-plane token missing from credentials cache.")

    cp_exp = credentials.get("cp_expires_at") or credentials.get("expires_at")
    expiry = float(cp_exp) if isinstance(cp_exp, (int, float)) else now + _CACHE_TTL_SECONDS

    claims = _decode_jwt_claims(token)
    tenant_id = tenant_id_override or claims.get("tenant_id") or os.getenv("KUMIHO_TENANT_ID")
    if not tenant_id:
        raise HTTPException(status_code=500, detail="Tenant ID missing from control-plane token.")

    tenant_slug = claims.get("tenant_slug")
    tenant_hint = tenant_slug or tenant_id
    id_token = credentials.get("id_token") if isinstance(credentials.get("id_token"), str) else None

    # IMPORTANT:
    # - If the caller supplied per-request credentials (service token in header), we must route
    #   per-tenant via discovery. Pinning to KUMIHO_SERVER_ENDPOINT would effectively make this
    #   FastAPI instance single-tenant / single-region and can break writes after regional rollouts.
    # - If we're using environment credentials (no per-request creds), keep honoring env endpoints.
    #
    # NOTE(review): with the token check above, reaching this point requires one of the two
    # keys to be present, so this flag is currently always False; the env branches are kept
    # to document intent.
    using_env_creds_only = "control_plane_token" not in credentials and "id_token" not in credentials

    server_url: Optional[str] = None
    grpc_authority: Optional[str] = None
    discovery_deadline: Optional[float] = None
    if using_env_creds_only:
        server_url = os.getenv("KUMIHO_SERVER_ENDPOINT") or os.getenv("KUMIHO_SERVER_ADDRESS") or None

    if not server_url and not grpc_authority:
        server_url, grpc_authority, discovery_deadline = _resolve_routing_from_discovery(
            id_token=id_token,
            tenant_hint=tenant_hint,
        )

    # Expire the context no later than the discovery refresh deadline, when known.
    expires_at = min(expiry, discovery_deadline) if discovery_deadline else expiry

    guardrails_claim = claims.get("guardrails")
    guardrails: Dict[str, Any] = guardrails_claim if isinstance(guardrails_claim, dict) else {}
    roles_claim = claims.get("roles")
    roles: Tuple[str, ...] = tuple(roles_claim) if isinstance(roles_claim, (list, tuple)) else tuple()

    context = TenantContext(
        auth_token=token,
        expires_at=expires_at,
        tenant_id=tenant_id,
        tenant_slug=tenant_slug,
        tenant_tier=claims.get("tenant_tier"),
        region_code=claims.get("region_code"),
        neo4j_db_name=claims.get("neo4j_db_name"),
        guardrails=guardrails,
        roles=roles,
        server_url=server_url,
        grpc_authority=grpc_authority,
    )

    # Only fall back to pinned env targets when we are using environment credentials.
    # For per-request credentials, pinning to a single regional endpoint can route writes
    # to a stale region during rollouts and produce opaque 500s.
    if using_env_creds_only and not context.server_url:
        context.server_url = os.getenv("KUMIHO_SERVER_ENDPOINT") or os.getenv("KUMIHO_SERVER_ADDRESS")
    return context
def ensure_tenant_context() -> TenantContext:
    """Build the tenant context from the KUMIHO_SERVICE_TOKEN environment variable.

    Raises:
        HTTPException: 401 when no environment service token is configured; the
            context builder may also raise 500 for missing tenant/endpoint settings.
    """
    token_from_env = _get_service_token_from_env()
    if not token_from_env:
        message = (
            "No service token found. Set KUMIHO_SERVICE_TOKEN environment variable "
            "or provide X-Kumiho-Token header."
        )
        raise HTTPException(status_code=401, detail=message)
    return _build_context_from_env(token_from_env)
def _get_service_token_from_env() -> Optional[str]:
    """Return the service token from the KUMIHO_SERVICE_TOKEN environment variable.

    Surrounding whitespace is stripped. Returns ``None`` when the variable is unset,
    empty, or whitespace-only, so callers never receive (or log finding) an empty token.
    """
    value = (os.getenv(_SERVICE_TOKEN_ENV_VAR) or "").strip()
    if not value:
        return None
    logger.debug("Found service token in %s", _SERVICE_TOKEN_ENV_VAR)
    return value
def get_user_token(authorization: Optional[str] = Header(None)) -> Optional[str]:
    """
    Extract a Firebase user token from the Authorization header (optional).

    This is for end-user authentication that gets passed through to kumiho-server.
    The service token (X-Kumiho-Token) is still required for API access.

    The ``Bearer`` scheme is matched case-insensitively, as HTTP auth schemes are
    case-insensitive (RFC 7235 section 2.1).

    Args:
        authorization: Optional Authorization header value.

    Returns:
        The stripped token, or None if the header is absent, uses another scheme,
        or carries an empty token.
    """
    if not authorization:
        return None
    scheme, _, credentials = authorization.partition(" ")
    if scheme.lower() != "bearer":
        return None
    token = credentials.strip()
    return token or None
from starlette.concurrency import run_in_threadpool
async def get_kumiho_client(
    request: Request,
    x_kumiho_token: Optional[str] = Header(None, alias="X-Kumiho-Token"),
    x_tenant_id: Optional[str] = Header(None, alias="x-tenant-id"),
    x_correlation_id: Optional[str] = Header(None, alias="x-correlation-id"),
    x_idempotency_key: Optional[str] = Header(None, alias="x-idempotency-key"),
    user_token: Optional[str] = Depends(get_user_token),
):
    """
    Create a per-request Kumiho client and make it the active client for the request.

    Each API call gets its own client instance configured with the caller's service
    token. This lets different clients share one FastAPI deployment with their own tokens.

    Token selection:
      * ``X-Kumiho-Token`` is the service (control-plane) token when present. Otherwise
        an ``Authorization: Bearer`` token is used as the service token.
      * When both are present, the Bearer token is forwarded as the ``id_token``.
      * With neither, the ``KUMIHO_SERVICE_TOKEN`` environment variable is used.

    Args:
        request: Current request. Routing and audit diagnostics are stored on ``request.state``.
        x_kumiho_token: The client's Kumiho service token from the X-Kumiho-Token header.
        x_tenant_id: Optional tenant ID. It overrides the token's tenant and is sent as
            ``x-tenant-id`` client metadata.
        x_correlation_id: Optional correlation ID forwarded as client metadata.
        x_idempotency_key: Optional idempotency key forwarded as client metadata.
        user_token: Optional token from the ``Authorization: Bearer`` header.

    Yields:
        A configured kumiho client, installed with ``kumiho.use_client`` while the request runs.

    Raises:
        HTTPException: 401 if no token is available or the token is invalid. 500 if tenant
            context or client creation fails. 503 if no regional server could be resolved.
    """
    using_request_token = False
    try:
        # Prefer X-Kumiho-Token, but fall back to Authorization: Bearer when it is missing.
        # This allows standard Bearer auth in tools like n8n.
        service_token = x_kumiho_token or user_token
        using_request_token = bool(service_token)
        if service_token:
            credentials: Dict[str, Any] = {"control_plane_token": service_token}
            # If BOTH headers were provided, the Bearer token is the ID token (user identity).
            # X-Kumiho-Token remains the service token (tenant identity).
            if x_kumiho_token and user_token:
                credentials["id_token"] = user_token
            # _build_tenant_context can perform blocking network discovery (DiscoveryManager /
            # urllib with a multi-second timeout). Run it off the event loop, like the other
            # blocking calls in this function.
            context = await run_in_threadpool(
                _build_tenant_context,
                credentials,
                tenant_id_override=x_tenant_id,
            )
        else:
            # Fall back to the KUMIHO_SERVICE_TOKEN environment variable (env reads only).
            context = ensure_tenant_context()
    except HTTPException:
        raise
    except Exception as exc:  # pragma: no cover - defensive
        raise HTTPException(status_code=500, detail=f"Failed to load tenant context: {exc}") from exc

    # Prefer the header tenant_id if provided, otherwise the context's tenant_id.
    tenant_id = x_tenant_id or context.tenant_id

    # NOTE: kumiho-python's tenant_hint parameter doesn't actually add x-tenant-id,
    # so we add it to the default metadata manually.
    metadata = [("x-tenant-id", tenant_id)]
    if x_correlation_id:
        metadata.append(("x-correlation-id", x_correlation_id))
    if x_idempotency_key:
        metadata.append(("x-idempotency-key", x_idempotency_key))

    # Use the preferred target from the context (which includes the env override).
    target = context.preferred_target()

    # Stash diagnostics for the error handler and logging (best effort).
    try:
        request.state.kumiho_tenant_hint = context.tenant_slug or tenant_id
        request.state.kumiho_target = target
    except Exception:
        pass

    # Stash a best-effort identity for response audit fields (author/username).
    try:
        claims = _decode_jwt_claims(context.auth_token)
        author, username = _derive_audit_identity_from_claims(
            claims,
            fallback_author=context.tenant_slug or context.tenant_id,
        )
        request.state.kumiho_author = author
        request.state.kumiho_username = username
    except Exception:
        pass

    # If we still don't have a target, try the environment fallback ONLY in env-credential mode.
    # When a caller supplies a service token per request, pinning to a fixed env endpoint can
    # route writes to the wrong region during multi-region rollouts.
    if not target and not using_request_token:
        target = os.getenv("KUMIHO_SERVER_ENDPOINT") or os.getenv("KUMIHO_SERVER_ADDRESS")

    # A missing target here means discovery failed or was skipped.
    # We must NOT let the SDK fall back to localhost:8080.
    if not target:
        tenant_hint = context.tenant_slug or context.tenant_id
        if tenant_hint:
            logger.info("Target not resolved yet, attempting last-minute discovery for %s", tenant_hint)
            try:
                server_url, grpc_authority, _ = await run_in_threadpool(
                    _resolve_routing_anonymous,
                    tenant_hint,
                )
            except Exception:
                logger.exception("Last-minute discovery failed for %s", tenant_hint)
                server_url, grpc_authority = None, None
            # Same preference order as TenantContext.preferred_target().
            target = grpc_authority or server_url
            try:
                request.state.kumiho_target = target
            except Exception:
                pass

    if not target:
        raise HTTPException(
            status_code=503,
            detail=(
                "Kumiho routing failed: Could not resolve a regional server for this tenant. "
                "Check if KUMIHO_CONTROL_PLANE_URL is accessible and the token is valid."
            )
        )

    # Always False at this point (we raised above when no target was resolved), so the
    # SDK never performs its own discovery.
    use_discovery = target is None

    logger.info(
        "Connecting to Kumiho Server at %s (tenant: %s)",
        target,
        context.tenant_slug or tenant_id,
    )
    if context.auth_token:
        logger.info("Using auth token: present")
    else:
        logger.info("No auth token in context")

    if _env_truthy(_DEBUG_ENDPOINTS_ENV):
        logger.warning(
            "Kumiho routing debug: server_url=%s grpc_authority=%s tenant_id=%s tenant_slug=%s use_discovery=%s",
            context.server_url,
            context.grpc_authority,
            tenant_id,
            context.tenant_slug,
            use_discovery,
        )

    if _env_truthy(_DEBUG_JWKS_ENV):
        # Diagnostics are best-effort and must never fail the request.
        try:
            effective_target, target_source = _resolve_effective_target(target)
            base_url = _infer_http_base(target or effective_target)
            if not target:
                logger.warning(
                    "Control plane JWKS diagnostics using fallback target: effective_target=%s source=%s",
                    effective_target,
                    target_source,
                )
            if base_url:
                jwks_url, kids, error = await run_in_threadpool(_fetch_jwks_diagnostics, base_url)
                if error:
                    logger.warning(
                        "Control plane JWKS diagnostics failed: base_url=%s error=%s",
                        base_url,
                        error,
                    )
                else:
                    logger.info(
                        "Control plane JWKS diagnostics: base_url=%s jwks_url=%s kids=%s",
                        base_url,
                        jwks_url,
                        kids,
                    )
            else:
                logger.warning(
                    "Control plane JWKS diagnostics skipped: cannot derive base URL from target=%s",
                    target,
                )
        except Exception:
            logger.exception("Control plane JWKS diagnostics raised unexpectedly")

    try:
        # Run the blocking connect in the threadpool to avoid blocking the event loop.
        client = await run_in_threadpool(
            kumiho.connect,
            endpoint=target,
            token=context.auth_token,
            enable_auto_login=False,
            use_discovery=use_discovery,
            default_metadata=metadata,
            # tenant_hint is accepted but not actually used by kumiho-python
        )
    except Exception as exc:  # pragma: no cover - upstream errors
        raise HTTPException(
            status_code=500,
            detail=f"Failed to initialize Kumiho client: {exc}",
        ) from exc

    # Set the client context for the duration of the request.
    # This allows kumiho.get_project() etc. to work without passing the client explicitly.
    with kumiho.use_client(client):
        yield client
def require_user_token(token: Optional[str] = Header(None, alias="authorization")) -> str:
    """
    Require a Firebase user token in the Authorization header.

    The ``Bearer`` scheme is matched case-insensitively (RFC 7235 section 2.1), and
    surrounding whitespace around the token is stripped.

    Args:
        token: Authorization header value.

    Returns:
        The extracted Firebase token (never empty).

    Raises:
        HTTPException: 401 if the header is missing, does not use the Bearer scheme,
            or carries an empty token.
    """
    if not token:
        raise HTTPException(
            status_code=401,
            detail="Missing Authorization header"
        )
    scheme, _, credentials = token.partition(" ")
    credentials = credentials.strip()
    if scheme.lower() != "bearer" or not credentials:
        raise HTTPException(
            status_code=401,
            detail="Invalid Authorization header format. Expected 'Bearer <token>'"
        )
    return credentials