Skip to content

Graph Database (Apache AGE)

The graph module provides async access to Apache AGE, a PostgreSQL extension for graph database queries using Cypher.

Overview

The Graph class wraps a psycopg async connection pool and exposes high-level CRUD operations (create, delete, match, merge) that accept Pydantic models. The cypher module generates Cypher query templates from those models.

Basic Usage

from imbi_common import graph, models

# Open a graph connection pool
db = graph.Graph()
await db.open()

# Create a node
org = models.Organization(name="My Org", slug="my-org")
await db.create(org)

# Match nodes
orgs = await db.match(models.Organization, {"slug": "my-org"})

# Match all nodes of a type, ordered
teams = await db.match(models.Team, order_by="name")

# Upsert a node
await db.merge(org, match_on=["slug"])

# Delete a node
await db.delete(org)

# Close the connection pool
await db.close()

FastAPI Dependency Injection

Wire graph_lifespan into the application lifespan, then declare Pool as a route parameter to receive the injected Graph instance:

import fastapi
from imbi_common import lifespan, models
from imbi_common.graph import Pool, graph_lifespan

app = fastapi.FastAPI(
    lifespan=lifespan.Lifespan(graph_lifespan),
)


@app.get('/orgs/{slug}')
async def get_org(slug: str, db: Pool) -> models.Organization:
    """Return the organization whose ``slug`` matches the path value.

    Responds with HTTP 404 when no organization matches, rather than
    letting an ``IndexError`` on the empty result list surface as a
    500.

    """
    results = await db.match(models.Organization, {"slug": slug})
    if not results:
        raise fastapi.HTTPException(
            status_code=404,
            detail=f'Organization {slug!r} not found',
        )
    return results[0]

To run custom initialisation after the pool opens (e.g. schema setup), register a startup callback before creating the app:

from imbi_common import graph

async def on_graph_ready(db: graph.Graph) -> None:
    """Startup callback that receives the opened ``Graph``."""
    # NOTE(review): graph_lifespan already awaits initialize() before
    # it opens the pool, so this call repeats that step and ``db`` is
    # unused -- confirm whether setup through ``db`` was intended.
    await graph.initialize()

# Register the callback before the application lifespan starts;
# presumably this is what graph_lifespan reads as ``_on_startup``.
graph.set_on_startup(on_graph_ready)

API Reference

Graph Client

Graph

Graph()

Wrapper around the PostgreSQL connection pool.

Supports both Apache AGE Cypher queries and pgvector similarity search against the embeddings table.

Source code in src/imbi_common/graph/client.py
def __init__(self) -> None:
    """Build the connection pool from settings without opening it.

    The pool is constructed with ``open=False``; it is only opened
    when ``open()`` is awaited, which also flips ``opened``.

    """
    self.opened = False
    self.settings = postgres = settings.Postgres()
    conninfo = str(postgres.url)
    self.pool = psycopg_pool.AsyncConnectionPool(
        conninfo=conninfo,
        min_size=postgres.min_pool_size,
        max_size=postgres.max_pool_size,
        configure=self._configure_connection,
        open=False,
    )

close async

close() -> None

Close the connection pool and release the embedding models.

Source code in src/imbi_common/graph/client.py
async def close(self) -> None:
    """Close the connection pool and release models.

    Cleanup always runs to completion: even when closing the pool
    raises, ``opened`` is reset and ``embeddings.close()`` is still
    called, so the instance is never left flagged as open after a
    failed shutdown.  The exception from the pool still propagates.

    """
    from imbi_common.graph import embeddings

    try:
        await self.pool.close()
    finally:
        # Reset the flag first: it cannot fail, whereas the
        # embeddings cleanup might.
        self.opened = False
        embeddings.close()

create async

create(node: GraphModelT) -> GraphModelT

Create a node and its relationships in the graph.

Source code in src/imbi_common/graph/client.py
async def create(
    self,
    node: GraphModelT,
) -> GraphModelT:
    """Persist *node* and its relationships, then return it.

    The statements produced by ``cypher.create`` are run as one
    batch.  ``models.Node`` instances are additionally handed to
    ``_auto_embed``; the object passed in is returned unchanged.

    """
    statements = cypher.create(node)
    await self._execute_batch(statements)
    if not isinstance(node, models.Node):
        return node
    await self._auto_embed(node)
    return node

