slice 3.0: LLM provider abstraction (LLMProvider protocol + FakeProvider + OllamaCloudProvider; 16/16 tests)
This commit is contained in:
176
lore_engine_poc/llm.py
Normal file
176
lore_engine_poc/llm.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""Lore Engine POC — LLM provider abstraction (slice 3).

A thin wrapper around the LLM call surface we need for
extraction. We define one ``LLMProvider`` Protocol with a
single method (``chat``) and ship two implementations:

* :class:`FakeProvider` — canned responses for tests. The
  test code scripts ``(messages, response)`` pairs; the
  provider matches the incoming messages to the script and
  returns the canned response.
* :class:`OllamaCloudProvider` — the real provider. Talks
  to ``https://ollama.com/api/chat`` over bearer-token auth
  using the ``urllib.request`` stdlib module (no new pip
  dependencies).

Why stdlib only and not LiteLLM:

* LiteLLM is great when you have many providers. We have
  one — Ollama Cloud — and one method to call.
* The auto-classifier blocked earlier pip installs of
  agent-chosen packages (see slice 2.6+). ``urllib`` is
  already in the standard library.
* The protocol stays uniform if we ever add a second
  provider (LiteLLM, anthropic, local vLLM): implement
  ``chat(messages) -> str`` and slot in.

The provider is intentionally **stateless**: one call →
one response. Stateful concerns (sessions, conversation
history, retries) live in the caller — the extractor
passes a single-message prompt and parses the single
string response.
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import urllib.request
|
||||
from typing import Any, Callable, Optional, Protocol, runtime_checkable
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Protocol — the duck-typed contract
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@runtime_checkable
class LLMProvider(Protocol):
    """Structural contract for a one-shot LLM call.

    Any object exposing a compatible ``chat`` method satisfies this
    protocol; no inheritance is required. ``chat`` receives an
    OpenAI-style ``messages`` list (plus optional keyword options) and
    hands back the assistant message content as a plain string.
    Interpreting that string is the caller's job — a provider returns
    it without parsing.
    """

    def chat(self, messages: list[dict], **opts: Any) -> str:
        ...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FakeProvider — for tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FakeProvider:
    """Canned-response provider for tests.

    The constructor takes a ``script``: a list of
    ``(messages_match, response)`` pairs. When ``chat`` is called, the
    incoming ``messages`` list is compared to each ``messages_match``
    with ``==`` (whole-list equality, not a substring match) and the
    ``response`` of the first matching entry is returned. Entries are
    not consumed, so the same prompt may be asked more than once.

    If no script entry matches, ``chat`` raises ``AssertionError`` so
    test drift is loud. Every call — matched or not — is recorded on
    ``self.calls`` for later assertions.
    """

    def __init__(
        self,
        script: Optional[list[tuple[list[dict], str]]] = None,
    ):
        # Copy the outer list so appending to the caller's list later
        # does not change this provider's script.
        self.script: list[tuple[list[dict], str]] = list(script or [])
        # One entry per ``chat`` call, in call order.
        self.calls: list[list[dict]] = []

    def chat(self, messages: list[dict], **opts: Any) -> str:
        """Return the scripted response for ``messages``.

        ``opts`` is accepted for signature compatibility with
        ``LLMProvider.chat`` and is otherwise ignored.

        Raises:
            AssertionError: if no script entry equals ``messages``.
        """
        import copy  # local import: only call recording needs it

        # Record a snapshot rather than the caller's live list, so a
        # caller that mutates or reuses ``messages`` afterwards cannot
        # rewrite what this provider says it was asked.
        self.calls.append(copy.deepcopy(messages))
        for match, response in self.script:
            if match == messages:
                return response
        raise AssertionError(
            f"FakeProvider: no scripted response for messages={messages!r}"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OllamaCloudProvider — real
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class OllamaCloudProvider:
    """Provider for Ollama Cloud (``https://ollama.com``).

    Auth is a bearer token in the ``Authorization`` header, taken from
    the ``api_key`` kwarg or the ``$OLLAMA_API_KEY`` env var. The model
    defaults to ``minimax-m3:cloud`` (the user's chosen slug) but can
    be overridden via the ``$LORE_LLM_MODEL`` env var or the
    constructor's ``model`` kwarg. The request timeout (seconds)
    defaults to 60 and can be overridden via ``$LORE_LLM_TIMEOUT`` or
    the ``timeout`` kwarg. Explicit kwargs always win over env vars.

    Environment values that cannot work are ignored in favour of the
    defaults: an empty ``$LORE_LLM_MODEL``, and a ``$LORE_LLM_TIMEOUT``
    that is non-numeric, non-finite, zero or negative.

    The provider is *fail-loud*: any HTTP error, timeout, or
    non-JSON response bubbles up. The extractor (the only
    caller) catches and degrades to an empty result so the
    graph still loads. See :mod:`lore_engine_poc.extraction`.
    """

    ENDPOINT = "https://ollama.com/api/chat"
    DEFAULT_MODEL = "minimax-m3:cloud"
    DEFAULT_TIMEOUT = 60.0

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        timeout: Optional[float] = None,
    ):
        """Resolve key, model and timeout from kwargs, then env vars.

        Raises:
            RuntimeError: if no non-empty API key is available from
                either ``api_key`` or ``$OLLAMA_API_KEY``.
        """
        self.api_key = (
            api_key
            if api_key is not None
            else os.environ.get("OLLAMA_API_KEY")
        )
        if not self.api_key:
            raise RuntimeError(
                "OllamaCloudProvider: $OLLAMA_API_KEY is not set. "
                "Either export the env var or pass api_key= explicitly."
            )

        if model is not None:
            self.model = model
        else:
            # ``or`` (not a ``.get`` default) so that an env var that is
            # set but empty is treated the same as an unset one.
            self.model = os.environ.get("LORE_LLM_MODEL") or self.DEFAULT_MODEL

        if timeout is not None:
            self.timeout = float(timeout)
        else:
            self.timeout = self._timeout_from_env()

    def _timeout_from_env(self) -> float:
        """Return ``$LORE_LLM_TIMEOUT`` as seconds, or the default if unusable."""
        fallback = float(self.DEFAULT_TIMEOUT)
        raw = os.environ.get("LORE_LLM_TIMEOUT")
        if raw is None:
            return fallback
        try:
            seconds = float(raw)
        except ValueError:
            return fallback
        # NaN fails both comparisons; inf, zero and negatives are
        # excluded as well — none of them is a usable socket timeout.
        if not 0 < seconds < float("inf"):
            return fallback
        return seconds

    def chat(self, messages: list[dict], **opts: Any) -> str:
        """POST ``messages`` to the chat endpoint and return the reply text.

        The request is non-streaming (``"stream": False``). ``opts`` is
        accepted for signature compatibility with ``LLMProvider.chat``
        but is not forwarded to the API.

        Raises:
            Whatever ``urllib.request.urlopen`` raises on HTTP or
            network failure, whatever ``json.loads`` raises on a
            non-JSON body, and ``KeyError`` when the decoded payload
            has no ``message.content``.
        """
        body = json.dumps({
            "model": self.model,
            "messages": messages,
            "stream": False,
        }).encode()
        # Supplying ``data`` makes urllib issue a POST.
        req = urllib.request.Request(
            self.ENDPOINT,
            data=body,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
        )
        with urllib.request.urlopen(req, timeout=self.timeout) as resp:
            data = json.loads(resp.read())
        return data["message"]["content"]
