Adds TransformNode to FuzzyFind Author Names
parent 64df8fb328
commit 72765532d3
11 changed files with 696 additions and 58 deletions
@@ -1,10 +1,13 @@
 """Author classification transform node using NER."""
-from base import TransformNode, TransformContext
 import os
 import sqlite3
 import pandas as pd
 import logging
+import fuzzysearch
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 
+from pipeline import TransformContext
+from transform_node import TransformNode
 
 try:
     from gliner import GLiNER
@@ -17,7 +20,7 @@ except ImportError:
 logger = logging.getLogger("knack-transform")
 
 
-class AuthorNode(TransformNode):
+class NerAuthorNode(TransformNode):
     """Transform node that extracts and classifies authors using NER.
 
     Creates two tables:
@@ -25,7 +28,8 @@ class AuthorNode(TransformNode):
     - post_authors: maps posts to their authors
     """
 
-    def __init__(self, model_name: str = "urchade/gliner_medium-v2.1",
+    def __init__(self, model_name: str = "urchade/gliner_multi-v2.1",
+                 model_path: str = None,
                  threshold: float = 0.5,
                  max_workers: int = 64,
                  device: str = "cpu"):
@@ -33,11 +37,13 @@ class AuthorNode(TransformNode):
 
         Args:
             model_name: GLiNER model to use
+            model_path: Optional local path to a downloaded GLiNER model
             threshold: Confidence threshold for entity predictions
             max_workers: Number of parallel workers for prediction
             device: Device to run model on ('cpu', 'cuda', 'mps')
         """
         self.model_name = model_name
+        self.model_path = model_path or os.environ.get('GLINER_MODEL_PATH')
         self.threshold = threshold
         self.max_workers = max_workers
         self.device = device
@@ -49,21 +55,31 @@ class AuthorNode(TransformNode):
         if not GLINER_AVAILABLE:
             raise ImportError("GLiNER is required for AuthorNode. Install with: pip install gliner")
 
         logger.info(f"Loading GLiNER model: {self.model_name}")
+        model_source = None
+        if self.model_path:
+            if os.path.exists(self.model_path):
+                model_source = self.model_path
+                logger.info(f"Loading GLiNER model from local path: {self.model_path}")
+            else:
+                logger.warning(f"GLINER_MODEL_PATH '{self.model_path}' not found; falling back to hub model {self.model_name}")
+
+        if model_source is None:
+            model_source = self.model_name
+            logger.info(f"Loading GLiNER model from hub: {self.model_name}")
+
         if self.device == "cuda" and torch.cuda.is_available():
             self.model = GLiNER.from_pretrained(
-                self.model_name,
+                model_source,
                 max_length=255
             ).to('cuda', dtype=torch.float16)
         elif self.device == "mps" and torch.backends.mps.is_available():
             self.model = GLiNER.from_pretrained(
-                self.model_name,
+                model_source,
                 max_length=255
             ).to('mps', dtype=torch.float16)
         else:
             self.model = GLiNER.from_pretrained(
-                self.model_name,
+                model_source,
                 max_length=255
             )
@@ -208,13 +224,6 @@ class AuthorNode(TransformNode):
 
         logger.info(f"Creating {len(mappings_df)} post-author mappings")
         mappings_df.to_sql('post_authors', con, if_exists='append', index=False)
 
-        # Mark posts as cleaned
-        processed_post_ids = mappings_df['post_id'].unique().tolist()
-        if processed_post_ids:
-            placeholders = ','.join('?' * len(processed_post_ids))
-            con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", processed_post_ids)
-            logger.info(f"Marked {len(processed_post_ids)} posts as cleaned")
-
         con.commit()
         logger.info("Authors and mappings stored successfully")
@@ -247,17 +256,165 @@ class AuthorNode(TransformNode):
         # Store results
         self._store_authors(con, results)
 
         # Mark posts without author entities as cleaned too (no authors found)
         processed_ids = set([r['id'] for r in results]) if results else set()
         unprocessed_ids = [pid for pid in posts_df['id'].tolist() if pid not in processed_ids]
         if unprocessed_ids:
             placeholders = ','.join('?' * len(unprocessed_ids))
             con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", unprocessed_ids)
             con.commit()
             logger.info(f"Marked {len(unprocessed_ids)} posts without author entities as cleaned")
 
         # Return context with results
         results_df = pd.DataFrame(results) if results else pd.DataFrame()
         logger.info("AuthorNode transformation complete")
 
         return TransformContext(results_df)
+
+
+class FuzzyAuthorNode(TransformNode):
+    """FuzzyAuthorNode.
+
+    This node takes input data plus author names that have already been classified
+    ('rules') and uses those rules to find additional, loosely matching author fields.
+    """
+
+    def __init__(self, max_l_dist: int = 1):
+        """Initialize FuzzyAuthorNode.
+
+        Args:
+            max_l_dist: Number of edit 'errors' allowed by the fuzzy search algorithm
+        """
+        self.max_l_dist = max_l_dist
+        logger.info(f"Initialized FuzzyAuthorNode with max_l_dist={max_l_dist}")
+
+    def _process_data(self, con: sqlite3.Connection, df: pd.DataFrame) -> pd.DataFrame:
+        """Process the input dataframe.
+
+        This is where the main transformation logic goes.
+
+        Args:
+            con: Database connection
+            df: Input dataframe from context
+
+        Returns:
+            Processed dataframe
+        """
+        logger.info(f"Processing {len(df)} rows")
+
+        # Retrieve all known authors from the authors table as 'rules'
+        authors_df = pd.read_sql("SELECT id, name FROM authors", con)
+
+        if authors_df.empty:
+            logger.warning("No authors found in database for fuzzy matching")
+            return pd.DataFrame(columns=['post_id', 'author_id'])
+
+        # Get existing post-author mappings to avoid duplicates
+        existing_mappings = pd.read_sql(
+            "SELECT post_id, author_id FROM post_authors", con
+        )
+        existing_post_ids = set(existing_mappings['post_id'].unique())
+
+        logger.info(f"Found {len(authors_df)} known authors for fuzzy matching")
+        logger.info(f"Found {len(existing_post_ids)} posts with existing author mappings")
+
+        # Filter to posts without author mappings and with non-null author field
+        if 'author' not in df.columns or 'id' not in df.columns:
+            logger.warning("Missing 'author' or 'id' column in input dataframe")
+            return pd.DataFrame(columns=['post_id', 'author_id'])
+
+        posts_to_process = df[
+            (df['id'].notna()) &
+            (df['author'].notna()) &
+            (~df['id'].isin(existing_post_ids))
+        ]
+
+        logger.info(f"Processing {len(posts_to_process)} posts for fuzzy matching")
+
+        # Perform fuzzy matching
+        mappings = []
+        for _, post_row in posts_to_process.iterrows():
+            post_id = post_row['id']
+            post_author = str(post_row['author'])
+
+            # Try to find matches against all known author names
+            for _, author_row in authors_df.iterrows():
+                author_id = author_row['id']
+                author_name = str(author_row['name'])
+
+                # Use fuzzysearch to find matches with allowed errors
+                matches = fuzzysearch.find_near_matches(
+                    author_name,
+                    post_author,
+                    max_l_dist=self.max_l_dist
+                )
+
+                if matches:
+                    logger.debug(f"Found fuzzy match: '{author_name}' in '{post_author}' for post {post_id}")
+                    mappings.append({
+                        'post_id': post_id,
+                        'author_id': author_id
+                    })
+                    # Only take the first match per post to avoid multiple mappings
+                    break
+
+        # Create result dataframe
+        result_df = pd.DataFrame(mappings, columns=['post_id', 'author_id']) if mappings else pd.DataFrame(columns=['post_id', 'author_id'])
+
+        logger.info(f"Processing complete. Found {len(result_df)} fuzzy matches")
+        return result_df
+
+    def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
+        """Store results back to the database.
+
+        Uses INSERT OR IGNORE to avoid inserting duplicates.
+
+        Args:
+            con: Database connection
+            df: Processed dataframe to store
+        """
+        if df.empty:
+            logger.info("No results to store")
+            return
+
+        logger.info(f"Storing {len(df)} results")
+
+        # Use INSERT OR IGNORE to handle duplicates (respects PRIMARY KEY constraint)
+        cursor = con.cursor()
+        inserted_count = 0
+
+        for _, row in df.iterrows():
+            cursor.execute(
+                "INSERT OR IGNORE INTO post_authors (post_id, author_id) VALUES (?, ?)",
+                (int(row['post_id']), int(row['author_id']))
+            )
+            if cursor.rowcount > 0:
+                inserted_count += 1
+
+        con.commit()
+        logger.info(f"Results stored successfully. Inserted {inserted_count} new mappings, skipped {len(df) - inserted_count} duplicates")
+
+    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
+        """Execute the transformation.
+
+        This is the main entry point called by the pipeline.
+
+        Args:
+            con: SQLite database connection
+            context: TransformContext containing input dataframe
+
+        Returns:
+            TransformContext with processed dataframe
+        """
+        logger.info("Starting FuzzyAuthorNode transformation")
+
+        # Get input dataframe from context
+        input_df = context.get_dataframe()
+
+        # Validate input
+        if input_df.empty:
+            logger.warning("Empty dataframe provided to FuzzyAuthorNode")
+            return context
+
+        # Process the data
+        result_df = self._process_data(con, input_df)
+
+        # Store results
+        self._store_results(con, result_df)
+
+        logger.info("FuzzyAuthorNode transformation complete")
+
+        # Return new context with results
+        return TransformContext(result_df)
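A note on max_l_dist, since FuzzyAuthorNode delegates all matching to fuzzysearch.find_near_matches: max_l_dist bounds the Levenshtein distance the search tolerates when looking for a known author name inside a post's raw author field. A minimal sketch, with made-up author strings that are not taken from this commit:

from fuzzysearch import find_near_matches

known_author = "John Smith"                      # hypothetical 'rule' from the authors table
raw_field = "guest article by Jon Smith (ed.)"   # hypothetical posts.author value

# "Jon Smith" is one deletion away from "John Smith", so max_l_dist=1 matches;
# max_l_dist=0 would not. Each Match exposes .start, .end, .dist and .matched.
matches = find_near_matches(known_author, raw_field, max_l_dist=1)
if matches:
    print(matches[0].matched, matches[0].dist)

Raising max_l_dist makes the search more permissive but also increases the risk of mapping a post to the wrong author, especially for short names.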
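For the NerAuthorNode changes earlier in the diff, a minimal sketch of how the loaded model is typically queried. The example text and the label list are assumptions for illustration and are not taken from this diff; GLiNER.from_pretrained and predict_entities are the library calls the node relies on.

from gliner import GLiNER

# CPU load path, mirroring the fallback branch in _load_model
model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1", max_length=255)

text = "Guest post written by Jane Doe and J. Smith"   # hypothetical post text
labels = ["person"]                                     # assumed label set; the node's label list is not shown here
for entity in model.predict_entities(text, labels, threshold=0.5):
    print(entity["text"], entity["label"], round(entity["score"], 2))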
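How the two nodes are chained is not part of this commit, so the following is only a hypothetical wiring sketch. It assumes TransformContext simply wraps a dataframe (as the run() signatures above suggest), that NerAuthorNode exposes the same run(con, context) interface as FuzzyAuthorNode, and that the posts table has id, author and is_cleaned columns (as the SQL in this diff implies).

import sqlite3
import pandas as pd

# NerAuthorNode, FuzzyAuthorNode and TransformContext come from the module
# diffed above (its path is not shown on this page).

con = sqlite3.connect("knack.db")                 # assumed database file name
posts = pd.read_sql("SELECT * FROM posts WHERE is_cleaned = 0", con)

ner = NerAuthorNode(threshold=0.5, device="cpu")
fuzzy = FuzzyAuthorNode(max_l_dist=1)

# NER pass first: fills the authors and post_authors tables.
ner.run(con, TransformContext(posts))

# Fuzzy pass second: maps leftover posts whose author field only loosely
# matches an already-known author name; posts with existing mappings are skipped.
fuzzy.run(con, TransformContext(posts))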