"""Parallel pipeline orchestration for transform nodes.""" import logging import os import sqlite3 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed from typing import List, Dict, Optional import pandas as pd import multiprocessing as mp logger = logging.getLogger("knack-transform") class TransformContext: """Context object containing the dataframe for transformation.""" def __init__(self, df: pd.DataFrame): self.df = df def get_dataframe(self) -> pd.DataFrame: """Get the pandas dataframe from the context.""" return self.df class NodeConfig: """Configuration for a transform node.""" def __init__(self, node_class: type, node_kwargs: Dict = None, dependencies: List[str] = None, name: str = None): """Initialize node configuration. Args: node_class: The TransformNode class to instantiate node_kwargs: Keyword arguments to pass to node constructor dependencies: List of node names that must complete before this one name: Optional name for the node (defaults to class name) """ self.node_class = node_class self.node_kwargs = node_kwargs or {} self.dependencies = dependencies or [] self.name = name or node_class.__name__ class ParallelPipeline: """Pipeline for executing transform nodes in parallel where possible. The pipeline analyzes dependencies between nodes and executes independent nodes concurrently using multiprocessing or threading. """ def __init__(self, max_workers: Optional[int] = None, use_processes: bool = False): """Initialize the parallel pipeline. Args: max_workers: Maximum number of parallel workers (defaults to CPU count) use_processes: If True, use ProcessPoolExecutor; if False, use ThreadPoolExecutor """ self.max_workers = max_workers or mp.cpu_count() self.use_processes = use_processes self.nodes: Dict[str, NodeConfig] = {} logger.info(f"Initialized ParallelPipeline with {self.max_workers} workers " f"({'processes' if use_processes else 'threads'})") def add_node(self, config: NodeConfig): """Add a node to the pipeline. Args: config: NodeConfig with node details and dependencies """ self.nodes[config.name] = config logger.info(f"Added node '{config.name}' with dependencies: {config.dependencies}") def _get_execution_stages(self) -> List[List[str]]: """Determine execution stages based on dependencies. Returns: List of stages, where each stage contains node names that can run in parallel """ stages = [] completed = set() remaining = set(self.nodes.keys()) while remaining: # Find nodes whose dependencies are all completed ready = [] for node_name in remaining: config = self.nodes[node_name] if all(dep in completed for dep in config.dependencies): ready.append(node_name) if not ready: # Circular dependency or missing dependency raise ValueError(f"Cannot resolve dependencies. Remaining nodes: {remaining}") stages.append(ready) completed.update(ready) remaining -= set(ready) return stages def _execute_node(self, node_name: str, db_path: str, context: TransformContext) -> tuple: """Execute a single node. 
        Args:
            node_name: Name of the node to execute
            db_path: Path to the SQLite database
            context: TransformContext for the node

        Returns:
            Tuple of (node_name, result_context, error)
        """
        con = None
        try:
            # Create a fresh database connection (not shared across processes/threads)
            con = sqlite3.connect(db_path)
            config = self.nodes[node_name]
            node = config.node_class(**config.node_kwargs)

            logger.info(f"Executing node: {node_name}")
            result_context = node.run(con, context)
            logger.info(f"Node '{node_name}' completed successfully")
            return node_name, result_context, None
        except Exception as e:
            logger.error(f"Error executing node '{node_name}': {e}", exc_info=True)
            return node_name, None, str(e)
        finally:
            # Always release the connection, even when the node raises
            if con is not None:
                con.close()

    def run(self, db_path: str, initial_context: TransformContext,
            fail_fast: bool = False) -> Dict[str, TransformContext]:
        """Execute the pipeline.

        Args:
            db_path: Path to the SQLite database
            initial_context: Initial TransformContext for the pipeline
            fail_fast: If True, stop execution on first error

        Returns:
            Dict mapping node names to their output TransformContext
        """
        logger.info("Starting parallel pipeline execution")

        stages = self._get_execution_stages()
        logger.info(f"Pipeline has {len(stages)} execution stage(s)")

        results = {}
        errors = []

        ExecutorClass = ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor

        for stage_num, stage_nodes in enumerate(stages, 1):
            logger.info(f"Stage {stage_num}/{len(stages)}: Executing {len(stage_nodes)} "
                        f"node(s) in parallel: {stage_nodes}")

            # For nodes in this stage, use the context from their dependencies.
            # If a node has multiple dependencies, the last dependency's output is used.
            stage_futures = {}

            with ExecutorClass(max_workers=min(self.max_workers, len(stage_nodes))) as executor:
                for node_name in stage_nodes:
                    config = self.nodes[node_name]

                    # Get context from dependencies (use the last dependency's output)
                    if config.dependencies:
                        context = results.get(config.dependencies[-1], initial_context)
                    else:
                        context = initial_context

                    future = executor.submit(self._execute_node, node_name, db_path, context)
                    stage_futures[future] = node_name

                # Wait for all nodes in this stage to complete
                for future in as_completed(stage_futures):
                    node_name = stage_futures[future]
                    name, result_context, error = future.result()

                    if error:
                        errors.append((name, error))
                        if fail_fast:
                            logger.error(f"Pipeline failed at node '{name}': {error}")
                            raise RuntimeError(f"Node '{name}' failed: {error}")
                    else:
                        results[name] = result_context

        if errors:
            logger.warning(f"Pipeline completed with {len(errors)} error(s)")
            for name, error in errors:
                logger.error(f" - {name}: {error}")
        else:
            logger.info("Pipeline completed successfully")

        return results


def create_default_pipeline(device: str = "cpu",
                            max_workers: Optional[int] = None) -> ParallelPipeline:
    """Create a pipeline with default transform nodes.
    Args:
        device: Device to use for compute-intensive nodes ('cpu', 'cuda', 'mps')
        max_workers: Maximum number of parallel workers

    Returns:
        Configured ParallelPipeline
    """
    from author_node import NerAuthorNode, FuzzyAuthorNode

    pipeline = ParallelPipeline(max_workers=max_workers, use_processes=False)

    # Add AuthorNode (no dependencies)
    pipeline.add_node(NodeConfig(
        node_class=NerAuthorNode,
        node_kwargs={
            'device': device,
            'model_path': os.environ.get('GLINER_MODEL_PATH')
        },
        dependencies=[],
        name='AuthorNode'
    ))

    # Add FuzzyAuthorNode (runs after AuthorNode)
    pipeline.add_node(NodeConfig(
        node_class=FuzzyAuthorNode,
        node_kwargs={
            'max_l_dist': 1
        },
        dependencies=['AuthorNode'],
        name='FuzzyAuthorNode'
    ))

    # TODO: Create a node to compute text embeddings and UMAP projections.
    # TODO: Create a node to pre-compute data backing the visuals to reduce load time.
    # TODO: Add more nodes here as they are implemented.
    # Example:
    # pipeline.add_node(NodeConfig(
    #     node_class=EmbeddingNode,
    #     node_kwargs={'device': device},
    #     dependencies=[],                 # No dependencies; runs in the first stage
    #     name='EmbeddingNode'
    # ))
    # pipeline.add_node(NodeConfig(
    #     node_class=UMAPNode,
    #     node_kwargs={'device': device},
    #     dependencies=['EmbeddingNode'],  # Runs after EmbeddingNode
    #     name='UMAPNode'
    # ))

    return pipeline
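

# A minimal usage sketch (kept under `if __name__ == "__main__"` so importing this
# module stays side-effect free). The KNACK_DB_PATH environment variable, the
# fallback filename, and the "documents" seed table are illustrative assumptions,
# not part of the project; adjust them to the actual schema the transform nodes read.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    db_path = os.environ.get("KNACK_DB_PATH", "knack.db")  # assumed env var / filename

    # Seed the pipeline context with a dataframe loaded from the database (hypothetical query).
    con = sqlite3.connect(db_path)
    seed_df = pd.read_sql_query("SELECT * FROM documents", con)
    con.close()

    pipeline = create_default_pipeline(device="cpu")

    # Peek at the resolved stages (illustration only; this is a private helper).
    # With the default nodes this yields [['AuthorNode'], ['FuzzyAuthorNode']].
    print(pipeline._get_execution_stages())

    results = pipeline.run(db_path, TransformContext(seed_df))
    for node_name, ctx in results.items():
        shape = ctx.get_dataframe().shape if ctx is not None else None
        print(f"{node_name}: output shape {shape}")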