Graph Database (Apache AGE)
The graph module provides async access to Apache AGE, a PostgreSQL extension
for graph database queries using Cypher.
Overview
The Graph class wraps a psycopg async connection pool and exposes
high-level CRUD operations (create, delete, match, merge) that
accept Pydantic models. The cypher module generates Cypher query
templates from those models.
Basic Usage
from imbi_common import graph, models
# Open a graph connection pool
db = graph.Graph()
await db.open()
# Create a node
org = models.Organization(name="My Org", slug="my-org")
await db.create(org)
# Match nodes
orgs = await db.match(models.Organization, {"slug": "my-org"})
# Match all nodes of a type, ordered
teams = await db.match(models.Team, order_by="name")
# Upsert a node
await db.merge(org, match_on=["slug"])
# Delete a node
await db.delete(org)
# Close the connection pool
await db.close()
FastAPI Dependency Injection
Wire graph_lifespan into the application lifespan, then declare Pool
as a route parameter to receive the injected Graph instance:
import fastapi
from imbi_common import lifespan, models
from imbi_common.graph import Pool, graph_lifespan
app = fastapi.FastAPI(
lifespan=lifespan.Lifespan(graph_lifespan),
)
@app.get('/orgs/{slug}')
async def get_org(slug: str, db: Pool) -> models.Organization:
    """Return the organization whose slug equals *slug*.

    Responds with HTTP 404 when nothing matches, rather than letting
    ``results[0]`` raise ``IndexError`` on an empty result list.
    """
    results = await db.match(models.Organization, {"slug": slug})
    if not results:
        raise fastapi.HTTPException(
            status_code=404,
            detail=f'Organization {slug!r} not found',
        )
    return results[0]
To run custom initialisation after the pool opens (e.g. schema setup),
register a startup callback before creating the app:
from imbi_common import graph
async def on_graph_ready(db: graph.Graph) -> None:
await graph.initialize()
graph.set_on_startup(on_graph_ready)
API Reference
Graph Client
Graph
Wrapper around the PostgreSQL connection pool.
Supports both Apache AGE Cypher queries and pgvector
similarity search against the embeddings table.
Source code in src/imbi_common/graph/client.py
def __init__(self) -> None:
    """Load the Postgres settings and build the connection pool.

    The pool is constructed with ``open=False``, so no connection is
    attempted here; ``opened`` starts out ``False``.
    """
    self.opened = False
    postgres = settings.Postgres()
    self.settings = postgres
    pool_options = {
        'conninfo': str(postgres.url),
        'min_size': postgres.min_pool_size,
        'max_size': postgres.max_pool_size,
        'configure': self._configure_connection,
        'open': False,
    }
    self.pool = psycopg_pool.AsyncConnectionPool(**pool_options)
close
async
Close the connection pool and release the embedding models.
Source code in src/imbi_common/graph/client.py
async def close(self) -> None:
    """Close the connection pool and release models.

    ``embeddings.close()`` and the ``opened`` reset run in ``finally``
    clauses, so they still happen when closing the pool raises; the
    exception then propagates to the caller.
    """
    # Imported at call time rather than at module import.
    from imbi_common.graph import embeddings

    try:
        await self.pool.close()
    finally:
        try:
            embeddings.close()
        finally:
            # Never leave the client flagged as open after a close attempt.
            self.opened = False
create
async
create(node: GraphModelT) -> GraphModelT
Create a node and its relationships in the graph.
Source code in src/imbi_common/graph/client.py
async def create(
    self,
    node: GraphModelT,
) -> GraphModelT:
    """Create *node* and its relationships in the graph, then return it.

    The statements from ``cypher.create`` are handed to
    ``_execute_batch``; ``models.Node`` instances are additionally
    passed to ``_auto_embed``.
    """
    statements = cypher.create(node)
    await self._execute_batch(statements)
    if not isinstance(node, models.Node):
        return node
    await self._auto_embed(node)
    return node
