"""Author classification transform node using NER.""" import os import sqlite3 import pandas as pd import logging import fuzzysearch from concurrent.futures import ThreadPoolExecutor from pipeline import TransformContext from transform_node import TransformNode logger = logging.getLogger("knack-transform") try: from gliner import GLiNER import torch GLINER_AVAILABLE = True except ImportError: GLINER_AVAILABLE = False logging.warning("GLiNER not available. Install with: pip install gliner") class NerAuthorNode(TransformNode): """Transform node that extracts and classifies authors using NER. Creates two tables: - authors: stores unique authors with their type (Person, Organisation, etc.) - post_authors: maps posts to their authors """ def __init__(self, model_name: str = "urchade/gliner_multi-v2.1", model_path: str = None, threshold: float = 0.5, max_workers: int = 64, device: str = "cpu"): """Initialize the AuthorNode. Args: model_name: GLiNER model to use model_path: Optional local path to a downloaded GLiNER model threshold: Confidence threshold for entity predictions max_workers: Number of parallel workers for prediction device: Device to run model on ('cpu', 'cuda', 'mps') """ self.model_name = model_name self.model_path = model_path or os.environ.get('GLINER_MODEL_PATH') self.threshold = threshold self.max_workers = max_workers self.device = device self.model = None self.labels = ["Person", "Organisation", "Email", "Newspaper", "NGO"] def _setup_model(self): """Initialize the NER model.""" if not GLINER_AVAILABLE: raise ImportError("GLiNER is required for AuthorNode. Install with: pip install gliner") model_source = None if self.model_path: if os.path.exists(self.model_path): model_source = self.model_path logger.info(f"Loading GLiNER model from local path: {self.model_path}") else: logger.warning(f"GLINER_MODEL_PATH '{self.model_path}' not found; falling back to hub model {self.model_name}") if model_source is None: model_source = self.model_name logger.info(f"Loading GLiNER model from hub: {self.model_name}") if self.device == "cuda" and torch.cuda.is_available(): self.model = GLiNER.from_pretrained( model_source, max_length=255 ).to('cuda', dtype=torch.float16) elif self.device == "mps" and torch.backends.mps.is_available(): self.model = GLiNER.from_pretrained( model_source, max_length=255 ).to('mps', dtype=torch.float16) else: self.model = GLiNER.from_pretrained( model_source, max_length=255 ) logger.info("Model loaded successfully") def _predict(self, text_data: dict): """Predict entities for a single author text. Args: text_data: Dict with 'author' and 'id' keys Returns: Tuple of (predictions, post_id) or None """ if text_data is None or text_data.get('author') is None: return None predictions = self.model.predict_entities( text_data['author'], self.labels, threshold=self.threshold ) return predictions, text_data['id'] def _classify_authors(self, posts_df: pd.DataFrame): """Classify all authors in the posts dataframe. 
        Args:
            posts_df: DataFrame with 'id' and 'author' columns

        Returns:
            List of dicts with 'text', 'label', 'id' keys
        """
        if self.model is None:
            self._setup_model()

        # Prepare input data
        authors_data = []
        for idx, row in posts_df.iterrows():
            if pd.notna(row['author']):
                authors_data.append({
                    'author': row['author'],
                    'id': row['id']
                })

        logger.info(f"Classifying {len(authors_data)} authors")

        results = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [executor.submit(self._predict, data) for data in authors_data]
            for future in futures:
                result = future.result()
                if result is not None:
                    predictions, post_id = result
                    for pred in predictions:
                        results.append({
                            'text': pred['text'],
                            'label': pred['label'],
                            'id': post_id
                        })

        logger.info(f"Classification complete. Found {len(results)} author entities")
        return results

    def _create_tables(self, con: sqlite3.Connection):
        """Create authors and post_authors tables if they don't exist."""
        logger.info("Creating authors tables")
        con.execute("""
            CREATE TABLE IF NOT EXISTS authors (
                id INTEGER PRIMARY KEY,
                name TEXT,
                type TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        con.execute("""
            CREATE TABLE IF NOT EXISTS post_authors (
                post_id INTEGER,
                author_id INTEGER,
                PRIMARY KEY (post_id, author_id),
                FOREIGN KEY (post_id) REFERENCES posts(id),
                FOREIGN KEY (author_id) REFERENCES authors(id)
            )
        """)
        con.commit()

    def _store_authors(self, con: sqlite3.Connection, results: list):
        """Store classified authors and their mappings.

        Args:
            con: Database connection
            results: List of classification results
        """
        if not results:
            logger.info("No authors to store")
            return

        # Convert results to DataFrame
        results_df = pd.DataFrame(results)

        # Get unique authors with their types
        unique_authors = results_df[['text', 'label']].drop_duplicates()
        unique_authors.columns = ['name', 'type']

        # Get existing authors
        existing_authors = pd.read_sql("SELECT id, name FROM authors", con)

        # Find new authors to insert
        if not existing_authors.empty:
            new_authors = unique_authors[~unique_authors['name'].isin(existing_authors['name'])]
        else:
            new_authors = unique_authors

        if not new_authors.empty:
            logger.info(f"Inserting {len(new_authors)} new authors")
            new_authors.to_sql('authors', con, if_exists='append', index=False)

        # Get all authors with their IDs
        all_authors = pd.read_sql("SELECT id, name FROM authors", con)
        name_to_id = dict(zip(all_authors['name'], all_authors['id']))

        # Create post_authors mappings
        mappings = []
        for _, row in results_df.iterrows():
            author_id = name_to_id.get(row['text'])
            if author_id:
                mappings.append({
                    'post_id': row['id'],
                    'author_id': author_id
                })

        if mappings:
            mappings_df = pd.DataFrame(mappings).drop_duplicates()

            # Clear existing mappings for these posts (optional, depends on your strategy)
            # post_ids = tuple(mappings_df['post_id'].unique())
            # con.execute(f"DELETE FROM post_authors WHERE post_id IN ({','.join('?' * len(post_ids))})", post_ids)

            logger.info(f"Creating {len(mappings_df)} post-author mappings")
            mappings_df.to_sql('post_authors', con, if_exists='append', index=False)

        con.commit()
        logger.info("Authors and mappings stored successfully")

    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
        """Execute the author classification transformation.
        Args:
            con: SQLite database connection
            context: TransformContext containing posts dataframe

        Returns:
            TransformContext with classified authors dataframe
        """
        logger.info("Starting AuthorNode transformation")
        posts_df = context.get_dataframe()

        # Ensure required columns exist
        if 'author' not in posts_df.columns:
            logger.warning("No 'author' column in dataframe. Skipping AuthorNode.")
            return context

        # Create tables
        self._create_tables(con)

        # Classify authors
        results = self._classify_authors(posts_df)

        # Store results
        self._store_authors(con, results)

        # Return context with results
        logger.info("AuthorNode transformation complete")
        return TransformContext(posts_df)


class FuzzyAuthorNode(TransformNode):
    """Transform node that maps posts to already classified authors via fuzzy matching.

    Uses the author names already stored in the authors table as 'rules' and
    searches for them, allowing a small number of errors, in the author field
    of posts that do not yet have a mapping.
    """

    def __init__(self, max_l_dist: int = 1):
        """Initialize FuzzyAuthorNode.

        Args:
            max_l_dist: The number of 'errors' allowed by the fuzzy search algorithm
        """
        self.max_l_dist = max_l_dist
        logger.info(f"Initialized FuzzyAuthorNode with max_l_dist={max_l_dist}")

    def _process_data(self, con: sqlite3.Connection, df: pd.DataFrame) -> pd.DataFrame:
        """Fuzzy-match post author strings against the known author names.

        Args:
            con: Database connection
            df: Input dataframe from context

        Returns:
            DataFrame with 'post_id' and 'author_id' columns for the new mappings
        """
        logger.info(f"Processing {len(df)} rows")

        # Retrieve all known authors from the authors table as 'rules'
        authors_df = pd.read_sql("SELECT id, name FROM authors", con)
        if authors_df.empty:
            logger.warning("No authors found in database for fuzzy matching")
            return pd.DataFrame(columns=['post_id', 'author_id'])

        # Get existing post-author mappings to avoid duplicates
        existing_mappings = pd.read_sql(
            "SELECT post_id, author_id FROM post_authors", con
        )
        existing_post_ids = set(existing_mappings['post_id'].unique())

        logger.info(f"Found {len(authors_df)} known authors for fuzzy matching")
        logger.info(f"Found {len(existing_post_ids)} posts with existing author mappings")

        # Filter to posts without author mappings and with non-null author field
        if 'author' not in df.columns or 'id' not in df.columns:
            logger.warning("Missing 'author' or 'id' column in input dataframe")
            return pd.DataFrame(columns=['post_id', 'author_id'])

        posts_to_process = df[
            (df['id'].notna()) &
            (df['author'].notna()) &
            (~df['id'].isin(existing_post_ids))
        ]
        logger.info(f"Processing {len(posts_to_process)} posts for fuzzy matching")

        # Perform fuzzy matching
        mappings = []
        for _, post_row in posts_to_process.iterrows():
            post_id = post_row['id']
            post_author = str(post_row['author'])

            # Try to find matches against all known author names
            for _, author_row in authors_df.iterrows():
                author_id = author_row['id']
                author_name = str(author_row['name'])

                # Very short author names (2 characters or fewer) get a fault tolerance of 0
                l_dist = self.max_l_dist if len(author_name) > 2 else 0

                # Use fuzzysearch to find matches with allowed errors
                matches = fuzzysearch.find_near_matches(
                    author_name,
                    post_author,
                    max_l_dist=l_dist,
                )

                if matches:
                    logger.debug(f"Found fuzzy match: '{author_name}' in '{post_author}' for post {post_id}")
                    mappings.append({
                        'post_id': post_id,
                        'author_id': author_id
                    })
                    # Only take the first match per post to avoid multiple mappings
                    break

        # Create result dataframe
        result_df = pd.DataFrame(mappings, columns=['post_id', 'author_id']) if mappings else pd.DataFrame(columns=['post_id', 'author_id'])

        logger.info(f"Processing complete. Found {len(result_df)} fuzzy matches")
        return result_df

    def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
        """Store results back to the database.

        Uses INSERT OR IGNORE to avoid inserting duplicates.

        Args:
            con: Database connection
            df: Processed dataframe to store
        """
        if df.empty:
            logger.info("No results to store")
            return

        logger.info(f"Storing {len(df)} results")

        # Use INSERT OR IGNORE to handle duplicates (respects PRIMARY KEY constraint)
        cursor = con.cursor()
        inserted_count = 0
        for _, row in df.iterrows():
            cursor.execute(
                "INSERT OR IGNORE INTO post_authors (post_id, author_id) VALUES (?, ?)",
                (int(row['post_id']), int(row['author_id']))
            )
            if cursor.rowcount > 0:
                inserted_count += 1

        con.commit()
        logger.info(f"Results stored successfully. Inserted {inserted_count} new mappings, skipped {len(df) - inserted_count} duplicates")

    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
        """Execute the transformation.

        This is the main entry point called by the pipeline.

        Args:
            con: SQLite database connection
            context: TransformContext containing input dataframe

        Returns:
            TransformContext with processed dataframe
        """
        logger.info("Starting FuzzyAuthorNode transformation")

        # Get input dataframe from context
        input_df = context.get_dataframe()

        # Validate input
        if input_df.empty:
            logger.warning("Empty dataframe provided to FuzzyAuthorNode")
            return context

        # Process the data
        result_df = self._process_data(con, input_df)

        # Store results
        self._store_results(con, result_df)

        logger.info("FuzzyAuthorNode transformation complete")

        # Return new context with results
        return TransformContext(input_df)
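

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original pipeline wiring; the database file
# name, the 'posts' table and its columns, and the manual TransformContext
# construction below are assumptions for illustration only).  It shows the
# intended ordering: run the NER node first so the authors table is
# populated, then the fuzzy node to map any remaining posts against those
# known author names.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    con = sqlite3.connect("knack.db")  # assumed database file
    posts_df = pd.read_sql("SELECT id, author FROM posts", con)  # assumed posts table
    context = TransformContext(posts_df)

    # NER pass: creates the authors / post_authors tables and fills them.
    context = NerAuthorNode(device="cpu", threshold=0.5).run(con, context)

    # Fuzzy pass: maps posts that the NER pass left unmapped.
    FuzzyAuthorNode(max_l_dist=1).run(con, context)

    con.close()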