- docker-compose: swap postgres image to pgvector/pgvector:pg16
- postgres/init.sql: CREATE EXTENSION vector; image_embedding table
- plugins/embeddings.py: embed_images + search_images_semantic
(sentence-transformers all-MiniLM-L6-v2, lazy-loaded, pgvector <=> cosine)
- plugins/images.py: register_image kicks off background embed worker
- seed.py: seed_embeddings writes 4 embeddings for the mock images
- README: semantic image search section + T3 note
- 11 tests across 4 files, all green:
test_embeddings_plugin.py (4): schema, ordering, idempotency, stub
test_embeddings_real_model.py (3): real MiniLM, acceptance queries
test_register_image_hook.py (2): manifest row, end-to-end hook
test_seed_embeddings.py (2): writes 4, idempotent
- Includes T3 consistency plugin skeleton (4 stub tools)
File: plugins/embeddings.py (227 lines, 8.3 KiB, Python)
"""
|
|
embeddings plugin — pgvector-backed semantic image search.
|
|
|
|
Replaces the substring-only `search_images_by_caption` for cases where the
|
|
caller doesn't know the exact wording. Captions are encoded with a local
|
|
sentence-transformer model (all-MiniLM-L6-v2, 384 dims) and stored in the
|
|
`image_embedding` table. Queries are encoded with the same model and
|
|
matched against the corpus via pgvector's cosine distance operator (`<=>`).
|
|
|
|
Design notes:
|
|
- Model is lazy-loaded on first use (gated by `_get_model()`) and cached in
|
|
a module-global so we don't pay the ~80MB download twice.
|
|
- The plugin is intentionally side-effect free on import — no model is
|
|
downloaded until something actually calls `_get_model()`. This keeps
|
|
gateway startup fast and testable.
|
|
- `embed_images()` is idempotent: it only embeds rows that don't already
|
|
have an entry in `image_embedding`. Safe to re-run after adding new
|
|
manifest rows.
|
|
"""
|
|
import logging
import os
import threading
from typing import List, Optional

from server import REGISTRY, get_postgres
|
|
|
|
LOG: logging.Logger = logging.getLogger(__name__)

# Vector width of the default model (all-MiniLM-L6-v2). Not enforced in this
# module; presumably matches the image_embedding column type — confirm.
EMBEDDING_DIM: int = 384
# Model id used when the EMBED_MODEL environment variable is not set.
DEFAULT_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"

# Cached model handle: None until _get_model() loads it on first use.
# NOTE(review): concurrent encode() calls on one shared SentenceTransformer
# are assumed to be safe; this is not verified here — confirm.
_model = None
|
|
|
|
|
|
# ─── DB helpers ─────────────────────────────────────────────────────────────
|
|
|
|
def _q_pg(sql: str, params=None, fetch: bool = True, pg_url: Optional[str] = None):
    """Execute *sql* and return the result rows as a list of dicts.

    Connection source:
      * ``pg_url`` given -> a dedicated ``psycopg2`` connection to that URL
        (for tests / scripts that need an out-of-band connection);
      * otherwise -> the connection returned by ``get_postgres()``.
    In both cases the connection is closed before returning. Nothing is
    committed, so use ``_exec_pg`` for writes.

    NOTE(review): closing the ``get_postgres()`` connection is only safe if
    that call returns a fresh connection each time, not a long-lived shared
    one — confirm.

    Returns one ``{column_name: value}`` dict per row, or ``[]`` when
    ``fetch`` is false or the statement produced no result set.
    """
    if pg_url is None:
        conn = get_postgres()
    else:
        import psycopg2

        conn = psycopg2.connect(pg_url)
    try:
        with conn.cursor() as cur:
            cur.execute(sql, params or ())
            if not (fetch and cur.description):
                return []
            names = [col[0] for col in cur.description]
            return [dict(zip(names, record)) for record in cur.fetchall()]
    finally:
        conn.close()
