"""FastAPI dependencies for Kumiho integration."""
import base64
import binascii
import json
import logging
import os
import time
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
from fastapi import Depends, Header, HTTPException, Request
import kumiho
from kumiho.discovery import DiscoveryError, DiscoveryManager
# Environment variable names consulted by this module.
_SERVICE_TOKEN_ENV_VAR = "KUMIHO_SERVICE_TOKEN"
_DEBUG_ENDPOINTS_ENV = "KUMIHO_DEBUG_ENDPOINTS"
_DEBUG_JWKS_ENV = "KUMIHO_DEBUG_JWKS"
# Tunables, read once at import time (seconds).
_CACHE_TTL_SECONDS = int(os.environ.get("KUMIHO_SERVICE_TOKEN_CACHE_SECONDS", "30"))
_DEBUG_JWKS_TIMEOUT_SECONDS = float(os.environ.get("KUMIHO_DEBUG_JWKS_TIMEOUT_SECONDS", "3"))
_DISCOVERY_TIMEOUT_SECONDS = float(os.environ.get("KUMIHO_DISCOVERY_TIMEOUT_SECONDS", "10"))
# Control-plane base URL; an unset or empty variable falls back to the hosted default.
_CONTROL_PLANE_URL = os.environ.get("KUMIHO_CONTROL_PLANE_URL") or "https://control.kumiho.cloud"
@dataclass
class TenantContext:
    """Tenant metadata and routing hints derived from the control-plane JWT.

    ``expires_at`` is an absolute Unix timestamp (seconds) after which the
    context should be rebuilt.
    """

    auth_token: str
    expires_at: float
    tenant_id: str
    tenant_slug: Optional[str]
    tenant_tier: Optional[str]
    region_code: Optional[str]
    neo4j_db_name: Optional[str]
    guardrails: Dict[str, Any]
    roles: Tuple[str, ...]
    server_url: Optional[str]
    grpc_authority: Optional[str]

    def preferred_target(self) -> Optional[str]:
        """Return the endpoint to dial for this tenant.

        A non-empty ``grpc_authority`` takes precedence; otherwise
        ``server_url`` is returned as-is (possibly ``None``).
        """
        if self.grpc_authority:
            return self.grpc_authority
        return self.server_url
