feat(ollama): Pull model on startup if not available in ollama
This commit is contained in:
@@ -35,6 +35,8 @@ class OllamaEmbeddingProvider(EmbeddingProvider):
|
||||
f"Initialized Ollama provider: {base_url} (model={model}, verify_ssl={verify_ssl})"
|
||||
)
|
||||
|
||||
self._check_model_is_loaded(autoload=True)
|
||||
|
||||
async def embed(self, text: str) -> list[float]:
|
||||
"""
|
||||
Generate embedding vector for text.
|
||||
@@ -80,6 +82,23 @@ class OllamaEmbeddingProvider(EmbeddingProvider):
|
||||
"""
|
||||
return self._dimension
|
||||
|
||||
def _check_model_is_loaded(self, autoload: bool = True):
|
||||
response = httpx.get(f"{self.base_url}/api/tags")
|
||||
response.raise_for_status()
|
||||
|
||||
models = [model["name"] for model in response.json().get("models", [])]
|
||||
logger.info("Ollama has following models pre-loaded: %s", models)
|
||||
|
||||
if (self.model not in models) and autoload:
|
||||
logger.warning(
|
||||
"Embedding model '%s' not yet available in ollama, attempting to pull now...",
|
||||
self.model,
|
||||
)
|
||||
response = httpx.post(
|
||||
f"{self.base_url}/api/pull", json={"model": self.model}
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
async def close(self):
|
||||
"""Close HTTP client."""
|
||||
await self.client.aclose()
|
||||
|
||||
@@ -27,7 +27,6 @@ async def temp_storage():
|
||||
yield storage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_webhook(temp_storage):
|
||||
"""Test storing a webhook."""
|
||||
await temp_storage.store_webhook(webhook_id=123, preset_id="notes_sync")
|
||||
@@ -39,7 +38,6 @@ async def test_store_webhook(temp_storage):
|
||||
assert "created_at" in webhooks[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_webhook_duplicate(temp_storage):
|
||||
"""Test storing duplicate webhook replaces existing."""
|
||||
await temp_storage.store_webhook(webhook_id=123, preset_id="notes_sync")
|
||||
@@ -51,7 +49,6 @@ async def test_store_webhook_duplicate(temp_storage):
|
||||
assert webhooks[0]["preset_id"] == "calendar_sync"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_webhooks_by_preset(temp_storage):
|
||||
"""Test retrieving webhooks by preset."""
|
||||
await temp_storage.store_webhook(webhook_id=123, preset_id="notes_sync")
|
||||
@@ -68,14 +65,12 @@ async def test_get_webhooks_by_preset(temp_storage):
|
||||
assert 789 in calendar_webhooks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_webhooks_by_preset_empty(temp_storage):
|
||||
"""Test retrieving webhooks for non-existent preset."""
|
||||
webhooks = await temp_storage.get_webhooks_by_preset("nonexistent")
|
||||
assert len(webhooks) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_webhook(temp_storage):
|
||||
"""Test deleting a webhook."""
|
||||
await temp_storage.store_webhook(webhook_id=123, preset_id="notes_sync")
|
||||
@@ -89,14 +84,12 @@ async def test_delete_webhook(temp_storage):
|
||||
assert 456 in webhooks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_webhook_nonexistent(temp_storage):
|
||||
"""Test deleting non-existent webhook."""
|
||||
deleted = await temp_storage.delete_webhook(webhook_id=999)
|
||||
assert deleted is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_all_webhooks(temp_storage):
|
||||
"""Test listing all webhooks."""
|
||||
await temp_storage.store_webhook(webhook_id=123, preset_id="notes_sync")
|
||||
@@ -119,14 +112,12 @@ async def test_list_all_webhooks(temp_storage):
|
||||
assert 789 in webhook_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_all_webhooks_empty(temp_storage):
|
||||
"""Test listing webhooks when none exist."""
|
||||
webhooks = await temp_storage.list_all_webhooks()
|
||||
assert len(webhooks) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_preset_webhooks(temp_storage):
|
||||
"""Test clearing all webhooks for a preset."""
|
||||
await temp_storage.store_webhook(webhook_id=123, preset_id="notes_sync")
|
||||
@@ -146,14 +137,12 @@ async def test_clear_preset_webhooks(temp_storage):
|
||||
assert 789 in calendar_webhooks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_preset_webhooks_nonexistent(temp_storage):
|
||||
"""Test clearing webhooks for non-existent preset."""
|
||||
deleted_count = await temp_storage.clear_preset_webhooks("nonexistent")
|
||||
assert deleted_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_timestamps(temp_storage):
|
||||
"""Test that webhook timestamps are properly stored."""
|
||||
start_time = time.time()
|
||||
@@ -167,7 +156,6 @@ async def test_webhook_timestamps(temp_storage):
|
||||
assert start_time <= created_at <= end_time
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_storage_without_encryption_key():
|
||||
"""Test that storage can be initialized without encryption key."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@@ -182,7 +170,6 @@ async def test_storage_without_encryption_key():
|
||||
assert 123 in webhooks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_presets_independence(temp_storage):
|
||||
"""Test that different presets maintain independent webhook lists."""
|
||||
presets = ["notes_sync", "calendar_sync", "deck_sync", "files_sync"]
|
||||
|
||||
Reference in New Issue
Block a user