|
|
|
|
|
|
def _exec_pg(sql: str, params=None, pg_url: Optional[str] = None) -> int:
    """Execute a write statement, commit it, and return the affected row count.

    Connection source:
      * ``pg_url`` given -> a dedicated ``psycopg2`` connection to that URL;
      * otherwise -> the connection returned by ``get_postgres()``.
    The connection is closed afterwards in every case. If ``execute`` raises,
    the exception propagates and the connection is closed without a commit.

    NOTE(review): closing the ``get_postgres()`` connection assumes it is a
    per-call connection rather than a long-lived shared one — confirm.

    Returns:
        ``cursor.rowcount`` for the executed statement.
    """
    if pg_url is not None:
        import psycopg2

        conn = psycopg2.connect(pg_url)
    else:
        conn = get_postgres()
    try:
        with conn.cursor() as cur:
            cur.execute(sql, params or ())
            # Capture before the cursor is closed by the context manager.
            rowcount = cur.rowcount
            conn.commit()
        return rowcount
    finally:
        conn.close()
|
|
|
|
|
|
# ─── Model loader ──────────────────────────────────────────────────────────
|
|
|
|
# Serializes the first model load so concurrent first callers (e.g. a
# background embed worker and a tool call) don't each construct the model.
_MODEL_LOCK = threading.Lock()


def _get_model():
    """Return the process-wide SentenceTransformer, loading it on first use.

    The model id comes from the ``EMBED_MODEL`` environment variable, falling
    back to ``DEFAULT_MODEL``. ``sentence_transformers`` is imported here, not
    at module import, so loading the plugin stays cheap.

    Once loaded, the handle is returned without taking the lock. The first
    load is double-checked under ``_MODEL_LOCK``, so concurrent first callers
    construct the model only once.
    """
    global _model
    if _model is not None:
        return _model
    with _MODEL_LOCK:
        # Re-check: another thread may have finished loading while we waited.
        if _model is None:
            from sentence_transformers import SentenceTransformer

            model_name = os.environ.get("EMBED_MODEL", DEFAULT_MODEL)
            LOG.info("loading embedding model: %s", model_name)
            _model = SentenceTransformer(model_name)
    return _model
|
|
|
|
|
|
def _encode(texts: List[str]) -> List[List[float]]:
    """Encode each text with the cached model; return one plain float list per text.

    Encoded rows that have ``.tolist()`` (numpy arrays from the real
    SentenceTransformer) are converted with it. Anything else, such as the
    plain sequences returned by test stubs, is copied with ``list()``.
    """

    def as_floats(vec):
        # numpy rows expose tolist(); plain sequences fall through to list().
        try:
            return vec.tolist()
        except AttributeError:
            return list(vec)

    encoded = _get_model().encode(texts, convert_to_numpy=True, show_progress_bar=False)
    return [as_floats(vec) for vec in encoded]
|
|
|
|
|
|
# ─── Internal search helper (used by tests + the tool handler) ─────────────
|
|
|
|
def _search_by_vector(query_vec: List[float], limit: int = 5, pg_url: Optional[str] = None):
    """Return up to ``limit`` embedded images nearest to ``query_vec``.

    Nearness is pgvector's cosine-distance operator (``<=>``). Rows are
    ordered by ascending distance, so the closest match comes first. Only
    images that have an ``image_embedding`` row are considered (inner join
    with ``image_manifest``).

    Each result dict has the keys ``image_id``, ``entity_id``,
    ``entity_type``, ``caption``, ``tags``, ``era``, ``object_key`` and
    ``distance``. There is no ``score`` key; ``distance`` is the raw
    ``<=>`` value.

    ``query_vec`` is sent as a pgvector text literal, with each component
    formatted to 6 decimal places.
    """
    sql = """
        SELECT
            m.image_id, m.entity_id, m.entity_type, m.caption,
            m.tags, m.era, m.object_key,
            (e.embedding <=> %s::vector) AS distance
        FROM image_embedding e
        JOIN image_manifest m USING (image_id)
        ORDER BY e.embedding <=> %s::vector
        LIMIT %s
    """
    vec_str = "[" + ",".join(f"{x:.6f}" for x in query_vec) + "]"
    # Bound twice: once for the SELECT-list distance, once for ORDER BY.
    return _q_pg(sql, (vec_str, vec_str, limit), pg_url=pg_url)
