feat: Use Ollama native batch API in embed_batch()

- Switch from sequential loop to /api/embed batch endpoint
- Use 'input' array parameter instead of individual 'prompt' requests
- Process in chunks of 32 to avoid quality degradation (issue #6262)
- Reduces HTTP overhead: e.g. 128 texts now take 4 requests (128 / 32) instead of 128
- Leaves embed() unchanged, so single-text embedding calls remain backward compatible

Ref: ollama/ollama#6262
This commit is contained in:
Chris Coutinho
2025-11-20 16:50:13 +01:00
parent ec2c274cd9
commit 25ef33de7f
+21 -8
View File
@@ -92,14 +92,21 @@ class OllamaProvider(Provider):
response.raise_for_status() response.raise_for_status()
return response.json()["embedding"] return response.json()["embedding"]
async def embed_batch(self, texts: list[str]) -> list[list[float]]: async def embed_batch(
self, texts: list[str], batch_size: int = 32
) -> list[list[float]]:
""" """
Generate embeddings for multiple texts (batched requests). Generate embeddings for multiple texts using Ollama's batch API.
Note: Ollama doesn't have native batch API, so we send requests sequentially. Uses /api/embed endpoint with array input for efficient batch processing.
Conservative batch size (32) prevents quality degradation observed in
Ollama issue #6262 with larger batches.
Note: Ollama processes batches serially, not in parallel.
Args: Args:
texts: List of texts to embed texts: List of texts to embed
batch_size: Maximum texts per batch (default: 32)
Returns: Returns:
List of vector embeddings List of vector embeddings
@@ -112,11 +119,17 @@ class OllamaProvider(Provider):
"Embedding not supported - no embedding_model configured" "Embedding not supported - no embedding_model configured"
) )
embeddings = [] all_embeddings = []
for text in texts: for i in range(0, len(texts), batch_size):
embedding = await self.embed(text) batch = texts[i : i + batch_size]
embeddings.append(embedding) response = await self.client.post(
return embeddings f"{self.base_url}/api/embed",
json={"model": self.embedding_model, "input": batch},
)
response.raise_for_status()
all_embeddings.extend(response.json()["embeddings"])
return all_embeddings
async def _detect_dimension(self): async def _detect_dimension(self):
""" """