|
||||
|
||||
|
||||
# Names exported by ``from lore_engine_poc.llm import *``: the provider
# protocol and the two implementations defined in this module.
__all__ = ["LLMProvider", "FakeProvider", "OllamaCloudProvider"]
|
||||
0
tests/test_extraction/__init__.py
Normal file
0
tests/test_extraction/__init__.py
Normal file
202
tests/test_extraction/test_llm.py
Normal file
202
tests/test_extraction/test_llm.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Tests for the LLM provider abstraction (slice 3.0).

Exercises ``LLMProvider`` as a Protocol (duck-typed), the
``FakeProvider`` used by tests, and the ``OllamaCloudProvider``
that talks to https://ollama.com/api/chat.

Tests for the real provider stub out ``urllib.request.urlopen``
so no network is hit even when ``OLLAMA_API_KEY`` is set.
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from lore_engine_poc.llm import (
|
||||
FakeProvider,
|
||||
OllamaCloudProvider,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FakeProvider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_fake_provider_returns_scripted_response():
    """The response scripted for a prompt comes back verbatim."""
    scripted_prompt = [{"role": "user", "content": "hi"}]
    provider = FakeProvider(script=[(scripted_prompt, '[["a","b","c"]]')])

    # A fresh, equal list: matching is by equality, not identity.
    reply = provider.chat(messages=[{"role": "user", "content": "hi"}])

    assert reply == '[["a","b","c"]]'
    assert len(provider.calls) == 1
|
||||
|
||||
|
||||
def test_fake_provider_records_every_call():
    """Each ``chat`` call is appended to ``calls`` in call order."""
    provider = FakeProvider(script=[
        ([{"role": "user", "content": "x"}], "ok"),
        ([{"role": "user", "content": "y"}], "ok2"),
    ])

    for text in ("x", "y"):
        provider.chat(messages=[{"role": "user", "content": text}])

    recorded = [call[0]["content"] for call in provider.calls]
    assert recorded == ["x", "y"]
|
||||
|
||||
|
||||
def test_fake_provider_raises_when_no_script_match():
    """An unscripted prompt must fail loudly instead of quietly
    handing back some other entry's response."""
    scripted = [{"role": "user", "content": "x"}]
    unscripted = [{"role": "user", "content": "unexpected"}]
    provider = FakeProvider(script=[(scripted, "ok")])

    with pytest.raises(AssertionError, match="no scripted response"):
        provider.chat(messages=unscripted)
|
||||
|
||||
|
||||
def test_fake_provider_empty_script_always_raises():
    """With nothing scripted, any prompt is rejected."""
    provider = FakeProvider(script=[])
    prompt = [{"role": "user", "content": "x"}]

    with pytest.raises(AssertionError):
        provider.chat(messages=prompt)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OllamaCloudProvider — construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_ollama_provider_raises_when_no_api_key(monkeypatch):
    """Construction fails with ``RuntimeError`` when neither the
    ``$OLLAMA_API_KEY`` env var nor an explicit key is available."""
    monkeypatch.delenv("OLLAMA_API_KEY", raising=False)

    expect_failure = pytest.raises(RuntimeError, match="OLLAMA_API_KEY")
    with expect_failure:
        OllamaCloudProvider()
|
||||
|
||||
|
||||
def test_ollama_provider_uses_explicit_api_key(monkeypatch):
    """An explicit ``api_key`` arg is used, and overrides the env var."""
    # Without any env var the explicit key is simply used.
    monkeypatch.delenv("OLLAMA_API_KEY", raising=False)
    p = OllamaCloudProvider(api_key="explicit-test-key")
    assert p.api_key == "explicit-test-key"

    # With a *different* key in the environment the explicit one must
    # still win; this is the case that actually proves precedence.
    monkeypatch.setenv("OLLAMA_API_KEY", "env-key-should-lose")
    p = OllamaCloudProvider(api_key="explicit-test-key")
    assert p.api_key == "explicit-test-key"
|
||||
|
||||
|
||||
def test_ollama_provider_uses_env_var_when_no_explicit(monkeypatch):
    """Without an ``api_key`` argument the key comes from the env var."""
    monkeypatch.setenv("OLLAMA_API_KEY", "env-test-key")

    provider = OllamaCloudProvider()

    assert provider.api_key == "env-test-key"
|
||||
|
||||
|
||||
def test_ollama_provider_default_model_is_minimax_m3_cloud(monkeypatch):
    """With no override anywhere, the model is the built-in default slug."""
    # Order of the two env manipulations is irrelevant; both must hold
    # before the provider is constructed.
    monkeypatch.delenv("LORE_LLM_MODEL", raising=False)
    monkeypatch.setenv("OLLAMA_API_KEY", "k")

    provider = OllamaCloudProvider()

    assert provider.model == "minimax-m3:cloud"
|
||||
|
||||
|
||||
def test_ollama_provider_honors_lore_llm_model_env(monkeypatch):
    """``$LORE_LLM_MODEL`` replaces the default model slug."""
    for name, value in (
        ("OLLAMA_API_KEY", "k"),
        ("LORE_LLM_MODEL", "other-model:7b"),
    ):
        monkeypatch.setenv(name, value)

    provider = OllamaCloudProvider()

    assert provider.model == "other-model:7b"
|
||||
|
||||
|
||||
def test_ollama_provider_honors_lore_llm_timeout_env(monkeypatch):
    """``$LORE_LLM_TIMEOUT`` is parsed to a float number of seconds."""
    for name, value in (
        ("OLLAMA_API_KEY", "k"),
        ("LORE_LLM_TIMEOUT", "5"),
    ):
        monkeypatch.setenv(name, value)

    provider = OllamaCloudProvider()

    assert provider.timeout == 5.0
|
||||
|
||||
|
||||
def test_ollama_provider_default_timeout_is_60(monkeypatch):
    """With no timeout override the provider waits 60 seconds."""
    monkeypatch.delenv("LORE_LLM_TIMEOUT", raising=False)
    monkeypatch.setenv("OLLAMA_API_KEY", "k")

    provider = OllamaCloudProvider()

    assert provider.timeout == 60.0
|
||||
|
||||
|
||||
def test_ollama_provider_endpoint_is_ollama_cloud(monkeypatch):
    """The endpoint constant, read through an instance, is the cloud chat URL."""
    monkeypatch.setenv("OLLAMA_API_KEY", "k")

    provider = OllamaCloudProvider()
    endpoint = provider.ENDPOINT

    assert endpoint == "https://ollama.com/api/chat"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OllamaCloudProvider.chat — wire format (urllib stubbed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_response(payload: dict):
|
||||
"""Build a fake ``HTTPResponse``-like object for urlopen."""
|
||||
body = json.dumps(payload).encode()
|
||||
resp = mock.MagicMock()
|
||||
resp.read.return_value = body
|
||||
resp.__enter__ = lambda self: self
|
||||
resp.__exit__ = lambda self, *a: None
|
||||
return resp
|
||||
|
||||
|
||||
def test_ollama_provider_posts_correct_body_and_headers(monkeypatch):
    """Wire format: bearer auth, model, messages, stream:false."""
    monkeypatch.setenv("OLLAMA_API_KEY", "k-test-1234")
    # The body assertion below expects the default model, so an ambient
    # ``$LORE_LLM_MODEL`` on the developer's machine must not leak in.
    monkeypatch.delenv("LORE_LLM_MODEL", raising=False)
    p = OllamaCloudProvider()

    fake_resp = _make_response({
        "model": "minimax-m3:cloud",
        "message": {"role": "assistant", "content": "hello"},
        "done": True,
    })

    with mock.patch("lore_engine_poc.llm.urllib.request.urlopen", return_value=fake_resp) as urlopen:
        out = p.chat(messages=[{"role": "user", "content": "ping"}])

    # urlopen called once
    assert urlopen.call_count == 1
    # Inspect the request
    req = urlopen.call_args.args[0]
    assert req.full_url == "https://ollama.com/api/chat"
    assert req.headers["Authorization"] == "Bearer k-test-1234"
    # urllib stores header names capitalised, hence "Content-type".
    assert req.headers["Content-type"] == "application/json"
    body = json.loads(req.data)
    assert body["model"] == "minimax-m3:cloud"
    assert body["messages"] == [{"role": "user", "content": "ping"}]
    assert body["stream"] is False
    # Returned content is the assistant message
    assert out == "hello"
|
||||
|
||||
|
||||
def test_ollama_provider_uses_override_model_in_request(monkeypatch):
    """A ``model=`` kwarg ends up in the JSON body that is sent."""
    monkeypatch.setenv("OLLAMA_API_KEY", "k")
    provider = OllamaCloudProvider(model="custom-model:7b")
    canned = _make_response({
        "message": {"role": "assistant", "content": "ok"},
        "done": True,
    })

    patcher = mock.patch(
        "lore_engine_poc.llm.urllib.request.urlopen",
        return_value=canned,
    )
    with patcher as urlopen:
        provider.chat(messages=[{"role": "user", "content": "x"}])

    sent_request = urlopen.call_args.args[0]
    sent_body = json.loads(sent_request.data)
    assert sent_body["model"] == "custom-model:7b"
|
||||
|
||||
|
||||
def test_ollama_provider_propagates_timeout(monkeypatch):
    """The constructor's ``timeout=`` reaches ``urlopen`` as a keyword."""
    monkeypatch.setenv("OLLAMA_API_KEY", "k")
    provider = OllamaCloudProvider(timeout=12.5)
    canned = _make_response({
        "message": {"role": "assistant", "content": "ok"},
        "done": True,
    })

    patcher = mock.patch(
        "lore_engine_poc.llm.urllib.request.urlopen",
        return_value=canned,
    )
    with patcher as urlopen:
        provider.chat(messages=[{"role": "user", "content": "x"}])

    passed_kwargs = urlopen.call_args.kwargs
    assert passed_kwargs.get("timeout") == 12.5
|
||||
|
||||
|
||||
def test_ollama_provider_network_error_propagates(monkeypatch):
    """Errors raised by ``urlopen`` surface unchanged to the caller."""
    monkeypatch.setenv("OLLAMA_API_KEY", "k")
    provider = OllamaCloudProvider()

    failing_urlopen = mock.patch(
        "lore_engine_poc.llm.urllib.request.urlopen",
        side_effect=OSError("connection refused"),
    )
    expect_oserror = pytest.raises(OSError, match="connection refused")
    # Patch is entered first and exited last, exactly like the nested form.
    with failing_urlopen, expect_oserror:
        provider.chat(messages=[{"role": "user", "content": "x"}])
|
||||
Reference in New Issue
Block a user