"""
Observability middleware for the Nextcloud MCP Server.

This module provides Starlette middleware that automatically instruments HTTP
requests with:
- Prometheus metrics (request count, latency, in-flight requests)
- OpenTelemetry distributed tracing
- Request/response timing and error tracking
"""

import logging
import time
from contextlib import nullcontext
from typing import Callable

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

from nextcloud_mcp_server.observability.metrics import (
    http_request_duration_seconds,
    http_requests_in_progress,
    http_requests_total,
)
from nextcloud_mcp_server.observability.tracing import (
    add_span_attribute,
    trace_operation,
)

logger = logging.getLogger(__name__)


class ObservabilityMiddleware(BaseHTTPMiddleware):
    """
    Starlette middleware for automatic HTTP request instrumentation.

    This middleware:
    - Records Prometheus metrics for each request (RED metrics)
    - Creates OpenTelemetry spans for distributed tracing
    - Tracks request timing and errors
    - Handles in-flight request counting
    """

    # Paths that are never wrapped in a tracing span (health, metrics and
    # polling endpoints) to reduce trace noise. Metrics are still recorded.
    _UNTRACED_PREFIXES = ("/health/",)
    _UNTRACED_PATHS = frozenset({"/metrics", "/app/vector-sync/status"})

    # Requests taking strictly longer than this many seconds are logged as slow.
    slow_request_threshold_seconds = 1.0

    async def dispatch(
        self,
        request: Request,
        call_next: Callable,
    ) -> Response:
        """
        Process HTTP request with observability instrumentation.

        Args:
            request: Starlette request object
            call_next: Next middleware or route handler

        Returns:
            Response from downstream handler

        Raises:
            Exception: Any exception raised downstream is re-raised after a
                status-500 metric sample and an error log have been recorded.
        """
        method = request.method
        path = request.url.path
        endpoint = self._get_endpoint_label(path)

        # Track the request as in-flight; the matching dec() is in `finally`.
        http_requests_in_progress.labels(method=method, endpoint=endpoint).inc()

        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed (or go negative) when the wall clock is adjusted mid-request.
        start_time = time.perf_counter()

        should_trace = not (
            path.startswith(self._UNTRACED_PREFIXES) or path in self._UNTRACED_PATHS
        )

        try:
            if should_trace:
                # Create span for request (OpenTelemetry auto-instrumentation
                # will create parent span)
                span_context = trace_operation(
                    f"HTTP {method} {endpoint}",
                    attributes={
                        "http.method": method,
                        "http.path": path,
                        "http.scheme": request.url.scheme,
                        "http.host": request.url.hostname,
                    },
                )
            else:
                # No span for noisy endpoints; a no-op context keeps a single
                # code path for both cases.
                span_context = nullcontext()

            with span_context:
                response = await call_next(request)

                if should_trace:
                    add_span_attribute("http.status_code", response.status_code)

                # Recorded inside the span so the slow-request log is emitted
                # while the span is still active.
                self._record_request_metrics(
                    method=method,
                    endpoint=endpoint,
                    status_code=response.status_code,
                    duration=time.perf_counter() - start_time,
                )
                return response

        except Exception:
            # Count the failure as an internal server error.
            duration = time.perf_counter() - start_time
            self._record_request_metrics(
                method=method,
                endpoint=endpoint,
                status_code=500,
                duration=duration,
            )

            # logger.exception logs at ERROR level and attaches the traceback.
            logger.exception(
                "Request failed: %s %s",
                method,
                path,
                extra={
                    "method": method,
                    "path": path,
                    "duration_seconds": duration,
                },
            )

            # Re-raise exception to be handled by error middleware
            raise

        finally:
            # Always balance the inc() above, on success and on failure.
            http_requests_in_progress.labels(method=method, endpoint=endpoint).dec()

    def _get_endpoint_label(self, path: str) -> str:
        """
        Get endpoint label for metrics, normalizing dynamic path segments.

        Known path families are collapsed into a single label so they do not
        each produce their own metric series.

        Args:
            path: Request path

        Returns:
            Normalized endpoint label
        """
        # Health check endpoints
        if path.startswith("/health/"):
            return "/health/*"

        # Metrics endpoint
        if path == "/metrics":
            return "/metrics"

        # MCP protocol endpoints
        if path == "/sse" or path.startswith("/sse/"):
            return "/sse"
        if path == "/messages" or path.startswith("/messages/"):
            return "/messages"

        # OAuth/OIDC endpoints
        if path.startswith("/oauth/"):
            return "/oauth/*"
        if path.startswith("/oidc/"):
            return "/oidc/*"

        # Catch-all: any other path is returned verbatim.
        # NOTE(review): every distinct unmatched path becomes its own label
        # value here - confirm that the set of such paths is bounded.
        return path

    def _record_request_metrics(
        self,
        method: str,
        endpoint: str,
        status_code: int,
        duration: float,
    ) -> None:
        """
        Record Prometheus metrics for an HTTP request.

        Increments the request counter, observes the duration histogram and
        logs a warning when the request exceeded the slow-request threshold.

        Args:
            method: HTTP method
            endpoint: Normalized endpoint label
            status_code: HTTP status code
            duration: Request duration in seconds
        """
        # Record request count (the status code label is passed as a string)
        http_requests_total.labels(
            method=method,
            endpoint=endpoint,
            status_code=str(status_code),
        ).inc()

        # Record request duration
        http_request_duration_seconds.labels(
            method=method,
            endpoint=endpoint,
        ).observe(duration)

        # Log slow requests
        if duration > self.slow_request_threshold_seconds:
            logger.warning(
                "Slow request: %s %s took %.3fs",
                method,
                endpoint,
                duration,
                extra={
                    "method": method,
                    "endpoint": endpoint,
                    "status_code": status_code,
                    "duration_seconds": duration,
                },
            )