"""Author classification transform node using NER.""" from base import TransformNode, TransformContext import sqlite3 import pandas as pd import logging from concurrent.futures import ThreadPoolExecutor from datetime import datetime try: from gliner import GLiNER import torch GLINER_AVAILABLE = True except ImportError: GLINER_AVAILABLE = False logging.warning("GLiNER not available. Install with: pip install gliner") logger = logging.getLogger("knack-transform") class AuthorNode(TransformNode): """Transform node that extracts and classifies authors using NER. Creates two tables: - authors: stores unique authors with their type (Person, Organisation, etc.) - post_authors: maps posts to their authors """ def __init__(self, model_name: str = "urchade/gliner_medium-v2.1", threshold: float = 0.5, max_workers: int = 64, device: str = "cpu"): """Initialize the AuthorNode. Args: model_name: GLiNER model to use threshold: Confidence threshold for entity predictions max_workers: Number of parallel workers for prediction device: Device to run model on ('cpu', 'cuda', 'mps') """ self.model_name = model_name self.threshold = threshold self.max_workers = max_workers self.device = device self.model = None self.labels = ["Person", "Organisation", "Email", "Newspaper", "NGO"] def _setup_model(self): """Initialize the NER model.""" if not GLINER_AVAILABLE: raise ImportError("GLiNER is required for AuthorNode. Install with: pip install gliner") logger.info(f"Loading GLiNER model: {self.model_name}") if self.device == "cuda" and torch.cuda.is_available(): self.model = GLiNER.from_pretrained( self.model_name, max_length=255 ).to('cuda', dtype=torch.float16) elif self.device == "mps" and torch.backends.mps.is_available(): self.model = GLiNER.from_pretrained( self.model_name, max_length=255 ).to('mps', dtype=torch.float16) else: self.model = GLiNER.from_pretrained( self.model_name, max_length=255 ) logger.info("Model loaded successfully") def _predict(self, text_data: dict): """Predict entities for a single author text. Args: text_data: Dict with 'author' and 'id' keys Returns: Tuple of (predictions, post_id) or None """ if text_data is None or text_data.get('author') is None: return None predictions = self.model.predict_entities( text_data['author'], self.labels, threshold=self.threshold ) return predictions, text_data['id'] def _classify_authors(self, posts_df: pd.DataFrame): """Classify all authors in the posts dataframe. Args: posts_df: DataFrame with 'id' and 'author' columns Returns: List of dicts with 'text', 'label', 'id' keys """ if self.model is None: self._setup_model() # Prepare input data authors_data = [] for idx, row in posts_df.iterrows(): if pd.notna(row['author']): authors_data.append({ 'author': row['author'], 'id': row['id'] }) logger.info(f"Classifying {len(authors_data)} authors") results = [] with ThreadPoolExecutor(max_workers=self.max_workers) as executor: futures = [executor.submit(self._predict, data) for data in authors_data] for future in futures: result = future.result() if result is not None: predictions, post_id = result for pred in predictions: results.append({ 'text': pred['text'], 'label': pred['label'], 'id': post_id }) logger.info(f"Classification complete. 
        return results

    def _create_tables(self, con: sqlite3.Connection):
        """Create authors and post_authors tables if they don't exist."""
        logger.info("Creating authors tables")
        con.execute("""
            CREATE TABLE IF NOT EXISTS authors (
                id INTEGER PRIMARY KEY,
                name TEXT,
                type TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        con.execute("""
            CREATE TABLE IF NOT EXISTS post_authors (
                post_id INTEGER,
                author_id INTEGER,
                PRIMARY KEY (post_id, author_id),
                FOREIGN KEY (post_id) REFERENCES posts(id),
                FOREIGN KEY (author_id) REFERENCES authors(id)
            )
        """)
        con.commit()

    def _store_authors(self, con: sqlite3.Connection, results: list):
        """Store classified authors and their mappings.

        Args:
            con: Database connection
            results: List of classification results
        """
        if not results:
            logger.info("No authors to store")
            return

        # Convert results to DataFrame
        results_df = pd.DataFrame(results)

        # Get unique authors with their types
        unique_authors = results_df[['text', 'label']].drop_duplicates()
        unique_authors.columns = ['name', 'type']

        # Get existing authors
        existing_authors = pd.read_sql("SELECT id, name FROM authors", con)

        # Find new authors to insert
        if not existing_authors.empty:
            new_authors = unique_authors[
                ~unique_authors['name'].isin(existing_authors['name'])
            ]
        else:
            new_authors = unique_authors

        if not new_authors.empty:
            logger.info(f"Inserting {len(new_authors)} new authors")
            new_authors.to_sql('authors', con, if_exists='append', index=False)

        # Get all authors with their IDs
        all_authors = pd.read_sql("SELECT id, name FROM authors", con)
        name_to_id = dict(zip(all_authors['name'], all_authors['id']))

        # Create post_authors mappings
        mappings = []
        for _, row in results_df.iterrows():
            author_id = name_to_id.get(row['text'])
            if author_id is not None:
                mappings.append({'post_id': row['id'], 'author_id': author_id})

        if mappings:
            mappings_df = pd.DataFrame(mappings).drop_duplicates()

            # Clear existing mappings for these posts (optional, depends on your strategy)
            # post_ids = tuple(mappings_df['post_id'].unique())
            # con.execute(f"DELETE FROM post_authors WHERE post_id IN ({','.join('?' * len(post_ids))})", post_ids)

            logger.info(f"Creating {len(mappings_df)} post-author mappings")
            mappings_df.to_sql('post_authors', con, if_exists='append', index=False)

            # Mark posts as cleaned
            processed_post_ids = mappings_df['post_id'].unique().tolist()
            if processed_post_ids:
                placeholders = ','.join('?' * len(processed_post_ids))
                con.execute(
                    f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})",
                    processed_post_ids,
                )
                logger.info(f"Marked {len(processed_post_ids)} posts as cleaned")

        con.commit()
        logger.info("Authors and mappings stored successfully")

    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
        """Execute the author classification transformation.

        Args:
            con: SQLite database connection
            context: TransformContext containing posts dataframe

        Returns:
            TransformContext with classified authors dataframe
        """
        logger.info("Starting AuthorNode transformation")
        posts_df = context.get_dataframe()

        # Ensure required columns exist
        if 'author' not in posts_df.columns:
            logger.warning("No 'author' column in dataframe. Skipping AuthorNode.")
            return context

        # Create tables
        self._create_tables(con)

        # Classify authors
        results = self._classify_authors(posts_df)

        # Store results
        self._store_authors(con, results)

        # Mark posts without author entities as cleaned too (no authors found)
        processed_ids = {r['id'] for r in results}
        unprocessed_ids = [
            pid for pid in posts_df['id'].tolist() if pid not in processed_ids
        ]
        if unprocessed_ids:
            placeholders = ','.join('?' * len(unprocessed_ids))
            con.execute(
                f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})",
                unprocessed_ids,
            )
            con.commit()
            logger.info(
                f"Marked {len(unprocessed_ids)} posts without author entities as cleaned"
            )

        # Return context with results
        results_df = pd.DataFrame(results) if results else pd.DataFrame()
        logger.info("AuthorNode transformation complete")
        return TransformContext(results_df)
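

if __name__ == "__main__":
    # Minimal usage sketch, illustrative only. Assumptions not confirmed by
    # this module: the database file name ("knack.db" is hypothetical) and a
    # `posts` table with `id`, `author`, and `is_cleaned` columns. The
    # TransformContext(df) / get_dataframe() round trip mirrors how run()
    # itself uses the context above.
    con = sqlite3.connect("knack.db")
    posts_df = pd.read_sql(
        "SELECT id, author FROM posts WHERE is_cleaned = 0", con
    )
    node = AuthorNode(threshold=0.5, device="cpu")
    result_context = node.run(con, TransformContext(posts_df))
    # Each row of the result holds one extracted entity: text, label, post id
    print(result_context.get_dataframe().head())
    con.close()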