# Module logger under a fixed dotted name (not derived from __name__).
logger = logging.getLogger(name="kumiho.fastapi.dependencies")
def _derive_audit_identity_from_claims(
    claims: Dict[str, Any],
    *,
    fallback_author: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
    """Return a best-effort ``(author, username)`` pair for response audit fields.

    Each value is the first candidate that is a non-blank string, returned
    whitespace-stripped, or ``None`` if no candidate qualifies. This is for
    display/audit convenience only; it does not grant access.

    Args:
        claims: Claims mapping (e.g. the unverified output of ``_decode_jwt_claims``).
        fallback_author: Last-resort ``author`` when no tenant-name claim is usable.
    """
    author_keys = ("tenant_display_name", "tenant_name", "tenant_slug", "tenant_id")
    username_keys = ("email", "preferred_username", "username", "client_id", "sub")

    def first_text(candidates: Any) -> Optional[str]:
        return next(
            (item.strip() for item in candidates if isinstance(item, str) and item.strip()),
            None,
        )

    author = first_text([*(claims.get(key) for key in author_keys), fallback_author])
    username = first_text(claims.get(key) for key in username_keys)
    return author, username
def _env_truthy(name: str) -> bool:
    """Report whether environment variable *name* holds an affirmative flag.

    ``1``, ``true`` and ``yes`` count as true (case-insensitive, surrounding
    whitespace ignored); unset, empty, or any other value is false.
    """
    raw = os.getenv(name) or ""
    return raw.strip().lower() in ("1", "true", "yes")
def _decode_jwt_claims(token: str) -> Dict[str, Any]:
    """Decode the payload segment of a JWT *without* verifying its signature.

    Returns the claims object. Returns ``{}`` when any of these hold:

    * the token is empty or has no ``.`` separator;
    * the payload is not valid base64url or JSON;
    * the payload decodes to something other than a JSON object.

    The result is unauthenticated and must be treated as untrusted.
    """
    if not token or "." not in token:
        return {}
    segment = token.split(".")[1]
    segment += "=" * (-len(segment) % 4)  # JWTs strip base64 padding; restore it
    try:
        raw = base64.urlsafe_b64decode(segment.encode("utf-8"))
        parsed = json.loads(raw)
    except (binascii.Error, ValueError, json.JSONDecodeError):
        return {}
    return parsed if isinstance(parsed, dict) else {}
def _discovery_refresh_deadline(record) -> Optional[float]:
    """Return ``record.cache_control.refresh_at`` as a Unix timestamp.

    Any failure yields ``None`` so callers can simply skip the deadline. Failures
    include missing attributes, a ``None`` value, or an object without ``.timestamp()``.
    """
    try:
        refresh_at = record.cache_control.refresh_at
        return refresh_at.timestamp()
    except Exception:  # pragma: no cover - defensive
        return None
def _build_discovery_url(base_url: str) -> str:
    """Turn a control-plane base URL into its tenant-discovery endpoint.

    Trailing slashes are ignored. A base already ending in ``/api``,
    ``/api/discovery`` or ``/api/discovery/tenant`` is only extended as far as
    needed, so every form maps to ``.../api/discovery/tenant``.
    """
    trimmed = base_url.rstrip("/")
    # Longest suffix first; the first match decides what still has to be appended.
    known_suffixes = (
        ("/api/discovery/tenant", ""),
        ("/api/discovery", "/tenant"),
        ("/api", "/discovery/tenant"),
    )
    for suffix, remainder in known_suffixes:
        if trimmed.endswith(suffix):
            return trimmed + remainder
    return trimmed + "/api/discovery/tenant"
def _post_json(
    url: str,
    payload: Dict[str, Any],
    *,
    timeout: float,
    headers: Optional[Dict[str, str]] = None,
) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """POST *payload* as JSON to *url* and parse a JSON-object reply.

    Transport and decoding failures are reported, not raised. The result is a
    ``(result, error)`` pair:

    * ``(dict, None)`` on success;
    * ``(None, "HTTP <code>: <first 200 chars of body>")`` on an HTTP error status;
    * ``(None, "invalid json: ...")`` when the reply is not a JSON object;
    * ``(None, str(exc))`` for any other failure (e.g. connection errors, timeouts).

    Args:
        url: Absolute endpoint URL.
        payload: JSON-serialisable request body.
        timeout: Socket timeout in seconds.
        headers: Extra headers, merged over the JSON ``Accept``/``Content-Type`` defaults.
    """
    encoded = json.dumps(payload).encode("utf-8")
    merged_headers = {"Accept": "application/json", "Content-Type": "application/json"}
    merged_headers.update(headers or {})
    req = urllib.request.Request(url, data=encoded, headers=merged_headers, method="POST")
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            text = resp.read().decode("utf-8", errors="replace")
    except urllib.error.HTTPError as http_err:
        try:
            err_body = http_err.read().decode("utf-8", errors="replace")
        except Exception:
            err_body = ""
        return None, f"HTTP {http_err.code}: {err_body[:200]}"
    except Exception as exc:
        return None, str(exc)
    try:
        parsed = json.loads(text)
    except (TypeError, ValueError) as exc:
        return None, f"invalid json: {exc}"
    if isinstance(parsed, dict):
        return parsed, None
    return None, "invalid json: expected object"
def _resolve_routing_anonymous(
    tenant_hint: str,
) -> Tuple[Optional[str], Optional[str], Optional[float]]:
    """Resolve a tenant's endpoints through the unauthenticated discovery endpoint.

    POSTs ``{"tenant_hint": ...}`` to the control plane's discovery URL. On success
    it returns ``(server_url, grpc_authority, None)``; non-string endpoint values
    become ``None``. The third slot is always ``None`` because this path yields no
    cache deadline.

    It returns ``(None, None, None)`` when the hint is empty. It also returns that,
    after logging a warning, when:

    * the request fails;
    * the reply contains an ``error`` key;
    * the reply lacks a ``region`` object.
    """
    unresolved: Tuple[Optional[str], Optional[str], Optional[float]] = (None, None, None)
    if not tenant_hint:
        return unresolved
    reply, failure = _post_json(
        _build_discovery_url(_CONTROL_PLANE_URL),
        {"tenant_hint": tenant_hint},
        timeout=_DISCOVERY_TIMEOUT_SECONDS,
        headers=None,
    )
    if failure:
        logger.warning("Discovery lookup failed; falling back to env target: %s", failure)
        return unresolved
    if reply and "error" in reply:
        logger.warning("Discovery lookup failed; falling back to env target: %s", reply.get("error"))
        return unresolved
    region = reply.get("region") if isinstance(reply, dict) else None
    if not isinstance(region, dict):
        logger.warning("Discovery lookup failed; falling back to env target: invalid region payload")
        return unresolved

    def text_field(key: str) -> Optional[str]:
        value = region.get(key)
        return value if isinstance(value, str) else None

    server_url = text_field("server_url")
    grpc_authority = text_field("grpc_authority")
    if _env_truthy(_DEBUG_ENDPOINTS_ENV):
        logger.warning(
            "Discovery routing result (anonymous): tenant_hint=%s server_url=%s grpc_authority=%s",
            tenant_hint,
            server_url,
            grpc_authority,
        )
    return server_url, grpc_authority, None
def _resolve_routing_from_discovery(
    *,
    id_token: Optional[str],
    tenant_hint: Optional[str],
) -> Tuple[Optional[str], Optional[str], Optional[float]]:
    """Resolve ``(server_url, grpc_authority, refresh_deadline)`` for a tenant.

    * No ``id_token`` but a ``tenant_hint``: delegate to the anonymous discovery
      endpoint (``_resolve_routing_anonymous``).
    * Neither: skip discovery and return ``(None, None, None)``.
    * With an ``id_token``: resolve through the SDK's ``DiscoveryManager``. Return
      the record's region endpoints and its cache refresh deadline (may be ``None``).

    Every discovery failure, including failure to construct the manager, is
    logged and turned into ``(None, None, None)`` so the caller can fall back.
    """
    if not id_token:
        if tenant_hint:
            if _env_truthy(_DEBUG_ENDPOINTS_ENV):
                logger.warning(
                    "Discovery routing using tenant_hint without id_token (tenant_hint=%s)",
                    tenant_hint,
                )
            return _resolve_routing_anonymous(tenant_hint)
        if _env_truthy(_DEBUG_ENDPOINTS_ENV):
            logger.warning(
                "Discovery routing skipped: no id_token (tenant_hint=%s)",
                tenant_hint,
            )
        return None, None, None
    try:
        # Constructed inside the try so a failing constructor takes the same
        # fallback path as any other discovery error instead of escaping.
        manager = DiscoveryManager()
        record = manager.resolve(id_token=id_token, tenant_hint=tenant_hint)
        deadline = _discovery_refresh_deadline(record)
        if _env_truthy(_DEBUG_ENDPOINTS_ENV):
            logger.warning(
                "Discovery routing result: tenant_hint=%s server_url=%s grpc_authority=%s",
                tenant_hint,
                record.region.server_url,
                record.region.grpc_authority,
            )
        return (
            record.region.server_url,
            record.region.grpc_authority,
            deadline,
        )
    except DiscoveryError as exc:  # pragma: no cover - defensive
        logger.warning("Discovery lookup failed; falling back to env target: %s", exc)
    except Exception:  # pragma: no cover - defensive
        logger.exception("Unexpected discovery failure; falling back to env target")
    return None, None, None
def _infer_http_base(target: Optional[str]) -> Optional[str]:
    """Best-effort conversion of a gRPC/HTTP target into an HTTP(S) base URL.

    In this module it feeds the JWKS diagnostics. Rules:

    * ``scheme://host[:port][/...]``: ``grpcs``/``https`` map to ``https`` and
      ``grpc``/``http`` map to ``http``. Any other scheme becomes ``https`` when
      the port is 443, else ``http``.
    * ``host[:port][/...]`` without a scheme: the port defaults to 8080. Port 443
      implies ``https``; any other port implies ``http``.
    * The port is kept unless it is the resulting scheme's default
      (80 for http, 443 for https).
    * IPv6 literal hosts are bracketed in the result.

    Returns ``None`` for empty input, a missing host, or an unparsable port.
    """
    if not target:
        return None
    raw = target.strip()
    if not raw:
        return None
    port: Optional[int] = None
    if "://" in raw:
        parsed = urllib.parse.urlparse(raw)
        host = parsed.hostname or ""
        try:
            # .port raises ValueError for non-numeric or out-of-range ports;
            # treat that like the unparsable-port case below.
            port = parsed.port
        except ValueError:
            return None
        if not host:
            return None
        scheme = parsed.scheme.lower()
        if scheme in {"grpcs", "https"}:
            scheme = "https"
        elif scheme in {"grpc", "http"}:
            scheme = "http"
        else:
            scheme = "https" if port == 443 else "http"
    else:
        trimmed = raw.split("/", 1)[0]
        if ":" in trimmed:
            host, port_str = trimmed.rsplit(":", 1)
            try:
                port = int(port_str)
            except ValueError:
                return None
        else:
            host = trimmed
        if not host:
            return None
        if port is None:
            port = 8080
        scheme = "https" if port == 443 else "http"
    # urlparse strips the brackets from IPv6 literals; restore them so the URL is valid.
    if ":" in host and not host.startswith("["):
        host = f"[{host}]"
    default_port = 443 if scheme == "https" else 80
    if port and port != default_port:
        return f"{scheme}://{host}:{port}"
    return f"{scheme}://{host}"
def _resolve_effective_target(target: Optional[str]) -> Tuple[str, str]:
    """Return ``(endpoint, source)`` describing which endpoint would be used.

    ``source`` is one of:

    * ``"context"``: an explicit, non-empty *target*;
    * ``"env"``: ``KUMIHO_SERVER_ENDPOINT`` or ``KUMIHO_SERVER_ADDRESS``;
    * ``"default"``: the ``localhost:8080`` fallback.
    """
    if not target:
        fallback = os.getenv("KUMIHO_SERVER_ENDPOINT") or os.getenv("KUMIHO_SERVER_ADDRESS")
        if not fallback:
            return "localhost:8080", "default"
        return fallback, "env"
    return target, "context"
def _fetch_json(url: str, timeout: float) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """GET *url* and parse a JSON-object reply, reporting failures as text.

    Returns ``(dict, None)`` on success. Otherwise returns ``(None, message)``,
    where *message* is one of:

    * ``"HTTP <code>: <first 200 chars of body>"``;
    * ``"invalid json: ..."``;
    * the stringified exception.
    """
    try:
        req = urllib.request.Request(url, headers={"Accept": "application/json"})
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            text = resp.read().decode("utf-8", errors="replace")
        try:
            parsed = json.loads(text)
        except (TypeError, ValueError) as exc:
            return None, f"invalid json: {exc}"
    except urllib.error.HTTPError as http_err:
        try:
            err_body = http_err.read().decode("utf-8", errors="replace")
        except Exception:
            err_body = ""
        return None, f"HTTP {http_err.code}: {err_body[:200]}"
    except Exception as exc:
        return None, str(exc)
    if isinstance(parsed, dict):
        return parsed, None
    return None, "invalid json: expected object"
def _fetch_jwks_diagnostics(base_url: str) -> Tuple[Optional[str], Optional[Tuple[str, ...]], Optional[str]]:
    """Query the server's ``/api/_control-plane/jwks`` debug endpoint.

    Returns ``(jwks_url, kids, error)``:

    * ``jwks_url`` is kept only when it is a string.
    * ``kids`` becomes a tuple of strings. A scalar becomes a one-element tuple,
      and a missing value stays ``None``.
    * ``error`` is the fetch failure, or the reply's own ``error`` field (stringified).
    """
    endpoint = base_url.rstrip("/") + "/api/_control-plane/jwks"
    body, fetch_error = _fetch_json(endpoint, _DEBUG_JWKS_TIMEOUT_SECONDS)
    if fetch_error:
        return None, None, fetch_error
    raw_url = body.get("jwks_url")
    jwks_url = raw_url if isinstance(raw_url, str) else None
    raw_kids = body.get("kids")
    kids: Optional[Tuple[str, ...]]
    if raw_kids is None:
        kids = None
    elif isinstance(raw_kids, list):
        kids = tuple(str(kid) for kid in raw_kids)
    else:
        kids = (str(raw_kids),)
    payload_error = str(body["error"]) if "error" in body else None
    return jwks_url, kids, payload_error
def _build_context_from_env(token: str) -> TenantContext:
    """Build a TenantContext for environment-token mode.

    Requires ``KUMIHO_TENANT_ID`` and one of ``KUMIHO_SERVER_ENDPOINT`` /
    ``KUMIHO_SERVER_ADDRESS``. Optional tenant fields come from the other
    ``KUMIHO_*`` variables. ``expires_at`` is set at least 60 seconds ahead.

    Raises:
        HTTPException: 500 when a required environment variable is missing.
    """
    tenant_id = os.getenv("KUMIHO_TENANT_ID")
    if not tenant_id:
        raise HTTPException(
            status_code=500,
            detail="KUMIHO_TENANT_ID must be set when using environment tokens.",
        )
    endpoint = os.getenv("KUMIHO_SERVER_ENDPOINT") or os.getenv("KUMIHO_SERVER_ADDRESS")
    if not endpoint:
        raise HTTPException(
            status_code=500,
            detail="KUMIHO_SERVER_ENDPOINT must be set when using environment tokens.",
        )
    lifetime = max(_CACHE_TTL_SECONDS, 60)
    issued_at = time.time()
    return TenantContext(
        auth_token=token,
        expires_at=issued_at + lifetime,
        tenant_id=tenant_id,
        tenant_slug=os.getenv("KUMIHO_TENANT_SLUG"),
        tenant_tier=os.getenv("KUMIHO_TENANT_TIER"),
        region_code=os.getenv("KUMIHO_REGION_CODE"),
        neo4j_db_name=os.getenv("KUMIHO_NEO4J_DB"),
        guardrails={},
        roles=(),
        server_url=endpoint,
        grpc_authority=os.getenv("KUMIHO_SERVER_AUTHORITY"),
    )
def _build_tenant_context(credentials: Dict[str, Any], tenant_id_override: Optional[str] = None) -> TenantContext:
    """Build a TenantContext from per-request credentials.

    The token (``control_plane_token``, else ``id_token``) is decoded *without*
    signature verification to read tenant metadata. The tenant's regional
    endpoints are then resolved via discovery.

    Args:
        credentials: Mapping with ``control_plane_token`` and/or ``id_token``.
            May also carry a numeric ``cp_expires_at`` / ``expires_at`` Unix timestamp.
        tenant_id_override: Tenant ID that takes precedence over the token's
            ``tenant_id`` claim (``get_kumiho_client`` passes the ``x-tenant-id``
            header here).

    Returns:
        A TenantContext. Its ``expires_at`` is the earlier of the credential
        expiry (default: now + cache TTL) and the discovery refresh deadline.

    Raises:
        HTTPException: 401 when no string token is present; 500 when no tenant
            ID can be determined.
    """
    now = time.time()
    token = credentials.get("control_plane_token") or credentials.get("id_token")
    if not isinstance(token, str):
        raise HTTPException(status_code=401, detail="Control-plane token missing from credentials cache.")
    cp_exp = credentials.get("cp_expires_at") or credentials.get("expires_at")
    expiry = float(cp_exp) if isinstance(cp_exp, (int, float)) else now + _CACHE_TTL_SECONDS

    claims = _decode_jwt_claims(token)
    # NOTE(review): the override wins over the token's own tenant claim; confirm
    # the server enforces tenant membership so this cannot redirect a token.
    tenant_id = tenant_id_override or claims.get("tenant_id") or os.getenv("KUMIHO_TENANT_ID")
    if not tenant_id:
        raise HTTPException(status_code=500, detail="Tenant ID missing from control-plane token.")
    tenant_slug = claims.get("tenant_slug")
    raw_id_token = credentials.get("id_token")
    id_token = raw_id_token if isinstance(raw_id_token, str) else None

    # A string token was found above, so the caller supplied per-request
    # credentials: always route per-tenant via discovery. Pinning to
    # KUMIHO_SERVER_ENDPOINT would make this FastAPI instance single-region and
    # can route writes to a stale region during rollouts. Environment-token mode
    # is built by _build_context_from_env and never reaches this function.
    server_url, grpc_authority, discovery_deadline = _resolve_routing_from_discovery(
        id_token=id_token,
        tenant_hint=tenant_slug or tenant_id,
    )
    # Expire no later than discovery's own cache refresh deadline, if one was given.
    expires_at = min(expiry, discovery_deadline) if discovery_deadline else expiry

    guardrails_claim = claims.get("guardrails")
    guardrails: Dict[str, Any] = guardrails_claim if isinstance(guardrails_claim, dict) else {}
    roles_claim = claims.get("roles")
    roles: Tuple[str, ...] = tuple(roles_claim) if isinstance(roles_claim, (list, tuple)) else tuple()

    return TenantContext(
        auth_token=token,
        expires_at=expires_at,
        tenant_id=tenant_id,
        tenant_slug=tenant_slug,
        tenant_tier=claims.get("tenant_tier"),
        region_code=claims.get("region_code"),
        neo4j_db_name=claims.get("neo4j_db_name"),
        guardrails=guardrails,
        roles=roles,
        server_url=server_url,
        grpc_authority=grpc_authority,
    )
def ensure_tenant_context() -> TenantContext:
    """Build the tenant context from the environment service token.

    Raises:
        HTTPException: 401 when ``KUMIHO_SERVICE_TOKEN`` is not set. It also
            propagates the 500 raised by ``_build_context_from_env`` when
            required environment variables are missing.
    """
    env_token = _get_service_token_from_env()
    if not env_token:
        raise HTTPException(
            status_code=401,
            detail=(
                "No service token found. Set KUMIHO_SERVICE_TOKEN environment variable "
                "or provide X-Kumiho-Token header."
            ),
        )
    return _build_context_from_env(env_token)
def _get_service_token_from_env() -> Optional[str]:
    """Read the service token from the ``KUMIHO_SERVICE_TOKEN`` environment variable.

    Returns the whitespace-stripped value, or ``None`` when the variable is
    unset or empty. A whitespace-only value yields ``""``.
    """
    raw = os.getenv(_SERVICE_TOKEN_ENV_VAR)
    if not raw:
        return None
    logger.debug("Found service token in %s", _SERVICE_TOKEN_ENV_VAR)
    return raw.strip()
def get_user_token(authorization: Optional[str] = Header(None)) -> Optional[str]:
    """Extract an optional bearer token from the ``Authorization`` header.

    ``get_kumiho_client`` uses this token in two ways. It is the service token
    when ``X-Kumiho-Token`` is absent. It is the user's ID token when both
    headers are sent.

    The ``Bearer`` scheme is matched case-insensitively, as HTTP auth schemes
    are (RFC 7235 §2.1). The token is stripped of surrounding whitespace.

    Args:
        authorization: Raw ``Authorization`` header value, if any.

    Returns:
        The token, or ``None`` when the header is missing, uses another scheme,
        or carries an empty token.
    """
    if not authorization:
        return None
    scheme, _, credentials = authorization.strip().partition(" ")
    if scheme.lower() != "bearer":
        return None
    token = credentials.strip()
    return token or None
from starlette.concurrency import run_in_threadpool
async def _log_jwks_diagnostics(target: Optional[str]) -> None:
    """Log which control-plane JWKS the server behind *target* reports (debug aid)."""
    effective_target, target_source = _resolve_effective_target(target)
    base_url = _infer_http_base(target or effective_target)
    if not target:
        logger.warning(
            "Control plane JWKS diagnostics using fallback target: effective_target=%s source=%s",
            effective_target,
            target_source,
        )
    if not base_url:
        logger.warning(
            "Control plane JWKS diagnostics skipped: cannot derive base URL from target=%s",
            target,
        )
        return
    jwks_url, kids, error = await run_in_threadpool(_fetch_jwks_diagnostics, base_url)
    if error:
        logger.warning(
            "Control plane JWKS diagnostics failed: base_url=%s error=%s",
            base_url,
            error,
        )
    else:
        logger.info(
            "Control plane JWKS diagnostics: base_url=%s jwks_url=%s kids=%s",
            base_url,
            jwks_url,
            kids,
        )


async def get_kumiho_client(
    request: Request,
    x_kumiho_token: Optional[str] = Header(None, alias="X-Kumiho-Token"),
    x_tenant_id: Optional[str] = Header(None, alias="x-tenant-id"),
    x_correlation_id: Optional[str] = Header(None, alias="x-correlation-id"),
    x_idempotency_key: Optional[str] = Header(None, alias="x-idempotency-key"),
    user_token: Optional[str] = Depends(get_user_token),
):
    """FastAPI dependency that yields a per-request Kumiho client.

    Token selection:

    * ``X-Kumiho-Token`` is the service token.
    * If it is absent, the ``Authorization: Bearer`` token is used instead.
    * When both are sent, the bearer token is forwarded as the user's ID token.
    * With neither, the ``KUMIHO_SERVICE_TOKEN`` environment token is used.

    Routing: the endpoint comes from the tenant context. Per-request tokens use
    discovery and environment tokens use env vars. A last-minute anonymous
    discovery lookup is the fallback. If that also fails the request is rejected
    with 503, rather than letting the SDK fall back to a default endpoint.

    The client is installed with ``kumiho.use_client`` for the duration of the
    request, so handlers can call module-level ``kumiho`` helpers.

    Args:
        request: Incoming request. Best-effort hints are stored on ``request.state``:
            ``kumiho_tenant_hint``, ``kumiho_target``, ``kumiho_author`` and
            ``kumiho_username``.
        x_kumiho_token: Service token header.
        x_tenant_id: Optional tenant ID. It overrides the token's tenant for the
            context and for the ``x-tenant-id`` client metadata.
        x_correlation_id: Forwarded as client metadata when present.
        x_idempotency_key: Forwarded as client metadata when present.
        user_token: Bearer token extracted by ``get_user_token``.

    Yields:
        A connected ``kumiho`` client.

    Raises:
        HTTPException: 401 when no token is available. 500 when the tenant
            context or the client cannot be built. 503 when no regional
            endpoint can be resolved.
    """
    # Prefer X-Kumiho-Token as the service (control-plane) token, but accept a
    # plain ``Authorization: Bearer`` token when it is missing, so standard
    # bearer auth (e.g. from n8n) works.
    service_token = x_kumiho_token or user_token
    using_request_token = bool(service_token)
    try:
        if service_token:
            credentials: Dict[str, Any] = {"control_plane_token": service_token}
            # With BOTH headers, the bearer token identifies the user (ID token)
            # and X-Kumiho-Token identifies the tenant (service token).
            if x_kumiho_token and user_token:
                credentials["id_token"] = user_token
            # Building the context performs blocking discovery HTTP calls
            # (urllib, up to the discovery timeout). Run it off the event loop,
            # like the other blocking calls in this dependency.
            context = await run_in_threadpool(
                _build_tenant_context,
                credentials,
                tenant_id_override=x_tenant_id,
            )
        else:
            # No per-request token: fall back to KUMIHO_SERVICE_TOKEN.
            context = ensure_tenant_context()
    except HTTPException:
        raise
    except Exception as exc:  # pragma: no cover - defensive
        raise HTTPException(status_code=500, detail=f"Failed to load tenant context: {exc}") from exc

    # The header tenant wins over the context's tenant.
    tenant_id = x_tenant_id or context.tenant_id
    # kumiho-python's tenant_hint does not add an x-tenant-id header, so send
    # the tenant explicitly as default metadata.
    metadata = [("x-tenant-id", tenant_id)]
    if x_correlation_id:
        metadata.append(("x-correlation-id", x_correlation_id))
    if x_idempotency_key:
        metadata.append(("x-idempotency-key", x_idempotency_key))

    target = context.preferred_target()

    # Best-effort diagnostics for error handlers and logging.
    try:
        request.state.kumiho_tenant_hint = context.tenant_slug or tenant_id
        request.state.kumiho_target = target
    except Exception:
        logger.debug("Could not store routing diagnostics on request.state", exc_info=True)

    # Best-effort identity for response audit fields (author/username).
    # Display only; it is not used for access decisions.
    try:
        claims = _decode_jwt_claims(context.auth_token)
        author, username = _derive_audit_identity_from_claims(
            claims,
            fallback_author=context.tenant_slug or context.tenant_id,
        )
        request.state.kumiho_author = author
        request.state.kumiho_username = username
    except Exception:
        logger.debug("Could not derive audit identity for request", exc_info=True)

    # Env endpoints are a fallback ONLY in env-credential mode. Pinning a
    # per-request token to a fixed endpoint can route writes to the wrong region
    # during multi-region rollouts.
    if not target and not using_request_token:
        target = os.getenv("KUMIHO_SERVER_ENDPOINT") or os.getenv("KUMIHO_SERVER_ADDRESS")

    # Still unresolved means discovery failed or was skipped. Try one anonymous
    # lookup rather than letting the SDK fall back to localhost:8080.
    if not target:
        tenant_hint = context.tenant_slug or context.tenant_id
        if tenant_hint:
            logger.info("Target not resolved yet, attempting last-minute discovery for %s", tenant_hint)
            server_url, grpc_authority, _ = await run_in_threadpool(
                _resolve_routing_anonymous,
                tenant_hint,
            )
            # Same precedence as TenantContext.preferred_target(). A discovered
            # gRPC authority must not be dropped when no server_url is returned.
            target = grpc_authority or server_url
            try:
                request.state.kumiho_target = target
            except Exception:
                logger.debug("Could not store routing diagnostics on request.state", exc_info=True)

    if not target:
        raise HTTPException(
            status_code=503,
            detail=(
                "Kumiho routing failed: Could not resolve a regional server for this tenant. "
                "Check if KUMIHO_CONTROL_PLANE_URL is accessible and the token is valid."
            )
        )

    # Always False here because target is resolved; kept explicit for the SDK call.
    use_discovery = target is None

    logger.info(
        "Connecting to Kumiho Server at %s (tenant: %s)",
        target,
        context.tenant_slug or tenant_id,
    )
    if context.auth_token:
        logger.info("Using auth token: present")
    else:
        logger.info("No auth token in context")
    if _env_truthy(_DEBUG_ENDPOINTS_ENV):
        logger.warning(
            "Kumiho routing debug: server_url=%s grpc_authority=%s tenant_id=%s tenant_slug=%s use_discovery=%s",
            context.server_url,
            context.grpc_authority,
            tenant_id,
            context.tenant_slug,
            use_discovery,
        )
    if _env_truthy(_DEBUG_JWKS_ENV):
        try:
            await _log_jwks_diagnostics(target)
        except Exception:
            # Diagnostics must never fail the request they are describing.
            logger.exception("Control plane JWKS diagnostics raised unexpectedly")

    try:
        # kumiho.connect blocks, so run it in the threadpool.
        client = await run_in_threadpool(
            kumiho.connect,
            endpoint=target,
            token=context.auth_token,
            enable_auto_login=False,
            use_discovery=use_discovery,
            default_metadata=metadata,
            # tenant_hint is accepted but not actually used by kumiho-python
        )
    except Exception as exc:  # pragma: no cover - upstream errors
        raise HTTPException(
            status_code=500,
            detail=f"Failed to initialize Kumiho client: {exc}",
        ) from exc
    # NOTE(review): the client is never explicitly closed after the request;
    # confirm whether kumiho clients hold channels that should be released here.
    with kumiho.use_client(client):
        yield client
def require_user_token(token: Optional[str] = Header(None, alias="authorization")) -> str:
    """Require a user bearer token in the ``Authorization`` header.

    The ``Bearer`` scheme is matched case-insensitively (RFC 7235 §2.1). The
    token is stripped of surrounding whitespace and must not be empty.

    Args:
        token: Raw ``Authorization`` header value.

    Returns:
        The non-empty bearer token.

    Raises:
        HTTPException: 401 if the header is missing, uses another scheme, or
            carries an empty token.
    """
    if not token:
        raise HTTPException(
            status_code=401,
            detail="Missing Authorization header"
        )
    scheme, _, credentials = token.strip().partition(" ")
    bearer = credentials.strip()
    # An empty token ("Bearer ") used to pass through as ""; reject it as malformed.
    if scheme.lower() != "bearer" or not bearer:
        raise HTTPException(
            status_code=401,
            detail="Invalid Authorization header format. Expected 'Bearer <token>'"
        )
    return bearer