delete
async
delete(node: GraphModel) -> None
Delete a node, its relationships, and embeddings.
The Cypher delete runs via AGE (requires autocommit),
then embeddings are cleaned up on the same connection.
Source code in src/imbi_common/graph/client.py
async def delete(self, node: models.GraphModel) -> None:
    """Remove *node*, its relationships, and its embeddings.

    The Cypher delete runs via AGE (requires autocommit); the
    embedding cleanup for ``models.Node`` instances then runs on the
    same connection.

    Raises ``RuntimeError`` when ``self.opened`` is false.
    """
    if not self.opened:
        raise RuntimeError('Graph pool is not open')
    statement = cypher.delete(node)
    async with self.pool.connection() as conn:
        await self._execute_on(conn, statement.cypher, statement.params)
        if not isinstance(node, models.Node):
            return
        await self._delete_embeddings_where(
            conn,
            node_label=type(node).__name__,
            node_id=node.id,
        )
execute
async
execute(
query_template: str,
params: dict[str, Any] | None = None,
columns: list[str] | None = None,
) -> list[dict[str, typing.Any]]
Wrap a Cypher query in SQL and execute it.
Parameters in params are serialized via
_cypher_param() using Cypher-compatible escaping
and interpolated into query_template via
sql.SQL.format().
The Cypher query is wrapped in AGE's cypher()
function. columns defines the AS (...) clause
— pass one name per value in the Cypher RETURN
clause. Defaults to ['n'] for single-column
returns.
Source code in src/imbi_common/graph/client.py
| async def execute(
self,
query_template: str,
params: dict[str, typing.Any] | None = None,
columns: list[str] | None = None,
) -> list[dict[str, typing.Any]]:
"""Wrap a Cypher query in SQL and execute it.
Parameters in *params* are serialized via
``_cypher_param()`` using Cypher-compatible escaping
and interpolated into *query_template* via
``sql.SQL.format()``.
The Cypher query is wrapped in AGE's ``cypher()``
function. *columns* defines the ``AS (...)`` clause
— pass one name per value in the Cypher ``RETURN``
clause. Defaults to ``['n']`` for single-column
returns.
"""
if not self.opened:
raise RuntimeError('Graph pool is not open')
async with self.pool.connection() as conn:
return await self._execute_on(
conn,
query_template,
params,
columns,
)
|
match
async
match(
node_type: type[ModelT],
params: dict[str, Any] | None = None,
order_by: str | None = None,
) -> list[ModelT]
Match nodes and return model instances.
Deserialization prefers model_validate (so field
validators run) and falls back to model_construct
when validation fails.
Source code in src/imbi_common/graph/client.py
async def match(
    self,
    node_type: type[ModelT],
    params: dict[str, typing.Any] | None = None,
    order_by: str | None = None,
) -> list[ModelT]:
    """Return the nodes of *node_type* matching *params* as models.

    Each returned value is decoded with ``parse_agtype``; only values
    that decode to a ``dict`` are converted. Conversion prefers
    ``model_validate`` (so field validators run) and falls back to
    ``model_construct`` when validation fails.
    """
    statement = cypher.match(node_type, params, order_by)
    fetched = await self.execute(statement.cypher, statement.params)
    # Lazy generator: each value is decoded right before it is converted,
    # keeping the decode/convert interleaving of a plain nested loop.
    decoded = (
        parse_agtype(cell) for record in fetched for cell in record.values()
    )
    return [
        self._row_to_model(node_type, props)
        for props in decoded
        if isinstance(props, dict)
    ]
merge
async
merge(
node: GraphModelT, match_on: list[str] | None = None
) -> GraphModelT
Upsert a node and its relationships in the graph.
Source code in src/imbi_common/graph/client.py
async def merge(
    self,
    node: GraphModelT,
    match_on: list[str] | None = None,
) -> GraphModelT:
    """Upsert *node* and its relationships in the graph, then return it.

    The statements from ``cypher.merge`` are handed to
    ``_execute_batch``; ``models.Node`` instances are additionally
    passed to ``_auto_embed``.
    """
    statements = cypher.merge(node, match_on)
    await self._execute_batch(statements)
    if not isinstance(node, models.Node):
        return node
    await self._auto_embed(node)
    return node