delete async

delete(node: GraphModel) -> None

Delete a node, its relationships, and embeddings.

The Cypher delete runs via AGE (requires autocommit), then embeddings are cleaned up on the same connection.

Source code in src/imbi_common/graph/client.py
async def delete(self, node: models.GraphModel) -> None:
    """Remove *node* from the graph along with its embeddings.

    The Cypher delete is issued first; for ``models.Node``
    instances the matching embedding rows are then removed using
    the same pooled connection.

    Raises ``RuntimeError`` when the pool has not been opened.

    """
    if not self.opened:
        raise RuntimeError('Graph pool is not open')
    statement = cypher.delete(node)
    has_embeddings = isinstance(node, models.Node)
    async with self.pool.connection() as connection:
        await self._execute_on(
            connection,
            statement.cypher,
            statement.params,
        )
        if has_embeddings:
            # Embedding rows are keyed by class name and node id.
            await self._delete_embeddings_where(
                connection,
                node_label=type(node).__name__,
                node_id=node.id,
            )

execute async

execute(
    query_template: str,
    params: dict[str, Any] | None = None,
    columns: list[str] | None = None,
) -> list[dict[str, typing.Any]]

Wrap a Cypher query in SQL and execute it.

Parameters in params are serialized via _cypher_param() using Cypher-compatible escaping and interpolated into query_template via sql.SQL.format().

The Cypher query is wrapped in AGE's cypher() function. columns defines the AS (...) clause — pass one name per value in the Cypher RETURN clause. Defaults to ['n'] for single-column returns.

Source code in src/imbi_common/graph/client.py
async def execute(
    self,
    query_template: str,
    params: dict[str, typing.Any] | None = None,
    columns: list[str] | None = None,
) -> list[dict[str, typing.Any]]:
    """Run a Cypher query template on a pooled connection.

    A connection is borrowed from the pool and the arguments are
    forwarded unchanged to ``_execute_on``, which performs the
    parameter serialization and the AGE ``cypher()`` wrapping.

    *columns* names the entries of the ``AS (...)`` clause, one per
    value in the Cypher ``RETURN`` clause (documented default:
    ``['n']``).

    Raises ``RuntimeError`` when the pool has not been opened.

    """
    if not self.opened:
        raise RuntimeError('Graph pool is not open')

    async with self.pool.connection() as connection:
        result = await self._execute_on(
            connection,
            query_template,
            params,
            columns,
        )
    return result

match async

match(
    node_type: type[ModelT],
    params: dict[str, Any] | None = None,
    order_by: str | None = None,
) -> list[ModelT]

Match nodes and return model instances.

Deserialization prefers model_validate (so field validators run) and falls back to model_construct when validation fails.

Source code in src/imbi_common/graph/client.py
async def match(
    self,
    node_type: type[ModelT],
    params: dict[str, typing.Any] | None = None,
    order_by: str | None = None,
) -> list[ModelT]:
    """Fetch nodes of *node_type* as model instances.

    The query comes from ``cypher.match``.  Every value of every
    returned row is decoded with ``parse_agtype``; values that do
    not decode to a dict are skipped, the rest are converted by
    ``_row_to_model`` (which prefers ``model_validate`` and falls
    back to ``model_construct`` when validation fails).

    """
    statement = cypher.match(node_type, params, order_by)
    raw_rows = await self.execute(statement.cypher, statement.params)
    # Lazy generator, so decoding and model conversion stay
    # interleaved value by value.
    decoded = (
        parse_agtype(value) for row in raw_rows for value in row.values()
    )
    return [
        self._row_to_model(node_type, props)
        for props in decoded
        if isinstance(props, dict)
    ]

merge async

merge(
    node: GraphModelT, match_on: list[str] | None = None
) -> GraphModelT

Upsert a node and its relationships in the graph.

