"""MCP sampling support for integration tests.

This module provides utilities to enable real LLM-based sampling in
integration tests using OpenAI or GitHub Models API.
"""

import logging
from typing import Any

from mcp import types
from mcp.client.session import ClientSession, RequestContext

from nextcloud_mcp_server.providers.openai import OpenAIProvider

logger = logging.getLogger(__name__)


def _text_parts(content: Any) -> list[str]:
    """Return the text of every text block in a sampling message's content.

    ``content`` is normally a single content block, but a list of blocks is
    also accepted defensively in case the installed MCP version allows it.
    Blocks without a ``text`` attribute (e.g. image or audio content) are
    skipped, because the provider only accepts a plain-text prompt.
    """
    blocks = content if isinstance(content, list) else [content]
    return [block.text for block in blocks if hasattr(block, "text")]


def create_openai_sampling_callback(provider: OpenAIProvider):
    """Factory to create a sampling callback using OpenAI provider.

    The callback conforms to MCP's SamplingFnT protocol and can be passed to
    ClientSession for handling sampling requests from the server.

    The conversation is flattened into a single text prompt of the form
    ``"System: ...\\n\\nUser: ...\\n\\nAssistant: ..."``. Only ``maxTokens`` is
    forwarded to the provider. Other request parameters such as
    ``temperature`` and ``stopSequences`` are not forwarded.

    Args:
        provider: OpenAIProvider instance configured with a generation model

    Returns:
        Async callback function for MCP sampling. On success it returns a
        ``CreateMessageResult``. If generation fails it returns ``ErrorData``
        with ``INTERNAL_ERROR`` instead of raising.

    Example:
        ```python
        provider = OpenAIProvider(
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=os.getenv("OPENAI_BASE_URL"),
            generation_model="gpt-4o-mini",
        )
        callback = create_openai_sampling_callback(provider)

        async for session in create_mcp_client_session(
            url="http://localhost:8000/mcp",
            sampling_callback=callback,
        ):
            # Session now supports sampling
            pass
        ```
    """

    async def sampling_callback(
        context: RequestContext[ClientSession, Any],
        params: types.CreateMessageRequestParams,
    ) -> types.CreateMessageResult | types.ErrorData:
        """Handle sampling requests using OpenAI provider."""
        logger.debug(
            "Sampling callback invoked with %d messages", len(params.messages)
        )

        # Flatten the conversation into one prompt, one paragraph per message.
        messages_text = []
        for msg in params.messages:
            parts = _text_parts(msg.content)
            if not parts:
                # Non-text content cannot be expressed in a text prompt. Skip
                # it, but leave a trace instead of dropping it silently.
                logger.warning(
                    "Skipping sampling message without text content (role=%s)",
                    msg.role,
                )
                continue
            role_prefix = "User" if msg.role == "user" else "Assistant"
            messages_text.append(f"{role_prefix}: " + "\n".join(parts))

        prompt = "\n\n".join(messages_text)

        # Add system prompt if provided
        if params.systemPrompt:
            prompt = f"System: {params.systemPrompt}\n\n{prompt}"

        logger.debug("Generating response for prompt (%d chars)", len(prompt))

        # The whole generation/result construction is guarded so that every
        # failure is reported back to the server as ErrorData, not raised.
        try:
            # Note: temperature is hardcoded in the provider at 0.7, so
            # params.temperature is intentionally not forwarded.
            response = await provider.generate(
                prompt=prompt,
                max_tokens=params.maxTokens,
            )

            model_name = provider.generation_model or "unknown"
            logger.info(
                "Sampling completed: %d chars from %s", len(response), model_name
            )

            # The provider returns only text, so a max-tokens truncation
            # cannot be distinguished here; always report a normal end of turn.
            return types.CreateMessageResult(
                role="assistant",
                content=types.TextContent(type="text", text=response),
                model=model_name,
                stopReason="endTurn",
            )
        except Exception as e:
            # logger.exception keeps the traceback of the provider failure.
            logger.exception("OpenAI generation failed: %s", e)
            return types.ErrorData(
                code=types.INTERNAL_ERROR,
                message=f"OpenAI generation failed: {e!s}",
            )

    return sampling_callback