feat(ollama): Pull model on startup if not available in ollama

This commit is contained in:
Chris Coutinho
2025-11-12 00:37:26 +01:00
parent 0eae33a918
commit 7e93097137
2 changed files with 19 additions and 13 deletions
@@ -35,6 +35,8 @@ class OllamaEmbeddingProvider(EmbeddingProvider):
f"Initialized Ollama provider: {base_url} (model={model}, verify_ssl={verify_ssl})"
)
self._check_model_is_loaded(autoload=True)
async def embed(self, text: str) -> list[float]:
"""
Generate embedding vector for text.
@@ -80,6 +82,23 @@ class OllamaEmbeddingProvider(EmbeddingProvider):
"""
return self._dimension
def _check_model_is_loaded(self, autoload: bool = True):
response = httpx.get(f"{self.base_url}/api/tags")
response.raise_for_status()
models = [model["name"] for model in response.json().get("models", [])]
logger.info("Ollama has following models pre-loaded: %s", models)
if (self.model not in models) and autoload:
logger.warning(
"Embedding model '%s' not yet available in ollama, attempting to pull now...",
self.model,
)
response = httpx.post(
f"{self.base_url}/api/pull", json={"model": self.model}
)
response.raise_for_status()
async def close(self):
"""Close HTTP client."""
await self.client.aclose()
-13
View File
@@ -27,7 +27,6 @@ async def temp_storage():
yield storage
@pytest.mark.asyncio
async def test_store_webhook(temp_storage):
"""Test storing a webhook."""
await temp_storage.store_webhook(webhook_id=123, preset_id="notes_sync")
@@ -39,7 +38,6 @@ async def test_store_webhook(temp_storage):
assert "created_at" in webhooks[0]
@pytest.mark.asyncio
async def test_store_webhook_duplicate(temp_storage):
"""Test storing duplicate webhook replaces existing."""
await temp_storage.store_webhook(webhook_id=123, preset_id="notes_sync")
@@ -51,7 +49,6 @@ async def test_store_webhook_duplicate(temp_storage):
assert webhooks[0]["preset_id"] == "calendar_sync"
@pytest.mark.asyncio
async def test_get_webhooks_by_preset(temp_storage):
"""Test retrieving webhooks by preset."""
await temp_storage.store_webhook(webhook_id=123, preset_id="notes_sync")
@@ -68,14 +65,12 @@ async def test_get_webhooks_by_preset(temp_storage):
assert 789 in calendar_webhooks
@pytest.mark.asyncio
async def test_get_webhooks_by_preset_empty(temp_storage):
"""Test retrieving webhooks for non-existent preset."""
webhooks = await temp_storage.get_webhooks_by_preset("nonexistent")
assert len(webhooks) == 0
@pytest.mark.asyncio
async def test_delete_webhook(temp_storage):
"""Test deleting a webhook."""
await temp_storage.store_webhook(webhook_id=123, preset_id="notes_sync")
@@ -89,14 +84,12 @@ async def test_delete_webhook(temp_storage):
assert 456 in webhooks
@pytest.mark.asyncio
async def test_delete_webhook_nonexistent(temp_storage):
"""Test deleting non-existent webhook."""
deleted = await temp_storage.delete_webhook(webhook_id=999)
assert deleted is False
@pytest.mark.asyncio
async def test_list_all_webhooks(temp_storage):
"""Test listing all webhooks."""
await temp_storage.store_webhook(webhook_id=123, preset_id="notes_sync")
@@ -119,14 +112,12 @@ async def test_list_all_webhooks(temp_storage):
assert 789 in webhook_ids
@pytest.mark.asyncio
async def test_list_all_webhooks_empty(temp_storage):
"""Test listing webhooks when none exist."""
webhooks = await temp_storage.list_all_webhooks()
assert len(webhooks) == 0
@pytest.mark.asyncio
async def test_clear_preset_webhooks(temp_storage):
"""Test clearing all webhooks for a preset."""
await temp_storage.store_webhook(webhook_id=123, preset_id="notes_sync")
@@ -146,14 +137,12 @@ async def test_clear_preset_webhooks(temp_storage):
assert 789 in calendar_webhooks
@pytest.mark.asyncio
async def test_clear_preset_webhooks_nonexistent(temp_storage):
"""Test clearing webhooks for non-existent preset."""
deleted_count = await temp_storage.clear_preset_webhooks("nonexistent")
assert deleted_count == 0
@pytest.mark.asyncio
async def test_webhook_timestamps(temp_storage):
"""Test that webhook timestamps are properly stored."""
start_time = time.time()
@@ -167,7 +156,6 @@ async def test_webhook_timestamps(temp_storage):
assert start_time <= created_at <= end_time
@pytest.mark.asyncio
async def test_storage_without_encryption_key():
"""Test that storage can be initialized without encryption key."""
with tempfile.TemporaryDirectory() as tmpdir:
@@ -182,7 +170,6 @@ async def test_storage_without_encryption_key():
assert 123 in webhooks
@pytest.mark.asyncio
async def test_multiple_presets_independence(temp_storage):
"""Test that different presets maintain independent webhook lists."""
presets = ["notes_sync", "calendar_sync", "deck_sync", "files_sync"]