Source code in src/imbi_common/graph/client.py
async def merge(
    self,
    node: GraphModelT,
    match_on: list[str] | None = None,
) -> GraphModelT:
    """Upsert *node* and its relationships, then return it.

    The statements produced by ``cypher.merge(node, match_on)`` are
    run as one batch.  ``models.Node`` instances are additionally
    handed to ``_auto_embed``; the object passed in is returned
    unchanged.

    """
    statements = cypher.merge(node, match_on)
    await self._execute_batch(statements)
    if not isinstance(node, models.Node):
        return node
    await self._auto_embed(node)
    return node

open async

open() -> None

Open the connection pool.

Source code in src/imbi_common/graph/client.py
async def open(self) -> None:
    """Open the connection pool.

    Opens the pool that ``__init__`` created with ``open=False`` and
    then sets ``opened``, the flag that ``delete``, ``execute`` and
    ``search`` check before using the pool.

    """
    await self.pool.open()
    # Only flagged as open once the pool itself opened without error.
    self.opened = True

search async

search(
    query: str,
    *,
    model_name: str = 'text',
    node_label: str | None = None,
    attribute: str | None = None,
    limit: int = 10,
    distance_threshold: float | None = None,
) -> list[SearchResult]

Search for nodes by semantic similarity.

Embeds query using the specified model, then performs a cosine similarity search against the embeddings table. Results are ordered by distance (ascending = most similar).

Source code in src/imbi_common/graph/client.py
async def search(
    self,
    query: str,
    *,
    model_name: str = 'text',
    node_label: str | None = None,
    attribute: str | None = None,
    limit: int = 10,
    distance_threshold: float | None = None,
) -> list[SearchResult]:
    """Rank stored embedding chunks by similarity to *query*.

    The query text is embedded with *model_name* and compared to the
    rows of ``public.embeddings`` for that model using the ``<=>``
    distance operator.  Rows are returned closest first, capped at
    *limit*.  *node_label*, *attribute* and *distance_threshold*
    each add an optional filter when not ``None``.

    Raises ``RuntimeError`` when the pool has not been opened.

    """
    if not self.opened:
        raise RuntimeError('Graph pool is not open')
    from imbi_common.graph import embeddings

    vector = await embeddings.aembed_one(query, model_name)
    dims = embeddings.get_dimensions(model_name)
    # One distance expression, reused in SELECT, WHERE and ORDER BY.
    distance = sql.SQL(
        '(embedding::vector({dims})) <=> ({vec}::vector({dims}))',
    ).format(dims=sql.Literal(dims), vec=sql.Placeholder('vec'))

    bindings: dict[str, typing.Any] = {
        'vec': vector,
        'model': model_name,
        'limit': limit,
    }
    statement = sql.SQL(
        'SELECT node_label, node_id, attribute,'
        '       chunk_text,'
        '       {distance} AS distance'
        '  FROM public.embeddings'
        ' WHERE model_name = {model}'
    ).format(distance=distance, model=sql.Placeholder('model'))

    # Optional filters, appended in a fixed order.
    if node_label is not None:
        bindings['label'] = node_label
        statement += sql.SQL(
            ' AND node_label = {label}',
        ).format(label=sql.Placeholder('label'))
    if attribute is not None:
        bindings['attribute'] = attribute
        statement += sql.SQL(
            ' AND attribute = {attribute}',
        ).format(attribute=sql.Placeholder('attribute'))
    if distance_threshold is not None:
        bindings['threshold'] = distance_threshold
        statement += sql.SQL(
            ' AND {distance} <= {threshold}',
        ).format(
            distance=distance,
            threshold=sql.Placeholder('threshold'),
        )
    statement += sql.SQL(
        ' ORDER BY {distance} LIMIT {limit}',
    ).format(distance=distance, limit=sql.Placeholder('limit'))

    async with (
        self.pool.connection() as conn,
        conn.cursor(row_factory=rows.dict_row) as cur,
    ):
        await cur.execute(statement, bindings)
        records = await cur.fetchall()

    fields = (
        'node_label',
        'node_id',
        'attribute',
        'chunk_text',
        'distance',
    )
    return [
        SearchResult(**{name: record[name] for name in fields})
        for record in records
    ]

search_nodes async

search_nodes(
    node_type: type[ModelT],
    query: str,
    *,
    model_name: str = 'text',
    limit: int = 10,
) -> list[ModelT]

Search and return full node instances.

Combines vector search with graph node retrieval. Results are deduplicated by id (multiple chunks from the same node may match).

