3f06e2ee77
Fixed 8 type checker errors across the codebase:
- vector/scanner.py: Handle None scroll results with null-safe iteration
- search/{bm25_hybrid,semantic}.py: Add None checks for result.payload
- auth/{unified_verifier,webhook_routes}.py: Assert non-None auth credentials
- client/webdav.py: Add None checks before int() conversions
- providers/openai.py: Assert embedding_model is not None
- search/algorithms.py: Explicitly type doc_types set and cast values
- observability/logging_config.py: Match parent class signature (log_data)
Also fixed test_create_tag_creates_system_tag to match WebDAV implementation
(was testing OCS API endpoint, now tests correct WebDAV endpoint with
Content-Location header).
Type checker: 0 errors (down from 8), 20 warnings (ignored)
Tests: All 192 unit tests passing
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
272 lines
8.6 KiB
Python
272 lines
8.6 KiB
Python
"""Unified OpenAI provider for embeddings and text generation.
|
|
|
|
Supports:
|
|
- OpenAI's standard API
|
|
- GitHub Models API (models.github.ai)
|
|
- Any OpenAI-compatible API via base_url override
|
|
"""
|
|
|
|
import logging
|
|
from functools import wraps
|
|
|
|
import anyio
|
|
from openai import AsyncOpenAI, RateLimitError
|
|
|
|
from .base import Provider
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Retry policy applied when the API reports a rate limit.
MAX_RETRIES: int = 5  # total attempts, including the first call
INITIAL_RETRY_DELAY: float = 2.0  # seconds to wait after the first failure
MAX_RETRY_DELAY: float = 60.0  # upper bound, in seconds, on any single wait
|
|
|
|
|
|
def retry_on_rate_limit(func):
    """Wrap an async callable so OpenAI rate-limit errors are retried.

    The wrapped coroutine is attempted up to ``MAX_RETRIES`` times. After each
    ``RateLimitError`` except the last, the wrapper sleeps and then doubles the
    delay, starting at ``INITIAL_RETRY_DELAY`` and never exceeding
    ``MAX_RETRY_DELAY``. Once the attempts are used up, the most recent
    ``RateLimitError`` is raised. Any other exception propagates immediately.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        failure: Exception | None = None
        retry_delay = INITIAL_RETRY_DELAY
        attempt = 0

        while attempt < MAX_RETRIES:
            attempt += 1
            try:
                return await func(*args, **kwargs)
            except RateLimitError as exc:
                failure = exc
                # Final attempt: stop looping and report the failure below.
                if attempt >= MAX_RETRIES:
                    break
                logger.warning(
                    f"Rate limit hit (attempt {attempt}/{MAX_RETRIES}), "
                    f"retrying in {retry_delay:.1f}s..."
                )
                await anyio.sleep(retry_delay)
                # Exponential backoff, capped at the configured maximum.
                retry_delay = min(retry_delay * 2, MAX_RETRY_DELAY)

        logger.error(f"Rate limit exceeded after {MAX_RETRIES} attempts")
        raise failure  # type: ignore[misc]

    return wrapper
|
|
|
|
|
|
# Output vector sizes for well-known OpenAI embedding models, keyed by the
# model name exactly as it is passed to the API.
OPENAI_EMBEDDING_DIMENSIONS: dict[str, int] = {
    "text-embedding-3-small": 1536,
    "text-embedding-3-large": 3072,
    "text-embedding-ada-002": 1536,
}

# The GitHub Models API addresses the same models with an "openai/" prefix.
OPENAI_EMBEDDING_DIMENSIONS.update(
    {
        "openai/text-embedding-3-small": 1536,
        "openai/text-embedding-3-large": 3072,
    }
)
|
|
|
|
|
|
class OpenAIProvider(Provider):
    """
    OpenAI provider supporting both embeddings and text generation.

    Works with:
    - OpenAI's standard API (api.openai.com)
    - GitHub Models API (models.github.ai)
    - Any OpenAI-compatible API (via base_url)

    Each capability is optional: passing ``None`` for ``embedding_model`` or
    ``generation_model`` disables it, and the corresponding methods then raise
    ``NotImplementedError``.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str | None = None,
        embedding_model: str | None = None,
        generation_model: str | None = None,
        timeout: float = 120.0,
    ) -> None:
        """
        Initialize OpenAI provider.

        Args:
            api_key: OpenAI API key (or GITHUB_TOKEN for GitHub Models)
            base_url: Base URL override (e.g., "https://models.github.ai/inference")
            embedding_model: Model for embeddings (e.g., "text-embedding-3-small").
                None disables embeddings.
            generation_model: Model for text generation (e.g., "gpt-4o-mini").
                None disables generation.
            timeout: HTTP timeout in seconds (default: 120)
        """
        self.embedding_model = embedding_model
        self.generation_model = generation_model
        # Cached embedding vector size; filled from the lookup table below or
        # from the first embedding the API returns.
        self._dimension: int | None = None

        # Initialize async client
        self.client = AsyncOpenAI(
            api_key=api_key,
            base_url=base_url,
            timeout=timeout,
        )

        # Try to get known dimension without API call
        if embedding_model and embedding_model in OPENAI_EMBEDDING_DIMENSIONS:
            self._dimension = OPENAI_EMBEDDING_DIMENSIONS[embedding_model]

        logger.info(
            f"Initialized OpenAI provider: base_url={base_url or 'default'} "
            f"(embedding_model={embedding_model}, generation_model={generation_model}, "
            f"dimension={self._dimension})"
        )

    @property
    def supports_embeddings(self) -> bool:
        """Whether this provider supports embedding generation."""
        return self.embedding_model is not None

    @property
    def supports_generation(self) -> bool:
        """Whether this provider supports text generation."""
        return self.generation_model is not None

    def _require_embedding_model(self) -> str:
        """Return the configured embedding model or raise if embeddings are disabled.

        An explicit check is used instead of ``assert`` so the guard still
        runs under ``python -O`` and narrows ``str | None`` to ``str`` for
        type checkers.

        Raises:
            NotImplementedError: If embeddings not enabled (no embedding_model)
        """
        model = self.embedding_model
        if not self.supports_embeddings or model is None:
            raise NotImplementedError(
                "Embedding not supported - no embedding_model configured"
            )
        return model

    def _remember_dimension(self, embedding: list[float]) -> None:
        """Cache the vector size of a returned embedding if it is not known yet."""
        if self._dimension is not None:
            return
        self._dimension = len(embedding)
        logger.info(
            f"Detected embedding dimension: {self._dimension} "
            f"for model {self.embedding_model}"
        )

    @retry_on_rate_limit
    async def embed(self, text: str) -> list[float]:
        """
        Generate embedding vector for text.

        Args:
            text: Input text to embed

        Returns:
            Vector embedding as list of floats

        Raises:
            NotImplementedError: If embeddings not enabled (no embedding_model)
        """
        model = self._require_embedding_model()

        response = await self.client.embeddings.create(
            input=text,
            model=model,
        )

        embedding = response.data[0].embedding

        # Update dimension if not set
        self._remember_dimension(embedding)

        return embedding

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings for multiple texts using OpenAI's batch API.

        OpenAI supports up to 2048 inputs per request.

        Args:
            texts: List of texts to embed

        Returns:
            List of vector embeddings, in the same order as ``texts``

        Raises:
            NotImplementedError: If embeddings not enabled (no embedding_model)
        """
        # Validate configuration first so a disabled provider fails even for
        # an empty input list.
        self._require_embedding_model()

        if not texts:
            return []

        # OpenAI supports batches up to 2048, but use smaller batches for safety
        batch_size = 100
        all_embeddings: list[list[float]] = []

        for start in range(0, len(texts), batch_size):
            batch = texts[start : start + batch_size]

            # Use helper method with retry logic for each batch
            batch_embeddings = await self._embed_batch_request(batch)
            all_embeddings.extend(batch_embeddings)

            # Update dimension if not set
            if batch_embeddings:
                self._remember_dimension(batch_embeddings[0])

        return all_embeddings

    @retry_on_rate_limit
    async def _embed_batch_request(self, batch: list[str]) -> list[list[float]]:
        """Make a single batch embedding request with retry logic.

        Raises:
            NotImplementedError: If embeddings not enabled (no embedding_model)
        """
        model = self._require_embedding_model()
        response = await self.client.embeddings.create(
            input=batch,
            model=model,
        )
        # Sort by index to maintain order
        sorted_data = sorted(response.data, key=lambda item: item.index)
        return [item.embedding for item in sorted_data]

    def get_dimension(self) -> int:
        """
        Get embedding dimension.

        Returns:
            Vector dimension for the configured embedding model

        Raises:
            NotImplementedError: If embeddings not enabled (no embedding_model)
            RuntimeError: If dimension not detected yet (call embed first)
        """
        self._require_embedding_model()

        if self._dimension is None:
            raise RuntimeError(
                f"Embedding dimension not detected yet for model {self.embedding_model}. "
                "Call embed() first or use a known model."
            )
        return self._dimension

    @retry_on_rate_limit
    async def generate(self, prompt: str, max_tokens: int = 500) -> str:
        """
        Generate text from a prompt.

        Args:
            prompt: The prompt to generate from
            max_tokens: Maximum tokens to generate

        Returns:
            Generated text (empty string if the response carries no content)

        Raises:
            NotImplementedError: If generation not enabled (no generation_model)
        """
        # Bind to a local so the None check narrows the type passed to the
        # API call, matching how the embedding methods are guarded.
        model = self.generation_model
        if not self.supports_generation or model is None:
            raise NotImplementedError(
                "Text generation not supported - no generation_model configured"
            )

        response = await self.client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=0.7,
        )

        return response.choices[0].message.content or ""

    async def close(self) -> None:
        """Close HTTP client."""
        await self.client.close()
|