Source code for langchain_core.runnables.graph_mermaid

import base64
import re
from dataclasses import asdict
from typing import Dict, List, Optional

from langchain_core.runnables.graph import (
    CurveStyle,
    Edge,
    MermaidDrawMethod,
    Node,
    NodeStyles,
)


def draw_mermaid(
    nodes: Dict[str, Node],
    edges: List[Edge],
    *,
    first_node: Optional[str] = None,
    last_node: Optional[str] = None,
    with_styles: bool = True,
    curve_style: CurveStyle = CurveStyle.LINEAR,
    node_styles: NodeStyles = NodeStyles(),
    wrap_label_n_words: int = 9,
) -> str:
    """Draws a Mermaid graph using the provided graph data.

    Args:
        nodes (Dict[str, Node]): Mapping of node id to node. Each node's
            ``name`` and ``metadata`` are used to build its label.
        edges (List[Edge]): List of edges, objects with a ``source``, a
            ``target``, optional ``data`` (used as the edge label) and a
            ``conditional`` flag (drawn as a dotted arrow when true).
        first_node (str, optional): Id of the first node. Defaults to None.
        last_node (str, optional): Id of the last node. Defaults to None.
        with_styles (bool, optional): Whether to include styles in the graph.
            Node declarations and the curve configuration header are only
            emitted when this is True. Defaults to True.
        curve_style (CurveStyle, optional): Curve style for the edges.
            Defaults to CurveStyle.LINEAR.
        node_styles (NodeStyles, optional): Node colors for different types.
            Defaults to NodeStyles().
        wrap_label_n_words (int, optional): Number of words after which an
            edge label is wrapped onto a new line. Values of zero or less
            disable wrapping. Defaults to 9.

    Returns:
        str: Mermaid graph syntax.
    """
    # Collect output fragments and join once at the end instead of growing a
    # string with repeated concatenation.
    parts: List[str] = []

    # Initialize Mermaid graph configuration
    if with_styles:
        parts.append(
            f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'"
            f"}}}}}}%%\ngraph TD;\n"
        )
    else:
        parts.append("graph TD;\n")

    if with_styles:
        # Node formatting templates. The fallback template is kept outside the
        # per-node mapping so that a node whose id happens to be "default"
        # cannot overwrite it when used as first or last node.
        default_template = "{0}({1})"
        special_templates: Dict[str, str] = {}
        if first_node is not None:
            special_templates[first_node] = "{0}([{0}]):::first"
        if last_node is not None:
            special_templates[last_node] = "{0}([{0}]):::last"

        # Add nodes to the graph
        for node_id, node in nodes.items():
            # Only the part after the last ":" of the node name is displayed.
            label = node.name.split(":")[-1]
            if node.metadata:
                metadata_lines = "\n".join(
                    f"{meta_key} = {meta_value}"
                    for meta_key, meta_value in node.metadata.items()
                )
                label = f"{label}<hr/><small><em>{metadata_lines}</em></small>"
            template = special_templates.get(node_id, default_template)
            node_label = template.format(_escape_node_label(node_id), label)
            parts.append(f"\t{node_label}\n")

    subgraph = ""
    # Add edges to the graph
    for edge in edges:
        # The text before the first ":" of a node id names its subgraph.
        src_prefix = edge.source.split(":")[0] if ":" in edge.source else None
        tgt_prefix = edge.target.split(":")[0] if ":" in edge.target else None

        # exit subgraph if source or target is not in the same subgraph
        if subgraph and (subgraph != src_prefix or subgraph != tgt_prefix):
            parts.append("\tend\n")
            subgraph = ""
        # enter subgraph if source and target are in the same subgraph
        if not subgraph and src_prefix and src_prefix == tgt_prefix:
            parts.append(f"\tsubgraph {src_prefix}\n")
            subgraph = src_prefix

        if edge.data is not None:
            edge_data = edge.data
            words = str(edge_data).split()
            # Add BR every wrap_label_n_words words. The positivity guard
            # avoids range() failing on a zero step and avoids a negative
            # step silently producing an empty label.
            if wrap_label_n_words > 0 and len(words) > wrap_label_n_words:
                edge_data = "&nbsp<br>&nbsp".join(
                    " ".join(words[i : i + wrap_label_n_words])
                    for i in range(0, len(words), wrap_label_n_words)
                )
            if edge.conditional:
                edge_label = f" -. &nbsp{edge_data}&nbsp .-> "
            else:
                edge_label = f" -- &nbsp{edge_data}&nbsp --> "
        elif edge.conditional:
            edge_label = " -.-> "
        else:
            edge_label = " --> "

        parts.append(
            f"\t{_escape_node_label(edge.source)}{edge_label}"
            f"{_escape_node_label(edge.target)};\n"
        )

    # Close a subgraph that is still open after the last edge.
    if subgraph:
        parts.append("end\n")

    # Add custom styles for nodes
    if with_styles:
        parts.append(_generate_mermaid_graph_styles(node_styles))
    return "".join(parts)