Source code in src/imbi_common/graph/client.py
async def search_nodes(
    self,
    node_type: type[ModelT],
    query: str,
    *,
    model_name: str = 'text',
    limit: int = 10,
) -> list[ModelT]:
    """Search and return full node instances.

    Combines vector search with graph node retrieval.  Results are
    deduplicated by ``id`` (multiple chunks from the same node may
    match) and returned in the ranking order produced by
    ``search()``.  Ids that the graph query does not return are
    dropped from the result.

    Returns an empty list when the vector search finds nothing.

    """
    # Over-fetch embedding rows so deduplication still
    # yields enough distinct nodes for the requested limit.
    chunk_multiplier = 5
    label = node_type.__name__
    hits = await self.search(
        query,
        model_name=model_name,
        node_label=label,
        limit=limit * chunk_multiplier,
    )
    # dict.fromkeys de-duplicates while keeping first-seen (best) order.
    node_ids = list(dict.fromkeys(hit.node_id for hit in hits))[:limit]
    if not node_ids:
        return []

    # One named parameter per id: {id0}, {id1}, ...
    params: dict[str, typing.Any] = {
        f'id{i}': nid for i, nid in enumerate(node_ids)
    }
    id_list = ', '.join(f'{{{name}}}' for name in params)
    cypher_q = f'MATCH (n:{label}) WHERE n.id IN [{id_list}] RETURN n'
    raw_rows = await self.execute(cypher_q, params)

    by_id: dict[str, ModelT] = {}
    for row in raw_rows:
        for value in row.values():
            props = parse_agtype(value)
            if not isinstance(props, dict):
                continue
            nid = props.get('id')
            if nid is None:
                # Cannot be ranked without an id; skip before
                # building a model that would be discarded.
                continue
            by_id[nid] = self._row_to_model(node_type, props)
    # Re-order to match the ranking from search()
    return [by_id[nid] for nid in node_ids if nid in by_id]

graph_lifespan async

graph_lifespan() -> abc.AsyncIterator[Graph]
Source code in src/imbi_common/graph/__init__.py
@contextlib.asynccontextmanager
async def graph_lifespan() -> abc.AsyncIterator[Graph]:
    """Yield an open ``Graph`` for the duration of the context.

    Awaits ``initialize()``, builds a ``Graph`` and opens its pool,
    then awaits the ``_on_startup`` callback when one is set.  The
    graph is closed on exit, including when the startup callback or
    the body of the context raises.

    """
    await initialize()
    graph = Graph()
    await graph.open()
    try:
        # The callback runs inside the ``try`` so a failing callback
        # still closes the pool.
        if _on_startup is not None:
            await _on_startup(graph)
        yield graph
    finally:
        await graph.close()

Pool module-attribute

Pool = Annotated[Graph, Depends(_inject_graph)]

Cypher Query Generation

Statement

Bases: NamedTuple

A Cypher query template paired with its parameter values.

create

create(node: GraphModel) -> list[Statement]

Generate CREATE statements for node and its edges.

Returns a list where the first entry creates the node and subsequent entries create each relationship.

Source code in src/imbi_common/graph/cypher.py
def create(node: models.GraphModel) -> list[Statement]:
    """Build the ``CREATE`` statements for *node* and its edges.

    The first statement creates the node from its properties; the
    statements returned by ``_edge_statements`` follow it.

    """
    props = _node_properties(node)
    node_stmt = Statement(
        cypher=f'CREATE (n:{_label(node)} {_props_template(props)}) RETURN n',
        params=props,
    )
    return [node_stmt, *_edge_statements(node)]

delete

delete(node: GraphModel) -> Statement

Generate a DETACH DELETE statement for node.

Source code in src/imbi_common/graph/cypher.py
def delete(node: models.GraphModel) -> Statement:
    """Build the ``DETACH DELETE`` statement for *node*.

    The node is matched on the property named by ``_identity``; its
    value is bound to the ``key`` parameter.

    """
    prop, value = _identity(node)
    # Doubled braces are left as-is for the later template-formatting
    # step, which turns them into the literal braces of the property
    # map while ``{key}`` becomes the bound value.
    pattern = '(n:' + _label(node) + ' {{' + prop + ': {key}}})'
    return Statement(
        cypher='MATCH ' + pattern + ' DETACH DELETE n RETURN n',
        params={'key': value},
    )

