Adds TransformNode to FuzzyFind Author Names

quorploop 2025-12-23 17:53:37 +01:00
parent 64df8fb328
commit 72765532d3
11 changed files with 696 additions and 58 deletions


@@ -1,10 +1,13 @@
"""Author classification transform node using NER."""
from base import TransformNode, TransformContext
import os
import sqlite3
import pandas as pd
import logging
import fuzzysearch
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pipeline import TransformContext
from transform_node import TransformNode
try:
from gliner import GLiNER
@@ -17,7 +20,7 @@ except ImportError:
logger = logging.getLogger("knack-transform")
class AuthorNode(TransformNode):
class NerAuthorNode(TransformNode):
"""Transform node that extracts and classifies authors using NER.
Creates two tables:
@@ -25,7 +28,8 @@ class AuthorNode(TransformNode):
- post_authors: maps posts to their authors
"""
def __init__(self, model_name: str = "urchade/gliner_medium-v2.1",
def __init__(self, model_name: str = "urchade/gliner_multi-v2.1",
model_path: str = None,
threshold: float = 0.5,
max_workers: int = 64,
device: str = "cpu"):
@@ -33,11 +37,13 @@ class AuthorNode(TransformNode):
Args:
model_name: GLiNER model to use
model_path: Optional local path to a downloaded GLiNER model
threshold: Confidence threshold for entity predictions
max_workers: Number of parallel workers for prediction
device: Device to run model on ('cpu', 'cuda', 'mps')
"""
self.model_name = model_name
self.model_path = model_path or os.environ.get('GLINER_MODEL_PATH')
self.threshold = threshold
self.max_workers = max_workers
self.device = device
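# Usage sketch for the constructor above; the model path below is made up and only
# illustrates the GLINER_MODEL_PATH fallback when model_path is not passed explicitly.
import os

os.environ["GLINER_MODEL_PATH"] = "/models/gliner_multi-v2.1"   # hypothetical local copy
node = NerAuthorNode(threshold=0.6, device="cpu")                # picks up the env var
# or point at the local model directly:
node = NerAuthorNode(model_path="/models/gliner_multi-v2.1", device="cuda")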
@@ -49,21 +55,31 @@
if not GLINER_AVAILABLE:
raise ImportError("GLiNER is required for NerAuthorNode. Install with: pip install gliner")
logger.info(f"Loading GLiNER model: {self.model_name}")
model_source = None
if self.model_path:
if os.path.exists(self.model_path):
model_source = self.model_path
logger.info(f"Loading GLiNER model from local path: {self.model_path}")
else:
logger.warning(f"GLINER_MODEL_PATH '{self.model_path}' not found; falling back to hub model {self.model_name}")
if model_source is None:
model_source = self.model_name
logger.info(f"Loading GLiNER model from hub: {self.model_name}")
if self.device == "cuda" and torch.cuda.is_available():
self.model = GLiNER.from_pretrained(
self.model_name,
model_source,
max_length=255
).to('cuda', dtype=torch.float16)
elif self.device == "mps" and torch.backends.mps.is_available():
self.model = GLiNER.from_pretrained(
self.model_name,
model_source,
max_length=255
).to('mps', dtype=torch.float16)
else:
self.model = GLiNER.from_pretrained(
self.model_name,
model_source,
max_length=255
)
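# Once loaded, the model is presumably queried per post roughly like this; the label
# set ("person") and the post-processing are assumptions, not the exact code of this node.
def predict_author_entities(model, text: str, threshold: float = 0.5) -> list[str]:
    # GLiNER takes free-text labels and returns dicts with 'text', 'label' and 'score'.
    entities = model.predict_entities(text, ["person"], threshold=threshold)
    return [e["text"] for e in entities if e["label"] == "person"]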
@@ -208,13 +224,6 @@ class AuthorNode(TransformNode):
logger.info(f"Creating {len(mappings_df)} post-author mappings")
mappings_df.to_sql('post_authors', con, if_exists='append', index=False)
# Mark posts as cleaned
processed_post_ids = mappings_df['post_id'].unique().tolist()
if processed_post_ids:
placeholders = ','.join('?' * len(processed_post_ids))
con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", processed_post_ids)
logger.info(f"Marked {len(processed_post_ids)} posts as cleaned")
con.commit()
logger.info("Authors and mappings stored successfully")
@@ -247,17 +256,165 @@
# Store results
self._store_authors(con, results)
# Mark posts without author entities as cleaned too (no authors found)
processed_ids = set([r['id'] for r in results]) if results else set()
unprocessed_ids = [pid for pid in posts_df['id'].tolist() if pid not in processed_ids]
if unprocessed_ids:
placeholders = ','.join('?' * len(unprocessed_ids))
con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", unprocessed_ids)
con.commit()
logger.info(f"Marked {len(unprocessed_ids)} posts without author entities as cleaned")
# Return context with results
results_df = pd.DataFrame(results) if results else pd.DataFrame()
logger.info("AuthorNode transformation complete")
return TransformContext(results_df)
class FuzzyAuthorNode(TransformNode):
"""FuzzyAuthorNode
This Node takes in data and rules of authornames that have been classified already
and uses those 'rule' to find more similar fields.
"""
def __init__(self, max_l_dist: int = 1):
"""Initialize FuzzyAuthorNode.
Args:
max_l_dist: Maximum number of edits (Levenshtein distance) the fuzzy search will tolerate
"""
self.max_l_dist = max_l_dist
logger.info(f"Initialized FuzzyAuthorNode with max_l_dist={max_l_dist}")
def _process_data(self, con: sqlite3.Connection, df: pd.DataFrame) -> pd.DataFrame:
"""Process the input dataframe.
This is where your main transformation logic goes.
Args:
con: Database connection
df: Input dataframe from context
Returns:
Dataframe of new post_id / author_id mappings
"""
logger.info(f"Processing {len(df)} rows")
# Retrieve all known authors from the authors table as 'rules'
authors_df = pd.read_sql("SELECT id, name FROM authors", con)
if authors_df.empty:
logger.warning("No authors found in database for fuzzy matching")
return pd.DataFrame(columns=['post_id', 'author_id'])
# Get existing post-author mappings to avoid duplicates
existing_mappings = pd.read_sql(
"SELECT post_id, author_id FROM post_authors", con
)
existing_post_ids = set(existing_mappings['post_id'].unique())
logger.info(f"Found {len(authors_df)} known authors for fuzzy matching")
logger.info(f"Found {len(existing_post_ids)} posts with existing author mappings")
# Filter to posts without author mappings and with non-null author field
if 'author' not in df.columns or 'id' not in df.columns:
logger.warning("Missing 'author' or 'id' column in input dataframe")
return pd.DataFrame(columns=['post_id', 'author_id'])
posts_to_process = df[
(df['id'].notna()) &
(df['author'].notna()) &
(~df['id'].isin(existing_post_ids))
]
logger.info(f"Processing {len(posts_to_process)} posts for fuzzy matching")
# Perform fuzzy matching
mappings = []
for _, post_row in posts_to_process.iterrows():
post_id = post_row['id']
post_author = str(post_row['author'])
# Try to find matches against all known author names
for _, author_row in authors_df.iterrows():
author_id = author_row['id']
author_name = str(author_row['name'])
# Use fuzzysearch to find matches with allowed errors
matches = fuzzysearch.find_near_matches(
author_name,
post_author,
max_l_dist=self.max_l_dist
)
if matches:
logger.debug(f"Found fuzzy match: '{author_name}' in '{post_author}' for post {post_id}")
mappings.append({
'post_id': post_id,
'author_id': author_id
})
# Only take the first match per post to avoid multiple mappings
break
# Create result dataframe
result_df = pd.DataFrame(mappings, columns=['post_id', 'author_id']) if mappings else pd.DataFrame(columns=['post_id', 'author_id'])
logger.info(f"Processing complete. Found {len(result_df)} fuzzy matches")
return result_df
def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
"""Store results back to the database.
Uses INSERT OR IGNORE to avoid inserting duplicates.
Args:
con: Database connection
df: Processed dataframe to store
"""
if df.empty:
logger.info("No results to store")
return
logger.info(f"Storing {len(df)} results")
# Use INSERT OR IGNORE to handle duplicates (respects PRIMARY KEY constraint)
cursor = con.cursor()
inserted_count = 0
for _, row in df.iterrows():
cursor.execute(
"INSERT OR IGNORE INTO post_authors (post_id, author_id) VALUES (?, ?)",
(int(row['post_id']), int(row['author_id']))
)
if cursor.rowcount > 0:
inserted_count += 1
con.commit()
logger.info(f"Results stored successfully. Inserted {inserted_count} new mappings, skipped {len(df) - inserted_count} duplicates")
def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
"""Execute the transformation.
This is the main entry point called by the pipeline.
Args:
con: SQLite database connection
context: TransformContext containing input dataframe
Returns:
TransformContext with processed dataframe
"""
logger.info("Starting FuzzyAuthorNode transformation")
# Get input dataframe from context
input_df = context.get_dataframe()
# Validate input
if input_df.empty:
logger.warning("Empty dataframe provided to FuzzyAuthorNode")
return context
# Process the data
result_df = self._process_data(con, input_df)
# Store results
self._store_results(con, result_df)
logger.info("FuzzyAuthorNode transformation complete")
# Return new context with results
return TransformContext(result_df)
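# Rough end-to-end sketch: run the NER pass first so the authors table is populated,
# then let the fuzzy pass map the remaining posts. The database file name, the posts
# query and NerAuthorNode.run's signature are assumptions based on the code above.
import sqlite3
import pandas as pd

con = sqlite3.connect("knack.db")                               # hypothetical db path
posts_ctx = TransformContext(pd.read_sql("SELECT * FROM posts", con))

NerAuthorNode(threshold=0.5, device="cpu").run(con, posts_ctx)  # fills authors / post_authors
FuzzyAuthorNode(max_l_dist=1).run(con, posts_ctx)               # fuzzy-matches the rest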