Adds TransformNode to FuzzyFind Author Names

parent 64df8fb328
commit 72765532d3

11 changed files with 696 additions and 58 deletions
@@ -21,7 +21,23 @@ services:
      - transform/.env
    volumes:
      - knack_data:/data
      - models:/models
    restart: unless-stopped

  sqlitebrowser:
    image: lscr.io/linuxserver/sqlitebrowser:latest
    container_name: sqlitebrowser
    environment:
      - PUID=1000
      - PGID=1000
      - TZ=Etc/UTC
    volumes:
      - knack_data:/data
    ports:
      - "3000:3000" # noVNC web UI
      - "3001:3001" # VNC server
    restart: unless-stopped

volumes:
  knack_data:
  models:
@@ -1,7 +1,6 @@
FROM python:3.12-slim

RUN mkdir /app
RUN mkdir /data
RUN mkdir -p /app /data /models

# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \

@@ -11,9 +10,12 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
    libopenblas-dev \
    liblapack-dev \
    pkg-config \
    curl \
    jq \
    && rm -rf /var/lib/apt/lists/*

#COPY /data/knack.sqlite /data
ENV GLINER_MODEL_ID=urchade/gliner_multi-v2.1
ENV GLINER_MODEL_PATH=/models/gliner_multi-v2.1

WORKDIR /app
COPY requirements.txt .

@@ -24,18 +26,21 @@ COPY .env .
RUN apt update -y
RUN apt install -y cron locales

COPY *.py .
# Ensure GLiNER helper scripts are available
COPY ensure_gliner_model.sh /usr/local/bin/ensure_gliner_model.sh
COPY entrypoint.sh /usr/local/bin/entrypoint.sh
RUN chmod +x /usr/local/bin/ensure_gliner_model.sh /usr/local/bin/entrypoint.sh

ENV PYTHONUNBUFFERED=1
ENV LANG=de_DE.UTF-8
ENV LC_ALL=de_DE.UTF-8
COPY *.py .

# Create cron job that runs every weekend (Sunday at 3 AM): 0 3 * * 0
# Testing every 30 minutes: */30 * * * *
RUN echo "0 3 * * 0 cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-transform
RUN echo "*/30 * * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-transform
RUN chmod 0644 /etc/cron.d/knack-transform
RUN crontab /etc/cron.d/knack-transform

# Start cron in foreground
CMD ["cron", "-f"]
# Persist models between container runs
VOLUME /models

CMD ["/usr/local/bin/entrypoint.sh"]
#CMD ["python", "main.py"]
@@ -6,10 +6,15 @@ Data transformation pipeline for the Knack scraper project.

This folder contains the transformation logic that processes data from the SQLite database. It runs on a scheduled basis (every weekend) via cron.

The pipeline supports **parallel execution** of independent transform nodes, allowing you to leverage multi-core processors for faster data transformation.

## Structure

- `base.py` - Abstract base class for transform nodes
- `main.py` - Main entry point and pipeline orchestration
- `pipeline.py` - Parallel pipeline orchestration system
- `main.py` - Main entry point and pipeline execution
- `author_node.py` - NER-based author classification node
- `example_node.py` - Template for creating new nodes
- `Dockerfile` - Docker image configuration with cron setup
- `requirements.txt` - Python dependencies
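For orientation, here is a minimal sketch of how a custom node would plug into the parallel pipeline introduced in this commit. It assumes the `NodeConfig`/`ParallelPipeline`/`TransformContext` API from `pipeline.py`, the `from transform_node import TransformNode` import path used elsewhere in the commit, and that `run` is the only method a node must implement; `UppercaseAuthorNode` and the in-memory database are purely illustrative, not part of the commit.

```python
import pandas as pd
import sqlite3

from pipeline import NodeConfig, ParallelPipeline, TransformContext
from transform_node import TransformNode  # import path as used in this commit


class UppercaseAuthorNode(TransformNode):
    """Toy node (illustration only): upper-cases the author column."""

    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
        df = context.get_dataframe().copy()
        df['author'] = df['author'].str.upper()
        return TransformContext(df)


if __name__ == "__main__":
    pipeline = ParallelPipeline(max_workers=2, use_processes=False)
    pipeline.add_node(NodeConfig(node_class=UppercaseAuthorNode, name='UppercaseAuthorNode'))

    df = pd.DataFrame({'id': [1], 'author': ['jane doe']})
    # Each node opens its own connection to db_path; ':memory:' is fine for this toy node.
    results = pipeline.run(db_path=':memory:', initial_context=TransformContext(df))
    print(results['UppercaseAuthorNode'].get_dataframe())
```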
@@ -1,10 +1,13 @@
"""Author classification transform node using NER."""
from base import TransformNode, TransformContext
import os
import sqlite3
import pandas as pd
import logging
import fuzzysearch
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime

from pipeline import TransformContext
from transform_node import TransformNode

try:
    from gliner import GLiNER

@@ -17,7 +20,7 @@ except ImportError:
logger = logging.getLogger("knack-transform")


class AuthorNode(TransformNode):
class NerAuthorNode(TransformNode):
    """Transform node that extracts and classifies authors using NER.

    Creates two tables:

@@ -25,7 +28,8 @@ class AuthorNode(TransformNode):
    - post_authors: maps posts to their authors
    """

    def __init__(self, model_name: str = "urchade/gliner_medium-v2.1",
    def __init__(self, model_name: str = "urchade/gliner_multi-v2.1",
                 model_path: str = None,
                 threshold: float = 0.5,
                 max_workers: int = 64,
                 device: str = "cpu"):

@@ -33,11 +37,13 @@ class AuthorNode(TransformNode):

        Args:
            model_name: GLiNER model to use
            model_path: Optional local path to a downloaded GLiNER model
            threshold: Confidence threshold for entity predictions
            max_workers: Number of parallel workers for prediction
            device: Device to run model on ('cpu', 'cuda', 'mps')
        """
        self.model_name = model_name
        self.model_path = model_path or os.environ.get('GLINER_MODEL_PATH')
        self.threshold = threshold
        self.max_workers = max_workers
        self.device = device

@@ -49,21 +55,31 @@ class AuthorNode(TransformNode):
        if not GLINER_AVAILABLE:
            raise ImportError("GLiNER is required for AuthorNode. Install with: pip install gliner")

        logger.info(f"Loading GLiNER model: {self.model_name}")
        model_source = None
        if self.model_path:
            if os.path.exists(self.model_path):
                model_source = self.model_path
                logger.info(f"Loading GLiNER model from local path: {self.model_path}")
            else:
                logger.warning(f"GLINER_MODEL_PATH '{self.model_path}' not found; falling back to hub model {self.model_name}")

        if model_source is None:
            model_source = self.model_name
            logger.info(f"Loading GLiNER model from hub: {self.model_name}")

        if self.device == "cuda" and torch.cuda.is_available():
            self.model = GLiNER.from_pretrained(
                self.model_name,
                model_source,
                max_length=255
            ).to('cuda', dtype=torch.float16)
        elif self.device == "mps" and torch.backends.mps.is_available():
            self.model = GLiNER.from_pretrained(
                self.model_name,
                model_source,
                max_length=255
            ).to('mps', dtype=torch.float16)
        else:
            self.model = GLiNER.from_pretrained(
                self.model_name,
                model_source,
                max_length=255
            )
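As a quick sanity check of the NER step, here is a minimal sketch of GLiNER usage outside the pipeline. It assumes the public `GLiNER.from_pretrained` / `predict_entities` API and will download the checkpoint on first use; the label set and sample text are made up for illustration.

```python
from gliner import GLiNER

# Load the multilingual checkpoint referenced by NerAuthorNode (downloads on first call).
model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1")

# Ask the model for person-like entities in a raw author field (sample text is invented).
text = "Ein Beitrag von Max Mustermann und der Redaktion"
entities = model.predict_entities(text, labels=["person"], threshold=0.5)

for ent in entities:
    # Each entity dict carries the matched text, its label, and a confidence score.
    print(ent["text"], ent["label"], round(ent["score"], 2))
```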
@@ -208,13 +224,6 @@ class AuthorNode(TransformNode):

        logger.info(f"Creating {len(mappings_df)} post-author mappings")
        mappings_df.to_sql('post_authors', con, if_exists='append', index=False)

        # Mark posts as cleaned
        processed_post_ids = mappings_df['post_id'].unique().tolist()
        if processed_post_ids:
            placeholders = ','.join('?' * len(processed_post_ids))
            con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", processed_post_ids)
            logger.info(f"Marked {len(processed_post_ids)} posts as cleaned")

        con.commit()
        logger.info("Authors and mappings stored successfully")
@@ -247,17 +256,165 @@ class AuthorNode(TransformNode):
        # Store results
        self._store_authors(con, results)

        # Mark posts without author entities as cleaned too (no authors found)
        processed_ids = set([r['id'] for r in results]) if results else set()
        unprocessed_ids = [pid for pid in posts_df['id'].tolist() if pid not in processed_ids]
        if unprocessed_ids:
            placeholders = ','.join('?' * len(unprocessed_ids))
            con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", unprocessed_ids)
            con.commit()
            logger.info(f"Marked {len(unprocessed_ids)} posts without author entities as cleaned")

        # Return context with results
        results_df = pd.DataFrame(results) if results else pd.DataFrame()
        logger.info("AuthorNode transformation complete")

        return TransformContext(results_df)


class FuzzyAuthorNode(TransformNode):
    """FuzzyAuthorNode

    This node takes the author names that have already been classified and uses
    them as 'rules' to find further, similar author fields via fuzzy search.
    """

    def __init__(self,
                 max_l_dist: int = 1):
        """Initialize FuzzyAuthorNode.

        Args:
            max_l_dist: The number of 'errors' allowed by the fuzzy search algorithm
        """
        self.max_l_dist = max_l_dist
        logger.info(f"Initialized FuzzyAuthorNode with max_l_dist={max_l_dist}")

    def _process_data(self, con: sqlite3.Connection, df: pd.DataFrame) -> pd.DataFrame:
        """Process the input dataframe.

        This is where the main transformation logic goes.

        Args:
            con: Database connection
            df: Input dataframe from context

        Returns:
            Processed dataframe
        """
        logger.info(f"Processing {len(df)} rows")

        # Retrieve all known authors from the authors table as 'rules'
        authors_df = pd.read_sql("SELECT id, name FROM authors", con)

        if authors_df.empty:
            logger.warning("No authors found in database for fuzzy matching")
            return pd.DataFrame(columns=['post_id', 'author_id'])

        # Get existing post-author mappings to avoid duplicates
        existing_mappings = pd.read_sql(
            "SELECT post_id, author_id FROM post_authors", con
        )
        existing_post_ids = set(existing_mappings['post_id'].unique())

        logger.info(f"Found {len(authors_df)} known authors for fuzzy matching")
        logger.info(f"Found {len(existing_post_ids)} posts with existing author mappings")

        # Filter to posts without author mappings and with non-null author field
        if 'author' not in df.columns or 'id' not in df.columns:
            logger.warning("Missing 'author' or 'id' column in input dataframe")
            return pd.DataFrame(columns=['post_id', 'author_id'])

        posts_to_process = df[
            (df['id'].notna()) &
            (df['author'].notna()) &
            (~df['id'].isin(existing_post_ids))
        ]

        logger.info(f"Processing {len(posts_to_process)} posts for fuzzy matching")

        # Perform fuzzy matching
        mappings = []
        for _, post_row in posts_to_process.iterrows():
            post_id = post_row['id']
            post_author = str(post_row['author'])

            # Try to find matches against all known author names
            for _, author_row in authors_df.iterrows():
                author_id = author_row['id']
                author_name = str(author_row['name'])

                # Use fuzzysearch to find matches with allowed errors
                matches = fuzzysearch.find_near_matches(
                    author_name,
                    post_author,
                    max_l_dist=self.max_l_dist
                )

                if matches:
                    logger.debug(f"Found fuzzy match: '{author_name}' in '{post_author}' for post {post_id}")
                    mappings.append({
                        'post_id': post_id,
                        'author_id': author_id
                    })
                    # Only take the first match per post to avoid multiple mappings
                    break

        # Create result dataframe
        result_df = pd.DataFrame(mappings, columns=['post_id', 'author_id']) if mappings else pd.DataFrame(columns=['post_id', 'author_id'])

        logger.info(f"Processing complete. Found {len(result_df)} fuzzy matches")
        return result_df

    def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
        """Store results back to the database.

        Uses INSERT OR IGNORE to avoid inserting duplicates.

        Args:
            con: Database connection
            df: Processed dataframe to store
        """
        if df.empty:
            logger.info("No results to store")
            return

        logger.info(f"Storing {len(df)} results")

        # Use INSERT OR IGNORE to handle duplicates (respects PRIMARY KEY constraint)
        cursor = con.cursor()
        inserted_count = 0

        for _, row in df.iterrows():
            cursor.execute(
                "INSERT OR IGNORE INTO post_authors (post_id, author_id) VALUES (?, ?)",
                (int(row['post_id']), int(row['author_id']))
            )
            if cursor.rowcount > 0:
                inserted_count += 1

        con.commit()
        logger.info(f"Results stored successfully. Inserted {inserted_count} new mappings, skipped {len(df) - inserted_count} duplicates")

    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
        """Execute the transformation.

        This is the main entry point called by the pipeline.

        Args:
            con: SQLite database connection
            context: TransformContext containing input dataframe

        Returns:
            TransformContext with processed dataframe
        """
        logger.info("Starting FuzzyAuthorNode transformation")

        # Get input dataframe from context
        input_df = context.get_dataframe()

        # Validate input
        if input_df.empty:
            logger.warning("Empty dataframe provided to FuzzyAuthorNode")
            return context

        # Process the data
        result_df = self._process_data(con, input_df)

        # Store results
        self._store_results(con, result_df)

        logger.info("FuzzyAuthorNode transformation complete")

        # Return new context with results
        return TransformContext(result_df)
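To make the `max_l_dist` knob concrete, here is a small, self-contained sketch of what `fuzzysearch.find_near_matches` returns. The call signature matches the one used in `FuzzyAuthorNode`; the author name and scraped field are invented for illustration.

```python
import fuzzysearch

known_author = "Anna Schmidt"                 # 'rule' as it would come from the authors table (made up)
raw_field = "Von Anna Schmit und Freunden"    # scraped author field with a typo (made up)

# max_l_dist=1 tolerates one insertion, deletion, or substitution.
matches = fuzzysearch.find_near_matches(known_author, raw_field, max_l_dist=1)

for m in matches:
    # Each match reports where it was found in raw_field and how many edits it needed.
    print(raw_field[m.start:m.end], m.dist)
```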
16 transform/ensure_gliner_model.sh Normal file

@@ -0,0 +1,16 @@
#!/usr/bin/env bash
set -euo pipefail

if [ -d "$GLINER_MODEL_PATH" ] && find "$GLINER_MODEL_PATH" -type f | grep -q .; then
  echo "GLiNER model already present at $GLINER_MODEL_PATH"
  exit 0
fi

echo "Downloading GLiNER model to $GLINER_MODEL_PATH"
mkdir -p "$GLINER_MODEL_PATH"
curl -sL "https://huggingface.co/api/models/${GLINER_MODEL_ID}" | jq -r '.siblings[].rfilename' | while read -r file; do
  target="${GLINER_MODEL_PATH}/${file}"
  mkdir -p "$(dirname "$target")"
  echo "Downloading ${file}"
  curl -sL "https://huggingface.co/${GLINER_MODEL_ID}/resolve/main/${file}" -o "$target"
done
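The same download step could also be done from Python instead of curl/jq. A hedged sketch, assuming the `snapshot_download` helper from the `huggingface_hub` package, which is not part of this commit's requirements.txt:

```python
import os

from huggingface_hub import snapshot_download  # assumed extra dependency, not in requirements.txt

model_id = os.environ.get("GLINER_MODEL_ID", "urchade/gliner_multi-v2.1")
model_path = os.environ.get("GLINER_MODEL_PATH", "/models/gliner_multi-v2.1")

# Fetches every file of the model repo into model_path, reusing files already present.
snapshot_download(repo_id=model_id, local_dir=model_path)
```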
8 transform/entrypoint.sh Normal file

@@ -0,0 +1,8 @@
#!/usr/bin/env bash
set -euo pipefail

# Run model download with output to stdout/stderr
/usr/local/bin/ensure_gliner_model.sh 2>&1

# Start cron in foreground with logging
exec cron -f -L 2
170 transform/example_node.py Normal file

@@ -0,0 +1,170 @@
"""Example template node for the transform pipeline.

This is a template showing how to create new transform nodes.
Copy this file and modify it for your specific transformation needs.
"""
from pipeline import TransformContext
from transform_node import TransformNode
import sqlite3
import pandas as pd
import logging

logger = logging.getLogger("knack-transform")


class ExampleNode(TransformNode):
    """Example transform node template.

    This node demonstrates the basic structure for creating
    new transformation nodes in the pipeline.
    """

    def __init__(self,
                 param1: str = "default_value",
                 param2: int = 42,
                 device: str = "cpu"):
        """Initialize the ExampleNode.

        Args:
            param1: Example string parameter
            param2: Example integer parameter
            device: Device to use for computations ('cpu', 'cuda', 'mps')
        """
        self.param1 = param1
        self.param2 = param2
        self.device = device
        logger.info(f"Initialized ExampleNode with param1={param1}, param2={param2}")

    def _create_tables(self, con: sqlite3.Connection):
        """Create any necessary tables in the database.

        This is optional - only needed if your node creates new tables.
        """
        logger.info("Creating example tables")

        con.execute("""
            CREATE TABLE IF NOT EXISTS example_results (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                post_id INTEGER,
                result_value TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (post_id) REFERENCES posts(id)
            )
        """)

        con.commit()

    def _process_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Process the input dataframe.

        This is where your main transformation logic goes.

        Args:
            df: Input dataframe from context

        Returns:
            Processed dataframe
        """
        logger.info(f"Processing {len(df)} rows")

        # Example: Add a new column based on existing data
        result_df = df.copy()
        result_df['processed'] = True
        result_df['example_value'] = result_df['id'].apply(lambda x: f"{self.param1}_{x}")

        logger.info("Processing complete")
        return result_df

    def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
        """Store results back to the database.

        This is optional - only needed if you want to persist results.

        Args:
            con: Database connection
            df: Processed dataframe to store
        """
        if df.empty:
            logger.info("No results to store")
            return

        logger.info(f"Storing {len(df)} results")

        # Example: Store to database
        # df[['post_id', 'result_value']].to_sql(
        #     'example_results',
        #     con,
        #     if_exists='append',
        #     index=False
        # )

        con.commit()
        logger.info("Results stored successfully")

    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
        """Execute the transformation.

        This is the main entry point called by the pipeline.

        Args:
            con: SQLite database connection
            context: TransformContext containing input dataframe

        Returns:
            TransformContext with processed dataframe
        """
        logger.info("Starting ExampleNode transformation")

        # Get input dataframe from context
        input_df = context.get_dataframe()

        # Validate input
        if input_df.empty:
            logger.warning("Empty dataframe provided to ExampleNode")
            return context

        # Create any necessary tables
        self._create_tables(con)

        # Process the data
        result_df = self._process_data(input_df)

        # Store results (optional)
        self._store_results(con, result_df)

        logger.info("ExampleNode transformation complete")

        # Return new context with results
        return TransformContext(result_df)


# Example usage:
if __name__ == "__main__":
    # This allows you to test your node independently
    import os
    os.chdir('/Users/linussilberstein/Documents/Knack-Scraper/transform')

    from pipeline import TransformContext
    import sqlite3

    # Create test data
    test_df = pd.DataFrame({
        'id': [1, 2, 3],
        'author': ['Test Author 1', 'Test Author 2', 'Test Author 3']
    })

    # Create test database connection
    test_con = sqlite3.connect(':memory:')

    # Create and run node
    node = ExampleNode(param1="test", param2=100)
    context = TransformContext(test_df)
    result_context = node.run(test_con, context)

    # Check results
    result_df = result_context.get_dataframe()
    print("\nResult DataFrame:")
    print(result_df)

    test_con.close()
    print("\n✓ ExampleNode test completed successfully!")
@@ -50,15 +50,14 @@ def main():
            logger.info("Transform pipeline skipped - no data available")
            return

        # Import transform nodes
        from author_node import AuthorNode
        from base import TransformContext
        # Import transform components
        from pipeline import create_default_pipeline, TransformContext
        import pandas as pd

        # Load posts data
        logger.info("Loading posts from database")
        sql = "SELECT id, author FROM posts WHERE author IS NOT NULL AND (is_cleaned IS NULL OR is_cleaned = 0) LIMIT ?"
        MAX_CLEANED_POSTS = os.environ.get("MAX_CLEANED_POSTS", 500)
        MAX_CLEANED_POSTS = os.environ.get("MAX_CLEANED_POSTS", 100)
        df = pd.read_sql(sql, con, params=[MAX_CLEANED_POSTS])
        logger.info(f"Loaded {len(df)} uncleaned posts with authors")

@@ -66,15 +65,29 @@ def main():
            logger.info("No uncleaned posts found. Transform pipeline skipped.")
            return

        # Create context and run author classification
        # Create initial context
        context = TransformContext(df)
        author_transform = AuthorNode(device=os.environ.get('COMPUTE_DEVICE', 'cpu'))  # Change to "cuda" or "mps" if available
        result_context = author_transform.run(con, context)

        # TODO: Create Node to compute Text Embeddings and UMAP.
        # TODO: Create Node to pre-compute data based on visuals to reduce load time.

        logger.info("Transform pipeline completed successfully")
        # Create and run parallel pipeline
        device = os.environ.get('COMPUTE_DEVICE', 'cpu')
        max_workers = int(os.environ.get('MAX_WORKERS', 4))

        pipeline = create_default_pipeline(device=device, max_workers=max_workers)
        results = pipeline.run(
            db_path=os.environ.get('DB_PATH', '/data/knack.sqlite'),
            initial_context=context,
            fail_fast=False  # Continue even if some nodes fail
        )

        logger.info(f"Pipeline completed. Processed {len(results)} node(s)")

        # Mark all processed posts as cleaned
        post_ids = df['id'].tolist()
        if post_ids:
            placeholders = ','.join('?' * len(post_ids))
            con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", post_ids)
            con.commit()
            logger.info(f"Marked {len(post_ids)} posts as cleaned")

    except Exception as e:
        logger.error(f"Error in transform pipeline: {e}", exc_info=True)
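A note on configuration: the entry point reads everything from the environment. The sketch below only collects the variables and defaults that appear in this diff; the explicit `int()` cast on `MAX_CLEANED_POSTS` is an editor suggestion (environment values arrive as strings), not part of the commit.

```python
import os

# Settings and defaults as they appear in main.py / pipeline.py in this diff.
device = os.environ.get('COMPUTE_DEVICE', 'cpu')           # 'cpu', 'cuda', or 'mps'
max_workers = int(os.environ.get('MAX_WORKERS', 4))        # parallel pipeline workers
db_path = os.environ.get('DB_PATH', '/data/knack.sqlite')  # SQLite database location

# Environment values are strings; casting before use as a SQL LIMIT parameter is safer
# than relying on SQLite's type coercion (suggestion, not in the commit).
max_cleaned_posts = int(os.environ.get('MAX_CLEANED_POSTS', 100))
```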
258 transform/pipeline.py Normal file

@@ -0,0 +1,258 @@
"""Parallel pipeline orchestration for transform nodes."""
import logging
import os
import sqlite3
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import List, Dict, Optional

import pandas as pd
import multiprocessing as mp

logger = logging.getLogger("knack-transform")


class TransformContext:
    """Context object containing the dataframe for transformation."""

    def __init__(self, df: pd.DataFrame):
        self.df = df

    def get_dataframe(self) -> pd.DataFrame:
        """Get the pandas dataframe from the context."""
        return self.df


class NodeConfig:
    """Configuration for a transform node."""

    def __init__(self,
                 node_class: type,
                 node_kwargs: Dict = None,
                 dependencies: List[str] = None,
                 name: str = None):
        """Initialize node configuration.

        Args:
            node_class: The TransformNode class to instantiate
            node_kwargs: Keyword arguments to pass to node constructor
            dependencies: List of node names that must complete before this one
            name: Optional name for the node (defaults to class name)
        """
        self.node_class = node_class
        self.node_kwargs = node_kwargs or {}
        self.dependencies = dependencies or []
        self.name = name or node_class.__name__


class ParallelPipeline:
    """Pipeline for executing transform nodes in parallel where possible.

    The pipeline analyzes dependencies between nodes and executes
    independent nodes concurrently using multiprocessing or threading.
    """

    def __init__(self,
                 max_workers: Optional[int] = None,
                 use_processes: bool = False):
        """Initialize the parallel pipeline.

        Args:
            max_workers: Maximum number of parallel workers (defaults to CPU count)
            use_processes: If True, use ProcessPoolExecutor; if False, use ThreadPoolExecutor
        """
        self.max_workers = max_workers or mp.cpu_count()
        self.use_processes = use_processes
        self.nodes: Dict[str, NodeConfig] = {}
        logger.info(f"Initialized ParallelPipeline with {self.max_workers} workers "
                    f"({'processes' if use_processes else 'threads'})")

    def add_node(self, config: NodeConfig):
        """Add a node to the pipeline.

        Args:
            config: NodeConfig with node details and dependencies
        """
        self.nodes[config.name] = config
        logger.info(f"Added node '{config.name}' with dependencies: {config.dependencies}")

    def _get_execution_stages(self) -> List[List[str]]:
        """Determine execution stages based on dependencies.

        Returns:
            List of stages, where each stage contains node names that can run in parallel
        """
        stages = []
        completed = set()
        remaining = set(self.nodes.keys())

        while remaining:
            # Find nodes whose dependencies are all completed
            ready = []
            for node_name in remaining:
                config = self.nodes[node_name]
                if all(dep in completed for dep in config.dependencies):
                    ready.append(node_name)

            if not ready:
                # Circular dependency or missing dependency
                raise ValueError(f"Cannot resolve dependencies. Remaining nodes: {remaining}")

            stages.append(ready)
            completed.update(ready)
            remaining -= set(ready)

        return stages

    def _execute_node(self,
                      node_name: str,
                      db_path: str,
                      context: TransformContext) -> tuple:
        """Execute a single node.

        Args:
            node_name: Name of the node to execute
            db_path: Path to the SQLite database
            context: TransformContext for the node

        Returns:
            Tuple of (node_name, result_context, error)
        """
        try:
            # Create fresh database connection (not shared across processes/threads)
            con = sqlite3.connect(db_path)

            config = self.nodes[node_name]
            node = config.node_class(**config.node_kwargs)

            logger.info(f"Executing node: {node_name}")
            result_context = node.run(con, context)

            con.close()
            logger.info(f"Node '{node_name}' completed successfully")

            return node_name, result_context, None

        except Exception as e:
            logger.error(f"Error executing node '{node_name}': {e}", exc_info=True)
            return node_name, None, str(e)

    def run(self,
            db_path: str,
            initial_context: TransformContext,
            fail_fast: bool = False) -> Dict[str, TransformContext]:
        """Execute the pipeline.

        Args:
            db_path: Path to the SQLite database
            initial_context: Initial TransformContext for the pipeline
            fail_fast: If True, stop execution on first error

        Returns:
            Dict mapping node names to their output TransformContext
        """
        logger.info("Starting parallel pipeline execution")

        stages = self._get_execution_stages()
        logger.info(f"Pipeline has {len(stages)} execution stage(s)")

        results = {}
        contexts = {None: initial_context}  # Track contexts from each node
        errors = []

        ExecutorClass = ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor

        for stage_num, stage_nodes in enumerate(stages, 1):
            logger.info(f"Stage {stage_num}/{len(stages)}: Executing {len(stage_nodes)} node(s) in parallel: {stage_nodes}")

            # For nodes in this stage, use the context from their dependencies
            # If multiple dependencies, we'll use the most recent one (or could merge)
            stage_futures = {}

            with ExecutorClass(max_workers=min(self.max_workers, len(stage_nodes))) as executor:
                for node_name in stage_nodes:
                    config = self.nodes[node_name]

                    # Get context from dependencies (use the last dependency's output)
                    if config.dependencies:
                        context = results.get(config.dependencies[-1], initial_context)
                    else:
                        context = initial_context

                    future = executor.submit(self._execute_node, node_name, db_path, context)
                    stage_futures[future] = node_name

                # Wait for all nodes in this stage to complete
                for future in as_completed(stage_futures):
                    node_name = stage_futures[future]
                    name, result_context, error = future.result()

                    if error:
                        errors.append((name, error))
                        if fail_fast:
                            logger.error(f"Pipeline failed at node '{name}': {error}")
                            raise RuntimeError(f"Node '{name}' failed: {error}")
                    else:
                        results[name] = result_context

        if errors:
            logger.warning(f"Pipeline completed with {len(errors)} error(s)")
            for name, error in errors:
                logger.error(f"  - {name}: {error}")
        else:
            logger.info("Pipeline completed successfully")

        return results


def create_default_pipeline(device: str = "cpu",
                            max_workers: Optional[int] = None) -> ParallelPipeline:
    """Create a pipeline with default transform nodes.

    Args:
        device: Device to use for compute-intensive nodes ('cpu', 'cuda', 'mps')
        max_workers: Maximum number of parallel workers

    Returns:
        Configured ParallelPipeline
    """
    from author_node import NerAuthorNode, FuzzyAuthorNode

    pipeline = ParallelPipeline(max_workers=max_workers, use_processes=False)

    # Add AuthorNode (no dependencies)
    pipeline.add_node(NodeConfig(
        node_class=NerAuthorNode,
        node_kwargs={
            'device': device,
            'model_path': os.environ.get('GLINER_MODEL_PATH')
        },
        dependencies=[],
        name='AuthorNode'
    ))

    pipeline.add_node(NodeConfig(
        node_class=FuzzyAuthorNode,
        node_kwargs={
            'max_l_dist': 1
        },
        dependencies=['AuthorNode'],
        name='FuzzyAuthorNode'
    ))

    # TODO: Create Node to compute Text Embeddings and UMAP.
    # TODO: Create Node to pre-compute data based on visuals to reduce load time.

    # TODO: Add more nodes here as they are implemented
    # Example:
    # pipeline.add_node(NodeConfig(
    #     node_class=EmbeddingNode,
    #     node_kwargs={'device': device},
    #     dependencies=[],  # Runs after AuthorNode
    #     name='EmbeddingNode'
    # ))

    # pipeline.add_node(NodeConfig(
    #     node_class=UMAPNode,
    #     node_kwargs={'device': device},
    #     dependencies=['EmbeddingNode'],  # Runs after EmbeddingNode
    #     name='UMAPNode'
    # ))

    return pipeline
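To illustrate how the dependency resolution in `_get_execution_stages` groups nodes into stages, a small sketch using only the classes defined in this file. The `A`/`B`/`C` names are placeholders and `object` stands in for real node classes, since only the dependency graph is exercised here (calling the private method directly is purely for demonstration).

```python
from pipeline import NodeConfig, ParallelPipeline

pipeline = ParallelPipeline(max_workers=2)
pipeline.add_node(NodeConfig(node_class=object, name='A'))
pipeline.add_node(NodeConfig(node_class=object, name='B'))
pipeline.add_node(NodeConfig(node_class=object, dependencies=['A', 'B'], name='C'))

# A and B have no dependencies, so they share a stage; C waits for both.
print(pipeline._get_execution_stages())  # expected: [['A', 'B'], ['C']] (order within a stage may vary)
```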
@@ -2,3 +2,4 @@ pandas
python-dotenv
gliner
torch
fuzzysearch
@@ -1,19 +1,8 @@
"""Base transform node for data pipeline."""
from abc import ABC, abstractmethod
import sqlite3
import pandas as pd


class TransformContext:
    """Context object containing the dataframe for transformation."""

    def __init__(self, df: pd.DataFrame):
        self.df = df

    def get_dataframe(self) -> pd.DataFrame:
        """Get the pandas dataframe from the context."""
        return self.df

from pipeline import TransformContext

class TransformNode(ABC):
    """Abstract base class for transformation nodes.