|
|
|
|
|
|
# ─── MCP tools ──────────────────────────────────────────────────────────────
|
|
|
|
def _do_embed_images(limit: int = 100, pg_url: Optional[str] = None) -> int:
    """Embed up to ``limit`` manifest images that have no stored embedding yet.

    Steps:
      1. select manifest rows with no ``image_embedding`` row, newest
         ``uploaded_at`` first;
      2. encode all of their captions in a single ``_encode`` call;
      3. upsert one embedding per image.
    Already-embedded images are skipped by the selection, so once everything
    is embedded a re-run writes nothing and returns 0.

    Returns the number of rows selected (and upserted) in this call.

    NOTE(review): each upsert goes through ``_exec_pg`` separately, i.e. one
    connection and one commit per image. Captions are passed to the encoder
    unchecked; this assumes ``caption`` is never NULL — confirm against the
    image_manifest schema.
    """
    pending = _q_pg(
        """
        SELECT m.image_id, m.caption
        FROM image_manifest m
        LEFT JOIN image_embedding e ON e.image_id = m.image_id
        WHERE e.image_id IS NULL
        ORDER BY m.uploaded_at DESC
        LIMIT %s
        """,
        (limit,),
        pg_url=pg_url,
    )
    if not pending:
        return 0

    vectors = _encode([row["caption"] for row in pending])

    upsert_sql = """
        INSERT INTO image_embedding (image_id, embedding)
        VALUES (%s, %s::vector)
        ON CONFLICT (image_id) DO UPDATE
        SET embedding = EXCLUDED.embedding,
            embedded_at = now()
    """
    # ON CONFLICT ... DO UPDATE: if a concurrent run stored this image first,
    # overwrite it instead of failing on the image_id conflict.
    for row, vec in zip(pending, vectors):
        literal = "[" + ",".join(format(x, ".6f") for x in vec) + "]"
        _exec_pg(upsert_sql, (row["image_id"], literal), pg_url=pg_url)
    return len(pending)
|
|
|
|
|
|
def _do_search_semantic(q: str, limit: int = 5, pg_url: Optional[str] = None) -> dict:
    """Embed ``q`` and return its ``limit`` nearest images.

    Result shape: ``{"q": q, "count": n, "images": [...]}``. Each image dict
    carries image_id, entity_id, entity_type, caption, tags and era from the
    matched row, plus ``distance`` coerced to ``float``. The row's
    ``object_key`` is not included.
    """
    query_vec = _encode([q])[0]
    hits = _search_by_vector(query_vec, limit=limit, pg_url=pg_url)
    passthrough = ("image_id", "entity_id", "entity_type", "caption", "tags", "era")
    images = [
        {**{key: hit[key] for key in passthrough}, "distance": float(hit["distance"])}
        for hit in hits
    ]
    return {"q": q, "count": len(images), "images": images}
|
|
|
|
|
|
@REGISTRY.tool(
    name="embed_images",
    description="Compute and store embeddings for images that don't have one yet. Idempotent. Returns the number of new embeddings written.",
    input_schema={
        "type": "object",
        "properties": {
            "limit": {
                "type": "integer",
                "default": 100,
                "description": "Maximum number of new embeddings to compute in this call.",
            },
        },
    },
)
def embed_images(args):
    """MCP tool handler: embed pending images using the gateway connection.

    Reads an optional ``limit`` from ``args`` (default 100, coerced with
    ``int()``) and returns ``{"embedded": <count written>}``.
    """
    written = _do_embed_images(limit=int(args.get("limit", 100)))
    return {"embedded": written}
|
|
|
|
|
|
@REGISTRY.tool(
    name="search_images_semantic",
    description="Find images whose captions are semantically closest to the query. Use this when the caller describes the image in their own words and exact-keyword search would miss.",
    input_schema={
        "type": "object",
        "properties": {
            "q": {"type": "string", "description": "Natural-language query, e.g. 'a noble lord with a scar'"},
            "limit": {"type": "integer", "default": 5},
        },
        "required": ["q"],
    },
)
def search_images_semantic(args):
    """MCP tool handler: semantic caption search via the gateway connection.

    ``args["q"]`` is required (a missing key raises ``KeyError``).
    ``args["limit"]`` is optional (default 5, coerced with ``int()``).
    Returns the dict produced by ``_do_search_semantic``.
    """
    return _do_search_semantic(q=args["q"], limit=int(args.get("limit", 5)))
|
|
|
|
|
|
def register(registry):
    """Plugin hook that intentionally does nothing.

    This module's tools are registered at import time by its
    ``@REGISTRY.tool`` decorators, so ``registry`` is not used here.
    NOTE(review): presumably kept so the plugin loader can call
    ``register(registry)`` on every plugin uniformly — confirm.
    """
    return None
|