def _escape_node_label(node_label: str) -> str:
    """Escapes the node label for Mermaid syntax.

    Every character other than an ASCII letter, an ASCII digit, ``-`` or
    ``_`` is replaced with an underscore.
    """
    disallowed = re.compile(r"[^a-zA-Z-_0-9]")
    return disallowed.sub("_", node_label)


def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
    """Generates Mermaid graph styles for different node types.

    Emits one tab-indented ``classDef <name> <style>`` line per field of the
    ``node_colors`` dataclass, in field order.
    """
    return "".join(
        f"\tclassDef {class_name} {style}\n"
        for class_name, style in asdict(node_colors).items()
    )
def draw_mermaid_png(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
    background_color: Optional[str] = "white",
    padding: int = 10,
) -> bytes:
    """Draws a Mermaid graph as PNG using provided syntax.

    Dispatches to the Pyppeteer renderer or to the Mermaid.INK API renderer
    depending on ``draw_method``.

    Args:
        mermaid_syntax (str): Mermaid graph syntax.
        output_file_path (str, optional): Path to save the PNG image.
            Defaults to None.
        draw_method (MermaidDrawMethod, optional): Method to draw the graph.
            Defaults to MermaidDrawMethod.API.
        background_color (str, optional): Background color of the image.
            Defaults to "white".
        padding (int, optional): Padding around the image. Only forwarded to
            the Pyppeteer renderer. Defaults to 10.

    Returns:
        bytes: PNG image bytes.

    Raises:
        ValueError: If an invalid draw method is provided.
    """
    if draw_method == MermaidDrawMethod.PYPPETEER:
        # Imported lazily: only this rendering path needs an event loop.
        import asyncio

        render_coroutine = _render_mermaid_using_pyppeteer(
            mermaid_syntax, output_file_path, background_color, padding
        )
        return asyncio.run(render_coroutine)

    if draw_method == MermaidDrawMethod.API:
        return _render_mermaid_using_api(
            mermaid_syntax, output_file_path, background_color
        )

    supported_methods = ", ".join(method.value for method in MermaidDrawMethod)
    raise ValueError(
        f"Invalid draw method: {draw_method}. "
        f"Supported draw methods are: {supported_methods}"
    )
