refactor: Move background tasks to server lifespan and remove deprecated SSE transport

- Move scanner/processor tasks from FastMCP session lifespan to Starlette
  server lifespan (correct architecture: background tasks run once at
  server level, not per-session)
- Change default CLI transport from SSE to streamable-http
- Remove SSE transport option from CLI (SSE is deprecated)
- Remove SSE client session factory from test fixtures
- Add tracing instrumentation to BM25 hybrid search operations and to the
  vector visualization search route for better observability

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Chris Coutinho
2025-11-23 04:02:30 +01:00
parent 2ab8dad6a5
commit fafeaf3d83
5 changed files with 314 additions and 405 deletions
+155 -244
View File
@@ -555,15 +555,15 @@ async def load_oauth_client_credentials(
@asynccontextmanager @asynccontextmanager
async def app_lifespan_basic(server: FastMCP) -> AsyncIterator[AppContext]: async def app_lifespan_basic(server: FastMCP) -> AsyncIterator[AppContext]:
""" """
Manage application lifecycle for BasicAuth mode. Manage application lifecycle for BasicAuth mode (FastMCP session lifespan).
Creates a single Nextcloud client with basic authentication Creates a single Nextcloud client with basic authentication
that is shared across all requests. that is shared across all requests within a session.
If vector sync is enabled (VECTOR_SYNC_ENABLED=true), also starts Note: Background tasks (scanner, processor) are started at server level
background tasks for automatic document indexing (ADR-007). in starlette_lifespan, not here. This lifespan runs per-session.
""" """
logger.info("Starting MCP server in BasicAuth mode") logger.info("Starting MCP session in BasicAuth mode")
logger.info("Creating Nextcloud client with BasicAuth") logger.info("Creating Nextcloud client with BasicAuth")
client = NextcloudClient.from_env() client = NextcloudClient.from_env()
@@ -579,91 +579,12 @@ async def app_lifespan_basic(server: FastMCP) -> AsyncIterator[AppContext]:
# Initialize document processors # Initialize document processors
initialize_document_processors() initialize_document_processors()
settings = get_settings() # Yield client context - scanner runs at server level (starlette_lifespan)
try:
# Check if vector sync is enabled yield AppContext(client=client, storage=storage)
if settings.vector_sync_enabled: finally:
logger.info("Vector sync enabled - starting background tasks") logger.info("Shutting down BasicAuth session")
await client.close()
# Get username from environment for BasicAuth mode
username = os.getenv("NEXTCLOUD_USERNAME")
if not username:
raise ValueError(
"NEXTCLOUD_USERNAME is required for vector sync in BasicAuth mode"
)
# Initialize Qdrant collection before starting background tasks
logger.info("Initializing Qdrant collection...")
from nextcloud_mcp_server.vector.qdrant_client import get_qdrant_client
try:
await get_qdrant_client() # Triggers collection creation if needed
logger.info("Qdrant collection ready")
except Exception as e:
logger.error(f"Failed to initialize Qdrant collection: {e}")
raise RuntimeError(
f"Cannot start vector sync - Qdrant initialization failed: {e}"
) from e
# Initialize shared state
send_stream, receive_stream = anyio.create_memory_object_stream(
max_buffer_size=settings.vector_sync_queue_max_size
)
shutdown_event = anyio.Event()
scanner_wake_event = anyio.Event()
# Start background tasks using anyio TaskGroup
async with anyio.create_task_group() as tg:
# Start scanner task
await tg.start(
scanner_task,
send_stream,
shutdown_event,
scanner_wake_event,
client,
username,
)
# Start processor pool (each gets a cloned receive stream)
for i in range(settings.vector_sync_processor_workers):
await tg.start(
processor_task,
i,
receive_stream.clone(),
shutdown_event,
client,
username,
)
logger.info(
f"Background sync tasks started: 1 scanner + {settings.vector_sync_processor_workers} processors"
)
# Yield with background tasks running
try:
yield AppContext(
client=client,
storage=storage,
document_send_stream=send_stream,
document_receive_stream=receive_stream,
shutdown_event=shutdown_event,
scanner_wake_event=scanner_wake_event,
)
finally:
# Shutdown signal
logger.info("Shutting down background sync tasks")
shutdown_event.set()
# TaskGroup automatically cancels all tasks on exit
logger.info("Background sync tasks stopped")
await client.close()
else:
# No vector sync - simple lifecycle
try:
yield AppContext(client=client, storage=storage)
finally:
logger.info("Shutting down BasicAuth mode")
await client.close()
async def setup_oauth_config(): async def setup_oauth_config():
@@ -979,7 +900,7 @@ async def setup_oauth_config():
) )
def get_app(transport: str = "sse", enabled_apps: list[str] | None = None): def get_app(transport: str = "streamable-http", enabled_apps: list[str] | None = None):
# Initialize observability (logging will be configured by uvicorn) # Initialize observability (logging will be configured by uvicorn)
settings = get_settings() settings = get_settings()
@@ -1197,180 +1118,170 @@ def get_app(transport: str = "sse", enabled_apps: list[str] | None = None):
"Dynamic tool filtering enabled for OAuth mode (JWT and Bearer tokens)" "Dynamic tool filtering enabled for OAuth mode (JWT and Bearer tokens)"
) )
if transport == "sse": mcp_app = mcp.streamable_http_app()
mcp_app = mcp.sse_app()
starlette_lifespan = None
elif transport in ("http", "streamable-http"):
mcp_app = mcp.streamable_http_app()
@asynccontextmanager @asynccontextmanager
async def starlette_lifespan(app: Starlette): async def starlette_lifespan(app: Starlette):
# Set OAuth context for OAuth login routes (ADR-004) # Set OAuth context for OAuth login routes (ADR-004)
if oauth_enabled: if oauth_enabled:
# Prepare OAuth config from setup_oauth_config closure variables # Prepare OAuth config from setup_oauth_config closure variables
mcp_server_url = os.getenv( mcp_server_url = os.getenv(
"NEXTCLOUD_MCP_SERVER_URL", "http://localhost:8000" "NEXTCLOUD_MCP_SERVER_URL", "http://localhost:8000"
) )
nextcloud_resource_uri = os.getenv( nextcloud_resource_uri = os.getenv("NEXTCLOUD_RESOURCE_URI", nextcloud_host)
"NEXTCLOUD_RESOURCE_URI", nextcloud_host discovery_url = os.getenv(
) "OIDC_DISCOVERY_URL",
discovery_url = os.getenv( f"{nextcloud_host}/.well-known/openid-configuration",
"OIDC_DISCOVERY_URL", )
f"{nextcloud_host}/.well-known/openid-configuration", scopes = os.getenv("NEXTCLOUD_OIDC_SCOPES", "")
)
scopes = os.getenv("NEXTCLOUD_OIDC_SCOPES", "")
oauth_context_dict = { oauth_context_dict = {
"storage": refresh_token_storage, "storage": refresh_token_storage,
"oauth_client": oauth_client, "oauth_client": oauth_client,
"token_verifier": token_verifier, # For querying IdP userinfo endpoint "token_verifier": token_verifier, # For querying IdP userinfo endpoint
"config": { "config": {
"mcp_server_url": mcp_server_url, "mcp_server_url": mcp_server_url,
"discovery_url": discovery_url, "discovery_url": discovery_url,
"client_id": client_id, # From setup_oauth_config (DCR or static) "client_id": client_id, # From setup_oauth_config (DCR or static)
"client_secret": client_secret, # From setup_oauth_config (DCR or static) "client_secret": client_secret, # From setup_oauth_config (DCR or static)
"scopes": scopes, "scopes": scopes,
"nextcloud_host": nextcloud_host, "nextcloud_host": nextcloud_host,
"nextcloud_resource_uri": nextcloud_resource_uri, "nextcloud_resource_uri": nextcloud_resource_uri,
"oauth_provider": oauth_provider, "oauth_provider": oauth_provider,
}, },
} }
app.state.oauth_context = oauth_context_dict app.state.oauth_context = oauth_context_dict
# Also set oauth_context on browser_app for session authentication # Also set oauth_context on browser_app for session authentication
# browser_app is in the same function scope (defined later in create_app) # browser_app is in the same function scope (defined later in create_app)
# We need to find it in the mounted routes # We need to find it in the mounted routes
for route in app.routes: for route in app.routes:
if isinstance(route, Mount) and route.path == "/app": if isinstance(route, Mount) and route.path == "/app":
route.app.state.oauth_context = oauth_context_dict route.app.state.oauth_context = oauth_context_dict
logger.info( logger.info(
"OAuth context shared with browser_app for session auth" "OAuth context shared with browser_app for session auth"
)
break
logger.info(
f"OAuth context initialized for login routes (client_id={client_id[:16]}...)"
)
else:
# BasicAuth mode - share storage with browser_app for webhook management
from nextcloud_mcp_server.auth.storage import RefreshTokenStorage
storage = RefreshTokenStorage.from_env()
await storage.initialize()
app.state.storage = storage
# Also share with browser_app for webhook routes
for route in app.routes:
if isinstance(route, Mount) and route.path == "/app":
route.app.state.storage = storage
logger.info(
"Storage shared with browser_app for webhook management"
)
break
# Start background vector sync tasks for BasicAuth mode (ADR-007)
# For streamable-http transport, FastMCP lifespan isn't automatically triggered
# so we manually start background tasks here if vector sync is enabled
import anyio as anyio_module
settings = get_settings()
if not oauth_enabled and settings.vector_sync_enabled:
logger.info("Starting background vector sync tasks for BasicAuth mode")
# Get username from environment
username = os.getenv("NEXTCLOUD_USERNAME")
if not username:
raise ValueError(
"NEXTCLOUD_USERNAME required for vector sync in BasicAuth mode"
) )
break
# Get Nextcloud client from MCP app context logger.info(
# Create client since we're outside FastMCP lifespan f"OAuth context initialized for login routes (client_id={client_id[:16]}...)"
client = NextcloudClient.from_env() )
else:
# BasicAuth mode - share storage with browser_app for webhook management
from nextcloud_mcp_server.auth.storage import RefreshTokenStorage
# Initialize Qdrant collection before starting background tasks storage = RefreshTokenStorage.from_env()
logger.info("Initializing Qdrant collection...") await storage.initialize()
from nextcloud_mcp_server.vector.qdrant_client import get_qdrant_client
try: app.state.storage = storage
await get_qdrant_client() # Triggers collection creation if needed
logger.info("Qdrant collection ready")
except Exception as e:
logger.error(f"Failed to initialize Qdrant collection: {e}")
raise RuntimeError(
f"Cannot start vector sync - Qdrant initialization failed: {e}"
) from e
# Initialize shared state # Also share with browser_app for webhook routes
send_stream, receive_stream = anyio_module.create_memory_object_stream( for route in app.routes:
max_buffer_size=settings.vector_sync_queue_max_size if isinstance(route, Mount) and route.path == "/app":
route.app.state.storage = storage
logger.info(
"Storage shared with browser_app for webhook management"
)
break
# Start background vector sync tasks for BasicAuth mode (ADR-007)
# Scanner runs at server-level (once), not per-session
import anyio as anyio_module
settings = get_settings()
if not oauth_enabled and settings.vector_sync_enabled:
logger.info("Starting background vector sync tasks for BasicAuth mode")
# Get username from environment
username = os.getenv("NEXTCLOUD_USERNAME")
if not username:
raise ValueError(
"NEXTCLOUD_USERNAME required for vector sync in BasicAuth mode"
) )
shutdown_event = anyio_module.Event()
scanner_wake_event = anyio_module.Event()
# Store in app state for access from routes (ADR-007) # Create client for vector sync (server-level, not per-session)
app.state.document_send_stream = send_stream client = NextcloudClient.from_env()
app.state.document_receive_stream = receive_stream
app.state.shutdown_event = shutdown_event
app.state.scanner_wake_event = scanner_wake_event
# Also share with browser_app for /app route # Initialize Qdrant collection before starting background tasks
for route in app.routes: logger.info("Initializing Qdrant collection...")
if isinstance(route, Mount) and route.path == "/app": from nextcloud_mcp_server.vector.qdrant_client import get_qdrant_client
route.app.state.document_send_stream = send_stream
route.app.state.document_receive_stream = receive_stream
route.app.state.shutdown_event = shutdown_event
route.app.state.scanner_wake_event = scanner_wake_event
logger.info(
"Vector sync state shared with browser_app for /app"
)
break
# Start background tasks using anyio TaskGroup try:
async with anyio_module.create_task_group() as tg: await get_qdrant_client() # Triggers collection creation if needed
# Start scanner task logger.info("Qdrant collection ready")
except Exception as e:
logger.error(f"Failed to initialize Qdrant collection: {e}")
raise RuntimeError(
f"Cannot start vector sync - Qdrant initialization failed: {e}"
) from e
# Initialize shared state
send_stream, receive_stream = anyio_module.create_memory_object_stream(
max_buffer_size=settings.vector_sync_queue_max_size
)
shutdown_event = anyio_module.Event()
scanner_wake_event = anyio_module.Event()
# Store in app state for access from routes (ADR-007)
app.state.document_send_stream = send_stream
app.state.document_receive_stream = receive_stream
app.state.shutdown_event = shutdown_event
app.state.scanner_wake_event = scanner_wake_event
# Also share with browser_app for /app route
for route in app.routes:
if isinstance(route, Mount) and route.path == "/app":
route.app.state.document_send_stream = send_stream
route.app.state.document_receive_stream = receive_stream
route.app.state.shutdown_event = shutdown_event
route.app.state.scanner_wake_event = scanner_wake_event
logger.info("Vector sync state shared with browser_app for /app")
break
# Start background tasks using anyio TaskGroup
async with anyio_module.create_task_group() as tg:
# Start scanner task
await tg.start(
scanner_task,
send_stream,
shutdown_event,
scanner_wake_event,
client,
username,
)
# Start processor pool (each gets a cloned receive stream)
for i in range(settings.vector_sync_processor_workers):
await tg.start( await tg.start(
scanner_task, processor_task,
send_stream, i,
receive_stream.clone(),
shutdown_event, shutdown_event,
scanner_wake_event,
client, client,
username, username,
) )
# Start processor pool (each gets a cloned receive stream) logger.info(
for i in range(settings.vector_sync_processor_workers): f"Background sync tasks started: 1 scanner + "
await tg.start( f"{settings.vector_sync_processor_workers} processors"
processor_task, )
i,
receive_stream.clone(),
shutdown_event,
client,
username,
)
logger.info( # Run MCP session manager and yield
f"Background sync tasks started: 1 scanner + "
f"{settings.vector_sync_processor_workers} processors"
)
# Run MCP session manager and yield
async with AsyncExitStack() as stack:
await stack.enter_async_context(mcp.session_manager.run())
try:
yield
finally:
# Shutdown signal
logger.info("Shutting down background sync tasks")
shutdown_event.set()
await client.close()
# TaskGroup automatically cancels all tasks on exit
else:
# No vector sync - just run MCP session manager
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
await stack.enter_async_context(mcp.session_manager.run()) await stack.enter_async_context(mcp.session_manager.run())
yield try:
yield
finally:
# Shutdown signal
logger.info("Shutting down background sync tasks")
shutdown_event.set()
await client.close()
# TaskGroup automatically cancels all tasks on exit
else:
# No vector sync - just run MCP session manager
async with AsyncExitStack() as stack:
await stack.enter_async_context(mcp.session_manager.run())
yield
# Health check endpoints for Kubernetes probes # Health check endpoints for Kubernetes probes
def health_live(request): def health_live(request):
+74 -37
View File
@@ -22,6 +22,7 @@ from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse from starlette.responses import HTMLResponse, JSONResponse
from nextcloud_mcp_server.config import get_settings from nextcloud_mcp_server.config import get_settings
from nextcloud_mcp_server.observability.tracing import trace_operation
from nextcloud_mcp_server.search import ( from nextcloud_mcp_server.search import (
BM25HybridSearchAlgorithm, BM25HybridSearchAlgorithm,
SemanticSearchAlgorithm, SemanticSearchAlgorithm,
@@ -139,7 +140,10 @@ async def vector_visualization_search(request: Request) -> JSONResponse:
_get_authenticated_client_for_userinfo, _get_authenticated_client_for_userinfo,
) )
async with await _get_authenticated_client_for_userinfo(request) as nc_client: # noqa: F841 with trace_operation("vector_viz.get_auth_client"):
auth_client_ctx = await _get_authenticated_client_for_userinfo(request)
async with auth_client_ctx as nc_client: # noqa: F841
# Create search algorithm (no client needed - verification removed) # Create search algorithm (no client needed - verification removed)
if algorithm == "semantic": if algorithm == "semantic":
search_algo = SemanticSearchAlgorithm(score_threshold=score_threshold) search_algo = SemanticSearchAlgorithm(score_threshold=score_threshold)
@@ -159,24 +163,40 @@ async def vector_visualization_search(request: Request) -> JSONResponse:
all_results = [] all_results = []
if doc_types is None or len(doc_types) == 0: if doc_types is None or len(doc_types) == 0:
# Cross-app search - search all indexed types # Cross-app search - search all indexed types
unverified_results = await search_algo.search( with trace_operation(
query=query, "vector_viz.search_execute",
user_id=username, attributes={
limit=limit * 2, # Buffer for verification filtering "search.algorithm": algorithm,
doc_type=None, # Search all types "search.limit": limit * 2,
score_threshold=score_threshold, "search.doc_type": "all",
) },
all_results.extend(unverified_results) ):
else:
# Search each document type and combine
for doc_type in doc_types:
unverified_results = await search_algo.search( unverified_results = await search_algo.search(
query=query, query=query,
user_id=username, user_id=username,
limit=limit * 2, # Buffer for verification filtering limit=limit * 2, # Buffer for verification filtering
doc_type=doc_type, doc_type=None, # Search all types
score_threshold=score_threshold, score_threshold=score_threshold,
) )
all_results.extend(unverified_results)
else:
# Search each document type and combine
for doc_type in doc_types:
with trace_operation(
"vector_viz.search_execute",
attributes={
"search.algorithm": algorithm,
"search.limit": limit * 2,
"search.doc_type": doc_type,
},
):
unverified_results = await search_algo.search(
query=query,
user_id=username,
limit=limit * 2, # Buffer for verification filtering
doc_type=doc_type,
score_threshold=score_threshold,
)
all_results.extend(unverified_results) all_results.extend(unverified_results)
# Sort by score before verification # Sort by score before verification
all_results.sort(key=lambda r: r.score, reverse=True) all_results.sort(key=lambda r: r.score, reverse=True)
@@ -190,22 +210,26 @@ async def vector_visualization_search(request: Request) -> JSONResponse:
# Store original scores and normalize for visualization # Store original scores and normalize for visualization
# (best result = 1.0, worst result = 0.0 within THIS result set) # (best result = 1.0, worst result = 0.0 within THIS result set)
# This makes visual encoding meaningful regardless of RRF normalization # This makes visual encoding meaningful regardless of RRF normalization
if search_results: with trace_operation(
scores = [r.score for r in search_results] "vector_viz.score_normalize",
min_score, max_score = min(scores), max(scores) attributes={"normalize.num_results": len(search_results)},
score_range = max_score - min_score if max_score > min_score else 1.0 ):
if search_results:
scores = [r.score for r in search_results]
min_score, max_score = min(scores), max(scores)
score_range = max_score - min_score if max_score > min_score else 1.0
logger.info( logger.info(
f"Normalizing scores for viz: original range [{min_score:.3f}, {max_score:.3f}] " f"Normalizing scores for viz: original range [{min_score:.3f}, {max_score:.3f}] "
f"→ [0.0, 1.0]" f"→ [0.0, 1.0]"
) )
# Store original score and rescale to 0-1 for visualization # Store original score and rescale to 0-1 for visualization
for r in search_results: for r in search_results:
# Store original score before normalization # Store original score before normalization
r.original_score = r.score r.original_score = r.score
# Rescale for visual encoding # Rescale for visual encoding
r.score = (r.score - min_score) / score_range r.score = (r.score - min_score) / score_range
if not search_results: if not search_results:
return JSONResponse( return JSONResponse(
@@ -220,7 +244,9 @@ async def vector_visualization_search(request: Request) -> JSONResponse:
# Fetch vectors for specific matching chunks from Qdrant using batch retrieve # Fetch vectors for specific matching chunks from Qdrant using batch retrieve
vector_fetch_start = time.perf_counter() vector_fetch_start = time.perf_counter()
qdrant_client = await get_qdrant_client()
with trace_operation("vector_viz.get_qdrant_client"):
qdrant_client = await get_qdrant_client()
chunk_vectors_map = {} # Map (doc_id, chunk_start, chunk_end) -> vector chunk_vectors_map = {} # Map (doc_id, chunk_start, chunk_end) -> vector
@@ -231,12 +257,16 @@ async def vector_visualization_search(request: Request) -> JSONResponse:
if point_ids: if point_ids:
# Single batch retrieve call instead of N sequential scroll calls # Single batch retrieve call instead of N sequential scroll calls
# This is ~50x faster for 50 results (1 HTTP request vs 50) # This is ~50x faster for 50 results (1 HTTP request vs 50)
points_response = await qdrant_client.retrieve( with trace_operation(
collection_name=settings.get_collection_name(), "vector_viz.vector_retrieve",
ids=point_ids, attributes={"retrieve.num_points": len(point_ids)},
with_vectors=["dense"], ):
with_payload=["doc_id", "chunk_start_offset", "chunk_end_offset"], points_response = await qdrant_client.retrieve(
) collection_name=settings.get_collection_name(),
ids=point_ids,
with_vectors=["dense"],
with_payload=["doc_id", "chunk_start_offset", "chunk_end_offset"],
)
# Build chunk_vectors_map from batch response # Build chunk_vectors_map from batch response
for point in points_response: for point in points_response:
@@ -367,9 +397,16 @@ async def vector_visualization_search(request: Request) -> JSONResponse:
import anyio import anyio
coords_3d, pca = await anyio.to_thread.run_sync( # type: ignore[attr-defined] with trace_operation(
lambda: _compute_pca(all_vectors_normalized) "vector_viz.pca_compute",
) attributes={
"pca.num_vectors": len(all_vectors_normalized),
"pca.embedding_dim": embedding_dim,
},
):
coords_3d, pca = await anyio.to_thread.run_sync( # type: ignore[attr-defined]
lambda: _compute_pca(all_vectors_normalized)
)
pca_duration = time.perf_counter() - pca_start pca_duration = time.perf_counter() - pca_start
# After fit, these attributes are guaranteed to be set # After fit, these attributes are guaranteed to be set
+2 -2
View File
@@ -29,9 +29,9 @@ from .app import get_app
@click.option( @click.option(
"--transport", "--transport",
"-t", "-t",
default="sse", default="streamable-http",
show_default=True, show_default=True,
type=click.Choice(["sse", "streamable-http", "http"]), type=click.Choice(["streamable-http", "http"]),
help="MCP transport protocol", help="MCP transport protocol",
) )
@click.option( @click.option(
+82 -67
View File
@@ -9,6 +9,7 @@ from qdrant_client.models import FieldCondition, Filter, MatchValue
from nextcloud_mcp_server.config import get_settings from nextcloud_mcp_server.config import get_settings
from nextcloud_mcp_server.embedding import get_bm25_service, get_embedding_service from nextcloud_mcp_server.embedding import get_bm25_service, get_embedding_service
from nextcloud_mcp_server.observability.metrics import record_qdrant_operation from nextcloud_mcp_server.observability.metrics import record_qdrant_operation
from nextcloud_mcp_server.observability.tracing import trace_operation
from nextcloud_mcp_server.search.algorithms import SearchAlgorithm, SearchResult from nextcloud_mcp_server.search.algorithms import SearchAlgorithm, SearchResult
from nextcloud_mcp_server.vector.placeholder import get_placeholder_filter from nextcloud_mcp_server.vector.placeholder import get_placeholder_filter
from nextcloud_mcp_server.vector.qdrant_client import get_qdrant_client from nextcloud_mcp_server.vector.qdrant_client import get_qdrant_client
@@ -99,15 +100,19 @@ class BM25HybridSearchAlgorithm(SearchAlgorithm):
) )
# Generate dense embedding for semantic search # Generate dense embedding for semantic search
embedding_service = get_embedding_service() with trace_operation("search.get_embedding_service"):
dense_embedding = await embedding_service.embed(query) embedding_service = get_embedding_service()
with trace_operation("search.dense_embedding"):
dense_embedding = await embedding_service.embed(query)
# Store for reuse by callers (e.g., viz_routes PCA visualization) # Store for reuse by callers (e.g., viz_routes PCA visualization)
self.query_embedding = dense_embedding self.query_embedding = dense_embedding
logger.debug(f"Generated dense embedding (dimension={len(dense_embedding)})") logger.debug(f"Generated dense embedding (dimension={len(dense_embedding)})")
# Generate sparse embedding for BM25 keyword search # Generate sparse embedding for BM25 keyword search
bm25_service = get_bm25_service() with trace_operation("search.get_bm25_service"):
sparse_embedding = await bm25_service.encode_async(query) bm25_service = get_bm25_service()
with trace_operation("search.sparse_embedding_bm25"):
sparse_embedding = await bm25_service.encode_async(query)
logger.debug( logger.debug(
f"Generated sparse embedding " f"Generated sparse embedding "
f"({len(sparse_embedding['indices'])} non-zero terms)" f"({len(sparse_embedding['indices'])} non-zero terms)"
@@ -134,38 +139,44 @@ class BM25HybridSearchAlgorithm(SearchAlgorithm):
query_filter = Filter(must=filter_conditions) query_filter = Filter(must=filter_conditions)
# Execute hybrid search with Qdrant native RRF fusion # Execute hybrid search with Qdrant native RRF fusion
qdrant_client = await get_qdrant_client() with trace_operation("search.get_qdrant_client"):
qdrant_client = await get_qdrant_client()
try: try:
# Use prefetch to run both dense and sparse searches # Use prefetch to run both dense and sparse searches
# Qdrant will automatically merge results using RRF # Qdrant will automatically merge results using RRF
search_response = await qdrant_client.query_points( with trace_operation(
collection_name=settings.get_collection_name(), "search.qdrant_query",
prefetch=[ attributes={"query.limit": limit * 2, "query.fusion": self.fusion_name},
# Dense semantic search ):
models.Prefetch( search_response = await qdrant_client.query_points(
query=dense_embedding, collection_name=settings.get_collection_name(),
using="dense", prefetch=[
limit=limit * 2, # Get extra for deduplication # Dense semantic search
filter=query_filter, models.Prefetch(
), query=dense_embedding,
# Sparse BM25 search using="dense",
models.Prefetch( limit=limit * 2, # Get extra for deduplication
query=models.SparseVector( filter=query_filter,
indices=sparse_embedding["indices"],
values=sparse_embedding["values"],
), ),
using="sparse", # Sparse BM25 search
limit=limit * 2, # Get extra for deduplication models.Prefetch(
filter=query_filter, query=models.SparseVector(
), indices=sparse_embedding["indices"],
], values=sparse_embedding["values"],
# Fusion query (RRF or DBSF based on initialization) ),
query=models.FusionQuery(fusion=self.fusion), using="sparse",
limit=limit * 2, # Get extra for deduplication limit=limit * 2, # Get extra for deduplication
score_threshold=score_threshold, filter=query_filter,
with_payload=True, ),
with_vectors=False, # Don't return vectors to save bandwidth ],
) # Fusion query (RRF or DBSF based on initialization)
query=models.FusionQuery(fusion=self.fusion),
limit=limit * 2, # Get extra for deduplication
score_threshold=score_threshold,
with_payload=True,
with_vectors=False, # Don't return vectors to save bandwidth
)
record_qdrant_operation("search", "success") record_qdrant_operation("search", "success")
except Exception: except Exception:
record_qdrant_operation("search", "error") record_qdrant_operation("search", "error")
@@ -185,47 +196,51 @@ class BM25HybridSearchAlgorithm(SearchAlgorithm):
# Deduplicate by (doc_id, doc_type, chunk_start, chunk_end) # Deduplicate by (doc_id, doc_type, chunk_start, chunk_end)
# This allows multiple chunks from same doc, but removes duplicate chunks # This allows multiple chunks from same doc, but removes duplicate chunks
seen_chunks = set() with trace_operation(
results = [] "search.deduplicate",
attributes={"dedupe.num_points": len(search_response.points)},
):
seen_chunks = set()
results = []
for result in search_response.points: for result in search_response.points:
# doc_id can be int (notes) or str (files - file paths) # doc_id can be int (notes) or str (files - file paths)
doc_id = result.payload["doc_id"] doc_id = result.payload["doc_id"]
doc_type = result.payload.get("doc_type", "note") doc_type = result.payload.get("doc_type", "note")
chunk_start = result.payload.get("chunk_start_offset") chunk_start = result.payload.get("chunk_start_offset")
chunk_end = result.payload.get("chunk_end_offset") chunk_end = result.payload.get("chunk_end_offset")
chunk_key = (doc_id, doc_type, chunk_start, chunk_end) chunk_key = (doc_id, doc_type, chunk_start, chunk_end)
# Skip if we've already seen this exact chunk # Skip if we've already seen this exact chunk
if chunk_key in seen_chunks: if chunk_key in seen_chunks:
continue continue
seen_chunks.add(chunk_key) seen_chunks.add(chunk_key)
# Return unverified results (verification happens at output stage) # Return unverified results (verification happens at output stage)
results.append( results.append(
SearchResult( SearchResult(
id=doc_id, id=doc_id,
doc_type=doc_type, doc_type=doc_type,
title=result.payload.get("title", "Untitled"), title=result.payload.get("title", "Untitled"),
excerpt=result.payload.get("excerpt", ""), excerpt=result.payload.get("excerpt", ""),
score=result.score, # Fusion score (RRF or DBSF) score=result.score, # Fusion score (RRF or DBSF)
metadata={ metadata={
"chunk_index": result.payload.get("chunk_index"), "chunk_index": result.payload.get("chunk_index"),
"total_chunks": result.payload.get("total_chunks"), "total_chunks": result.payload.get("total_chunks"),
"search_method": f"bm25_hybrid_{self.fusion_name}", "search_method": f"bm25_hybrid_{self.fusion_name}",
}, },
chunk_start_offset=result.payload.get("chunk_start_offset"), chunk_start_offset=result.payload.get("chunk_start_offset"),
chunk_end_offset=result.payload.get("chunk_end_offset"), chunk_end_offset=result.payload.get("chunk_end_offset"),
page_number=result.payload.get("page_number"), page_number=result.payload.get("page_number"),
chunk_index=result.payload.get("chunk_index", 0), chunk_index=result.payload.get("chunk_index", 0),
total_chunks=result.payload.get("total_chunks", 1), total_chunks=result.payload.get("total_chunks", 1),
point_id=str(result.id), # Qdrant point ID for batch retrieval point_id=str(result.id), # Qdrant point ID for batch retrieval
)
) )
)
if len(results) >= limit: if len(results) >= limit:
break break
logger.info(f"Returning {len(results)} unverified results after deduplication") logger.info(f"Returning {len(results)} unverified results after deduplication")
if results: if results:
+1 -55
View File
@@ -9,7 +9,6 @@ import pytest
from httpx import HTTPStatusError from httpx import HTTPStatusError
from mcp import ClientSession from mcp import ClientSession
from mcp.client.session import RequestContext from mcp.client.session import RequestContext
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client from mcp.client.streamable_http import streamablehttp_client
from mcp.types import ElicitRequestParams, ElicitResult, ErrorData from mcp.types import ElicitRequestParams, ElicitResult, ErrorData
@@ -172,51 +171,6 @@ async def create_mcp_client_session(
logger.debug(f"{client_name} client session cleaned up successfully") logger.debug(f"{client_name} client session cleaned up successfully")
async def create_mcp_client_session_sse(
url: str,
token: str | None = None,
client_name: str = "MCP",
elicitation_callback: Any = None,
) -> AsyncGenerator[ClientSession, Any]:
"""
Factory function to create an MCP client session using SSE transport.
Similar to create_mcp_client_session but uses SSE transport instead of streamable-http.
Uses native async context managers to ensure correct LIFO cleanup order.
Args:
url: MCP server URL (e.g., "http://localhost:8000/sse")
token: Optional OAuth access token for Bearer authentication
client_name: Client name for logging (e.g., "Basic MCP (SSE)")
elicitation_callback: Optional callback for handling elicitation requests
Yields:
Initialized MCP ClientSession
Note:
SSE transport is being deprecated in favor of streamable-http.
This function exists for compatibility testing only.
"""
logger.info(f"Creating SSE client for {client_name}")
# Prepare headers with OAuth token if provided
headers = {"Authorization": f"Bearer {token}"} if token else None
# Use native async with - Python ensures LIFO cleanup
# Cleanup order will be: ClientSession.__aexit__ -> sse_client.__aexit__
# Note: sse_client yields only (read_stream, write_stream), not 3 values like streamablehttp_client
async with sse_client(url, headers=headers) as (read_stream, write_stream):
async with ClientSession(
read_stream, write_stream, elicitation_callback=elicitation_callback
) as session:
await session.initialize()
logger.info(f"{client_name} client session initialized successfully")
yield session
# Cleanup happens automatically in LIFO order - no exception suppression needed
logger.debug(f"{client_name} client session cleaned up successfully")
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
async def nc_client(anyio_backend) -> AsyncGenerator[NextcloudClient, Any]: async def nc_client(anyio_backend) -> AsyncGenerator[NextcloudClient, Any]:
""" """
@@ -255,18 +209,10 @@ async def nc_client(anyio_backend) -> AsyncGenerator[NextcloudClient, Any]:
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
async def nc_mcp_client(anyio_backend) -> AsyncGenerator[ClientSession, Any]: async def nc_mcp_client(anyio_backend) -> AsyncGenerator[ClientSession, Any]:
""" """
Fixture to create an MCP client session for integration tests using SSE transport. Fixture to create an MCP client session for integration tests using streamable-http.
Uses anyio pytest plugin for proper async fixture handling. Uses anyio pytest plugin for proper async fixture handling.
Note: SSE transport is being deprecated. This fixture uses SSE for compatibility testing.
""" """
# async for session in create_mcp_client_session_sse(
# url="http://localhost:8000/sse", client_name="Basic MCP (SSE)"
# ):
# yield session
async for session in create_mcp_client_session( async for session in create_mcp_client_session(
url="http://localhost:8000/mcp", url="http://localhost:8000/mcp",
client_name="Basic MCP (HTTP)", client_name="Basic MCP (HTTP)",