open
async
Open the connection pool.
Source code in src/imbi_common/graph/client.py
async def open(self) -> None:
    """Open the connection pool and flag the client as opened."""
    pool = self.pool
    await pool.open()
    # Only reached when opening the pool did not raise.
    self.opened = True
search
async
search(
query: str,
*,
model_name: str = 'text',
node_label: str | None = None,
attribute: str | None = None,
limit: int = 10,
distance_threshold: float | None = None,
) -> list[SearchResult]
Search for nodes by semantic similarity.
Embeds query using the specified model, then
performs a cosine similarity search against the
embeddings table. Results are ordered by
distance (ascending = most similar).
Source code in src/imbi_common/graph/client.py
async def search(
    self,
    query: str,
    *,
    model_name: str = 'text',
    node_label: str | None = None,
    attribute: str | None = None,
    limit: int = 10,
    distance_threshold: float | None = None,
) -> list[SearchResult]:
    """Find embedded chunks that are semantically close to *query*.

    *query* is embedded with the model named by *model_name*, then a
    cosine similarity search runs against the ``embeddings`` table.
    *node_label*, *attribute* and *distance_threshold* each add a
    filter when not ``None``. Rows come back ordered by distance
    (ascending = most similar), capped at *limit*.

    Raises ``RuntimeError`` when ``self.opened`` is false.
    """
    if not self.opened:
        raise RuntimeError('Graph pool is not open')
    from imbi_common.graph import embeddings

    query_vector = await embeddings.aembed_one(query, model_name)
    dimensions = embeddings.get_dimensions(model_name)

    # One distance expression, reused in SELECT, WHERE and ORDER BY.
    distance_expr = sql.SQL(
        '(embedding::vector({dims})) <=> ({vec}::vector({dims}))',
    ).format(
        dims=sql.Literal(dimensions),
        vec=sql.Placeholder('vec'),
    )
    bind: dict[str, typing.Any] = {
        'vec': query_vector,
        'model': model_name,
        'limit': limit,
    }
    statement = sql.SQL(
        'SELECT node_label, node_id, attribute,'
        ' chunk_text,'
        ' {distance} AS distance'
        ' FROM public.embeddings'
        ' WHERE model_name = {model}'
    ).format(
        distance=distance_expr,
        model=sql.Placeholder('model'),
    )

    # Optional filters: each adds its bind value and its AND clause.
    if node_label is not None:
        bind['label'] = node_label
        statement = statement + sql.SQL(
            ' AND node_label = {label}',
        ).format(label=sql.Placeholder('label'))
    if attribute is not None:
        bind['attribute'] = attribute
        statement = statement + sql.SQL(
            ' AND attribute = {attribute}',
        ).format(attribute=sql.Placeholder('attribute'))
    if distance_threshold is not None:
        bind['threshold'] = distance_threshold
        statement = statement + sql.SQL(
            ' AND {distance} <= {threshold}',
        ).format(
            distance=distance_expr,
            threshold=sql.Placeholder('threshold'),
        )
    statement = statement + sql.SQL(
        ' ORDER BY {distance} LIMIT {limit}',
    ).format(
        distance=distance_expr,
        limit=sql.Placeholder('limit'),
    )

    async with self.pool.connection() as conn:
        async with conn.cursor(row_factory=rows.dict_row) as cur:
            await cur.execute(statement, bind)
            fetched = await cur.fetchall()

    result_fields = (
        'node_label',
        'node_id',
        'attribute',
        'chunk_text',
        'distance',
    )
    return [
        SearchResult(**{name: record[name] for name in result_fields})
        for record in fetched
    ]
