"""Parallel pipeline orchestration for transform nodes."""

import logging
import os
import sqlite3
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import List, Dict, Optional

import pandas as pd
import multiprocessing as mp

logger = logging.getLogger("knack-transform")


class TransformContext:
    """Context object containing the dataframe for transformation."""

    def __init__(self, df: pd.DataFrame):
        self.df = df

    def get_dataframe(self) -> pd.DataFrame:
        """Get the pandas dataframe from the context."""
        return self.df


class NodeConfig:
    """Configuration for a transform node."""

    def __init__(self,
                 node_class: type,
                 node_kwargs: Optional[Dict] = None,
                 dependencies: Optional[List[str]] = None,
                 name: Optional[str] = None):
        """Initialize node configuration.

        Args:
            node_class: The TransformNode class to instantiate
            node_kwargs: Keyword arguments to pass to the node constructor
            dependencies: List of node names that must complete before this one
            name: Optional name for the node (defaults to the class name)
        """
        self.node_class = node_class
        self.node_kwargs = node_kwargs or {}
        self.dependencies = dependencies or []
        self.name = name or node_class.__name__
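
# Minimal usage sketch (illustrative only; MyNode is a hypothetical
# TransformNode subclass and "threshold" a hypothetical keyword argument,
# neither of which is defined in this module):
#
#     config = NodeConfig(MyNode,
#                         node_kwargs={"threshold": 0.5},
#                         dependencies=["AuthorNode"])
#
# Omitting name falls back to the class name; an empty dependencies list
# lets the node run in the first stage.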


class ParallelPipeline:
    """Pipeline for executing transform nodes in parallel where possible.

    The pipeline analyzes dependencies between nodes and executes
    independent nodes concurrently using multiprocessing or threading.
    """
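
    # Note on the executor choice: with use_processes=False all nodes share a
    # single interpreter, so real parallelism depends on the nodes releasing
    # the GIL (e.g. inside pandas, sqlite3, or native model code); with
    # use_processes=True each node runs in its own process, which avoids the
    # GIL but requires node classes, their kwargs, and the TransformContext
    # to be picklable.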
    def __init__(self,
                 max_workers: Optional[int] = None,
                 use_processes: bool = False):
        """Initialize the parallel pipeline.

        Args:
            max_workers: Maximum number of parallel workers (defaults to CPU count)
            use_processes: If True, use ProcessPoolExecutor; if False, use ThreadPoolExecutor
        """
        self.max_workers = max_workers or mp.cpu_count()
        self.use_processes = use_processes
        self.nodes: Dict[str, NodeConfig] = {}
        logger.info(f"Initialized ParallelPipeline with {self.max_workers} workers "
                    f"({'processes' if use_processes else 'threads'})")

    def add_node(self, config: NodeConfig):
        """Add a node to the pipeline.

        Args:
            config: NodeConfig with node details and dependencies
        """
        self.nodes[config.name] = config
        logger.info(f"Added node '{config.name}' with dependencies: {config.dependencies}")

    def _get_execution_stages(self) -> List[List[str]]:
        """Determine execution stages based on dependencies.

        Returns:
            List of stages, where each stage contains node names that can run in parallel
        """
        stages = []
        completed = set()
        remaining = set(self.nodes.keys())

        while remaining:
            # Find nodes whose dependencies are all completed
            ready = []
            for node_name in remaining:
                config = self.nodes[node_name]
                if all(dep in completed for dep in config.dependencies):
                    ready.append(node_name)

            if not ready:
                # Circular dependency or missing dependency
                raise ValueError(f"Cannot resolve dependencies. Remaining nodes: {remaining}")

            stages.append(ready)
            completed.update(ready)
            remaining -= set(ready)

        return stages
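
    # Illustrative example: with nodes "A" and "B" declaring no dependencies
    # and "C" depending on "A", the method above yields [["A", "B"], ["C"]]
    # (the order of names within a stage is not deterministic, since each
    # stage is derived from a set).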

    def _execute_node(self,
                      node_name: str,
                      db_path: str,
                      context: TransformContext) -> tuple:
        """Execute a single node.

        Args:
            node_name: Name of the node to execute
            db_path: Path to the SQLite database
            context: TransformContext for the node

        Returns:
            Tuple of (node_name, result_context, error)
        """
        con = None
        try:
            # Create a fresh database connection (not shared across processes/threads)
            con = sqlite3.connect(db_path)

            config = self.nodes[node_name]
            node = config.node_class(**config.node_kwargs)

            logger.info(f"Executing node: {node_name}")
            result_context = node.run(con, context)

            logger.info(f"Node '{node_name}' completed successfully")
            return node_name, result_context, None

        except Exception as e:
            logger.error(f"Error executing node '{node_name}': {e}", exc_info=True)
            return node_name, None, str(e)

        finally:
            # Close the connection even if the node raised
            if con is not None:
                con.close()

    def run(self,
            db_path: str,
            initial_context: TransformContext,
            fail_fast: bool = False) -> Dict[str, TransformContext]:
        """Execute the pipeline.

        Args:
            db_path: Path to the SQLite database
            initial_context: Initial TransformContext for the pipeline
            fail_fast: If True, stop execution on first error

        Returns:
            Dict mapping node names to their output TransformContext
        """
        logger.info("Starting parallel pipeline execution")

        stages = self._get_execution_stages()
        logger.info(f"Pipeline has {len(stages)} execution stage(s)")

        results = {}
        errors = []

        ExecutorClass = ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor

        for stage_num, stage_nodes in enumerate(stages, 1):
            logger.info(f"Stage {stage_num}/{len(stages)}: Executing {len(stage_nodes)} node(s) in parallel: {stage_nodes}")

            # For nodes in this stage, use the context produced by their dependencies.
            # If a node has several dependencies, the last one's output is used.
            stage_futures = {}

            with ExecutorClass(max_workers=min(self.max_workers, len(stage_nodes))) as executor:
                for node_name in stage_nodes:
                    config = self.nodes[node_name]

                    # Get the context from the dependencies (use the last dependency's output)
                    if config.dependencies:
                        context = results.get(config.dependencies[-1], initial_context)
                    else:
                        context = initial_context

                    future = executor.submit(self._execute_node, node_name, db_path, context)
                    stage_futures[future] = node_name

                # Wait for all nodes in this stage to complete
                for future in as_completed(stage_futures):
                    node_name = stage_futures[future]
                    name, result_context, error = future.result()

                    if error:
                        errors.append((name, error))
                        if fail_fast:
                            logger.error(f"Pipeline failed at node '{name}': {error}")
                            raise RuntimeError(f"Node '{name}' failed: {error}")
                    else:
                        results[name] = result_context

        if errors:
            logger.warning(f"Pipeline completed with {len(errors)} error(s)")
            for name, error in errors:
                logger.error(f"  - {name}: {error}")
        else:
            logger.info("Pipeline completed successfully")

        return results


def create_default_pipeline(device: str = "cpu",
                            max_workers: Optional[int] = None) -> ParallelPipeline:
    """Create a pipeline with default transform nodes.

    Args:
        device: Device to use for compute-intensive nodes ('cpu', 'cuda', 'mps')
        max_workers: Maximum number of parallel workers

    Returns:
        Configured ParallelPipeline
    """
    from author_node import NerAuthorNode, FuzzyAuthorNode

    pipeline = ParallelPipeline(max_workers=max_workers, use_processes=False)

    # Add AuthorNode (no dependencies)
    pipeline.add_node(NodeConfig(
        node_class=NerAuthorNode,
        node_kwargs={
            'device': device,
            'model_path': os.environ.get('GLINER_MODEL_PATH')
        },
        dependencies=[],
        name='AuthorNode'
    ))

    pipeline.add_node(NodeConfig(
        node_class=FuzzyAuthorNode,
        node_kwargs={
            'max_l_dist': 1
        },
        dependencies=['AuthorNode'],
        name='FuzzyAuthorNode'
    ))

    # TODO: Create a node to compute text embeddings and UMAP.
    # TODO: Create a node to pre-compute data for the visuals to reduce load time.

    # TODO: Add more nodes here as they are implemented.
    # Example:
    # pipeline.add_node(NodeConfig(
    #     node_class=EmbeddingNode,
    #     node_kwargs={'device': device},
    #     dependencies=[],  # No dependencies; runs in the first stage
    #     name='EmbeddingNode'
    # ))

    # pipeline.add_node(NodeConfig(
    #     node_class=UMAPNode,
    #     node_kwargs={'device': device},
    #     dependencies=['EmbeddingNode'],  # Runs after EmbeddingNode
    #     name='UMAPNode'
    # ))

    return pipeline
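

if __name__ == "__main__":
    # Minimal usage sketch. The database path ("knack.db") and the table the
    # initial dataframe is read from ("documents") are assumptions made for
    # this example, not values required by the module.
    logging.basicConfig(level=logging.INFO)

    db_path = "knack.db"
    con = sqlite3.connect(db_path)
    df = pd.read_sql("SELECT * FROM documents", con)
    con.close()

    pipeline = create_default_pipeline(device="cpu")
    results = pipeline.run(db_path, TransformContext(df))

    for node_name, ctx in results.items():
        print(node_name, ctx.get_dataframe().shape)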