Implements feature to clean up the authors free-text field

This commit is contained in:
quorploop 2025-12-21 21:18:05 +01:00
parent bcd210ce01
commit 64df8fb328
14 changed files with 804 additions and 310 deletions

transform/author_node.py Normal file

@@ -0,0 +1,263 @@
"""Author classification transform node using NER."""
from base import TransformNode, TransformContext
import sqlite3
import pandas as pd
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
try:
from gliner import GLiNER
import torch
GLINER_AVAILABLE = True
except ImportError:
GLINER_AVAILABLE = False
logging.warning("GLiNER not available. Install with: pip install gliner")
logger = logging.getLogger("knack-transform")


class AuthorNode(TransformNode):
    """Transform node that extracts and classifies authors using NER.

    Creates two tables:
    - authors: stores unique authors with their type (Person, Organisation, etc.)
    - post_authors: maps posts to their authors
    """

    def __init__(self, model_name: str = "urchade/gliner_medium-v2.1",
                 threshold: float = 0.5,
                 max_workers: int = 64,
                 device: str = "cpu"):
        """Initialize the AuthorNode.

        Args:
            model_name: GLiNER model to use
            threshold: Confidence threshold for entity predictions
            max_workers: Number of parallel workers for prediction
            device: Device to run model on ('cpu', 'cuda', 'mps')
        """
        self.model_name = model_name
        self.threshold = threshold
        self.max_workers = max_workers
        self.device = device
        self.model = None
        self.labels = ["Person", "Organisation", "Email", "Newspaper", "NGO"]

    def _setup_model(self):
        """Initialize the NER model."""
        if not GLINER_AVAILABLE:
            raise ImportError("GLiNER is required for AuthorNode. Install with: pip install gliner")
        logger.info(f"Loading GLiNER model: {self.model_name}")
        if self.device == "cuda" and torch.cuda.is_available():
            self.model = GLiNER.from_pretrained(
                self.model_name,
                max_length=255
            ).to('cuda', dtype=torch.float16)
        elif self.device == "mps" and torch.backends.mps.is_available():
            self.model = GLiNER.from_pretrained(
                self.model_name,
                max_length=255
            ).to('mps', dtype=torch.float16)
        else:
            self.model = GLiNER.from_pretrained(
                self.model_name,
                max_length=255
            )
        logger.info("Model loaded successfully")

    def _predict(self, text_data: dict):
        """Predict entities for a single author text.

        Args:
            text_data: Dict with 'author' and 'id' keys

        Returns:
            Tuple of (predictions, post_id) or None
        """
        if text_data is None or text_data.get('author') is None:
            return None
        predictions = self.model.predict_entities(
            text_data['author'],
            self.labels,
            threshold=self.threshold
        )
        return predictions, text_data['id']

    def _classify_authors(self, posts_df: pd.DataFrame):
        """Classify all authors in the posts dataframe.

        Args:
            posts_df: DataFrame with 'id' and 'author' columns

        Returns:
            List of dicts with 'text', 'label', 'id' keys
        """
        if self.model is None:
            self._setup_model()
        # Prepare input data
        authors_data = []
        for idx, row in posts_df.iterrows():
            if pd.notna(row['author']):
                authors_data.append({
                    'author': row['author'],
                    'id': row['id']
                })
        logger.info(f"Classifying {len(authors_data)} authors")
        results = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [executor.submit(self._predict, data) for data in authors_data]
            for future in futures:
                result = future.result()
                if result is not None:
                    predictions, post_id = result
                    for pred in predictions:
                        results.append({
                            'text': pred['text'],
                            'label': pred['label'],
                            'id': post_id
                        })
        logger.info(f"Classification complete. Found {len(results)} author entities")
        return results

    def _create_tables(self, con: sqlite3.Connection):
        """Create authors and post_authors tables if they don't exist."""
        logger.info("Creating authors tables")
        con.execute("""
            CREATE TABLE IF NOT EXISTS authors (
                id INTEGER PRIMARY KEY,
                name TEXT,
                type TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        con.execute("""
            CREATE TABLE IF NOT EXISTS post_authors (
                post_id INTEGER,
                author_id INTEGER,
                PRIMARY KEY (post_id, author_id),
                FOREIGN KEY (post_id) REFERENCES posts(id),
                FOREIGN KEY (author_id) REFERENCES authors(id)
            )
        """)
        con.commit()

    def _store_authors(self, con: sqlite3.Connection, results: list):
        """Store classified authors and their mappings.

        Args:
            con: Database connection
            results: List of classification results
        """
        if not results:
            logger.info("No authors to store")
            return
        # Convert results to DataFrame
        results_df = pd.DataFrame(results)
        # Get unique authors with their types
        unique_authors = results_df[['text', 'label']].drop_duplicates()
        unique_authors.columns = ['name', 'type']
        # Get existing authors
        existing_authors = pd.read_sql("SELECT id, name FROM authors", con)
        # Find new authors to insert
        if not existing_authors.empty:
            new_authors = unique_authors[~unique_authors['name'].isin(existing_authors['name'])]
        else:
            new_authors = unique_authors
        if not new_authors.empty:
            logger.info(f"Inserting {len(new_authors)} new authors")
            new_authors.to_sql('authors', con, if_exists='append', index=False)
        # Get all authors with their IDs
        all_authors = pd.read_sql("SELECT id, name FROM authors", con)
        name_to_id = dict(zip(all_authors['name'], all_authors['id']))
        # Create post_authors mappings
        mappings = []
        for _, row in results_df.iterrows():
            author_id = name_to_id.get(row['text'])
            if author_id:
                mappings.append({
                    'post_id': row['id'],
                    'author_id': author_id
                })
        if mappings:
            mappings_df = pd.DataFrame(mappings).drop_duplicates()
            # Clear existing mappings for these posts (optional, depends on your strategy)
            # post_ids = tuple(mappings_df['post_id'].unique())
            # con.execute(f"DELETE FROM post_authors WHERE post_id IN ({','.join('?' * len(post_ids))})", post_ids)
            logger.info(f"Creating {len(mappings_df)} post-author mappings")
            mappings_df.to_sql('post_authors', con, if_exists='append', index=False)
            # Mark posts as cleaned
            processed_post_ids = mappings_df['post_id'].unique().tolist()
            if processed_post_ids:
                placeholders = ','.join('?' * len(processed_post_ids))
                con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", processed_post_ids)
                logger.info(f"Marked {len(processed_post_ids)} posts as cleaned")
        con.commit()
        logger.info("Authors and mappings stored successfully")

    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
        """Execute the author classification transformation.

        Args:
            con: SQLite database connection
            context: TransformContext containing posts dataframe

        Returns:
            TransformContext with classified authors dataframe
        """
        logger.info("Starting AuthorNode transformation")
        posts_df = context.get_dataframe()
        # Ensure required columns exist
        if 'author' not in posts_df.columns:
            logger.warning("No 'author' column in dataframe. Skipping AuthorNode.")
            return context
        # Create tables
        self._create_tables(con)
        # Classify authors
        results = self._classify_authors(posts_df)
        # Store results
        self._store_authors(con, results)
        # Mark posts without author entities as cleaned too (no authors found)
        processed_ids = {r['id'] for r in results} if results else set()
        unprocessed_ids = [pid for pid in posts_df['id'].tolist() if pid not in processed_ids]
        if unprocessed_ids:
            placeholders = ','.join('?' * len(unprocessed_ids))
            con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", unprocessed_ids)
            con.commit()
            logger.info(f"Marked {len(unprocessed_ids)} posts without author entities as cleaned")
        # Return context with results
        results_df = pd.DataFrame(results) if results else pd.DataFrame()
        logger.info("AuthorNode transformation complete")
        return TransformContext(results_df)
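
For reference, a minimal wiring sketch for the new node follows. It is not part of the commit: the import paths, the database filename, and the direct TransformContext(DataFrame) construction are assumptions based on the imports and the return statement in run() above; the posts columns (id, author, is_cleaned) are the ones the node itself reads and updates.

# Hypothetical usage sketch (not part of this commit). Assumes a posts table
# with id, author and is_cleaned columns, and that TransformContext can be
# built directly from a DataFrame, as run() does above.
import sqlite3
import pandas as pd

from base import TransformContext             # assumed import path, mirrors author_node.py
from transform.author_node import AuthorNode  # assumed module path

con = sqlite3.connect("knack.db")  # illustrative database file

# Only feed posts whose free-text author field has not been cleaned yet.
posts_df = pd.read_sql("SELECT id, author FROM posts WHERE is_cleaned = 0", con)

node = AuthorNode(device="cpu", threshold=0.5)
node.run(con, TransformContext(posts_df))

# Inspect the result: one row per post/author pair with the NER label.
print(pd.read_sql(
    """
    SELECT p.id AS post_id, a.name, a.type
    FROM post_authors pa
    JOIN authors a ON a.id = pa.author_id
    JOIN posts p ON p.id = pa.post_id
    """,
    con,
))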