async def _render_mermaid_using_pyppeteer(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    background_color: Optional[str] = "white",
    padding: int = 10,
    device_scale_factor: int = 3,
) -> bytes:
    """Renders Mermaid graph using Pyppeteer.

    Launches a headless browser, loads Mermaid JS from the jsDelivr CDN,
    renders the graph to SVG and takes a screenshot of the page.

    Args:
        mermaid_syntax (str): Mermaid graph syntax.
        output_file_path (str, optional): Path to save the image to.
            Defaults to None (nothing is written).
        background_color (str, optional): Value assigned to the page body's
            CSS ``background``. Defaults to "white".
        padding (int, optional): Pixels added to the SVG width and height when
            sizing the viewport. Defaults to 10.
        device_scale_factor (int, optional): Device scale factor of the
            viewport. Defaults to 3.

    Returns:
        bytes: Screenshot image bytes.

    Raises:
        ImportError: If Pyppeteer is not installed.
    """
    try:
        from pyppeteer import launch  # type: ignore[import]
    except ImportError as e:
        raise ImportError(
            "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."
        ) from e

    browser = await launch()
    # Always close the browser, even when a page operation fails, so that no
    # headless browser process is left behind.
    try:
        page = await browser.newPage()

        # Setup Mermaid JS
        await page.goto("about:blank")
        await page.addScriptTag(
            {"url": "https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"}
        )
        await page.evaluate(
            """() => {
                mermaid.initialize({startOnLoad:true});
            }"""
        )

        # Render SVG
        svg_code = await page.evaluate(
            """(mermaidGraph) => {
                return mermaid.mermaidAPI.render('mermaid', mermaidGraph);
            }""",
            mermaid_syntax,
        )

        # Insert the SVG into the page and set the page background
        await page.evaluate(
            """(svg, background_color) => {
                document.body.innerHTML = svg;
                document.body.style.background = background_color;
            }""",
            svg_code["svg"],
            background_color,
        )

        # Size the viewport to the rendered SVG plus padding, then take a
        # screenshot
        dimensions = await page.evaluate(
            """() => {
                const svgElement = document.querySelector('svg');
                const rect = svgElement.getBoundingClientRect();
                return { width: rect.width, height: rect.height };
            }"""
        )
        await page.setViewport(
            {
                "width": int(dimensions["width"] + padding),
                "height": int(dimensions["height"] + padding),
                "deviceScaleFactor": device_scale_factor,
            }
        )

        img_bytes = await page.screenshot({"fullPage": False})
    finally:
        await browser.close()

    if output_file_path is not None:
        with open(output_file_path, "wb") as file:
            file.write(img_bytes)

    return img_bytes


def _render_mermaid_using_api(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    background_color: Optional[str] = "white",
) -> bytes:
    """Renders Mermaid graph using the Mermaid.INK API.

    Args:
        mermaid_syntax (str): Mermaid graph syntax.
        output_file_path (str, optional): Path to save the image to.
            Defaults to None (nothing is written).
        background_color (str, optional): Either a hexadecimal color code
            such as ``"#fff"`` / ``"#ffffff"`` or a color name such as
            ``"white"``. When None, no background color is requested.
            Defaults to "white".

    Returns:
        bytes: Image bytes returned by the API.

    Raises:
        ImportError: If the `requests` module is not installed.
        ValueError: If the API responds with a status code other than 200.
    """
    try:
        import requests  # type: ignore[import]
    except ImportError as e:
        raise ImportError(
            "Install the `requests` module to use the Mermaid.INK API: "
            "`pip install requests`."
        ) from e

    # Use Mermaid API to render the image
    mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode(
        "ascii"
    )
    image_url = f"https://mermaid.ink/img/{mermaid_syntax_encoded}"

    # Only request a background when one was given; formatting None into the
    # URL would send the literal text "None" as the color.
    if background_color is not None:
        if re.fullmatch(r"#(?:[0-9a-fA-F]{3}){1,2}", background_color):
            # A raw "#" would start the URL fragment and the color would never
            # reach the server, so hex codes are sent without it.
            color_param = background_color[1:]
        else:
            # Non-hex values are treated as color names, which the API expects
            # to be prefixed with "!".
            color_param = f"!{background_color}"
        image_url = f"{image_url}?bgColor={color_param}"

    # Bound the request so an unresponsive server cannot hang the caller.
    response = requests.get(image_url, timeout=10)
    if response.status_code != 200:
        raise ValueError(
            f"Failed to render the graph using the Mermaid.INK API. "
            f"Status code: {response.status_code}."
        )

    img_bytes = response.content
    if output_file_path is not None:
        with open(output_file_path, "wb") as file:
            file.write(img_bytes)

    return img_bytes