# Knack-Scraper/transform/author_node.py
"""Author classification transform node using NER."""
import os
import sqlite3
import pandas as pd
import logging
import fuzzysearch
from concurrent.futures import ThreadPoolExecutor
from pipeline import TransformContext
from transform_node import TransformNode
logger = logging.getLogger("knack-transform")
try:
from gliner import GLiNER
import torch
GLINER_AVAILABLE = True
except ImportError:
GLINER_AVAILABLE = False
    logger.warning("GLiNER not available. Install with: pip install gliner")
class NerAuthorNode(TransformNode):
"""Transform node that extracts and classifies authors using NER.
Creates two tables:
- authors: stores unique authors with their type (Person, Organisation, etc.)
- post_authors: maps posts to their authors
"""
def __init__(self, model_name: str = "urchade/gliner_multi-v2.1",
model_path: str = None,
threshold: float = 0.5,
max_workers: int = 64,
device: str = "cpu"):
"""Initialize the AuthorNode.
Args:
model_name: GLiNER model to use
model_path: Optional local path to a downloaded GLiNER model
threshold: Confidence threshold for entity predictions
max_workers: Number of parallel workers for prediction
device: Device to run model on ('cpu', 'cuda', 'mps')
"""
self.model_name = model_name
self.model_path = model_path or os.environ.get('GLINER_MODEL_PATH')
self.threshold = threshold
self.max_workers = max_workers
self.device = device
self.model = None
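        # Entity labels passed to predict_entities(); GLiNER performs zero-shot
        # NER against whatever label strings are listed here.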
self.labels = ["Person", "Organisation", "Email", "Newspaper", "NGO"]
def _setup_model(self):
"""Initialize the NER model."""
if not GLINER_AVAILABLE:
raise ImportError("GLiNER is required for AuthorNode. Install with: pip install gliner")
model_source = None
if self.model_path:
if os.path.exists(self.model_path):
model_source = self.model_path
logger.info(f"Loading GLiNER model from local path: {self.model_path}")
else:
logger.warning(f"GLINER_MODEL_PATH '{self.model_path}' not found; falling back to hub model {self.model_name}")
if model_source is None:
model_source = self.model_name
logger.info(f"Loading GLiNER model from hub: {self.model_name}")
if self.device == "cuda" and torch.cuda.is_available():
self.model = GLiNER.from_pretrained(
model_source,
max_length=255
).to('cuda', dtype=torch.float16)
elif self.device == "mps" and torch.backends.mps.is_available():
self.model = GLiNER.from_pretrained(
model_source,
max_length=255
).to('mps', dtype=torch.float16)
else:
self.model = GLiNER.from_pretrained(
model_source,
max_length=255
)
logger.info("Model loaded successfully")
def _predict(self, text_data: dict):
"""Predict entities for a single author text.
Args:
text_data: Dict with 'author' and 'id' keys
Returns:
Tuple of (predictions, post_id) or None
"""
if text_data is None or text_data.get('author') is None:
return None
predictions = self.model.predict_entities(
text_data['author'],
self.labels,
threshold=self.threshold
)
return predictions, text_data['id']
def _classify_authors(self, posts_df: pd.DataFrame):
"""Classify all authors in the posts dataframe.
Args:
posts_df: DataFrame with 'id' and 'author' columns
Returns:
List of dicts with 'text', 'label', 'id' keys
"""
if self.model is None:
self._setup_model()
# Prepare input data
authors_data = []
for idx, row in posts_df.iterrows():
if pd.notna(row['author']):
authors_data.append({
'author': row['author'],
'id': row['id']
})
logger.info(f"Classifying {len(authors_data)} authors")
results = []
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [executor.submit(self._predict, data) for data in authors_data]
for future in futures:
result = future.result()
if result is not None:
predictions, post_id = result
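                    # predict_entities() returns a list of dicts; only the
                    # 'text' and 'label' keys are used here.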
for pred in predictions:
results.append({
'text': pred['text'],
'label': pred['label'],
'id': post_id
})
logger.info(f"Classification complete. Found {len(results)} author entities")
return results
def _create_tables(self, con: sqlite3.Connection):
"""Create authors and post_authors tables if they don't exist."""
logger.info("Creating authors tables")
con.execute("""
CREATE TABLE IF NOT EXISTS authors (
id INTEGER PRIMARY KEY,
name TEXT,
type TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
con.execute("""
CREATE TABLE IF NOT EXISTS post_authors (
post_id INTEGER,
author_id INTEGER,
PRIMARY KEY (post_id, author_id),
FOREIGN KEY (post_id) REFERENCES posts(id),
FOREIGN KEY (author_id) REFERENCES authors(id)
)
""")
con.commit()
def _store_authors(self, con: sqlite3.Connection, results: list):
"""Store classified authors and their mappings.
Args:
con: Database connection
results: List of classification results
"""
if not results:
logger.info("No authors to store")
return
# Convert results to DataFrame
results_df = pd.DataFrame(results)
        # Get unique authors with their types; keep only one row per name so
        # the same author is not inserted twice with different labels
        unique_authors = results_df[['text', 'label']].drop_duplicates(subset='text')
        unique_authors.columns = ['name', 'type']
# Get existing authors
existing_authors = pd.read_sql("SELECT id, name FROM authors", con)
# Find new authors to insert
if not existing_authors.empty:
new_authors = unique_authors[~unique_authors['name'].isin(existing_authors['name'])]
else:
new_authors = unique_authors
if not new_authors.empty:
logger.info(f"Inserting {len(new_authors)} new authors")
new_authors.to_sql('authors', con, if_exists='append', index=False)
# Get all authors with their IDs
all_authors = pd.read_sql("SELECT id, name FROM authors", con)
name_to_id = dict(zip(all_authors['name'], all_authors['id']))
# Create post_authors mappings
mappings = []
for _, row in results_df.iterrows():
author_id = name_to_id.get(row['text'])
if author_id:
mappings.append({
'post_id': row['id'],
'author_id': author_id
})
        if mappings:
            mappings_df = pd.DataFrame(mappings).drop_duplicates()
            logger.info(f"Creating {len(mappings_df)} post-author mappings")
            # INSERT OR IGNORE respects the (post_id, author_id) primary key, so
            # re-running the node does not fail on mappings that already exist.
            # (Alternatively, existing mappings for these posts could be deleted
            # first, depending on the update strategy.)
            rows = [(int(r.post_id), int(r.author_id))
                    for r in mappings_df.itertuples(index=False)]
            con.executemany(
                "INSERT OR IGNORE INTO post_authors (post_id, author_id) VALUES (?, ?)",
                rows
            )
        con.commit()
logger.info("Authors and mappings stored successfully")
def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
"""Execute the author classification transformation.
Args:
con: SQLite database connection
context: TransformContext containing posts dataframe
Returns:
TransformContext with classified authors dataframe
"""
logger.info("Starting AuthorNode transformation")
posts_df = context.get_dataframe()
# Ensure required columns exist
if 'author' not in posts_df.columns:
logger.warning("No 'author' column in dataframe. Skipping AuthorNode.")
return context
# Create tables
self._create_tables(con)
# Classify authors
results = self._classify_authors(posts_df)
# Store results
self._store_authors(con, results)
# Return context with results
logger.info("AuthorNode transformation complete")
return TransformContext(posts_df)
class FuzzyAuthorNode(TransformNode):
"""FuzzyAuthorNode
This Node takes in data and rules of authornames that have been classified already
and uses those 'rule' to find more similar fields.
"""
def __init__(self,
max_l_dist: int = 1,):
"""Initialize FuzzyAuthorNode.
Args:
            max_l_dist: Maximum number of edits (Levenshtein distance) allowed for a fuzzy match
"""
self.max_l_dist = max_l_dist
logger.info(f"Initialized FuzzyAuthorNode with max_l_dist={max_l_dist}")
def _process_data(self, con: sqlite3.Connection, df: pd.DataFrame) -> pd.DataFrame:
"""Process the input dataframe.
This is where your main transformation logic goes.
Args:
con: Database connection
df: Input dataframe from context
Returns:
Processed dataframe
"""
logger.info(f"Processing {len(df)} rows")
# Retrieve all known authors from the authors table as 'rules'
authors_df = pd.read_sql("SELECT id, name FROM authors", con)
if authors_df.empty:
logger.warning("No authors found in database for fuzzy matching")
return pd.DataFrame(columns=['post_id', 'author_id'])
# Get existing post-author mappings to avoid duplicates
existing_mappings = pd.read_sql(
"SELECT post_id, author_id FROM post_authors", con
)
existing_post_ids = set(existing_mappings['post_id'].unique())
logger.info(f"Found {len(authors_df)} known authors for fuzzy matching")
logger.info(f"Found {len(existing_post_ids)} posts with existing author mappings")
# Filter to posts without author mappings and with non-null author field
if 'author' not in df.columns or 'id' not in df.columns:
logger.warning("Missing 'author' or 'id' column in input dataframe")
return pd.DataFrame(columns=['post_id', 'author_id'])
posts_to_process = df[
(df['id'].notna()) &
(df['author'].notna()) &
(~df['id'].isin(existing_post_ids))
]
logger.info(f"Processing {len(posts_to_process)} posts for fuzzy matching")
# Perform fuzzy matching
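        # find_near_matches() looks for the author name as an approximate
        # substring of the post's raw author string, allowing up to max_l_dist
        # single-character edits (Levenshtein distance).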
mappings = []
for _, post_row in posts_to_process.iterrows():
post_id = post_row['id']
post_author = str(post_row['author'])
# Try to find matches against all known author names
for _, author_row in authors_df.iterrows():
author_id = author_row['id']
author_name = str(author_row['name'])
                # Require an exact match for author names of 2 characters or fewer
l_dist = self.max_l_dist if len(author_name) > 2 else 0
# Use fuzzysearch to find matches with allowed errors
matches = fuzzysearch.find_near_matches(
author_name,
post_author,
max_l_dist=l_dist,
)
if matches:
logger.debug(f"Found fuzzy match: '{author_name}' in '{post_author}' for post {post_id}")
mappings.append({
'post_id': post_id,
'author_id': author_id
})
# Only take the first match per post to avoid multiple mappings
break
# Create result dataframe
        result_df = pd.DataFrame(mappings, columns=['post_id', 'author_id'])
logger.info(f"Processing complete. Found {len(result_df)} fuzzy matches")
return result_df
def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
"""Store results back to the database.
Uses INSERT OR IGNORE to avoid inserting duplicates.
Args:
con: Database connection
df: Processed dataframe to store
"""
if df.empty:
logger.info("No results to store")
return
logger.info(f"Storing {len(df)} results")
# Use INSERT OR IGNORE to handle duplicates (respects PRIMARY KEY constraint)
cursor = con.cursor()
inserted_count = 0
for _, row in df.iterrows():
cursor.execute(
"INSERT OR IGNORE INTO post_authors (post_id, author_id) VALUES (?, ?)",
(int(row['post_id']), int(row['author_id']))
)
if cursor.rowcount > 0:
inserted_count += 1
con.commit()
logger.info(f"Results stored successfully. Inserted {inserted_count} new mappings, skipped {len(df) - inserted_count} duplicates")
def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
"""Execute the transformation.
This is the main entry point called by the pipeline.
Args:
con: SQLite database connection
context: TransformContext containing input dataframe
Returns:
TransformContext with processed dataframe
"""
logger.info("Starting FuzzyAuthorNode transformation")
# Get input dataframe from context
input_df = context.get_dataframe()
# Validate input
if input_df.empty:
logger.warning("Empty dataframe provided to FuzzyAuthorNode")
return context
# Process the data
result_df = self._process_data(con, input_df)
# Store results
self._store_results(con, result_df)
logger.info("FuzzyAuthorNode transformation complete")
# Return new context with results
return TransformContext(input_df)
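

# Minimal usage sketch (not part of the pipeline): assumes an existing SQLite
# database with a `posts` table that has `id` and `author` columns, and that a
# TransformContext can be built directly from a pandas DataFrame, as run() does
# above. The database path below is hypothetical.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    con = sqlite3.connect("knack.db")
    posts_df = pd.read_sql("SELECT id, author FROM posts", con)
    context = TransformContext(posts_df)
    # Run the NER pass first so the authors table is populated, then the fuzzy
    # pass to map remaining posts against the known author names.
    context = NerAuthorNode(device="cpu").run(con, context)
    context = FuzzyAuthorNode(max_l_dist=1).run(con, context)
    con.close()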