Adds TransformNode to FuzzyFind Author Names

quorploop 2025-12-23 17:53:37 +01:00
parent 64df8fb328
commit 72765532d3
11 changed files with 696 additions and 58 deletions

View file

@ -21,7 +21,23 @@ services:
- transform/.env
volumes:
- knack_data:/data
- models:/models
restart: unless-stopped
sqlitebrowser:
image: lscr.io/linuxserver/sqlitebrowser:latest
container_name: sqlitebrowser
environment:
- PUID=1000
- PGID=1000
- TZ=Etc/UTC
volumes:
- knack_data:/data
ports:
- "3000:3000" # noVNC web UI
- "3001:3001" # VNC server
restart: unless-stopped
volumes:
knack_data:
models:

View file

@ -1,7 +1,6 @@
FROM python:3.12-slim
RUN mkdir /app
RUN mkdir /data
RUN mkdir -p /app /data /models
# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
@ -11,9 +10,12 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libopenblas-dev \
liblapack-dev \
pkg-config \
curl \
jq \
&& rm -rf /var/lib/apt/lists/*
#COPY /data/knack.sqlite /data
ENV GLINER_MODEL_ID=urchade/gliner_multi-v2.1
ENV GLINER_MODEL_PATH=/models/gliner_multi-v2.1
WORKDIR /app
COPY requirements.txt .
@ -24,18 +26,21 @@ COPY .env .
RUN apt update -y
RUN apt install -y cron locales
COPY *.py .
# Ensure GLiNER helper scripts are available
COPY ensure_gliner_model.sh /usr/local/bin/ensure_gliner_model.sh
COPY entrypoint.sh /usr/local/bin/entrypoint.sh
RUN chmod +x /usr/local/bin/ensure_gliner_model.sh /usr/local/bin/entrypoint.sh
ENV PYTHONUNBUFFERED=1
ENV LANG=de_DE.UTF-8
ENV LC_ALL=de_DE.UTF-8
COPY *.py .
# Production cron schedule: every Sunday at 3 AM (0 3 * * 0)
# Testing schedule: every 30 minutes (*/30 * * * *)
RUN echo "0 3 * * 0 cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-transform
RUN echo "*/30 * * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-transform
RUN chmod 0644 /etc/cron.d/knack-transform
RUN crontab /etc/cron.d/knack-transform
# Start cron in foreground
CMD ["cron", "-f"]
# Persist models between container runs
VOLUME /models
CMD ["/usr/local/bin/entrypoint.sh"]
#CMD ["python", "main.py"]

View file

@ -6,10 +6,15 @@ Data transformation pipeline for the Knack scraper project.
This folder contains the transformation logic that processes data from the SQLite database. It runs on a scheduled basis (every weekend) via cron.
The pipeline supports **parallel execution** of independent transform nodes, allowing you to leverage multi-core processors for faster data transformation.
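A minimal usage sketch of the pipeline API added in this commit (the database path and sample data are illustrative):

```python
from pipeline import create_default_pipeline, TransformContext
import pandas as pd

# Build the default pipeline (NER author node + fuzzy author node) and run it.
pipeline = create_default_pipeline(device="cpu", max_workers=4)
posts = pd.DataFrame({"id": [1], "author": ["Jane Doe"]})

results = pipeline.run(
    db_path="/data/knack.sqlite",            # illustrative path
    initial_context=TransformContext(posts),
    fail_fast=False,                          # keep going if a single node fails
)
# results maps node names ('AuthorNode', 'FuzzyAuthorNode') to their output contexts
```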
## Structure
- `base.py` - Abstract base class for transform nodes
- `main.py` - Main entry point and pipeline orchestration
- `pipeline.py` - Parallel pipeline orchestration system
- `main.py` - Main entry point and pipeline execution
- `author_node.py` - NER-based author classification node
- `example_node.py` - Template for creating new nodes
- `Dockerfile` - Docker image configuration with cron setup
- `requirements.txt` - Python dependencies

View file

@ -1,10 +1,13 @@
"""Author classification transform node using NER."""
from base import TransformNode, TransformContext
import os
import sqlite3
import pandas as pd
import logging
import fuzzysearch
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pipeline import TransformContext
from transform_node import TransformNode
try:
from gliner import GLiNER
@ -17,7 +20,7 @@ except ImportError:
logger = logging.getLogger("knack-transform")
class AuthorNode(TransformNode):
class NerAuthorNode(TransformNode):
"""Transform node that extracts and classifies authors using NER.
Creates two tables:
@ -25,7 +28,8 @@ class AuthorNode(TransformNode):
- post_authors: maps posts to their authors
"""
def __init__(self, model_name: str = "urchade/gliner_medium-v2.1",
def __init__(self, model_name: str = "urchade/gliner_multi-v2.1",
model_path: str = None,
threshold: float = 0.5,
max_workers: int = 64,
device: str = "cpu"):
@ -33,11 +37,13 @@ class AuthorNode(TransformNode):
Args:
model_name: GLiNER model to use
model_path: Optional local path to a downloaded GLiNER model
threshold: Confidence threshold for entity predictions
max_workers: Number of parallel workers for prediction
device: Device to run model on ('cpu', 'cuda', 'mps')
"""
self.model_name = model_name
self.model_path = model_path or os.environ.get('GLINER_MODEL_PATH')
self.threshold = threshold
self.max_workers = max_workers
self.device = device
@ -49,21 +55,31 @@ class AuthorNode(TransformNode):
if not GLINER_AVAILABLE:
raise ImportError("GLiNER is required for AuthorNode. Install with: pip install gliner")
logger.info(f"Loading GLiNER model: {self.model_name}")
model_source = None
if self.model_path:
if os.path.exists(self.model_path):
model_source = self.model_path
logger.info(f"Loading GLiNER model from local path: {self.model_path}")
else:
logger.warning(f"GLINER_MODEL_PATH '{self.model_path}' not found; falling back to hub model {self.model_name}")
if model_source is None:
model_source = self.model_name
logger.info(f"Loading GLiNER model from hub: {self.model_name}")
if self.device == "cuda" and torch.cuda.is_available():
self.model = GLiNER.from_pretrained(
self.model_name,
model_source,
max_length=255
).to('cuda', dtype=torch.float16)
elif self.device == "mps" and torch.backends.mps.is_available():
self.model = GLiNER.from_pretrained(
self.model_name,
model_source,
max_length=255
).to('mps', dtype=torch.float16)
else:
self.model = GLiNER.from_pretrained(
self.model_name,
model_source,
max_length=255
)
@ -209,13 +225,6 @@ class AuthorNode(TransformNode):
logger.info(f"Creating {len(mappings_df)} post-author mappings")
mappings_df.to_sql('post_authors', con, if_exists='append', index=False)
# Mark posts as cleaned
processed_post_ids = mappings_df['post_id'].unique().tolist()
if processed_post_ids:
placeholders = ','.join('?' * len(processed_post_ids))
con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", processed_post_ids)
logger.info(f"Marked {len(processed_post_ids)} posts as cleaned")
con.commit()
logger.info("Authors and mappings stored successfully")
@ -247,17 +256,165 @@ class AuthorNode(TransformNode):
# Store results
self._store_authors(con, results)
# Mark posts without author entities as cleaned too (no authors found)
processed_ids = set([r['id'] for r in results]) if results else set()
unprocessed_ids = [pid for pid in posts_df['id'].tolist() if pid not in processed_ids]
if unprocessed_ids:
placeholders = ','.join('?' * len(unprocessed_ids))
con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", unprocessed_ids)
con.commit()
logger.info(f"Marked {len(unprocessed_ids)} posts without author entities as cleaned")
# Return context with results
results_df = pd.DataFrame(results) if results else pd.DataFrame()
logger.info("AuthorNode transformation complete")
return TransformContext(results_df)
class FuzzyAuthorNode(TransformNode):
"""FuzzyAuthorNode
This Node takes in data and rules of authornames that have been classified already
and uses those 'rule' to find more similar fields.
"""
def __init__(self, max_l_dist: int = 1):
"""Initialize FuzzyAuthorNode.
Args:
max_l_dist: Maximum Levenshtein distance (number of allowed 'errors') for a fuzzy match
"""
self.max_l_dist = max_l_dist
logger.info(f"Initialized FuzzyAuthorNode with max_l_dist={max_l_dist}")
def _process_data(self, con: sqlite3.Connection, df: pd.DataFrame) -> pd.DataFrame:
"""Process the input dataframe.
This is where your main transformation logic goes.
Args:
con: Database connection
df: Input dataframe from context
Returns:
Dataframe of new post_id/author_id mappings
"""
logger.info(f"Processing {len(df)} rows")
# Retrieve all known authors from the authors table as 'rules'
authors_df = pd.read_sql("SELECT id, name FROM authors", con)
if authors_df.empty:
logger.warning("No authors found in database for fuzzy matching")
return pd.DataFrame(columns=['post_id', 'author_id'])
# Get existing post-author mappings to avoid duplicates
existing_mappings = pd.read_sql(
"SELECT post_id, author_id FROM post_authors", con
)
existing_post_ids = set(existing_mappings['post_id'].unique())
logger.info(f"Found {len(authors_df)} known authors for fuzzy matching")
logger.info(f"Found {len(existing_post_ids)} posts with existing author mappings")
# Filter to posts without author mappings and with non-null author field
if 'author' not in df.columns or 'id' not in df.columns:
logger.warning("Missing 'author' or 'id' column in input dataframe")
return pd.DataFrame(columns=['post_id', 'author_id'])
posts_to_process = df[
(df['id'].notna()) &
(df['author'].notna()) &
(~df['id'].isin(existing_post_ids))
]
logger.info(f"Processing {len(posts_to_process)} posts for fuzzy matching")
# Perform fuzzy matching
mappings = []
for _, post_row in posts_to_process.iterrows():
post_id = post_row['id']
post_author = str(post_row['author'])
# Try to find matches against all known author names
for _, author_row in authors_df.iterrows():
author_id = author_row['id']
author_name = str(author_row['name'])
# Use fuzzysearch to find matches with allowed errors
matches = fuzzysearch.find_near_matches(
author_name,
post_author,
max_l_dist=self.max_l_dist
)
if matches:
logger.debug(f"Found fuzzy match: '{author_name}' in '{post_author}' for post {post_id}")
mappings.append({
'post_id': post_id,
'author_id': author_id
})
# Only take the first match per post to avoid multiple mappings
break
# Create result dataframe
result_df = pd.DataFrame(mappings, columns=['post_id', 'author_id']) if mappings else pd.DataFrame(columns=['post_id', 'author_id'])
logger.info(f"Processing complete. Found {len(result_df)} fuzzy matches")
return result_df
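# Sketch of the fuzzysearch behaviour the loop above relies on; max_l_dist
# bounds the Levenshtein distance of an accepted match:
#
#   from fuzzysearch import find_near_matches
#   find_near_matches("Jane Doe", "von Jane Doe, Redaktion", max_l_dist=1)
#   # -> non-empty list of Match(start, end, dist, matched) => post gets mapped
#   find_near_matches("Jane Doe", "unrelated byline", max_l_dist=1)
#   # -> [] => no mapping for this author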
def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
"""Store results back to the database.
Uses INSERT OR IGNORE to avoid inserting duplicates.
Args:
con: Database connection
df: Processed dataframe to store
"""
if df.empty:
logger.info("No results to store")
return
logger.info(f"Storing {len(df)} results")
# Use INSERT OR IGNORE to handle duplicates (respects PRIMARY KEY constraint)
cursor = con.cursor()
inserted_count = 0
for _, row in df.iterrows():
cursor.execute(
"INSERT OR IGNORE INTO post_authors (post_id, author_id) VALUES (?, ?)",
(int(row['post_id']), int(row['author_id']))
)
if cursor.rowcount > 0:
inserted_count += 1
con.commit()
logger.info(f"Results stored successfully. Inserted {inserted_count} new mappings, skipped {len(df) - inserted_count} duplicates")
def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
"""Execute the transformation.
This is the main entry point called by the pipeline.
Args:
con: SQLite database connection
context: TransformContext containing input dataframe
Returns:
TransformContext with processed dataframe
"""
logger.info("Starting FuzzyAuthorNode transformation")
# Get input dataframe from context
input_df = context.get_dataframe()
# Validate input
if input_df.empty:
logger.warning("Empty dataframe provided to FuzzyAuthorNode")
return context
# Process the data
result_df = self._process_data(con, input_df)
# Store results
self._store_results(con, result_df)
logger.info("FuzzyAuthorNode transformation complete")
# Return new context with results
return TransformContext(result_df)

View file

@ -0,0 +1,16 @@
#!/usr/bin/env bash
set -euo pipefail
if [ -d "$GLINER_MODEL_PATH" ] && find "$GLINER_MODEL_PATH" -type f | grep -q .; then
echo "GLiNER model already present at $GLINER_MODEL_PATH"
exit 0
fi
echo "Downloading GLiNER model to $GLINER_MODEL_PATH"
mkdir -p "$GLINER_MODEL_PATH"
curl -sL "https://huggingface.co/api/models/${GLINER_MODEL_ID}" | jq -r '.siblings[].rfilename' | while read -r file; do
target="${GLINER_MODEL_PATH}/${file}"
mkdir -p "$(dirname "$target")"
echo "Downloading ${file}"
curl -sL "https://huggingface.co/${GLINER_MODEL_ID}/resolve/main/${file}" -o "$target"
done

transform/entrypoint.sh (new file, 8 lines)
View file

@ -0,0 +1,8 @@
#!/usr/bin/env bash
set -euo pipefail
# Run model download with output to stdout/stderr
/usr/local/bin/ensure_gliner_model.sh 2>&1
# Start cron in foreground with logging
exec cron -f -L 2

transform/example_node.py (new file, 170 lines)
View file

@ -0,0 +1,170 @@
"""Example template node for the transform pipeline.
This is a template showing how to create new transform nodes.
Copy this file and modify it for your specific transformation needs.
"""
from pipeline import TransformContext
from transform_node import TransformNode
import sqlite3
import pandas as pd
import logging
logger = logging.getLogger("knack-transform")
class ExampleNode(TransformNode):
"""Example transform node template.
This node demonstrates the basic structure for creating
new transformation nodes in the pipeline.
"""
def __init__(self,
param1: str = "default_value",
param2: int = 42,
device: str = "cpu"):
"""Initialize the ExampleNode.
Args:
param1: Example string parameter
param2: Example integer parameter
device: Device to use for computations ('cpu', 'cuda', 'mps')
"""
self.param1 = param1
self.param2 = param2
self.device = device
logger.info(f"Initialized ExampleNode with param1={param1}, param2={param2}")
def _create_tables(self, con: sqlite3.Connection):
"""Create any necessary tables in the database.
This is optional - only needed if your node creates new tables.
"""
logger.info("Creating example tables")
con.execute("""
CREATE TABLE IF NOT EXISTS example_results (
id INTEGER PRIMARY KEY AUTOINCREMENT,
post_id INTEGER,
result_value TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (post_id) REFERENCES posts(id)
)
""")
con.commit()
def _process_data(self, df: pd.DataFrame) -> pd.DataFrame:
"""Process the input dataframe.
This is where your main transformation logic goes.
Args:
df: Input dataframe from context
Returns:
Processed dataframe
"""
logger.info(f"Processing {len(df)} rows")
# Example: Add a new column based on existing data
result_df = df.copy()
result_df['processed'] = True
result_df['example_value'] = result_df['id'].apply(lambda x: f"{self.param1}_{x}")
logger.info("Processing complete")
return result_df
def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
"""Store results back to the database.
This is optional - only needed if you want to persist results.
Args:
con: Database connection
df: Processed dataframe to store
"""
if df.empty:
logger.info("No results to store")
return
logger.info(f"Storing {len(df)} results")
# Example: Store to database
# df[['post_id', 'result_value']].to_sql(
# 'example_results',
# con,
# if_exists='append',
# index=False
# )
con.commit()
logger.info("Results stored successfully")
def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
"""Execute the transformation.
This is the main entry point called by the pipeline.
Args:
con: SQLite database connection
context: TransformContext containing input dataframe
Returns:
TransformContext with processed dataframe
"""
logger.info("Starting ExampleNode transformation")
# Get input dataframe from context
input_df = context.get_dataframe()
# Validate input
if input_df.empty:
logger.warning("Empty dataframe provided to ExampleNode")
return context
# Create any necessary tables
self._create_tables(con)
# Process the data
result_df = self._process_data(input_df)
# Store results (optional)
self._store_results(con, result_df)
logger.info("ExampleNode transformation complete")
# Return new context with results
return TransformContext(result_df)
# Example usage:
if __name__ == "__main__":
# This allows you to test your node independently
import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))  # run from this file's directory
from pipeline import TransformContext
import sqlite3
# Create test data
test_df = pd.DataFrame({
'id': [1, 2, 3],
'author': ['Test Author 1', 'Test Author 2', 'Test Author 3']
})
# Create test database connection
test_con = sqlite3.connect(':memory:')
# Create and run node
node = ExampleNode(param1="test", param2=100)
context = TransformContext(test_df)
result_context = node.run(test_con, context)
# Check results
result_df = result_context.get_dataframe()
print("\nResult DataFrame:")
print(result_df)
test_con.close()
print("\n✓ ExampleNode test completed successfully!")

View file

@ -50,15 +50,14 @@ def main():
logger.info("Transform pipeline skipped - no data available")
return
# Import transform nodes
from author_node import AuthorNode
from base import TransformContext
# Import transform components
from pipeline import create_default_pipeline, TransformContext
import pandas as pd
# Load posts data
logger.info("Loading posts from database")
sql = "SELECT id, author FROM posts WHERE author IS NOT NULL AND (is_cleaned IS NULL OR is_cleaned = 0) LIMIT ?"
MAX_CLEANED_POSTS = os.environ.get("MAX_CLEANED_POSTS", 500)
MAX_CLEANED_POSTS = int(os.environ.get("MAX_CLEANED_POSTS", 100))
df = pd.read_sql(sql, con, params=[MAX_CLEANED_POSTS])
logger.info(f"Loaded {len(df)} uncleaned posts with authors")
@ -66,15 +65,29 @@ def main():
logger.info("No uncleaned posts found. Transform pipeline skipped.")
return
# Create context and run author classification
# Create initial context
context = TransformContext(df)
author_transform = AuthorNode(device=os.environ.get('COMPUTE_DEVICE', 'cpu')) # Change to "cuda" or "mps" if available
result_context = author_transform.run(con, context)
# TODO: Create Node to compute Text Embeddings and UMAP.
# TODO: Create Node to pre-compute data based on visuals to reduce load time.
# Create and run parallel pipeline
device = os.environ.get('COMPUTE_DEVICE', 'cpu')
max_workers = int(os.environ.get('MAX_WORKERS', 4))
logger.info("Transform pipeline completed successfully")
pipeline = create_default_pipeline(device=device, max_workers=max_workers)
results = pipeline.run(
db_path=os.environ.get('DB_PATH', '/data/knack.sqlite'),
initial_context=context,
fail_fast=False # Continue even if some nodes fail
)
logger.info(f"Pipeline completed. Processed {len(results)} node(s)")
# Mark all processed posts as cleaned
post_ids = df['id'].tolist()
if post_ids:
placeholders = ','.join('?' * len(post_ids))
con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", post_ids)
con.commit()
logger.info(f"Marked {len(post_ids)} posts as cleaned")
except Exception as e:
logger.error(f"Error in transform pipeline: {e}", exc_info=True)

transform/pipeline.py (new file, 258 lines)
View file

@ -0,0 +1,258 @@
"""Parallel pipeline orchestration for transform nodes."""
import logging
import os
import sqlite3
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import List, Dict, Optional
import pandas as pd
import multiprocessing as mp
logger = logging.getLogger("knack-transform")
class TransformContext:
"""Context object containing the dataframe for transformation."""
def __init__(self, df: pd.DataFrame):
self.df = df
def get_dataframe(self) -> pd.DataFrame:
"""Get the pandas dataframe from the context."""
return self.df
class NodeConfig:
"""Configuration for a transform node."""
def __init__(self,
node_class: type,
node_kwargs: Dict = None,
dependencies: List[str] = None,
name: str = None):
"""Initialize node configuration.
Args:
node_class: The TransformNode class to instantiate
node_kwargs: Keyword arguments to pass to node constructor
dependencies: List of node names that must complete before this one
name: Optional name for the node (defaults to class name)
"""
self.node_class = node_class
self.node_kwargs = node_kwargs or {}
self.dependencies = dependencies or []
self.name = name or node_class.__name__
class ParallelPipeline:
"""Pipeline for executing transform nodes in parallel where possible.
The pipeline analyzes dependencies between nodes and executes
independent nodes concurrently using multiprocessing or threading.
"""
def __init__(self,
max_workers: Optional[int] = None,
use_processes: bool = False):
"""Initialize the parallel pipeline.
Args:
max_workers: Maximum number of parallel workers (defaults to CPU count)
use_processes: If True, use ProcessPoolExecutor; if False, use ThreadPoolExecutor
"""
self.max_workers = max_workers or mp.cpu_count()
self.use_processes = use_processes
self.nodes: Dict[str, NodeConfig] = {}
logger.info(f"Initialized ParallelPipeline with {self.max_workers} workers "
f"({'processes' if use_processes else 'threads'})")
def add_node(self, config: NodeConfig):
"""Add a node to the pipeline.
Args:
config: NodeConfig with node details and dependencies
"""
self.nodes[config.name] = config
logger.info(f"Added node '{config.name}' with dependencies: {config.dependencies}")
def _get_execution_stages(self) -> List[List[str]]:
"""Determine execution stages based on dependencies.
Returns:
List of stages, where each stage contains node names that can run in parallel
"""
stages = []
completed = set()
remaining = set(self.nodes.keys())
while remaining:
# Find nodes whose dependencies are all completed
ready = []
for node_name in remaining:
config = self.nodes[node_name]
if all(dep in completed for dep in config.dependencies):
ready.append(node_name)
if not ready:
# Circular dependency or missing dependency
raise ValueError(f"Cannot resolve dependencies. Remaining nodes: {remaining}")
stages.append(ready)
completed.update(ready)
remaining -= set(ready)
return stages
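# Worked example of the stage resolution above: given nodes 'A' and 'B' with no
# dependencies and 'C' with dependencies=['A', 'B'], _get_execution_stages()
# returns [['A', 'B'], ['C']] (order within a stage is arbitrary): A and B run
# in parallel in stage 1, and C runs alone in stage 2 once both have completed.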
def _execute_node(self,
node_name: str,
db_path: str,
context: TransformContext) -> tuple:
"""Execute a single node.
Args:
node_name: Name of the node to execute
db_path: Path to the SQLite database
context: TransformContext for the node
Returns:
Tuple of (node_name, result_context, error)
"""
try:
# Create fresh database connection (not shared across processes/threads)
con = sqlite3.connect(db_path)
config = self.nodes[node_name]
node = config.node_class(**config.node_kwargs)
logger.info(f"Executing node: {node_name}")
result_context = node.run(con, context)
con.close()
logger.info(f"Node '{node_name}' completed successfully")
return node_name, result_context, None
except Exception as e:
logger.error(f"Error executing node '{node_name}': {e}", exc_info=True)
return node_name, None, str(e)
def run(self,
db_path: str,
initial_context: TransformContext,
fail_fast: bool = False) -> Dict[str, TransformContext]:
"""Execute the pipeline.
Args:
db_path: Path to the SQLite database
initial_context: Initial TransformContext for the pipeline
fail_fast: If True, stop execution on first error
Returns:
Dict mapping node names to their output TransformContext
"""
logger.info("Starting parallel pipeline execution")
stages = self._get_execution_stages()
logger.info(f"Pipeline has {len(stages)} execution stage(s)")
results = {}
contexts = {None: initial_context} # Track contexts from each node
errors = []
ExecutorClass = ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor
for stage_num, stage_nodes in enumerate(stages, 1):
logger.info(f"Stage {stage_num}/{len(stages)}: Executing {len(stage_nodes)} node(s) in parallel: {stage_nodes}")
# For nodes in this stage, use the context from their dependencies
# If multiple dependencies, we'll use the most recent one (or could merge)
stage_futures = {}
with ExecutorClass(max_workers=min(self.max_workers, len(stage_nodes))) as executor:
for node_name in stage_nodes:
config = self.nodes[node_name]
# Get context from dependencies (use the last dependency's output)
if config.dependencies:
context = results.get(config.dependencies[-1], initial_context)
else:
context = initial_context
future = executor.submit(self._execute_node, node_name, db_path, context)
stage_futures[future] = node_name
# Wait for all nodes in this stage to complete
for future in as_completed(stage_futures):
node_name = stage_futures[future]
name, result_context, error = future.result()
if error:
errors.append((name, error))
if fail_fast:
logger.error(f"Pipeline failed at node '{name}': {error}")
raise RuntimeError(f"Node '{name}' failed: {error}")
else:
results[name] = result_context
if errors:
logger.warning(f"Pipeline completed with {len(errors)} error(s)")
for name, error in errors:
logger.error(f" - {name}: {error}")
else:
logger.info("Pipeline completed successfully")
return results
def create_default_pipeline(device: str = "cpu",
max_workers: Optional[int] = None) -> ParallelPipeline:
"""Create a pipeline with default transform nodes.
Args:
device: Device to use for compute-intensive nodes ('cpu', 'cuda', 'mps')
max_workers: Maximum number of parallel workers
Returns:
Configured ParallelPipeline
"""
from author_node import NerAuthorNode, FuzzyAuthorNode
pipeline = ParallelPipeline(max_workers=max_workers, use_processes=False)
# Add AuthorNode (no dependencies)
pipeline.add_node(NodeConfig(
node_class=NerAuthorNode,
node_kwargs={
'device': device,
'model_path': os.environ.get('GLINER_MODEL_PATH')
},
dependencies=[],
name='AuthorNode'
))
pipeline.add_node(NodeConfig(
node_class=FuzzyAuthorNode,
node_kwargs={
'max_l_dist': 1
},
dependencies=['AuthorNode'],
name='FuzzyAuthorNode'
))
# TODO: Create Node to compute Text Embeddings and UMAP.
# TODO: Create Node to pre-compute data based on visuals to reduce load time.
# TODO: Add more nodes here as they are implemented
# Example:
# pipeline.add_node(NodeConfig(
# node_class=EmbeddingNode,
# node_kwargs={'device': device},
# dependencies=[],  # no dependencies; runs in the first stage alongside AuthorNode
# name='EmbeddingNode'
# ))
# pipeline.add_node(NodeConfig(
# node_class=UMAPNode,
# node_kwargs={'device': device},
# dependencies=['EmbeddingNode'], # Runs after EmbeddingNode
# name='UMAPNode'
# ))
return pipeline

View file

@ -2,3 +2,4 @@ pandas
python-dotenv
gliner
torch
fuzzysearch

View file

@ -1,19 +1,8 @@
"""Base transform node for data pipeline."""
from abc import ABC, abstractmethod
import sqlite3
import pandas as pd
class TransformContext:
"""Context object containing the dataframe for transformation."""
def __init__(self, df: pd.DataFrame):
self.df = df
def get_dataframe(self) -> pd.DataFrame:
"""Get the pandas dataframe from the context."""
return self.df
from pipeline import TransformContext
class TransformNode(ABC):
"""Abstract base class for transformation nodes.