search_nodes
async
search_nodes(
node_type: type[ModelT],
query: str,
*,
model_name: str = 'text',
limit: int = 10,
) -> list[ModelT]
Search and return full node instances.
Combines vector search with graph node retrieval.
Results are deduplicated by id (multiple chunks
from the same node may match).
Source code in src/imbi_common/graph/client.py
async def search_nodes(
    self,
    node_type: type[ModelT],
    query: str,
    *,
    model_name: str = 'text',
    limit: int = 10,
) -> list[ModelT]:
    """Run a vector search and return the matching nodes as models.

    Embedding hits for *node_type* are reduced to distinct node ids
    (several chunks of one node may match), the nodes are fetched
    from the graph by ``id``, and the result keeps the ranking
    produced by ``search()``.
    """
    label = node_type.__name__
    # Ask for more embedding rows than nodes wanted, so that
    # deduplication still leaves enough distinct nodes for *limit*.
    overfetch_factor = 5
    hits = await self.search(
        query,
        model_name=model_name,
        node_label=label,
        limit=limit * overfetch_factor,
    )

    # Order-preserving dedup of node ids, then cap at *limit*.
    ranked_ids = []
    seen = set()
    for hit in hits:
        hit_id = hit.node_id
        if hit_id in seen:
            continue
        seen.add(hit_id)
        ranked_ids.append(hit_id)
    ranked_ids = ranked_ids[:limit]
    if not ranked_ids:
        return []

    # One named placeholder ({id0}, {id1}, ...) per id.
    params: dict[str, typing.Any] = {
        f'id{pos}': nid for pos, nid in enumerate(ranked_ids)
    }
    placeholders = ', '.join(f'{{{name}}}' for name in params)
    cypher_q = f'MATCH (n:{label}) WHERE n.id IN [{placeholders}] RETURN n'
    fetched = await self.execute(cypher_q, params)

    # Index the fetched nodes by id so they can be re-ranked below.
    nodes_by_id: dict[str, ModelT] = {}
    for record in fetched:
        for cell in record.values():
            props = parse_agtype(cell)
            if not isinstance(props, dict):
                continue
            found_id = props.get('id')
            model = self._row_to_model(node_type, props)
            if found_id is not None:
                nodes_by_id[found_id] = model
    return [nodes_by_id[nid] for nid in ranked_ids if nid in nodes_by_id]
graph_lifespan
async
graph_lifespan() -> abc.AsyncIterator[Graph]
Source code in src/imbi_common/graph/__init__.py
@contextlib.asynccontextmanager
async def graph_lifespan() -> abc.AsyncIterator[Graph]:
    """Yield an opened ``Graph`` and close it when the context exits.

    Calls ``initialize()``, opens a new ``Graph``, awaits the
    module-level ``_on_startup`` callback when one is set, then yields
    the graph. The graph is closed in ``finally``, including when the
    callback or the body of the context raises.
    """
    await initialize()
    client = Graph()
    await client.open()
    try:
        startup_hook = _on_startup
        if startup_hook is not None:
            await startup_hook(client)
        yield client
    finally:
        await client.close()
Pool
module-attribute
Pool = Annotated[Graph, Depends(_inject_graph)]
Cypher Query Generation
Statement
Bases: NamedTuple
A Cypher query template paired with its parameter values.
create
create(node: GraphModel) -> list[Statement]
Generate CREATE statements for node and its edges.
Returns a list where the first entry creates the node and
subsequent entries create each relationship.
Source code in src/imbi_common/graph/cypher.py
def create(node: models.GraphModel) -> list[Statement]:
    """Build the ``CREATE`` statements for *node* and its edges.

    The first statement in the returned list creates the node itself;
    the remaining ones, from ``_edge_statements``, create each
    relationship.
    """
    properties = _node_properties(node)
    label = _label(node)
    template = _props_template(properties)
    node_statement = Statement(
        cypher=f'CREATE (n:{label} {template}) RETURN n',
        params=properties,
    )
    return [node_statement, *_edge_statements(node)]
delete
delete(node: GraphModel) -> Statement
Generate a DETACH DELETE statement for node.
Source code in src/imbi_common/graph/cypher.py
def delete(node: models.GraphModel) -> Statement:
    """Build the ``DETACH DELETE`` statement that removes *node*.

    The node is located by the property name and value returned from
    ``_identity``; the value travels as the ``key`` parameter.
    """
    key, val = _identity(node)
    label = _label(node)
    # The template keeps doubled braces around the property map and a
    # single-braced ``{key}`` placeholder for the identity value.
    property_filter = '{{' + key + ': {key}}}'
    return Statement(
        cypher=f'MATCH (n:{label} {property_filter}) DETACH DELETE n RETURN n',
        params={'key': val},
    )
