feat: Add OpenAI provider support for embeddings and generation

Adds OpenAI provider to the unified provider architecture (ADR-015),
supporting:
- OpenAI API (api.openai.com)
- GitHub Models API (models.github.ai/inference)
- OpenAI-compatible endpoints (Fireworks, Together, etc.)

Features:
- Embedding support with text-embedding-3-small/large models
- Text generation via chat completions API
- Automatic retry with exponential backoff for rate limits
- Provider auto-detection in registry (priority after Bedrock)

Environment variables:
- OPENAI_API_KEY: API key (required)
- OPENAI_BASE_URL: Base URL override (optional)
- OPENAI_EMBEDDING_MODEL: Embedding model (default: text-embedding-3-small)
- OPENAI_GENERATION_MODEL: Generation model (default: gpt-4o-mini)

Also adds:
- Integration tests for RAG pipeline with MCP sampling
- MCP client sampling support for integration tests
- Ground truth Q&A pairs for Nextcloud User Manual

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Chris Coutinho
2025-11-23 00:33:32 +01:00
parent 959cb8b21a
commit 208365cd3d
14 changed files with 1176 additions and 21 deletions
+38 -3
View File
@@ -217,6 +217,11 @@ class Settings:
ollama_embedding_model: str = "nomic-embed-text"
ollama_verify_ssl: bool = True
# OpenAI settings (for embeddings)
openai_api_key: Optional[str] = None
openai_base_url: Optional[str] = None
openai_embedding_model: str = "text-embedding-3-small"
# Document chunking settings (for vector embeddings)
document_chunk_size: int = 2048 # Characters per chunk
document_chunk_overlap: int = 200 # Overlapping characters between chunks
@@ -275,6 +280,29 @@ class Settings:
f"DOCUMENT_CHUNK_OVERLAP ({self.document_chunk_overlap}) cannot be negative."
)
def get_embedding_model_name(self) -> str:
"""
Get the active embedding model name based on provider priority.
Priority order (same as ProviderRegistry):
1. OpenAI - if OPENAI_API_KEY is set
2. Ollama - if OLLAMA_BASE_URL is set
3. Simple - fallback (returns "simple-384")
Returns:
Active embedding model name
"""
# Check OpenAI first (higher priority than Ollama in registry)
if self.openai_api_key:
return self.openai_embedding_model
# Check Ollama
if self.ollama_base_url:
return self.ollama_embedding_model
# Fallback to simple provider indicator
return "simple-384"
def get_collection_name(self) -> str:
"""
Get Qdrant collection name.
@@ -290,8 +318,9 @@ class Settings:
Format: {deployment-id}-{model-name}
Examples:
- "my-deployment-nomic-embed-text" (OTEL_SERVICE_NAME set)
- "mcp-container-all-minilm" (hostname fallback)
- "my-deployment-nomic-embed-text" (Ollama)
- "my-deployment-text-embedding-3-small" (OpenAI)
- "mcp-container-openai-text-embedding-3-small" (hostname fallback)
Returns:
Collection name string
@@ -311,7 +340,7 @@ class Settings:
# Sanitize deployment ID and model name
deployment_id = deployment_id.lower().replace(" ", "-").replace("_", "-")
model_name = self.ollama_embedding_model.replace("/", "-").replace(":", "-")
model_name = self.get_embedding_model_name().replace("/", "-").replace(":", "-")
return f"{deployment_id}-{model_name}"
@@ -371,6 +400,12 @@ def get_settings() -> Settings:
ollama_base_url=os.getenv("OLLAMA_BASE_URL"),
ollama_embedding_model=os.getenv("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text"),
ollama_verify_ssl=os.getenv("OLLAMA_VERIFY_SSL", "true").lower() == "true",
# OpenAI settings
openai_api_key=os.getenv("OPENAI_API_KEY"),
openai_base_url=os.getenv("OPENAI_BASE_URL"),
openai_embedding_model=os.getenv(
"OPENAI_EMBEDDING_MODEL", "text-embedding-3-small"
),
# Document chunking settings
document_chunk_size=int(os.getenv("DOCUMENT_CHUNK_SIZE", "2048")),
document_chunk_overlap=int(os.getenv("DOCUMENT_CHUNK_OVERLAP", "200")),
@@ -4,12 +4,14 @@ from .anthropic import AnthropicProvider
from .base import Provider
from .bedrock import BedrockProvider
from .ollama import OllamaProvider
from .openai import OpenAIProvider
from .registry import get_provider, reset_provider
from .simple import SimpleProvider
__all__ = [
"Provider",
"OllamaProvider",
"OpenAIProvider",
"AnthropicProvider",
"SimpleProvider",
"BedrockProvider",
+227
View File
@@ -0,0 +1,227 @@
"""Unified OpenAI provider for embeddings and text generation.
Supports:
- OpenAI's standard API
- GitHub Models API (models.github.ai)
- Any OpenAI-compatible API via base_url override
"""
import logging
from openai import AsyncOpenAI
from .base import Provider
logger = logging.getLogger(__name__)
# Well-known embedding dimensions for OpenAI models
OPENAI_EMBEDDING_DIMENSIONS: dict[str, int] = {
"text-embedding-3-small": 1536,
"text-embedding-3-large": 3072,
"text-embedding-ada-002": 1536,
# GitHub Models API uses openai/ prefix
"openai/text-embedding-3-small": 1536,
"openai/text-embedding-3-large": 3072,
}
class OpenAIProvider(Provider):
"""
OpenAI provider supporting both embeddings and text generation.
Works with:
- OpenAI's standard API (api.openai.com)
- GitHub Models API (models.github.ai)
- Any OpenAI-compatible API (via base_url)
"""
def __init__(
self,
api_key: str,
base_url: str | None = None,
embedding_model: str | None = None,
generation_model: str | None = None,
timeout: float = 120.0,
):
"""
Initialize OpenAI provider.
Args:
api_key: OpenAI API key (or GITHUB_TOKEN for GitHub Models)
base_url: Base URL override (e.g., "https://models.github.ai/inference")
embedding_model: Model for embeddings (e.g., "text-embedding-3-small").
None disables embeddings.
generation_model: Model for text generation (e.g., "gpt-4o-mini").
None disables generation.
timeout: HTTP timeout in seconds (default: 120)
"""
self.embedding_model = embedding_model
self.generation_model = generation_model
self._dimension: int | None = None
# Initialize async client
self.client = AsyncOpenAI(
api_key=api_key,
base_url=base_url,
timeout=timeout,
)
# Try to get known dimension without API call
if embedding_model and embedding_model in OPENAI_EMBEDDING_DIMENSIONS:
self._dimension = OPENAI_EMBEDDING_DIMENSIONS[embedding_model]
logger.info(
f"Initialized OpenAI provider: base_url={base_url or 'default'} "
f"(embedding_model={embedding_model}, generation_model={generation_model}, "
f"dimension={self._dimension})"
)
@property
def supports_embeddings(self) -> bool:
"""Whether this provider supports embedding generation."""
return self.embedding_model is not None
@property
def supports_generation(self) -> bool:
"""Whether this provider supports text generation."""
return self.generation_model is not None
async def embed(self, text: str) -> list[float]:
"""
Generate embedding vector for text.
Args:
text: Input text to embed
Returns:
Vector embedding as list of floats
Raises:
NotImplementedError: If embeddings not enabled (no embedding_model)
"""
if not self.supports_embeddings:
raise NotImplementedError(
"Embedding not supported - no embedding_model configured"
)
response = await self.client.embeddings.create(
input=text,
model=self.embedding_model,
)
embedding = response.data[0].embedding
# Update dimension if not set
if self._dimension is None:
self._dimension = len(embedding)
logger.info(
f"Detected embedding dimension: {self._dimension} "
f"for model {self.embedding_model}"
)
return embedding
async def embed_batch(self, texts: list[str]) -> list[list[float]]:
"""
Generate embeddings for multiple texts using OpenAI's batch API.
OpenAI supports up to 2048 inputs per request.
Args:
texts: List of texts to embed
Returns:
List of vector embeddings
Raises:
NotImplementedError: If embeddings not enabled (no embedding_model)
"""
if not self.supports_embeddings:
raise NotImplementedError(
"Embedding not supported - no embedding_model configured"
)
if not texts:
return []
# OpenAI supports batches up to 2048, but use smaller batches for safety
batch_size = 100
all_embeddings: list[list[float]] = []
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
response = await self.client.embeddings.create(
input=batch,
model=self.embedding_model,
)
# Sort by index to maintain order
sorted_data = sorted(response.data, key=lambda x: x.index)
batch_embeddings = [item.embedding for item in sorted_data]
all_embeddings.extend(batch_embeddings)
# Update dimension if not set
if self._dimension is None and batch_embeddings:
self._dimension = len(batch_embeddings[0])
logger.info(
f"Detected embedding dimension: {self._dimension} "
f"for model {self.embedding_model}"
)
return all_embeddings
def get_dimension(self) -> int:
"""
Get embedding dimension.
Returns:
Vector dimension for the configured embedding model
Raises:
NotImplementedError: If embeddings not enabled (no embedding_model)
RuntimeError: If dimension not detected yet (call embed first)
"""
if not self.supports_embeddings:
raise NotImplementedError(
"Embedding not supported - no embedding_model configured"
)
if self._dimension is None:
raise RuntimeError(
f"Embedding dimension not detected yet for model {self.embedding_model}. "
"Call embed() first or use a known model."
)
return self._dimension
async def generate(self, prompt: str, max_tokens: int = 500) -> str:
"""
Generate text from a prompt.
Args:
prompt: The prompt to generate from
max_tokens: Maximum tokens to generate
Returns:
Generated text
Raises:
NotImplementedError: If generation not enabled (no generation_model)
"""
if not self.supports_generation:
raise NotImplementedError(
"Text generation not supported - no generation_model configured"
)
response = await self.client.chat.completions.create(
model=self.generation_model,
messages=[{"role": "user", "content": prompt}],
max_tokens=max_tokens,
temperature=0.7,
)
return response.choices[0].message.content or ""
async def close(self) -> None:
"""Close HTTP client."""
await self.client.close()
+38 -8
View File
@@ -6,6 +6,7 @@ import os
from .base import Provider
from .bedrock import BedrockProvider
from .ollama import OllamaProvider
from .openai import OpenAIProvider
from .simple import SimpleProvider
logger = logging.getLogger(__name__)
@@ -17,8 +18,9 @@ class ProviderRegistry:
Checks environment variables in priority order and creates appropriate provider:
1. Bedrock (AWS_REGION + BEDROCK_*_MODEL)
2. Ollama (OLLAMA_BASE_URL)
3. Simple (fallback for testing/development)
2. OpenAI (OPENAI_API_KEY)
3. Ollama (OLLAMA_BASE_URL)
4. Simple (fallback for testing/development)
"""
@staticmethod
@@ -28,8 +30,9 @@ class ProviderRegistry:
Priority order:
1. Bedrock - if AWS_REGION or BEDROCK_EMBEDDING_MODEL is set
2. Ollama - if OLLAMA_BASE_URL is set
3. Simple - fallback for testing/development
2. OpenAI - if OPENAI_API_KEY is set
3. Ollama - if OLLAMA_BASE_URL is set
4. Simple - fallback for testing/development
Returns:
Provider instance
@@ -42,6 +45,12 @@ class ProviderRegistry:
- BEDROCK_EMBEDDING_MODEL: Model ID for embeddings (e.g., "amazon.titan-embed-text-v2:0")
- BEDROCK_GENERATION_MODEL: Model ID for text generation (e.g., "anthropic.claude-3-sonnet-20240229-v1:0")
OpenAI:
- OPENAI_API_KEY: OpenAI API key (or GITHUB_TOKEN for GitHub Models)
- OPENAI_BASE_URL: Base URL override (e.g., "https://models.github.ai/inference")
- OPENAI_EMBEDDING_MODEL: Model for embeddings (default: "text-embedding-3-small")
- OPENAI_GENERATION_MODEL: Model for text generation (e.g., "gpt-4o-mini")
Ollama:
- OLLAMA_BASE_URL: Ollama API base URL (e.g., "http://localhost:11434")
- OLLAMA_EMBEDDING_MODEL: Model for embeddings (default: "nomic-embed-text")
@@ -70,7 +79,28 @@ class ProviderRegistry:
aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
)
# 2. Check for Ollama
# 2. Check for OpenAI
openai_api_key = os.getenv("OPENAI_API_KEY")
if openai_api_key:
base_url = os.getenv("OPENAI_BASE_URL")
embedding_model = os.getenv(
"OPENAI_EMBEDDING_MODEL", "text-embedding-3-small"
)
generation_model = os.getenv("OPENAI_GENERATION_MODEL")
logger.info(
f"Using OpenAI provider: base_url={base_url or 'default'}, "
f"embedding_model={embedding_model}, "
f"generation_model={generation_model}"
)
return OpenAIProvider(
api_key=openai_api_key,
base_url=base_url,
embedding_model=embedding_model,
generation_model=generation_model,
)
# 3. Check for Ollama (local LLM)
ollama_url = os.getenv("OLLAMA_BASE_URL")
if ollama_url:
embedding_model = os.getenv("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text")
@@ -89,12 +119,12 @@ class ProviderRegistry:
verify_ssl=verify_ssl,
)
# 3. Fallback to Simple provider for development/testing
# 4. Fallback to Simple provider for development/testing
dimension = int(os.getenv("SIMPLE_EMBEDDING_DIMENSION", "384"))
logger.warning(
"No provider configured (AWS_REGION, OLLAMA_BASE_URL not set). "
"No provider configured (AWS_REGION, OPENAI_API_KEY, OLLAMA_BASE_URL not set). "
"Using SimpleProvider for testing/development. "
"For production, configure Bedrock or Ollama."
"For production, configure Bedrock, OpenAI, or Ollama."
)
return SimpleProvider(dimension=dimension)
+7 -5
View File
@@ -93,27 +93,29 @@ async def get_qdrant_client() -> AsyncQdrantClient:
# Validate dimension matches
if actual_dimension != expected_dimension:
embedding_model = settings.get_embedding_model_name()
raise ValueError(
f"Dimension mismatch for collection '{collection_name}':\n"
f" Expected: {expected_dimension} (from embedding model '{settings.ollama_embedding_model}')\n"
f" Expected: {expected_dimension} (from embedding model '{embedding_model}')\n"
f" Found: {actual_dimension}\n"
f"This usually means you changed the embedding model.\n"
f"Solutions:\n"
f" 1. Delete the old collection: Collection will be recreated with new dimensions\n"
f" 2. Set QDRANT_COLLECTION to use a different collection name\n"
f" 3. Revert OLLAMA_EMBEDDING_MODEL to the original model"
f" 3. Revert to the original embedding model"
)
logger.info(
f"Using existing Qdrant collection: {collection_name} "
f"(dimension={actual_dimension}, model={settings.ollama_embedding_model})"
f"(dimension={actual_dimension}, model={settings.get_embedding_model_name()})"
)
else:
# Collection doesn't exist - create it
embedding_model = settings.get_embedding_model_name()
logger.info(
f"Collection '{collection_name}' not found, creating with "
f"dimension={expected_dimension}, model={settings.ollama_embedding_model}..."
f"dimension={expected_dimension}, model={embedding_model}..."
)
await _qdrant_client.create_collection(
collection_name=collection_name,
@@ -134,7 +136,7 @@ async def get_qdrant_client() -> AsyncQdrantClient:
logger.info(
f"Created Qdrant collection: {collection_name}\n"
f" Dense vector dimension: {expected_dimension}\n"
f" Dense embedding model: {settings.ollama_embedding_model}\n"
f" Dense embedding model: {embedding_model}\n"
f" Sparse vectors: BM25 (for hybrid search)\n"
f" Distance: COSINE\n"
f"Background sync will index all documents with dense + sparse vectors."
+1
View File
@@ -39,6 +39,7 @@ dependencies = [
"pymupdf>=1.26.6",
"pymupdf4llm>=0.2.2",
"pymupdf-layout>=1.26.6",
"openai>=2.8.1",
]
classifiers = [
"Development Status :: 4 - Beta",
+7 -1
View File
@@ -114,6 +114,7 @@ async def create_mcp_client_session(
token: str | None = None,
client_name: str = "MCP",
elicitation_callback: Any = None,
sampling_callback: Any = None,
) -> AsyncGenerator[ClientSession, Any]:
"""
Factory function to create an MCP client session with proper lifecycle management.
@@ -133,6 +134,8 @@ async def create_mcp_client_session(
client_name: Client name for logging (e.g., "OAuth MCP (Playwright)")
elicitation_callback: Optional callback for handling elicitation requests.
Should match signature: async def callback(context: RequestContext, params: ElicitRequestParams) -> ElicitResult | ErrorData
sampling_callback: Optional callback for handling sampling (LLM generation) requests.
Should match signature: async def callback(context: RequestContext, params: CreateMessageRequestParams) -> CreateMessageResult | ErrorData
Yields:
Initialized MCP ClientSession
@@ -156,7 +159,10 @@ async def create_mcp_client_session(
_,
):
async with ClientSession(
read_stream, write_stream, elicitation_callback=elicitation_callback
read_stream,
write_stream,
elicitation_callback=elicitation_callback,
sampling_callback=sampling_callback,
) as session:
await session.initialize()
logger.info(f"{client_name} client session initialized successfully")
@@ -0,0 +1,37 @@
[
{
"id": "nc-manual-001",
"query": "What is two-factor authentication and how does it protect my Nextcloud account?",
"ground_truth": "Two-factor authentication (2FA) protects your Nextcloud account by requiring two different proofs of identity - something you know (like a password) and something you have (like a code from your phone). The first factor is typically a password, and the second can be a text message or code generated on your phone.",
"expected_topics": ["two-factor authentication", "2FA", "password", "security"],
"difficulty": "easy"
},
{
"id": "nc-manual-002",
"query": "How do file quotas work in Nextcloud when sharing files?",
"ground_truth": "When you share files with other users, the shared files count against the original share owner's quota. When you share a folder and allow others to upload files, all uploaded and edited files count against your quota. Re-shared files still count against the original share owner's quota. Deleted files in trash don't count against quotas until trash exceeds 50% of quota.",
"expected_topics": ["quota", "sharing", "files", "storage"],
"difficulty": "medium"
},
{
"id": "nc-manual-003",
"query": "How do I install the Nextcloud desktop sync client on Linux?",
"ground_truth": "Linux users must follow instructions on the download page to add the appropriate repository for their Linux distribution, install the signing key, and use their package managers to install the desktop sync client. Linux users also need a password manager enabled, such as GNOME Keyring or KWallet, so the sync client can login automatically.",
"expected_topics": ["Linux", "desktop client", "installation", "package manager", "GNOME Keyring", "KWallet"],
"difficulty": "medium"
},
{
"id": "nc-manual-004",
"query": "What are the system requirements for the Nextcloud desktop client on Windows?",
"ground_truth": "The Nextcloud desktop sync client requires Windows 10 or later, 64-bits only.",
"expected_topics": ["Windows", "system requirements", "desktop client"],
"difficulty": "easy"
},
{
"id": "nc-manual-005",
"query": "How do I use client applications with two-factor authentication enabled?",
"ground_truth": "Once you have enabled 2FA, your clients will no longer be able to connect with just your password unless they also support two-factor authentication. To solve this, you should generate device-specific passwords for them. This is managed through the connected browsers and devices settings.",
"expected_topics": ["2FA", "client applications", "device-specific passwords", "app passwords"],
"difficulty": "medium"
}
]
+94
View File
@@ -0,0 +1,94 @@
"""MCP sampling support for integration tests.
This module provides utilities to enable real LLM-based sampling in integration tests
using OpenAI or GitHub Models API.
"""
import logging
from typing import Any
from mcp import types
from mcp.client.session import ClientSession, RequestContext
from nextcloud_mcp_server.providers.openai import OpenAIProvider
logger = logging.getLogger(__name__)
def create_openai_sampling_callback(provider: OpenAIProvider):
"""Factory to create a sampling callback using OpenAI provider.
The callback conforms to MCP's SamplingFnT protocol and can be passed
to ClientSession for handling sampling requests from the server.
Args:
provider: OpenAIProvider instance configured with a generation model
Returns:
Async callback function for MCP sampling
Example:
```python
provider = OpenAIProvider(
api_key=os.getenv("OPENAI_API_KEY"),
base_url=os.getenv("OPENAI_BASE_URL"),
generation_model="gpt-4o-mini",
)
callback = create_openai_sampling_callback(provider)
async for session in create_mcp_client_session(
url="http://localhost:8000/mcp",
sampling_callback=callback,
):
# Session now supports sampling
pass
```
"""
async def sampling_callback(
context: RequestContext[ClientSession, Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData:
"""Handle sampling requests using OpenAI provider."""
logger.debug(f"Sampling callback invoked with {len(params.messages)} messages")
# Extract messages and build prompt
messages_text = []
for msg in params.messages:
if hasattr(msg.content, "text"):
role_prefix = "User" if msg.role == "user" else "Assistant"
messages_text.append(f"{role_prefix}: {msg.content.text}")
prompt = "\n\n".join(messages_text)
# Add system prompt if provided
if params.systemPrompt:
prompt = f"System: {params.systemPrompt}\n\n{prompt}"
logger.debug(f"Generating response for prompt ({len(prompt)} chars)")
try:
# Generate response using OpenAI provider
# Note: temperature is hardcoded in the provider at 0.7
response = await provider.generate(
prompt=prompt,
max_tokens=params.maxTokens,
)
model_name = provider.generation_model or "unknown"
logger.info(f"Sampling completed: {len(response)} chars from {model_name}")
return types.CreateMessageResult(
role="assistant",
content=types.TextContent(type="text", text=response),
model=model_name,
stopReason="endTurn",
)
except Exception as e:
logger.error(f"OpenAI generation failed: {e}")
return types.ErrorData(
code=types.INTERNAL_ERROR,
message=f"OpenAI generation failed: {e!s}",
)
return sampling_callback
+300
View File
@@ -0,0 +1,300 @@
"""Integration tests for RAG pipeline with OpenAI/GitHub Models API.
These tests validate the complete semantic search and MCP sampling flow using:
1. OpenAI embeddings for semantic search
2. MCP sampling for answer generation
3. Pre-indexed Nextcloud User Manual as the knowledge base
Environment Variables:
OPENAI_API_KEY: OpenAI API key or GitHub token for models.github.ai
OPENAI_BASE_URL: Base URL override (e.g., "https://models.github.ai/inference")
OPENAI_EMBEDDING_MODEL: Embedding model (default: "text-embedding-3-small")
OPENAI_GENERATION_MODEL: Generation model for sampling (default: "gpt-4o-mini")
For GitHub CI, set:
OPENAI_API_KEY: ${{ secrets.GITHUB_TOKEN }}
OPENAI_BASE_URL: https://models.github.ai/inference
OPENAI_EMBEDDING_MODEL: openai/text-embedding-3-small
OPENAI_GENERATION_MODEL: openai/gpt-4o-mini
Prerequisites:
- Nextcloud User Manual indexed in Qdrant (via vector sync)
- VECTOR_SYNC_ENABLED=true on the MCP server
"""
import json
import os
from pathlib import Path
from typing import Any, AsyncGenerator
import pytest
from mcp import ClientSession
from nextcloud_mcp_server.providers.openai import OpenAIProvider
from tests.conftest import create_mcp_client_session
from tests.integration.sampling_support import create_openai_sampling_callback
# Skip all tests if OpenAI API key not configured
pytestmark = [
pytest.mark.integration,
pytest.mark.skipif(
not os.getenv("OPENAI_API_KEY"),
reason="OPENAI_API_KEY not set - skipping OpenAI RAG tests",
),
]
# Ground truth fixture path
FIXTURES_DIR = Path(__file__).parent / "fixtures"
GROUND_TRUTH_FILE = FIXTURES_DIR / "nextcloud_manual_ground_truth.json"
@pytest.fixture(scope="module")
def ground_truth_qa():
"""Load ground truth Q&A pairs for the Nextcloud manual."""
if not GROUND_TRUTH_FILE.exists():
pytest.skip(f"Ground truth file not found: {GROUND_TRUTH_FILE}")
with open(GROUND_TRUTH_FILE) as f:
return json.load(f)
@pytest.fixture(scope="module")
async def openai_provider():
"""OpenAI provider configured from environment (embeddings only)."""
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL")
embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small")
provider = OpenAIProvider(
api_key=api_key,
base_url=base_url,
embedding_model=embedding_model,
generation_model=None, # Embeddings only
)
yield provider
await provider.close()
@pytest.fixture(scope="module")
async def openai_generation_provider():
"""OpenAI provider configured for text generation (for sampling callback)."""
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL")
generation_model = os.getenv("OPENAI_GENERATION_MODEL", "gpt-4o-mini")
# For GitHub Models API, use the prefixed model name
if base_url and "models.github.ai" in base_url:
if not generation_model.startswith("openai/"):
generation_model = f"openai/{generation_model}"
provider = OpenAIProvider(
api_key=api_key,
base_url=base_url,
embedding_model=None, # Generation only
generation_model=generation_model,
)
yield provider
await provider.close()
@pytest.fixture(scope="module")
async def nc_mcp_client_with_sampling(
anyio_backend, openai_generation_provider
) -> AsyncGenerator[ClientSession, Any]:
"""MCP client with OpenAI-based sampling support.
This fixture creates an MCP client that can handle sampling requests
from the server using OpenAI for text generation.
"""
sampling_callback = create_openai_sampling_callback(openai_generation_provider)
async for session in create_mcp_client_session(
url="http://localhost:8000/mcp",
client_name="OpenAI Sampling MCP",
sampling_callback=sampling_callback,
):
yield session
async def test_openai_embeddings_work(openai_provider: OpenAIProvider):
"""Test that OpenAI embeddings can be generated."""
embedding = await openai_provider.embed("test query about Nextcloud")
assert isinstance(embedding, list)
assert len(embedding) > 0
assert all(isinstance(x, float) for x in embedding)
# OpenAI embedding dimensions: 1536 (small) or 3072 (large)
assert len(embedding) in [1536, 3072]
async def test_semantic_search_retrieval(nc_mcp_client, ground_truth_qa):
"""Test that semantic search retrieves relevant documents from the manual.
This tests the retrieval component of RAG - ensuring that queries
return relevant chunks from the indexed Nextcloud User Manual.
"""
# Use first query from ground truth
test_case = ground_truth_qa[0] # 2FA question
query = test_case["query"]
expected_topics = test_case["expected_topics"]
# Perform semantic search via MCP tool
result = await nc_mcp_client.call_tool(
"nc_semantic_search",
arguments={
"query": query,
"limit": 5,
"score_threshold": 0.0,
},
)
assert result.isError is False, f"Tool call failed: {result}"
data = result.structuredContent
# Verify we got results
assert data["success"] is True
assert data["total_found"] > 0, f"No results for query: {query}"
assert len(data["results"]) > 0
# Check that at least one result contains expected topic keywords
all_excerpts = " ".join([r["excerpt"].lower() for r in data["results"]])
topic_found = any(topic.lower() in all_excerpts for topic in expected_topics)
assert topic_found, (
f"Expected topics {expected_topics} not found in results for query: {query}"
)
async def test_semantic_search_answer_with_sampling(
nc_mcp_client_with_sampling, ground_truth_qa
):
"""Test semantic search with MCP sampling for answer generation.
This tests the full RAG pipeline:
1. Semantic search retrieves relevant documents
2. MCP sampling generates an answer from the retrieved context
3. OpenAI generates the answer via the sampling callback
Uses nc_mcp_client_with_sampling which has OpenAI-based sampling enabled.
"""
# Use the 2FA question - has clear expected answer
test_case = ground_truth_qa[0]
query = test_case["query"]
result = await nc_mcp_client_with_sampling.call_tool(
"nc_semantic_search_answer",
arguments={
"query": query,
"limit": 5,
"score_threshold": 0.0,
"max_answer_tokens": 300,
},
)
assert result.isError is False, f"Tool call failed: {result}"
data = result.structuredContent
# Verify response structure
assert data["success"] is True
assert "query" in data
assert "generated_answer" in data
assert "sources" in data
assert "search_method" in data
# Check for either successful sampling or graceful fallback
fallback_methods = {
"semantic_sampling_unsupported",
"semantic_sampling_user_declined",
"semantic_sampling_timeout",
"semantic_sampling_mcp_error",
"semantic_sampling_fallback",
}
if data["search_method"] in fallback_methods:
# Fallback mode - verify sources still returned
assert len(data["sources"]) > 0, "Expected sources even in fallback mode"
pytest.skip(
f"MCP sampling not available (method: {data['search_method']}), "
f"but retrieval succeeded with {len(data['sources'])} sources"
)
else:
# Successful sampling - verify answer quality
assert data["search_method"] == "semantic_sampling"
assert data["generated_answer"] is not None
assert len(data["generated_answer"]) > 50 # Non-trivial answer
# Check answer contains relevant content
answer_lower = data["generated_answer"].lower()
assert any(
keyword in answer_lower
for keyword in ["two-factor", "2fa", "authentication", "password"]
), f"Answer doesn't seem relevant to query: {data['generated_answer'][:200]}"
@pytest.mark.parametrize(
"qa_index,min_expected_results",
[
(0, 1), # 2FA question
(1, 1), # File quotas question
(2, 1), # Linux installation question
(3, 1), # Windows requirements question
(4, 1), # Client apps with 2FA question
],
)
async def test_retrieval_quality_all_queries(
nc_mcp_client, ground_truth_qa, qa_index, min_expected_results
):
"""Test retrieval quality for all ground truth queries.
Validates that each query returns at least the minimum expected
number of relevant results from the Nextcloud manual.
"""
if qa_index >= len(ground_truth_qa):
pytest.skip(f"Ground truth index {qa_index} not available")
test_case = ground_truth_qa[qa_index]
query = test_case["query"]
result = await nc_mcp_client.call_tool(
"nc_semantic_search",
arguments={
"query": query,
"limit": 5,
"score_threshold": 0.0,
},
)
assert result.isError is False
data = result.structuredContent
assert data["total_found"] >= min_expected_results, (
f"Query '{query}' returned {data['total_found']} results, "
f"expected at least {min_expected_results}"
)
async def test_no_results_for_unrelated_query(nc_mcp_client):
"""Test that completely unrelated queries return low/no scores.
The Nextcloud manual shouldn't have relevant content for
quantum physics queries.
"""
result = await nc_mcp_client.call_tool(
"nc_semantic_search",
arguments={
"query": "quantum entanglement hadron collider particle physics",
"limit": 5,
"score_threshold": 0.5, # Higher threshold to filter irrelevant
},
)
assert result.isError is False
data = result.structuredContent
# Should have few or no high-scoring results
# Low score threshold means we might get some results, but they should be low quality
if data["total_found"] > 0:
# If results exist, they should have low scores
max_score = max(r["score"] for r in data["results"])
assert max_score < 0.8, f"Unexpected high score {max_score} for unrelated query"
+26 -4
View File
@@ -3,8 +3,8 @@
DEPRECATED: This module is maintained for backward compatibility with RAG evaluation tests.
New code should use nextcloud_mcp_server.providers directly.
Supports Ollama (local), Anthropic (cloud), and Bedrock (AWS) providers for both ground truth
generation and evaluation.
Supports Ollama (local), Anthropic (cloud), Bedrock (AWS), and OpenAI (cloud) providers
for both ground truth generation and evaluation.
"""
import os
@@ -13,6 +13,7 @@ from nextcloud_mcp_server.providers import (
AnthropicProvider,
BedrockProvider,
OllamaProvider,
OpenAIProvider,
Provider,
)
@@ -25,11 +26,14 @@ def create_llm_provider(
anthropic_model: str | None = None,
bedrock_region: str | None = None,
bedrock_model: str | None = None,
openai_api_key: str | None = None,
openai_base_url: str | None = None,
openai_model: str | None = None,
) -> Provider:
"""Create an LLM provider from environment variables or arguments.
Args:
provider: Provider type ('ollama', 'anthropic', or 'bedrock').
provider: Provider type ('ollama', 'anthropic', 'bedrock', or 'openai').
Defaults to RAG_EVAL_PROVIDER env var or 'ollama'
ollama_base_url: Ollama base URL. Defaults to RAG_EVAL_OLLAMA_BASE_URL or 'http://localhost:11434'
ollama_model: Ollama model. Defaults to RAG_EVAL_OLLAMA_MODEL or 'llama3.2:1b'
@@ -38,6 +42,9 @@ def create_llm_provider(
bedrock_region: AWS region. Defaults to RAG_EVAL_BEDROCK_REGION or AWS_REGION env var
bedrock_model: Bedrock model ID. Defaults to RAG_EVAL_BEDROCK_MODEL or
'anthropic.claude-3-sonnet-20240229-v1:0'
openai_api_key: OpenAI API key. Defaults to OPENAI_API_KEY env var
openai_base_url: OpenAI base URL. Defaults to OPENAI_BASE_URL (for GitHub Models API)
openai_model: OpenAI model. Defaults to OPENAI_GENERATION_MODEL or 'gpt-4o-mini'
Returns:
Provider instance
@@ -83,7 +90,22 @@ def create_llm_provider(
region_name=region, embedding_model=None, generation_model=model
)
elif provider == "openai":
api_key = openai_api_key or os.environ.get("OPENAI_API_KEY")
if not api_key:
raise ValueError(
"OpenAI API key required. Set OPENAI_API_KEY environment variable."
)
base_url = openai_base_url or os.environ.get("OPENAI_BASE_URL")
model = openai_model or os.environ.get("OPENAI_GENERATION_MODEL", "gpt-4o-mini")
return OpenAIProvider(
api_key=api_key,
base_url=base_url,
embedding_model=None,
generation_model=model,
)
else:
raise ValueError(
f"Invalid provider: {provider}. Must be 'ollama', 'anthropic', or 'bedrock'."
f"Invalid provider: {provider}. Must be 'ollama', 'anthropic', 'bedrock', or 'openai'."
)
+292
View File
@@ -0,0 +1,292 @@
"""Unit tests for OpenAI provider."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from nextcloud_mcp_server.providers.openai import (
OPENAI_EMBEDDING_DIMENSIONS,
OpenAIProvider,
)
@pytest.fixture
def mock_openai_client(mocker):
"""Mock OpenAI AsyncClient."""
mock_client = MagicMock()
mock_client.embeddings = MagicMock()
mock_client.chat = MagicMock()
mock_client.chat.completions = MagicMock()
mock_client.close = AsyncMock()
mocker.patch(
"nextcloud_mcp_server.providers.openai.AsyncOpenAI", return_value=mock_client
)
return mock_client
@pytest.mark.unit
async def test_openai_embedding(mock_openai_client):
"""Test OpenAI embedding with text-embedding-3-small."""
# Mock response
mock_embedding_data = MagicMock()
mock_embedding_data.embedding = [0.1, 0.2, 0.3]
mock_embedding_data.index = 0
mock_response = MagicMock()
mock_response.data = [mock_embedding_data]
mock_openai_client.embeddings.create = AsyncMock(return_value=mock_response)
# Create provider
provider = OpenAIProvider(
api_key="test-key",
embedding_model="text-embedding-3-small",
generation_model=None,
)
# Test embedding
embedding = await provider.embed("test text")
assert embedding == [0.1, 0.2, 0.3]
mock_openai_client.embeddings.create.assert_called_once_with(
input="test text",
model="text-embedding-3-small",
)
@pytest.mark.unit
async def test_openai_embedding_batch(mock_openai_client):
"""Test OpenAI batch embedding."""
# Mock response
mock_embedding_data_1 = MagicMock()
mock_embedding_data_1.embedding = [0.1, 0.2, 0.3]
mock_embedding_data_1.index = 0
mock_embedding_data_2 = MagicMock()
mock_embedding_data_2.embedding = [0.4, 0.5, 0.6]
mock_embedding_data_2.index = 1
mock_response = MagicMock()
mock_response.data = [mock_embedding_data_1, mock_embedding_data_2]
mock_openai_client.embeddings.create = AsyncMock(return_value=mock_response)
# Create provider
provider = OpenAIProvider(
api_key="test-key",
embedding_model="text-embedding-3-small",
generation_model=None,
)
# Test batch embedding
embeddings = await provider.embed_batch(["text1", "text2"])
assert len(embeddings) == 2
assert embeddings[0] == [0.1, 0.2, 0.3]
assert embeddings[1] == [0.4, 0.5, 0.6]
mock_openai_client.embeddings.create.assert_called_once_with(
input=["text1", "text2"],
model="text-embedding-3-small",
)
@pytest.mark.unit
async def test_openai_generation(mock_openai_client):
"""Test OpenAI text generation."""
# Mock response
mock_choice = MagicMock()
mock_choice.message.content = "Generated response"
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_response)
# Create provider
provider = OpenAIProvider(
api_key="test-key",
embedding_model=None,
generation_model="gpt-4o-mini",
)
# Test generation
text = await provider.generate("test prompt", max_tokens=100)
assert text == "Generated response"
mock_openai_client.chat.completions.create.assert_called_once_with(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "test prompt"}],
max_tokens=100,
temperature=0.7,
)
@pytest.mark.unit
async def test_openai_both_capabilities(mock_openai_client):
"""Test OpenAI with both embedding and generation models."""
# Mock embedding response
mock_embedding_data = MagicMock()
mock_embedding_data.embedding = [0.1, 0.2]
mock_embedding_data.index = 0
mock_embed_response = MagicMock()
mock_embed_response.data = [mock_embedding_data]
mock_openai_client.embeddings.create = AsyncMock(return_value=mock_embed_response)
# Mock generation response
mock_choice = MagicMock()
mock_choice.message.content = "Response"
mock_gen_response = MagicMock()
mock_gen_response.choices = [mock_choice]
mock_openai_client.chat.completions.create = AsyncMock(
return_value=mock_gen_response
)
# Create provider with both models
provider = OpenAIProvider(
api_key="test-key",
embedding_model="text-embedding-3-small",
generation_model="gpt-4o-mini",
)
assert provider.supports_embeddings is True
assert provider.supports_generation is True
# Test both capabilities
embedding = await provider.embed("test")
assert embedding == [0.1, 0.2]
text = await provider.generate("test")
assert text == "Response"
@pytest.mark.unit
async def test_openai_no_embeddings():
"""Test OpenAI provider with no embedding model raises error."""
provider = OpenAIProvider(
api_key="test-key",
embedding_model=None,
generation_model="gpt-4o-mini",
)
assert provider.supports_embeddings is False
with pytest.raises(NotImplementedError, match="no embedding_model configured"):
await provider.embed("test")
with pytest.raises(NotImplementedError, match="no embedding_model configured"):
await provider.embed_batch(["test"])
with pytest.raises(NotImplementedError, match="no embedding_model configured"):
provider.get_dimension()
@pytest.mark.unit
async def test_openai_no_generation():
"""Test OpenAI provider with no generation model raises error."""
provider = OpenAIProvider(
api_key="test-key",
embedding_model="text-embedding-3-small",
generation_model=None,
)
assert provider.supports_generation is False
with pytest.raises(NotImplementedError, match="no generation_model configured"):
await provider.generate("test")
@pytest.mark.unit
async def test_openai_known_dimension():
"""Test dimension detection for known OpenAI models."""
provider = OpenAIProvider(
api_key="test-key",
embedding_model="text-embedding-3-small",
)
# Known model should have dimension set from lookup table
assert provider.get_dimension() == 1536
@pytest.mark.unit
async def test_openai_unknown_dimension_detected(mock_openai_client):
"""Test dimension detection for unknown model via API call."""
# Mock response with specific dimension
mock_embedding_data = MagicMock()
mock_embedding_data.embedding = [0.1] * 768
mock_embedding_data.index = 0
mock_response = MagicMock()
mock_response.data = [mock_embedding_data]
mock_openai_client.embeddings.create = AsyncMock(return_value=mock_response)
provider = OpenAIProvider(
api_key="test-key",
embedding_model="custom-embedding-model",
)
# Dimension not known yet for custom model
with pytest.raises(RuntimeError, match="not detected yet"):
provider.get_dimension()
# Detect dimension via embed call
await provider.embed("test")
# Now dimension should be available
assert provider.get_dimension() == 768
@pytest.mark.unit
async def test_openai_github_models_api(mock_openai_client):
"""Test OpenAI provider with GitHub Models API configuration."""
# Mock response
mock_embedding_data = MagicMock()
mock_embedding_data.embedding = [0.1, 0.2, 0.3]
mock_embedding_data.index = 0
mock_response = MagicMock()
mock_response.data = [mock_embedding_data]
mock_openai_client.embeddings.create = AsyncMock(return_value=mock_response)
# Create provider with GitHub Models configuration
provider = OpenAIProvider(
api_key="ghp_test_token",
base_url="https://models.github.ai/inference",
embedding_model="openai/text-embedding-3-small",
generation_model=None,
)
# Known dimension for GitHub Models prefixed model
assert (
provider.get_dimension()
== OPENAI_EMBEDDING_DIMENSIONS["openai/text-embedding-3-small"]
)
# Test embedding
embedding = await provider.embed("test text")
assert embedding == [0.1, 0.2, 0.3]
@pytest.mark.unit
async def test_openai_empty_batch():
"""Test OpenAI batch embedding with empty list."""
provider = OpenAIProvider(
api_key="test-key",
embedding_model="text-embedding-3-small",
)
embeddings = await provider.embed_batch([])
assert embeddings == []
@pytest.mark.unit
async def test_openai_close(mock_openai_client):
"""Test OpenAI client close."""
provider = OpenAIProvider(
api_key="test-key",
embedding_model="text-embedding-3-small",
)
await provider.close()
mock_openai_client.close.assert_called_once()
+86
View File
@@ -259,3 +259,89 @@ class TestChunkConfigValidation:
match="DOCUMENT_CHUNK_OVERLAP .* must be less than DOCUMENT_CHUNK_SIZE",
):
get_settings()
class TestEmbeddingModelName:
"""Test get_embedding_model_name() method."""
def test_openai_takes_priority(self):
"""Test that OpenAI model is returned when OPENAI_API_KEY is set."""
settings = Settings(
openai_api_key="test-key",
openai_embedding_model="text-embedding-3-large",
ollama_base_url="http://ollama:11434",
ollama_embedding_model="nomic-embed-text",
)
assert settings.get_embedding_model_name() == "text-embedding-3-large"
def test_ollama_used_when_no_openai(self):
"""Test that Ollama model is returned when no OpenAI configured."""
settings = Settings(
ollama_base_url="http://ollama:11434",
ollama_embedding_model="all-minilm",
)
assert settings.get_embedding_model_name() == "all-minilm"
def test_simple_fallback(self):
"""Test fallback to simple provider when nothing configured."""
settings = Settings()
assert settings.get_embedding_model_name() == "simple-384"
@patch.dict(
os.environ,
{
"OPENAI_API_KEY": "test-openai-key",
"OPENAI_EMBEDDING_MODEL": "openai/text-embedding-3-small",
},
clear=True,
)
def test_get_settings_openai_model(self):
"""Test get_settings() loads OpenAI embedding model."""
settings = get_settings()
assert settings.openai_api_key == "test-openai-key"
assert settings.openai_embedding_model == "openai/text-embedding-3-small"
assert settings.get_embedding_model_name() == "openai/text-embedding-3-small"
class TestCollectionNameWithProviders:
"""Test get_collection_name() with different providers."""
def test_collection_name_with_openai(self):
"""Test collection name uses OpenAI model when configured."""
settings = Settings(
openai_api_key="test-key",
openai_embedding_model="text-embedding-3-small",
otel_service_name="my-deployment",
)
assert settings.get_collection_name() == "my-deployment-text-embedding-3-small"
def test_collection_name_with_github_models(self):
"""Test collection name sanitizes GitHub Models prefix."""
settings = Settings(
openai_api_key="ghp_test",
openai_embedding_model="openai/text-embedding-3-small",
otel_service_name="my-deployment",
)
# Slashes should be replaced with dashes
assert (
settings.get_collection_name()
== "my-deployment-openai-text-embedding-3-small"
)
def test_collection_name_with_ollama(self):
"""Test collection name uses Ollama model when no OpenAI."""
settings = Settings(
ollama_base_url="http://ollama:11434",
ollama_embedding_model="nomic-embed-text",
otel_service_name="my-deployment",
)
assert settings.get_collection_name() == "my-deployment-nomic-embed-text"
def test_collection_name_explicit_override(self):
"""Test explicit QDRANT_COLLECTION overrides auto-generation."""
settings = Settings(
qdrant_collection="custom-collection",
openai_api_key="test-key",
openai_embedding_model="text-embedding-3-large",
)
assert settings.get_collection_name() == "custom-collection"
Generated
+21
View File
@@ -1951,6 +1951,7 @@ dependencies = [
{ name = "jinja2" },
{ name = "langchain-text-splitters" },
{ name = "mcp", extra = ["cli"] },
{ name = "openai" },
{ name = "opentelemetry-api" },
{ name = "opentelemetry-exporter-otlp-proto-grpc" },
{ name = "opentelemetry-instrumentation-asgi" },
@@ -1999,6 +2000,7 @@ requires-dist = [
{ name = "jinja2", specifier = ">=3.1.6" },
{ name = "langchain-text-splitters", specifier = ">=1.0.0" },
{ name = "mcp", extras = ["cli"], specifier = ">=1.22,<1.23" },
{ name = "openai", specifier = ">=2.8.1" },
{ name = "opentelemetry-api", specifier = ">=1.28.2" },
{ name = "opentelemetry-exporter-otlp-proto-grpc", specifier = ">=1.28.2" },
{ name = "opentelemetry-instrumentation-asgi", specifier = ">=0.49b2" },
@@ -2146,6 +2148,25 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b6/ca/862b1e7a639460f0ca25fd5b6135fb42cf9deea86d398a92e44dfda2279d/onnxruntime-1.23.2-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e2b9233c4947907fd1818d0e581c049c41ccc39b2856cc942ff6d26317cee145", size = 17394184, upload-time = "2025-10-22T03:47:08.127Z" },
]
[[package]]
name = "openai"
version = "2.8.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "distro" },
{ name = "httpx" },
{ name = "jiter" },
{ name = "pydantic" },
{ name = "sniffio" },
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d5/e4/42591e356f1d53c568418dc7e30dcda7be31dd5a4d570bca22acb0525862/openai-2.8.1.tar.gz", hash = "sha256:cb1b79eef6e809f6da326a7ef6038719e35aa944c42d081807bfa1be8060f15f", size = 602490, upload-time = "2025-11-17T22:39:59.549Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/55/4f/dbc0c124c40cb390508a82770fb9f6e3ed162560181a85089191a851c59a/openai-2.8.1-py3-none-any.whl", hash = "sha256:c6c3b5a04994734386e8dad3c00a393f56d3b68a27cd2e8acae91a59e4122463", size = 1022688, upload-time = "2025-11-17T22:39:57.675Z" },
]
[[package]]
name = "opentelemetry-api"
version = "1.38.0"