match

match(
    node_type: type[BaseModel],
    params: dict[str, Any] | None = None,
    order_by: str | None = None,
) -> Statement

Generate a MATCH statement for node_type.

When params is provided the matched nodes are filtered by those properties; otherwise all nodes of the label are returned.

order_by, when given, appends ORDER BY n.<field>.

Source code in src/imbi_common/graph/cypher.py
def match(
    node_type: type[pydantic.BaseModel],
    params: dict[str, typing.Any] | None = None,
    order_by: str | None = None,
) -> Statement:
    """Build a ``MATCH`` statement for *node_type*.

    With *params*, only nodes carrying those property values match;
    without them every node of the label matches.  *order_by*
    appends ``ORDER BY n.<field>``.

    Both the *params* keys and *order_by* must name non-edge model
    fields; anything else raises ``ValueError``.

    """
    filters = dict(params) if params else {}
    label = _label(node_type)
    edge_names = {name for name, _, _ in _edge_fields(node_type)}
    scalar = set(node_type.model_fields) - edge_names

    unknown = [key for key in filters if key not in scalar]
    if unknown:
        raise ValueError(f'Unknown field(s) for {label}: {", ".join(unknown)}')

    # The property map is only emitted when there is something to filter.
    selector = f' {_props_template(filters)}' if filters else ''
    cypher = f'MATCH (n:{label}{selector}) RETURN n'

    if order_by:
        if order_by not in scalar:
            raise ValueError(f'Unknown order_by field for {label}: {order_by}')
        cypher = f'{cypher} ORDER BY n.{order_by}'
    return Statement(cypher=cypher, params=filters)

merge

merge(
    node: GraphModel, match_on: list[str] | None = None
) -> list[Statement]

Generate MERGE statements for node and its edges.

match_on lists the property names used to identify the node for the MERGE clause. Defaults to ['slug'] for Node subclasses (stable business key) and ['id'] for plain GraphModel subclasses. All other non-None scalar properties appear in the SET clause.

id and created_at use COALESCE so they are written on first creation but preserved on subsequent merges (Apache AGE does not support ON CREATE SET / ON MATCH SET).

Properties whose value is None are omitted so that existing graph values are preserved rather than being deleted.

Source code in src/imbi_common/graph/cypher.py
def merge(
    node: models.GraphModel,
    match_on: list[str] | None = None,
) -> list[Statement]:
    """Build the ``MERGE`` statements for *node* and its edges.

    *match_on* names the properties placed in the ``MERGE`` pattern;
    when omitted, the key reported by ``_identity`` is used.  An
    empty list, or a key that is not among the node's properties,
    raises ``ValueError``.

    Every other property whose value is not ``None`` is written in
    the ``SET`` clause; ``None`` values are left out so existing
    graph values are not overwritten.  ``id`` and ``created_at`` are
    wrapped in ``coalesce`` so an existing value wins over the new
    one (Apache AGE does not support ``ON CREATE SET`` /
    ``ON MATCH SET``).

    """
    keys = [_identity(node)[0]] if match_on is None else match_on
    if not keys:
        raise ValueError('match_on must contain at least one key')

    props = _node_properties(node)
    unknown = [key for key in keys if key not in props]
    if unknown:
        raise ValueError(
            f'Unknown merge key(s) for {_label(node)}: {", ".join(unknown)}'
        )

    # Written on first creation, preserved on later merges.
    once_only = {'id', 'created_at'}
    assignments = [
        (
            f'n.{name} = coalesce(n.{name}, {{{name}}})'
            if name in once_only
            else f'n.{name} = {{{name}}}'
        )
        for name, value in props.items()
        if name not in keys and value is not None
    ]

    match_props = {key: props[key] for key in keys}
    parts = [f'MERGE (n:{_label(node)} {_props_template(match_props)})']
    if assignments:
        parts.append('SET ' + ', '.join(assignments))
    parts.append('RETURN n')

    node_stmt = Statement(cypher=' '.join(parts), params=props)
    return [node_stmt, *_edge_statements(node, verb='MERGE')]