From 72765532d39613b65cfec72441b381b8329c594b Mon Sep 17 00:00:00 2001 From: quorploop <> Date: Tue, 23 Dec 2025 17:53:37 +0100 Subject: [PATCH] Adds TransformNode to FuzzyFind Author Names --- docker-compose.yml | 16 ++ transform/Dockerfile | 25 ++- transform/README.md | 7 +- transform/author_node.py | 205 +++++++++++++++--- transform/ensure_gliner_model.sh | 16 ++ transform/entrypoint.sh | 8 + transform/example_node.py | 170 +++++++++++++++ transform/main.py | 35 ++- transform/pipeline.py | 258 +++++++++++++++++++++++ transform/requirements.txt | 1 + transform/{base.py => transform_node.py} | 13 +- 11 files changed, 696 insertions(+), 58 deletions(-) create mode 100644 transform/ensure_gliner_model.sh create mode 100644 transform/entrypoint.sh create mode 100644 transform/example_node.py create mode 100644 transform/pipeline.py rename transform/{base.py => transform_node.py} (70%) diff --git a/docker-compose.yml b/docker-compose.yml index 5c5c4e7..4ab3b8c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -21,7 +21,23 @@ services: - transform/.env volumes: - knack_data:/data + - models:/models + restart: unless-stopped + + sqlitebrowser: + image: lscr.io/linuxserver/sqlitebrowser:latest + container_name: sqlitebrowser + environment: + - PUID=1000 + - PGID=1000 + - TZ=Etc/UTC + volumes: + - knack_data:/data + ports: + - "3000:3000" # noVNC web UI + - "3001:3001" # VNC server restart: unless-stopped volumes: knack_data: + models: diff --git a/transform/Dockerfile b/transform/Dockerfile index 4c72480..682af4f 100644 --- a/transform/Dockerfile +++ b/transform/Dockerfile @@ -1,7 +1,6 @@ FROM python:3.12-slim -RUN mkdir /app -RUN mkdir /data +RUN mkdir -p /app /data /models # Install build dependencies RUN apt-get update && apt-get install -y --no-install-recommends \ @@ -11,9 +10,12 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ libopenblas-dev \ liblapack-dev \ pkg-config \ + curl \ + jq \ && rm -rf /var/lib/apt/lists/* -#COPY /data/knack.sqlite /data +ENV GLINER_MODEL_ID=urchade/gliner_multi-v2.1 +ENV GLINER_MODEL_PATH=/models/gliner_multi-v2.1 WORKDIR /app COPY requirements.txt . @@ -24,18 +26,21 @@ COPY .env . RUN apt update -y RUN apt install -y cron locales -COPY *.py . +# Ensure GLiNER helper scripts are available +COPY ensure_gliner_model.sh /usr/local/bin/ensure_gliner_model.sh +COPY entrypoint.sh /usr/local/bin/entrypoint.sh +RUN chmod +x /usr/local/bin/ensure_gliner_model.sh /usr/local/bin/entrypoint.sh -ENV PYTHONUNBUFFERED=1 -ENV LANG=de_DE.UTF-8 -ENV LC_ALL=de_DE.UTF-8 +COPY *.py . # Create cron job that runs every weekend (Sunday at 3 AM) 0 3 * * 0 # Testing every 30 Minutes */30 * * * * -RUN echo "0 3 * * 0 cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-transform +RUN echo "*/30 * * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-transform RUN chmod 0644 /etc/cron.d/knack-transform RUN crontab /etc/cron.d/knack-transform -# Start cron in foreground -CMD ["cron", "-f"] +# Persist models between container runs +VOLUME /models + +CMD ["/usr/local/bin/entrypoint.sh"] #CMD ["python", "main.py"] diff --git a/transform/README.md b/transform/README.md index 44ddeb1..9e3665a 100644 --- a/transform/README.md +++ b/transform/README.md @@ -6,10 +6,15 @@ Data transformation pipeline for the Knack scraper project. This folder contains the transformation logic that processes data from the SQLite database. It runs on a scheduled basis (every weekend) via cron. 
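+
+On first start the container downloads the GLiNER model into the shared `models` volume (see `ensure_gliner_model.sh`), so later runs load it from local disk instead of the Hugging Face hub. A minimal sketch of the related settings, with assumed example values -- the first two default to the values baked into the Dockerfile, the last two are read by `main.py` and would typically go into `transform/.env`:
+
+```bash
+GLINER_MODEL_ID=urchade/gliner_multi-v2.1     # model fetched from the Hugging Face hub
+GLINER_MODEL_PATH=/models/gliner_multi-v2.1   # local cache inside the models volume
+COMPUTE_DEVICE=cpu                            # "cuda" or "mps" if available
+MAX_WORKERS=4                                 # workers for the parallel pipeline
+```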
+The pipeline supports **parallel execution** of independent transform nodes, allowing you to leverage multi-core processors for faster data transformation.
+
 ## Structure
 
-- `base.py` - Abstract base class for transform nodes
-- `main.py` - Main entry point and pipeline orchestration
+- `transform_node.py` - Abstract base class for transform nodes (renamed from `base.py`)
+- `pipeline.py` - Parallel pipeline orchestration system
+- `main.py` - Main entry point and pipeline execution
+- `author_node.py` - NER-based author classification node
+- `example_node.py` - Template for creating new nodes
 - `Dockerfile` - Docker image configuration with cron setup
 - `requirements.txt` - Python dependencies
 
diff --git a/transform/author_node.py b/transform/author_node.py
index 23d3365..719a191 100644
--- a/transform/author_node.py
+++ b/transform/author_node.py
@@ -1,10 +1,13 @@
 """Author classification transform node using NER."""
-from base import TransformNode, TransformContext
+import os
 import sqlite3
 import pandas as pd
 import logging
+import fuzzysearch
 from concurrent.futures import ThreadPoolExecutor
-from datetime import datetime
+
+from pipeline import TransformContext
+from transform_node import TransformNode
 
 try:
     from gliner import GLiNER
@@ -17,7 +20,7 @@ except ImportError:
 logger = logging.getLogger("knack-transform")
 
 
-class AuthorNode(TransformNode):
+class NerAuthorNode(TransformNode):
     """Transform node that extracts and classifies authors using NER.
 
     Creates two tables:
@@ -25,7 +28,8 @@ class AuthorNode(TransformNode):
     - post_authors: maps posts to their authors
     """
 
-    def __init__(self, model_name: str = "urchade/gliner_medium-v2.1",
+    def __init__(self, model_name: str = "urchade/gliner_multi-v2.1",
+                 model_path: str | None = None,
                  threshold: float = 0.5,
                  max_workers: int = 64,
                  device: str = "cpu"):
@@ -33,11 +37,13 @@ class AuthorNode(TransformNode):
 
         Args:
             model_name: GLiNER model to use
+            model_path: Optional local path to a downloaded GLiNER model
             threshold: Confidence threshold for entity predictions
             max_workers: Number of parallel workers for prediction
             device: Device to run model on ('cpu', 'cuda', 'mps')
         """
         self.model_name = model_name
+        self.model_path = model_path or os.environ.get('GLINER_MODEL_PATH')
         self.threshold = threshold
         self.max_workers = max_workers
         self.device = device
@@ -49,21 +55,31 @@ class AuthorNode(TransformNode):
         if not GLINER_AVAILABLE:
             raise ImportError("GLiNER is required for AuthorNode. 
Install with: pip install gliner")
 
-        logger.info(f"Loading GLiNER model: {self.model_name}")
+        model_source = None
+        if self.model_path:
+            if os.path.exists(self.model_path):
+                model_source = self.model_path
+                logger.info(f"Loading GLiNER model from local path: {self.model_path}")
+            else:
+                logger.warning(f"GLINER_MODEL_PATH '{self.model_path}' not found; falling back to hub model {self.model_name}")
+
+        if model_source is None:
+            model_source = self.model_name
+            logger.info(f"Loading GLiNER model from hub: {self.model_name}")
 
         if self.device == "cuda" and torch.cuda.is_available():
             self.model = GLiNER.from_pretrained(
-                self.model_name,
+                model_source,
                 max_length=255
             ).to('cuda', dtype=torch.float16)
         elif self.device == "mps" and torch.backends.mps.is_available():
             self.model = GLiNER.from_pretrained(
-                self.model_name,
+                model_source,
                 max_length=255
             ).to('mps', dtype=torch.float16)
         else:
             self.model = GLiNER.from_pretrained(
-                self.model_name,
+                model_source,
                 max_length=255
             )
@@ -208,13 +224,6 @@ class AuthorNode(TransformNode):
 
             logger.info(f"Creating {len(mappings_df)} post-author mappings")
             mappings_df.to_sql('post_authors', con, if_exists='append', index=False)
-
-            # Mark posts as cleaned
-            processed_post_ids = mappings_df['post_id'].unique().tolist()
-            if processed_post_ids:
-                placeholders = ','.join('?' * len(processed_post_ids))
-                con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", processed_post_ids)
-                logger.info(f"Marked {len(processed_post_ids)} posts as cleaned")
 
         con.commit()
         logger.info("Authors and mappings stored successfully")
@@ -247,17 +256,165 @@ class AuthorNode(TransformNode):
 
         # Store results
         self._store_authors(con, results)
 
-        # Mark posts without author entities as cleaned too (no authors found)
-        processed_ids = set([r['id'] for r in results]) if results else set()
-        unprocessed_ids = [pid for pid in posts_df['id'].tolist() if pid not in processed_ids]
-        if unprocessed_ids:
-            placeholders = ','.join('?' * len(unprocessed_ids))
-            con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", unprocessed_ids)
-            con.commit()
-            logger.info(f"Marked {len(unprocessed_ids)} posts without author entities as cleaned")
-
         # Return context with results
         results_df = pd.DataFrame(results) if results else pd.DataFrame()
 
         logger.info("AuthorNode transformation complete")
         return TransformContext(results_df)
+
+
+class FuzzyAuthorNode(TransformNode):
+    """Fuzzy author matching node.
+
+    Takes posts together with the author names that have already been classified
+    and uses those names as 'rules' to fuzzy-match the author fields of further posts.
+    """
+
+    def __init__(self, max_l_dist: int = 1):
+        """Initialize FuzzyAuthorNode.
+
+        Args:
+            max_l_dist: Maximum Levenshtein distance (number of edits) tolerated by the fuzzy search
+        """
+        self.max_l_dist = max_l_dist
+        logger.info(f"Initialized FuzzyAuthorNode with max_l_dist={max_l_dist}")
+
+    def _process_data(self, con: sqlite3.Connection, df: pd.DataFrame) -> pd.DataFrame:
+        """Process the input dataframe.
+
+        Fuzzy-matches each post's author field against the known names in the authors table.
+ + Args: + con: Database connection + df: Input dataframe from context + + Returns: + Processed dataframe + """ + logger.info(f"Processing {len(df)} rows") + + # Retrieve all known authors from the authors table as 'rules' + authors_df = pd.read_sql("SELECT id, name FROM authors", con) + + if authors_df.empty: + logger.warning("No authors found in database for fuzzy matching") + return pd.DataFrame(columns=['post_id', 'author_id']) + + # Get existing post-author mappings to avoid duplicates + existing_mappings = pd.read_sql( + "SELECT post_id, author_id FROM post_authors", con + ) + existing_post_ids = set(existing_mappings['post_id'].unique()) + + logger.info(f"Found {len(authors_df)} known authors for fuzzy matching") + logger.info(f"Found {len(existing_post_ids)} posts with existing author mappings") + + # Filter to posts without author mappings and with non-null author field + if 'author' not in df.columns or 'id' not in df.columns: + logger.warning("Missing 'author' or 'id' column in input dataframe") + return pd.DataFrame(columns=['post_id', 'author_id']) + + posts_to_process = df[ + (df['id'].notna()) & + (df['author'].notna()) & + (~df['id'].isin(existing_post_ids)) + ] + + logger.info(f"Processing {len(posts_to_process)} posts for fuzzy matching") + + # Perform fuzzy matching + mappings = [] + for _, post_row in posts_to_process.iterrows(): + post_id = post_row['id'] + post_author = str(post_row['author']) + + # Try to find matches against all known author names + for _, author_row in authors_df.iterrows(): + author_id = author_row['id'] + author_name = str(author_row['name']) + + # Use fuzzysearch to find matches with allowed errors + matches = fuzzysearch.find_near_matches( + author_name, + post_author, + max_l_dist=self.max_l_dist + ) + + if matches: + logger.debug(f"Found fuzzy match: '{author_name}' in '{post_author}' for post {post_id}") + mappings.append({ + 'post_id': post_id, + 'author_id': author_id + }) + # Only take the first match per post to avoid multiple mappings + break + + # Create result dataframe + result_df = pd.DataFrame(mappings, columns=['post_id', 'author_id']) if mappings else pd.DataFrame(columns=['post_id', 'author_id']) + + logger.info(f"Processing complete. Found {len(result_df)} fuzzy matches") + return result_df + + def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame): + """Store results back to the database. + + Uses INSERT OR IGNORE to avoid inserting duplicates. + + Args: + con: Database connection + df: Processed dataframe to store + """ + if df.empty: + logger.info("No results to store") + return + + logger.info(f"Storing {len(df)} results") + + # Use INSERT OR IGNORE to handle duplicates (respects PRIMARY KEY constraint) + cursor = con.cursor() + inserted_count = 0 + + for _, row in df.iterrows(): + cursor.execute( + "INSERT OR IGNORE INTO post_authors (post_id, author_id) VALUES (?, ?)", + (int(row['post_id']), int(row['author_id'])) + ) + if cursor.rowcount > 0: + inserted_count += 1 + + con.commit() + logger.info(f"Results stored successfully. Inserted {inserted_count} new mappings, skipped {len(df) - inserted_count} duplicates") + + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the transformation. + + This is the main entry point called by the pipeline. 
+ + Args: + con: SQLite database connection + context: TransformContext containing input dataframe + + Returns: + TransformContext with processed dataframe + """ + logger.info("Starting FuzzyAuthorNode transformation") + + # Get input dataframe from context + input_df = context.get_dataframe() + + # Validate input + if input_df.empty: + logger.warning("Empty dataframe provided to FuzzyAuthorNode") + return context + + # Process the data + result_df = self._process_data(con, input_df) + + # Store results + self._store_results(con, result_df) + + logger.info("FuzzyAuthorNode transformation complete") + + # Return new context with results + return TransformContext(result_df) diff --git a/transform/ensure_gliner_model.sh b/transform/ensure_gliner_model.sh new file mode 100644 index 0000000..4df8215 --- /dev/null +++ b/transform/ensure_gliner_model.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -euo pipefail + +if [ -d "$GLINER_MODEL_PATH" ] && find "$GLINER_MODEL_PATH" -type f | grep -q .; then + echo "GLiNER model already present at $GLINER_MODEL_PATH" + exit 0 +fi + +echo "Downloading GLiNER model to $GLINER_MODEL_PATH" +mkdir -p "$GLINER_MODEL_PATH" +curl -sL "https://huggingface.co/api/models/${GLINER_MODEL_ID}" | jq -r '.siblings[].rfilename' | while read -r file; do + target="${GLINER_MODEL_PATH}/${file}" + mkdir -p "$(dirname "$target")" + echo "Downloading ${file}" + curl -sL "https://huggingface.co/${GLINER_MODEL_ID}/resolve/main/${file}" -o "$target" +done diff --git a/transform/entrypoint.sh b/transform/entrypoint.sh new file mode 100644 index 0000000..8beab84 --- /dev/null +++ b/transform/entrypoint.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Run model download with output to stdout/stderr +/usr/local/bin/ensure_gliner_model.sh 2>&1 + +# Start cron in foreground with logging +exec cron -f -L 2 diff --git a/transform/example_node.py b/transform/example_node.py new file mode 100644 index 0000000..69900d1 --- /dev/null +++ b/transform/example_node.py @@ -0,0 +1,170 @@ +"""Example template node for the transform pipeline. + +This is a template showing how to create new transform nodes. +Copy this file and modify it for your specific transformation needs. +""" +from pipeline import TransformContext +from transform_node import TransformNode +import sqlite3 +import pandas as pd +import logging + +logger = logging.getLogger("knack-transform") + + +class ExampleNode(TransformNode): + """Example transform node template. + + This node demonstrates the basic structure for creating + new transformation nodes in the pipeline. + """ + + def __init__(self, + param1: str = "default_value", + param2: int = 42, + device: str = "cpu"): + """Initialize the ExampleNode. + + Args: + param1: Example string parameter + param2: Example integer parameter + device: Device to use for computations ('cpu', 'cuda', 'mps') + """ + self.param1 = param1 + self.param2 = param2 + self.device = device + logger.info(f"Initialized ExampleNode with param1={param1}, param2={param2}") + + def _create_tables(self, con: sqlite3.Connection): + """Create any necessary tables in the database. + + This is optional - only needed if your node creates new tables. 
+ """ + logger.info("Creating example tables") + + con.execute(""" + CREATE TABLE IF NOT EXISTS example_results ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + post_id INTEGER, + result_value TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (post_id) REFERENCES posts(id) + ) + """) + + con.commit() + + def _process_data(self, df: pd.DataFrame) -> pd.DataFrame: + """Process the input dataframe. + + This is where your main transformation logic goes. + + Args: + df: Input dataframe from context + + Returns: + Processed dataframe + """ + logger.info(f"Processing {len(df)} rows") + + # Example: Add a new column based on existing data + result_df = df.copy() + result_df['processed'] = True + result_df['example_value'] = result_df['id'].apply(lambda x: f"{self.param1}_{x}") + + logger.info("Processing complete") + return result_df + + def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame): + """Store results back to the database. + + This is optional - only needed if you want to persist results. + + Args: + con: Database connection + df: Processed dataframe to store + """ + if df.empty: + logger.info("No results to store") + return + + logger.info(f"Storing {len(df)} results") + + # Example: Store to database + # df[['post_id', 'result_value']].to_sql( + # 'example_results', + # con, + # if_exists='append', + # index=False + # ) + + con.commit() + logger.info("Results stored successfully") + + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the transformation. + + This is the main entry point called by the pipeline. + + Args: + con: SQLite database connection + context: TransformContext containing input dataframe + + Returns: + TransformContext with processed dataframe + """ + logger.info("Starting ExampleNode transformation") + + # Get input dataframe from context + input_df = context.get_dataframe() + + # Validate input + if input_df.empty: + logger.warning("Empty dataframe provided to ExampleNode") + return context + + # Create any necessary tables + self._create_tables(con) + + # Process the data + result_df = self._process_data(input_df) + + # Store results (optional) + self._store_results(con, result_df) + + logger.info("ExampleNode transformation complete") + + # Return new context with results + return TransformContext(result_df) + + +# Example usage: +if __name__ == "__main__": + # This allows you to test your node independently + import os + os.chdir('/Users/linussilberstein/Documents/Knack-Scraper/transform') + + from pipeline import TransformContext + import sqlite3 + + # Create test data + test_df = pd.DataFrame({ + 'id': [1, 2, 3], + 'author': ['Test Author 1', 'Test Author 2', 'Test Author 3'] + }) + + # Create test database connection + test_con = sqlite3.connect(':memory:') + + # Create and run node + node = ExampleNode(param1="test", param2=100) + context = TransformContext(test_df) + result_context = node.run(test_con, context) + + # Check results + result_df = result_context.get_dataframe() + print("\nResult DataFrame:") + print(result_df) + + test_con.close() + print("\n✓ ExampleNode test completed successfully!") diff --git a/transform/main.py b/transform/main.py index 29b9a38..d07d905 100644 --- a/transform/main.py +++ b/transform/main.py @@ -50,15 +50,14 @@ def main(): logger.info("Transform pipeline skipped - no data available") return - # Import transform nodes - from author_node import AuthorNode - from base import TransformContext + # Import transform components + from pipeline import 
create_default_pipeline, TransformContext import pandas as pd # Load posts data logger.info("Loading posts from database") sql = "SELECT id, author FROM posts WHERE author IS NOT NULL AND (is_cleaned IS NULL OR is_cleaned = 0) LIMIT ?" - MAX_CLEANED_POSTS = os.environ.get("MAX_CLEANED_POSTS", 500) + MAX_CLEANED_POSTS = os.environ.get("MAX_CLEANED_POSTS", 100) df = pd.read_sql(sql, con, params=[MAX_CLEANED_POSTS]) logger.info(f"Loaded {len(df)} uncleaned posts with authors") @@ -66,15 +65,29 @@ def main(): logger.info("No uncleaned posts found. Transform pipeline skipped.") return - # Create context and run author classification + # Create initial context context = TransformContext(df) - author_transform = AuthorNode(device=os.environ.get('COMPUTE_DEVICE', 'cpu')) # Change to "cuda" or "mps" if available - result_context = author_transform.run(con, context) - - # TODO: Create Node to compute Text Embeddings and UMAP. - # TODO: Create Node to pre-compute data based on visuals to reduce load time. - logger.info("Transform pipeline completed successfully") + # Create and run parallel pipeline + device = os.environ.get('COMPUTE_DEVICE', 'cpu') + max_workers = int(os.environ.get('MAX_WORKERS', 4)) + + pipeline = create_default_pipeline(device=device, max_workers=max_workers) + results = pipeline.run( + db_path=os.environ.get('DB_PATH', '/data/knack.sqlite'), + initial_context=context, + fail_fast=False # Continue even if some nodes fail + ) + + logger.info(f"Pipeline completed. Processed {len(results)} node(s)") + + # Mark all processed posts as cleaned + post_ids = df['id'].tolist() + if post_ids: + placeholders = ','.join('?' * len(post_ids)) + con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", post_ids) + con.commit() + logger.info(f"Marked {len(post_ids)} posts as cleaned") except Exception as e: logger.error(f"Error in transform pipeline: {e}", exc_info=True) diff --git a/transform/pipeline.py b/transform/pipeline.py new file mode 100644 index 0000000..1a97f1f --- /dev/null +++ b/transform/pipeline.py @@ -0,0 +1,258 @@ +"""Parallel pipeline orchestration for transform nodes.""" +import logging +import os +import sqlite3 +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed +from typing import List, Dict, Optional + +import pandas as pd +import multiprocessing as mp + +logger = logging.getLogger("knack-transform") + +class TransformContext: + """Context object containing the dataframe for transformation.""" + + def __init__(self, df: pd.DataFrame): + self.df = df + + def get_dataframe(self) -> pd.DataFrame: + """Get the pandas dataframe from the context.""" + return self.df + +class NodeConfig: + """Configuration for a transform node.""" + + def __init__(self, + node_class: type, + node_kwargs: Dict = None, + dependencies: List[str] = None, + name: str = None): + """Initialize node configuration. + + Args: + node_class: The TransformNode class to instantiate + node_kwargs: Keyword arguments to pass to node constructor + dependencies: List of node names that must complete before this one + name: Optional name for the node (defaults to class name) + """ + self.node_class = node_class + self.node_kwargs = node_kwargs or {} + self.dependencies = dependencies or [] + self.name = name or node_class.__name__ + +class ParallelPipeline: + """Pipeline for executing transform nodes in parallel where possible. + + The pipeline analyzes dependencies between nodes and executes + independent nodes concurrently using multiprocessing or threading. 
+ """ + + def __init__(self, + max_workers: Optional[int] = None, + use_processes: bool = False): + """Initialize the parallel pipeline. + + Args: + max_workers: Maximum number of parallel workers (defaults to CPU count) + use_processes: If True, use ProcessPoolExecutor; if False, use ThreadPoolExecutor + """ + self.max_workers = max_workers or mp.cpu_count() + self.use_processes = use_processes + self.nodes: Dict[str, NodeConfig] = {} + logger.info(f"Initialized ParallelPipeline with {self.max_workers} workers " + f"({'processes' if use_processes else 'threads'})") + + def add_node(self, config: NodeConfig): + """Add a node to the pipeline. + + Args: + config: NodeConfig with node details and dependencies + """ + self.nodes[config.name] = config + logger.info(f"Added node '{config.name}' with dependencies: {config.dependencies}") + + def _get_execution_stages(self) -> List[List[str]]: + """Determine execution stages based on dependencies. + + Returns: + List of stages, where each stage contains node names that can run in parallel + """ + stages = [] + completed = set() + remaining = set(self.nodes.keys()) + + while remaining: + # Find nodes whose dependencies are all completed + ready = [] + for node_name in remaining: + config = self.nodes[node_name] + if all(dep in completed for dep in config.dependencies): + ready.append(node_name) + + if not ready: + # Circular dependency or missing dependency + raise ValueError(f"Cannot resolve dependencies. Remaining nodes: {remaining}") + + stages.append(ready) + completed.update(ready) + remaining -= set(ready) + + return stages + + def _execute_node(self, + node_name: str, + db_path: str, + context: TransformContext) -> tuple: + """Execute a single node. + + Args: + node_name: Name of the node to execute + db_path: Path to the SQLite database + context: TransformContext for the node + + Returns: + Tuple of (node_name, result_context, error) + """ + try: + # Create fresh database connection (not shared across processes/threads) + con = sqlite3.connect(db_path) + + config = self.nodes[node_name] + node = config.node_class(**config.node_kwargs) + + logger.info(f"Executing node: {node_name}") + result_context = node.run(con, context) + + con.close() + logger.info(f"Node '{node_name}' completed successfully") + + return node_name, result_context, None + + except Exception as e: + logger.error(f"Error executing node '{node_name}': {e}", exc_info=True) + return node_name, None, str(e) + + def run(self, + db_path: str, + initial_context: TransformContext, + fail_fast: bool = False) -> Dict[str, TransformContext]: + """Execute the pipeline. 
+ + Args: + db_path: Path to the SQLite database + initial_context: Initial TransformContext for the pipeline + fail_fast: If True, stop execution on first error + + Returns: + Dict mapping node names to their output TransformContext + """ + logger.info("Starting parallel pipeline execution") + + stages = self._get_execution_stages() + logger.info(f"Pipeline has {len(stages)} execution stage(s)") + + results = {} + contexts = {None: initial_context} # Track contexts from each node + errors = [] + + ExecutorClass = ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor + + for stage_num, stage_nodes in enumerate(stages, 1): + logger.info(f"Stage {stage_num}/{len(stages)}: Executing {len(stage_nodes)} node(s) in parallel: {stage_nodes}") + + # For nodes in this stage, use the context from their dependencies + # If multiple dependencies, we'll use the most recent one (or could merge) + stage_futures = {} + + with ExecutorClass(max_workers=min(self.max_workers, len(stage_nodes))) as executor: + for node_name in stage_nodes: + config = self.nodes[node_name] + + # Get context from dependencies (use the last dependency's output) + if config.dependencies: + context = results.get(config.dependencies[-1], initial_context) + else: + context = initial_context + + future = executor.submit(self._execute_node, node_name, db_path, context) + stage_futures[future] = node_name + + # Wait for all nodes in this stage to complete + for future in as_completed(stage_futures): + node_name = stage_futures[future] + name, result_context, error = future.result() + + if error: + errors.append((name, error)) + if fail_fast: + logger.error(f"Pipeline failed at node '{name}': {error}") + raise RuntimeError(f"Node '{name}' failed: {error}") + else: + results[name] = result_context + + if errors: + logger.warning(f"Pipeline completed with {len(errors)} error(s)") + for name, error in errors: + logger.error(f" - {name}: {error}") + else: + logger.info("Pipeline completed successfully") + + return results + + +def create_default_pipeline(device: str = "cpu", + max_workers: Optional[int] = None) -> ParallelPipeline: + """Create a pipeline with default transform nodes. + + Args: + device: Device to use for compute-intensive nodes ('cpu', 'cuda', 'mps') + max_workers: Maximum number of parallel workers + + Returns: + Configured ParallelPipeline + """ + from author_node import NerAuthorNode, FuzzyAuthorNode + + pipeline = ParallelPipeline(max_workers=max_workers, use_processes=False) + + # Add AuthorNode (no dependencies) + pipeline.add_node(NodeConfig( + node_class=NerAuthorNode, + node_kwargs={ + 'device': device, + 'model_path': os.environ.get('GLINER_MODEL_PATH') + }, + dependencies=[], + name='AuthorNode' + )) + + pipeline.add_node(NodeConfig( + node_class=FuzzyAuthorNode, + node_kwargs={ + 'max_l_dist': 1 + }, + dependencies=['AuthorNode'], + name='FuzzyAuthorNode' + )) + + # TODO: Create Node to compute Text Embeddings and UMAP. + # TODO: Create Node to pre-compute data based on visuals to reduce load time. 
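+
+    # Note: with the two nodes registered above, _get_execution_stages() resolves to
+    # [['AuthorNode'], ['FuzzyAuthorNode']] -- the fuzzy pass runs in a second stage so
+    # that the names written to the authors table by the NER pass are already available
+    # when FuzzyAuthorNode queries them as matching rules.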
+ + # TODO: Add more nodes here as they are implemented + # Example: + # pipeline.add_node(NodeConfig( + # node_class=EmbeddingNode, + # node_kwargs={'device': device}, + # dependencies=[], # Runs after AuthorNode + # name='EmbeddingNode' + # )) + + # pipeline.add_node(NodeConfig( + # node_class=UMAPNode, + # node_kwargs={'device': device}, + # dependencies=['EmbeddingNode'], # Runs after EmbeddingNode + # name='UMAPNode' + # )) + + return pipeline diff --git a/transform/requirements.txt b/transform/requirements.txt index c95bd6d..e210d05 100644 --- a/transform/requirements.txt +++ b/transform/requirements.txt @@ -2,3 +2,4 @@ pandas python-dotenv gliner torch +fuzzysearch \ No newline at end of file diff --git a/transform/base.py b/transform/transform_node.py similarity index 70% rename from transform/base.py rename to transform/transform_node.py index 59a4f31..54e6bed 100644 --- a/transform/base.py +++ b/transform/transform_node.py @@ -1,19 +1,8 @@ """Base transform node for data pipeline.""" from abc import ABC, abstractmethod import sqlite3 -import pandas as pd - - -class TransformContext: - """Context object containing the dataframe for transformation.""" - - def __init__(self, df: pd.DataFrame): - self.df = df - - def get_dataframe(self) -> pd.DataFrame: - """Get the pandas dataframe from the context.""" - return self.df +from pipeline import TransformContext class TransformNode(ABC): """Abstract base class for transformation nodes.