match
match(
node_type: type[BaseModel],
params: dict[str, Any] | None = None,
order_by: str | None = None,
) -> Statement
Generate a MATCH statement for node_type.
When params is provided the matched nodes are filtered by
those properties; otherwise all nodes of the label are returned.
order_by, when given, appends ORDER BY n.<field>.
Source code in src/imbi_common/graph/cypher.py
def match(
    node_type: type[pydantic.BaseModel],
    params: dict[str, typing.Any] | None = None,
    order_by: str | None = None,
) -> Statement:
    """Build a ``MATCH`` statement for *node_type*.

    With *params*, matched nodes are filtered by those properties;
    without, every node of the label is returned. *order_by*, when
    given, appends ``ORDER BY n.<field>``.

    Raises ``ValueError`` when a *params* key or *order_by* is not a
    non-edge field of the model.
    """
    filters = dict(params) if params else {}
    label = _label(node_type)
    edge_names = {name for name, _, _ in _edge_fields(node_type)}
    scalar_fields = set(node_type.model_fields) - edge_names

    unknown = [name for name in filters if name not in scalar_fields]
    if unknown:
        raise ValueError(f'Unknown field(s) for {label}: {", ".join(unknown)}')

    if filters:
        pattern = f'(n:{label} {_props_template(filters)})'
    else:
        pattern = f'(n:{label})'
    query = f'MATCH {pattern} RETURN n'

    if order_by:
        if order_by not in scalar_fields:
            raise ValueError(f'Unknown order_by field for {label}: {order_by}')
        query = f'{query} ORDER BY n.{order_by}'
    return Statement(cypher=query, params=filters)
merge
merge(
node: GraphModel, match_on: list[str] | None = None
) -> list[Statement]
Generate MERGE statements for node and its edges.
match_on lists the property names used to identify the node
for the MERGE clause. Defaults to ['slug'] for
Node subclasses (stable business key) and ['id'] for
plain GraphModel subclasses. All other non-None scalar
properties appear in the SET clause.
id and created_at use COALESCE so they are written
on first creation but preserved on subsequent merges (Apache
AGE does not support ON CREATE SET / ON MATCH SET).
Properties whose value is None are omitted so that existing
graph values are preserved rather than being deleted.
Source code in src/imbi_common/graph/cypher.py
def merge(
    node: models.GraphModel,
    match_on: list[str] | None = None,
) -> list[Statement]:
    """Build the ``MERGE`` statements for *node* and its edges.

    *match_on* names the properties that identify the node in the
    ``MERGE`` clause. It defaults to ``['slug']`` for ``Node``
    subclasses (stable business key) and ``['id']`` for plain
    ``GraphModel`` subclasses. Every other non-None scalar property
    goes into the ``SET`` clause.

    ``id`` and ``created_at`` are wrapped in ``COALESCE`` so they are
    written on first creation and left untouched by later merges
    (Apache AGE does not support ``ON CREATE SET`` / ``ON MATCH SET``).

    Properties whose value is ``None`` are left out, so existing graph
    values are preserved rather than deleted.

    Raises ``ValueError`` when *match_on* is empty or names a property
    the node does not have.
    """
    if match_on is None:
        match_on = [_identity(node)[0]]
    if not match_on:
        raise ValueError('match_on must contain at least one key')

    props = _node_properties(node)
    unknown = [key for key in match_on if key not in props]
    if unknown:
        raise ValueError(
            f'Unknown merge key(s) for {_label(node)}: {", ".join(unknown)}'
        )

    match_props = {key: props[key] for key in match_on}
    # Written only if the node does not already have a value for them.
    once_only = {'id', 'created_at'}
    assignments = [
        f'n.{name} = coalesce(n.{name}, {{{name}}})'
        if name in once_only
        else f'n.{name} = {{{name}}}'
        for name, value in props.items()
        if name not in match_on and value is not None
    ]

    clauses = [f'MERGE (n:{_label(node)} {_props_template(match_props)})']
    if assignments:
        clauses.append('SET ' + ', '.join(assignments))
    clauses.append('RETURN n')

    node_statement = Statement(cypher=' '.join(clauses), params=props)
    return [node_statement, *_edge_statements(node, verb='MERGE')]