# Knack-Scraper/transform/author_node.py
"""Author classification transform node using NER."""
import logging
import sqlite3
from concurrent.futures import ThreadPoolExecutor

import pandas as pd

from base import TransformNode, TransformContext

try:
    from gliner import GLiNER
    import torch
    GLINER_AVAILABLE = True
except ImportError:
    GLINER_AVAILABLE = False
    logging.warning("GLiNER not available. Install with: pip install gliner")

logger = logging.getLogger("knack-transform")


class AuthorNode(TransformNode):
    """Transform node that extracts and classifies authors using NER.

    Creates two tables:
      - authors: stores unique authors with their type (Person, Organisation, etc.)
      - post_authors: maps posts to their authors
    """

    def __init__(self, model_name: str = "urchade/gliner_medium-v2.1",
                 threshold: float = 0.5,
                 max_workers: int = 64,
                 device: str = "cpu"):
        """Initialize the AuthorNode.

        Args:
            model_name: GLiNER model to use
            threshold: Confidence threshold for entity predictions
            max_workers: Number of parallel workers for prediction
            device: Device to run the model on ('cpu', 'cuda', 'mps')
        """
        self.model_name = model_name
        self.threshold = threshold
        self.max_workers = max_workers
        self.device = device
        self.model = None  # loaded lazily on first use
        # Free-text label prompts for GLiNER's zero-shot NER; editing this
        # list changes the classification taxonomy without retraining.
        self.labels = ["Person", "Organisation", "Email", "Newspaper", "NGO"]

    def _setup_model(self):
        """Initialize the NER model, moving it to the requested device if available."""
        if not GLINER_AVAILABLE:
            raise ImportError("GLiNER is required for AuthorNode. Install with: pip install gliner")
        logger.info(f"Loading GLiNER model: {self.model_name}")
        self.model = GLiNER.from_pretrained(self.model_name, max_length=255)
        # Use half precision on GPU backends; otherwise stay on CPU in float32.
        if self.device == "cuda" and torch.cuda.is_available():
            self.model = self.model.to('cuda', dtype=torch.float16)
        elif self.device == "mps" and torch.backends.mps.is_available():
            self.model = self.model.to('mps', dtype=torch.float16)
        logger.info("Model loaded successfully")

    def _predict(self, text_data: dict):
        """Predict entities for a single author text.

        Args:
            text_data: Dict with 'author' and 'id' keys

        Returns:
            Tuple of (predictions, post_id), or None if there is no author text
        """
        if text_data is None or text_data.get('author') is None:
            return None
        predictions = self.model.predict_entities(
            text_data['author'],
            self.labels,
            threshold=self.threshold
        )
        return predictions, text_data['id']
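
    # For reference: predict_entities returns a list of dicts, and the code in
    # _classify_authors below relies only on their 'text' and 'label' keys.
    # Illustrative (assumed) output for an author string like
    # "Jane Doe, Example News":
    #   [{'text': 'Jane Doe', 'label': 'Person', ...},
    #    {'text': 'Example News', 'label': 'Newspaper', ...}]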

    def _classify_authors(self, posts_df: pd.DataFrame):
        """Classify all authors in the posts dataframe.

        Args:
            posts_df: DataFrame with 'id' and 'author' columns

        Returns:
            List of dicts with 'text', 'label', 'id' keys
        """
        if self.model is None:
            self._setup_model()
        # Prepare input data, skipping rows with no author value
        authors_data = []
        for _, row in posts_df.iterrows():
            if pd.notna(row['author']):
                authors_data.append({
                    'author': row['author'],
                    'id': row['id']
                })
        logger.info(f"Classifying {len(authors_data)} authors")
        results = []
        # Threads overlap well with GPU-bound inference; on CPU the GIL
        # limits the speedup from a large max_workers value.
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [executor.submit(self._predict, data) for data in authors_data]
            for future in futures:
                result = future.result()
                if result is not None:
                    predictions, post_id = result
                    for pred in predictions:
                        results.append({
                            'text': pred['text'],
                            'label': pred['label'],
                            'id': post_id
                        })
        logger.info(f"Classification complete. Found {len(results)} author entities")
        return results

    def _create_tables(self, con: sqlite3.Connection):
        """Create authors and post_authors tables if they don't exist."""
        logger.info("Creating authors tables")
        con.execute("""
            CREATE TABLE IF NOT EXISTS authors (
                id INTEGER PRIMARY KEY,
                name TEXT,
                type TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        con.execute("""
            CREATE TABLE IF NOT EXISTS post_authors (
                post_id INTEGER,
                author_id INTEGER,
                PRIMARY KEY (post_id, author_id),
                FOREIGN KEY (post_id) REFERENCES posts(id),
                FOREIGN KEY (author_id) REFERENCES authors(id)
            )
        """)
        con.commit()
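
    # Downstream consumers can join classified authors back to posts with a
    # query like the following (a sketch; assumes the posts table defined
    # elsewhere in this pipeline):
    #   SELECT p.id, a.name, a.type
    #   FROM posts p
    #   JOIN post_authors pa ON pa.post_id = p.id
    #   JOIN authors a ON a.id = pa.author_id;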

    def _store_authors(self, con: sqlite3.Connection, results: list):
        """Store classified authors and their mappings.

        Args:
            con: Database connection
            results: List of classification results
        """
        if not results:
            logger.info("No authors to store")
            return
        # Convert results to DataFrame
        results_df = pd.DataFrame(results)
        # Get unique authors with their types; dedupe on name so an author
        # that received several labels is not inserted twice (first label wins)
        unique_authors = results_df[['text', 'label']].drop_duplicates(subset=['text'])
        unique_authors.columns = ['name', 'type']
        # Get existing authors
        existing_authors = pd.read_sql("SELECT id, name FROM authors", con)
        # Find new authors to insert
        if not existing_authors.empty:
            new_authors = unique_authors[~unique_authors['name'].isin(existing_authors['name'])]
        else:
            new_authors = unique_authors
        if not new_authors.empty:
            logger.info(f"Inserting {len(new_authors)} new authors")
            new_authors.to_sql('authors', con, if_exists='append', index=False)
        # Get all authors with their IDs
        all_authors = pd.read_sql("SELECT id, name FROM authors", con)
        name_to_id = dict(zip(all_authors['name'], all_authors['id']))
        # Create post_authors mappings
        mappings = []
        for _, row in results_df.iterrows():
            author_id = name_to_id.get(row['text'])
            if author_id is not None:
                mappings.append({
                    'post_id': row['id'],
                    'author_id': author_id
                })
        if mappings:
            mappings_df = pd.DataFrame(mappings).drop_duplicates()
            # Clearing existing mappings for these posts first is optional and
            # depends on the rerun strategy:
            # post_ids = tuple(mappings_df['post_id'].unique())
            # con.execute(f"DELETE FROM post_authors WHERE post_id IN ({','.join('?' * len(post_ids))})", post_ids)
            logger.info(f"Creating {len(mappings_df)} post-author mappings")
            mappings_df.to_sql('post_authors', con, if_exists='append', index=False)
            # Mark posts as cleaned
            processed_post_ids = mappings_df['post_id'].unique().tolist()
            if processed_post_ids:
                placeholders = ','.join('?' * len(processed_post_ids))
                con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", processed_post_ids)
                logger.info(f"Marked {len(processed_post_ids)} posts as cleaned")
        con.commit()
        logger.info("Authors and mappings stored successfully")
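
    # Note: to_sql(..., if_exists='append') can raise on re-runs if a
    # (post_id, author_id) pair already exists, since the table's primary key
    # forbids duplicates. One idempotent alternative (a sketch, not the
    # strategy this pipeline commits to) would be:
    #   con.executemany(
    #       "INSERT OR IGNORE INTO post_authors (post_id, author_id) VALUES (?, ?)",
    #       mappings_df.itertuples(index=False, name=None))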

    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
        """Execute the author classification transformation.

        Args:
            con: SQLite database connection
            context: TransformContext containing the posts dataframe

        Returns:
            TransformContext with the classified-authors dataframe
        """
        logger.info("Starting AuthorNode transformation")
        posts_df = context.get_dataframe()
        # Ensure required columns exist
        if 'author' not in posts_df.columns:
            logger.warning("No 'author' column in dataframe. Skipping AuthorNode.")
            return context
        # Create tables
        self._create_tables(con)
        # Classify authors
        results = self._classify_authors(posts_df)
        # Store results
        self._store_authors(con, results)
        # Also mark posts that yielded no author entities as cleaned
        processed_ids = {r['id'] for r in results}
        unprocessed_ids = [pid for pid in posts_df['id'].tolist() if pid not in processed_ids]
        if unprocessed_ids:
            placeholders = ','.join('?' * len(unprocessed_ids))
            con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", unprocessed_ids)
            con.commit()
            logger.info(f"Marked {len(unprocessed_ids)} posts without author entities as cleaned")
        # Return context with results
        results_df = pd.DataFrame(results) if results else pd.DataFrame()
        logger.info("AuthorNode transformation complete")
        return TransformContext(results_df)
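

# Illustrative usage (a sketch, not part of the pipeline): build an in-memory
# database with a minimal posts table and run the node over it. The posts
# schema below is an assumption; match it to the real scraper schema. Running
# this requires `pip install gliner` plus a model download.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    con = sqlite3.connect(":memory:")
    con.execute(
        "CREATE TABLE posts (id INTEGER PRIMARY KEY, author TEXT, is_cleaned INTEGER DEFAULT 0)"
    )
    con.execute("INSERT INTO posts (id, author) VALUES (1, 'Jane Doe, Example News')")
    con.commit()

    posts_df = pd.read_sql("SELECT id, author FROM posts", con)
    node = AuthorNode(device="cpu")
    result = node.run(con, TransformContext(posts_df))
    print(result.get_dataframe())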