Knack-Scraper/transform/pipeline.py
"""Parallel pipeline orchestration for transform nodes."""
import logging
import os
import sqlite3
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import List, Dict, Optional
import pandas as pd
import multiprocessing as mp
logger = logging.getLogger("knack-transform")
class TransformContext:
"""Context object containing the dataframe for transformation."""
    # Possibly add a dict to the context to carry additional information
def __init__(self, df: pd.DataFrame):
self.df = df
def get_dataframe(self) -> pd.DataFrame:
"""Get the pandas dataframe from the context."""
return self.df
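# Typical flow (sketch): wrap the working dataframe once, e.g. ctx = TransformContext(df),
# and let each node read it via ctx.get_dataframe() and return a new context.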
class NodeConfig:
"""Configuration for a transform node."""
    def __init__(self,
                 node_class: type,
                 node_kwargs: Optional[Dict] = None,
                 dependencies: Optional[List[str]] = None,
                 name: Optional[str] = None):
"""Initialize node configuration.
Args:
node_class: The TransformNode class to instantiate
node_kwargs: Keyword arguments to pass to node constructor
dependencies: List of node names that must complete before this one
name: Optional name for the node (defaults to class name)
"""
self.node_class = node_class
self.node_kwargs = node_kwargs or {}
self.dependencies = dependencies or []
self.name = name or node_class.__name__
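# Example configuration (CleanTextNode is a hypothetical node class, used only to
# illustrate the constructor arguments):
#   NodeConfig(node_class=CleanTextNode,
#              node_kwargs={"column": "text"},
#              dependencies=[],
#              name="CleanTextNode")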
class ParallelPipeline:
"""Pipeline for executing transform nodes in parallel where possible.
The pipeline analyzes dependencies between nodes and executes
independent nodes concurrently using multiprocessing or threading.
"""
def __init__(self,
max_workers: Optional[int] = None,
use_processes: bool = False):
"""Initialize the parallel pipeline.
Args:
max_workers: Maximum number of parallel workers (defaults to CPU count)
use_processes: If True, use ProcessPoolExecutor; if False, use ThreadPoolExecutor
"""
self.max_workers = max_workers or mp.cpu_count()
self.use_processes = use_processes
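        # Note: with use_processes=True the pipeline and its node classes must be
        # picklable, since bound methods are submitted to worker processes.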
self.nodes: Dict[str, NodeConfig] = {}
logger.info(f"Initialized ParallelPipeline with {self.max_workers} workers "
f"({'processes' if use_processes else 'threads'})")
def add_node(self, config: NodeConfig):
"""Add a node to the pipeline.
Args:
config: NodeConfig with node details and dependencies
"""
self.nodes[config.name] = config
logger.info(f"Added node '{config.name}' with dependencies: {config.dependencies}")
def _get_execution_stages(self) -> List[List[str]]:
"""Determine execution stages based on dependencies.
Returns:
List of stages, where each stage contains node names that can run in parallel
"""
stages = []
completed = set()
remaining = set(self.nodes.keys())
while remaining:
# Find nodes whose dependencies are all completed
ready = []
for node_name in remaining:
config = self.nodes[node_name]
if all(dep in completed for dep in config.dependencies):
ready.append(node_name)
if not ready:
# Circular dependency or missing dependency
raise ValueError(f"Cannot resolve dependencies. Remaining nodes: {remaining}")
stages.append(ready)
completed.update(ready)
remaining -= set(ready)
return stages
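    # Example of how _get_execution_stages resolves a small graph (hypothetical
    # node names A, B, C, where C depends on both A and B):
    #   stage 1: ['A', 'B']  -> independent, executed in parallel
    #   stage 2: ['C']       -> runs once both dependencies have completed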
def _execute_node(self,
node_name: str,
db_path: str,
context: TransformContext) -> tuple:
"""Execute a single node.
Args:
node_name: Name of the node to execute
db_path: Path to the SQLite database
context: TransformContext for the node
Returns:
Tuple of (node_name, result_context, error)
"""
        try:
            # Create a fresh database connection (not shared across processes/threads)
            con = sqlite3.connect(db_path)
            try:
                config = self.nodes[node_name]
                node = config.node_class(**config.node_kwargs)
                logger.info(f"Executing node: {node_name}")
                result_context = node.run(con, context)
            finally:
                # Always release the connection, even if the node raises
                con.close()
            logger.info(f"Node '{node_name}' completed successfully")
            return node_name, result_context, None
        except Exception as e:
            logger.error(f"Error executing node '{node_name}': {e}", exc_info=True)
            return node_name, None, str(e)
def run(self,
db_path: str,
initial_context: TransformContext,
fail_fast: bool = False) -> Dict[str, TransformContext]:
"""Execute the pipeline.
Args:
db_path: Path to the SQLite database
initial_context: Initial TransformContext for the pipeline
fail_fast: If True, stop execution on first error
Returns:
Dict mapping node names to their output TransformContext
"""
logger.info("Starting parallel pipeline execution")
stages = self._get_execution_stages()
logger.info(f"Pipeline has {len(stages)} execution stage(s)")
results = {}
errors = []
ExecutorClass = ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor
for stage_num, stage_nodes in enumerate(stages, 1):
logger.info(f"Stage {stage_num}/{len(stages)}: Executing {len(stage_nodes)} node(s) in parallel: {stage_nodes}")
            # For nodes in this stage, use the context produced by their dependencies.
            # With multiple dependencies, the last listed dependency's output is used
            # (the results could also be merged instead).
stage_futures = {}
with ExecutorClass(max_workers=min(self.max_workers, len(stage_nodes))) as executor:
for node_name in stage_nodes:
config = self.nodes[node_name]
# Get context from dependencies (use the last dependency's output)
if config.dependencies:
context = results.get(config.dependencies[-1], initial_context)
else:
context = initial_context
future = executor.submit(self._execute_node, node_name, db_path, context)
stage_futures[future] = node_name
# Wait for all nodes in this stage to complete
for future in as_completed(stage_futures):
node_name = stage_futures[future]
name, result_context, error = future.result()
if error:
errors.append((name, error))
if fail_fast:
logger.error(f"Pipeline failed at node '{name}': {error}")
raise RuntimeError(f"Node '{name}' failed: {error}")
else:
results[name] = result_context
if errors:
logger.warning(f"Pipeline completed with {len(errors)} error(s)")
for name, error in errors:
logger.error(f" - {name}: {error}")
else:
logger.info("Pipeline completed successfully")
return results
def create_default_pipeline(device: str = "cpu",
max_workers: Optional[int] = None) -> ParallelPipeline:
"""Create a pipeline with default transform nodes.
Args:
device: Device to use for compute-intensive nodes ('cpu', 'cuda', 'mps')
max_workers: Maximum number of parallel workers
Returns:
Configured ParallelPipeline
"""
from author_node import NerAuthorNode, FuzzyAuthorNode
from embeddings_node import TextEmbeddingNode, UmapNode
pipeline = ParallelPipeline(max_workers=max_workers, use_processes=False)
# Add AuthorNode (no dependencies)
pipeline.add_node(NodeConfig(
node_class=NerAuthorNode,
node_kwargs={
'device': device,
'model_path': os.environ.get('GLINER_MODEL_PATH')
},
dependencies=[],
name='AuthorNode'
))
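    # Add FuzzyAuthorNode (runs after AuthorNode and receives its output context)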
pipeline.add_node(NodeConfig(
node_class=FuzzyAuthorNode,
node_kwargs={
'max_l_dist': 1
},
dependencies=['AuthorNode'],
name='FuzzyAuthorNode'
))
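    # Add TextEmbeddingNode (no dependencies; runs in parallel with AuthorNode in the first stage)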
pipeline.add_node(NodeConfig(
node_class=TextEmbeddingNode,
node_kwargs={
'device': device,
'model_path': os.environ.get('MINILM_MODEL_PATH')
},
dependencies=[],
name='TextEmbeddingNode'
))
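    # Add UmapNode (runs after TextEmbeddingNode and receives its output context)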
pipeline.add_node(NodeConfig(
node_class=UmapNode,
node_kwargs={},
dependencies=['TextEmbeddingNode'],
name='UmapNode'
))
return pipeline
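# Minimal usage sketch: load rows from the SQLite database, run the default pipeline,
# and log per-node results. The 'articles' table and the KNACK_DB_PATH / TRANSFORM_DEVICE
# environment variables below are assumptions for illustration, not part of the schema.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    db_path = os.environ.get("KNACK_DB_PATH", "knack.db")  # assumed env var / filename
    # Load the rows to transform into the initial context (table name is assumed)
    con = sqlite3.connect(db_path)
    try:
        df = pd.read_sql_query("SELECT * FROM articles", con)
    finally:
        con.close()
    pipeline = create_default_pipeline(device=os.environ.get("TRANSFORM_DEVICE", "cpu"),
                                       max_workers=4)
    results = pipeline.run(db_path, TransformContext(df), fail_fast=False)
    for name, ctx in results.items():
        logger.info(f"{name}: {len(ctx.get_dataframe())} rows in output context")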