"""Author classification transform node using NER."""
|
|
import os
|
|
import sqlite3
|
|
import pandas as pd
|
|
import logging
|
|
import fuzzysearch
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from pipeline import TransformContext
|
|
from transform_node import TransformNode
|
|
|
|
logger = logging.getLogger("knack-transform")
|
|
|
|
try:
|
|
from gliner import GLiNER
|
|
import torch
|
|
GLINER_AVAILABLE = True
|
|
except ImportError:
|
|
GLINER_AVAILABLE = False
|
|
logging.warning("GLiNER not available. Install with: pip install gliner")
|
|
|
|
class NerAuthorNode(TransformNode):
    """Transform node that extracts and classifies authors using NER.

    Creates two tables:
    - authors: stores unique authors with their type (Person, Organisation, etc.)
    - post_authors: maps posts to their authors
    """

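    # Illustrative example (hypothetical values, not taken from real data): a post whose
    # raw `author` field reads "Jane Doe, Greenpeace" would end up roughly as
    #   authors:       (1, 'Jane Doe', 'Person'), (2, 'Greenpeace', 'NGO')
    #   post_authors:  (<post id>, 1), (<post id>, 2)
    # The exact labels depend on the GLiNER model and the configured threshold.
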
    def __init__(self, model_name: str = "urchade/gliner_multi-v2.1",
                 model_path: str = None,
                 threshold: float = 0.5,
                 max_workers: int = 64,
                 device: str = "cpu"):
        """Initialize the NerAuthorNode.

        Args:
            model_name: GLiNER model to use
            model_path: Optional local path to a downloaded GLiNER model
                (defaults to the GLINER_MODEL_PATH environment variable)
            threshold: Confidence threshold for entity predictions
            max_workers: Number of parallel workers for prediction
            device: Device to run the model on ('cpu', 'cuda', 'mps')
        """
        self.model_name = model_name
        self.model_path = model_path or os.environ.get('GLINER_MODEL_PATH')
        self.threshold = threshold
        self.max_workers = max_workers
        self.device = device
        self.model = None
        self.labels = ["Person", "Organisation", "Email", "Newspaper", "NGO"]

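    # Example instantiations (hypothetical paths; GLINER_MODEL_PATH can also be set in
    # the environment instead of passing model_path explicitly):
    #     node = NerAuthorNode()  # hub model, CPU, threshold 0.5
    #     node = NerAuthorNode(model_path="/models/gliner_multi-v2.1", device="cuda")
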
    def _setup_model(self):
        """Initialize the NER model."""
        if not GLINER_AVAILABLE:
            raise ImportError("GLiNER is required for NerAuthorNode. Install with: pip install gliner")

        model_source = None
        if self.model_path:
            if os.path.exists(self.model_path):
                model_source = self.model_path
                logger.info(f"Loading GLiNER model from local path: {self.model_path}")
            else:
                logger.warning(f"GLINER_MODEL_PATH '{self.model_path}' not found; falling back to hub model {self.model_name}")

        if model_source is None:
            model_source = self.model_name
            logger.info(f"Loading GLiNER model from hub: {self.model_name}")

        if self.device == "cuda" and torch.cuda.is_available():
            self.model = GLiNER.from_pretrained(
                model_source,
                max_length=255
            ).to('cuda', dtype=torch.float16)
        elif self.device == "mps" and torch.backends.mps.is_available():
            self.model = GLiNER.from_pretrained(
                model_source,
                max_length=255
            ).to('mps', dtype=torch.float16)
        else:
            self.model = GLiNER.from_pretrained(
                model_source,
                max_length=255
            )

        logger.info("Model loaded successfully")

    def _predict(self, text_data: dict):
        """Predict entities for a single author text.

        Args:
            text_data: Dict with 'author' and 'id' keys

        Returns:
            Tuple of (predictions, post_id), or None if there is no author text
        """
        if text_data is None or text_data.get('author') is None:
            return None

        predictions = self.model.predict_entities(
            text_data['author'],
            self.labels,
            threshold=self.threshold
        )
        return predictions, text_data['id']

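    # Sketch of the predict_entities output (hypothetical values): GLiNER returns a list
    # of dicts, each typically carrying 'text', 'label', 'score', 'start' and 'end', e.g.
    #     [{'text': 'Jane Doe', 'label': 'Person', 'score': 0.93, 'start': 0, 'end': 8}]
    # Only 'text' and 'label' are consumed by _classify_authors below.
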
    def _classify_authors(self, posts_df: pd.DataFrame):
        """Classify all authors in the posts dataframe.

        Args:
            posts_df: DataFrame with 'id' and 'author' columns

        Returns:
            List of dicts with 'text', 'label', 'id' keys
        """
        if self.model is None:
            self._setup_model()

        # Prepare input data
        authors_data = []
        for _, row in posts_df.iterrows():
            if pd.notna(row['author']):
                authors_data.append({
                    'author': row['author'],
                    'id': row['id']
                })

        logger.info(f"Classifying {len(authors_data)} authors")

        results = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [executor.submit(self._predict, data) for data in authors_data]

            for future in futures:
                result = future.result()
                if result is not None:
                    predictions, post_id = result
                    for pred in predictions:
                        results.append({
                            'text': pred['text'],
                            'label': pred['label'],
                            'id': post_id
                        })

        logger.info(f"Classification complete. Found {len(results)} author entities")
        return results

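    # Shape of the returned list (hypothetical values):
    #     [{'text': 'Jane Doe', 'label': 'Person', 'id': 42},
    #      {'text': 'Greenpeace', 'label': 'NGO', 'id': 42}]
    # where 'id' is the id of the post the entity was found in.
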
    def _create_tables(self, con: sqlite3.Connection):
        """Create authors and post_authors tables if they don't exist."""
        logger.info("Creating authors tables")

        con.execute("""
            CREATE TABLE IF NOT EXISTS authors (
                id INTEGER PRIMARY KEY,
                name TEXT,
                type TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        con.execute("""
            CREATE TABLE IF NOT EXISTS post_authors (
                post_id INTEGER,
                author_id INTEGER,
                PRIMARY KEY (post_id, author_id),
                FOREIGN KEY (post_id) REFERENCES posts(id),
                FOREIGN KEY (author_id) REFERENCES authors(id)
            )
        """)

        con.commit()

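    # Once populated, authors per post can be read back with a join such as
    # (illustrative query, not used by the pipeline itself):
    #     SELECT p.id, a.name, a.type
    #     FROM posts p
    #     JOIN post_authors pa ON pa.post_id = p.id
    #     JOIN authors a ON a.id = pa.author_id;
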
    def _store_authors(self, con: sqlite3.Connection, results: list):
        """Store classified authors and their mappings.

        Args:
            con: Database connection
            results: List of classification results
        """
        if not results:
            logger.info("No authors to store")
            return

        # Convert results to DataFrame
        results_df = pd.DataFrame(results)

        # Get unique authors with their types (one row per name; if NER assigned more
        # than one label to the same name, keep the first, since the name -> id lookup
        # below can only reference a single authors row per name)
        unique_authors = results_df[['text', 'label']].drop_duplicates(subset=['text'])
        unique_authors.columns = ['name', 'type']

        # Get existing authors
        existing_authors = pd.read_sql("SELECT id, name FROM authors", con)

        # Find new authors to insert
        if not existing_authors.empty:
            new_authors = unique_authors[~unique_authors['name'].isin(existing_authors['name'])]
        else:
            new_authors = unique_authors

        if not new_authors.empty:
            logger.info(f"Inserting {len(new_authors)} new authors")
            new_authors.to_sql('authors', con, if_exists='append', index=False)

        # Get all authors with their IDs
        all_authors = pd.read_sql("SELECT id, name FROM authors", con)
        name_to_id = dict(zip(all_authors['name'], all_authors['id']))

        # Create post_authors mappings
        mappings = []
        for _, row in results_df.iterrows():
            author_id = name_to_id.get(row['text'])
            if author_id:
                mappings.append({
                    'post_id': row['id'],
                    'author_id': author_id
                })

        if mappings:
            mappings_df = pd.DataFrame(mappings).drop_duplicates()

            # Clear existing mappings for these posts (optional, depends on your strategy)
            # post_ids = tuple(mappings_df['post_id'].unique())
            # con.execute(f"DELETE FROM post_authors WHERE post_id IN ({','.join('?' * len(post_ids))})", post_ids)

            logger.info(f"Creating {len(mappings_df)} post-author mappings")
            mappings_df.to_sql('post_authors', con, if_exists='append', index=False)

        con.commit()
        logger.info("Authors and mappings stored successfully")

    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
        """Execute the author classification transformation.

        Args:
            con: SQLite database connection
            context: TransformContext containing posts dataframe

        Returns:
            TransformContext with the original posts dataframe; the classified
            authors and their post mappings are written to the database
        """
        logger.info("Starting NerAuthorNode transformation")

        posts_df = context.get_dataframe()

        # Ensure required columns exist
        if 'author' not in posts_df.columns:
            logger.warning("No 'author' column in dataframe. Skipping NerAuthorNode.")
            return context

        # Create tables
        self._create_tables(con)

        # Classify authors
        results = self._classify_authors(posts_df)

        # Store results
        self._store_authors(con, results)

        logger.info("NerAuthorNode transformation complete")

        # Pass the posts dataframe on unchanged; author data lives in the database
        return TransformContext(posts_df)


class FuzzyAuthorNode(TransformNode):
    """FuzzyAuthorNode

    This node takes author names that have already been classified (the rows in the
    authors table) and uses them as 'rules' to fuzzy-match the author field of posts
    that do not yet have an author mapping.
    """

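    # Illustrative example (hypothetical values): with max_l_dist=1, a known author
    # 'Jane Doe' also matches a post whose author field reads 'Jane Do' or 'Jane Doee',
    # so that post gets linked to the existing authors row instead of staying unmapped.
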
    def __init__(self,
                 max_l_dist: int = 1):
        """Initialize FuzzyAuthorNode.

        Args:
            max_l_dist: Maximum number of 'errors' (Levenshtein distance) allowed
                by the fuzzy search algorithm
        """
        self.max_l_dist = max_l_dist
        logger.info(f"Initialized FuzzyAuthorNode with max_l_dist={max_l_dist}")

    def _process_data(self, con: sqlite3.Connection, df: pd.DataFrame) -> pd.DataFrame:
        """Fuzzy-match known author names against posts without an author mapping.

        Args:
            con: Database connection
            df: Input dataframe from context

        Returns:
            DataFrame with 'post_id' and 'author_id' columns, one row per new mapping
        """
        logger.info(f"Processing {len(df)} rows")

        # Retrieve all known authors from the authors table as 'rules'
        authors_df = pd.read_sql("SELECT id, name FROM authors", con)

        if authors_df.empty:
            logger.warning("No authors found in database for fuzzy matching")
            return pd.DataFrame(columns=['post_id', 'author_id'])

        # Get existing post-author mappings to avoid duplicates
        existing_mappings = pd.read_sql(
            "SELECT post_id, author_id FROM post_authors", con
        )
        existing_post_ids = set(existing_mappings['post_id'].unique())

        logger.info(f"Found {len(authors_df)} known authors for fuzzy matching")
        logger.info(f"Found {len(existing_post_ids)} posts with existing author mappings")

        # Filter to posts without author mappings and with non-null author field
        if 'author' not in df.columns or 'id' not in df.columns:
            logger.warning("Missing 'author' or 'id' column in input dataframe")
            return pd.DataFrame(columns=['post_id', 'author_id'])

        posts_to_process = df[
            (df['id'].notna()) &
            (df['author'].notna()) &
            (~df['id'].isin(existing_post_ids))
        ]

        logger.info(f"Processing {len(posts_to_process)} posts for fuzzy matching")

        # Perform fuzzy matching
        mappings = []
        for _, post_row in posts_to_process.iterrows():
            post_id = post_row['id']
            post_author = str(post_row['author'])

            # Try to find matches against all known author names
            for _, author_row in authors_df.iterrows():
                author_id = author_row['id']
                author_name = str(author_row['name'])
                # For author names of 2 characters or fewer, require an exact match
                # (zero fault tolerance)
                l_dist = self.max_l_dist if len(author_name) > 2 else 0

                # Use fuzzysearch to find matches with allowed errors
                matches = fuzzysearch.find_near_matches(
                    author_name,
                    post_author,
                    max_l_dist=l_dist,
                )

                if matches:
                    logger.debug(f"Found fuzzy match: '{author_name}' in '{post_author}' for post {post_id}")
                    mappings.append({
                        'post_id': post_id,
                        'author_id': author_id
                    })
                    # Only take the first matching author per post to avoid multiple mappings
                    break

        # Create result dataframe
        if mappings:
            result_df = pd.DataFrame(mappings, columns=['post_id', 'author_id'])
        else:
            result_df = pd.DataFrame(columns=['post_id', 'author_id'])

        logger.info(f"Processing complete. Found {len(result_df)} fuzzy matches")
        return result_df

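    # fuzzysearch.find_near_matches(subsequence, sequence, max_l_dist=...) returns a list
    # of Match objects (start, end, dist, matched). Illustrative call (hypothetical text):
    #     find_near_matches('Jane Doe', 'by Jane Do, staff writer', max_l_dist=1)
    # matches the 'Jane Do' span at distance 1, so the post above would be mapped.
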
    def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
        """Store results back to the database.

        Uses INSERT OR IGNORE to avoid inserting duplicates.

        Args:
            con: Database connection
            df: Processed dataframe to store
        """
        if df.empty:
            logger.info("No results to store")
            return

        logger.info(f"Storing {len(df)} results")

        # Use INSERT OR IGNORE to handle duplicates (respects PRIMARY KEY constraint)
        cursor = con.cursor()
        inserted_count = 0

        for _, row in df.iterrows():
            cursor.execute(
                "INSERT OR IGNORE INTO post_authors (post_id, author_id) VALUES (?, ?)",
                (int(row['post_id']), int(row['author_id']))
            )
            if cursor.rowcount > 0:
                inserted_count += 1

        con.commit()
        logger.info(f"Results stored successfully. Inserted {inserted_count} new mappings, skipped {len(df) - inserted_count} duplicates")

    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
        """Execute the transformation.

        This is the main entry point called by the pipeline.

        Args:
            con: SQLite database connection
            context: TransformContext containing input dataframe

        Returns:
            TransformContext carrying the unchanged input dataframe; the new
            post-author mappings are written to the database
        """
        logger.info("Starting FuzzyAuthorNode transformation")

        # Get input dataframe from context
        input_df = context.get_dataframe()

        # Validate input
        if input_df.empty:
            logger.warning("Empty dataframe provided to FuzzyAuthorNode")
            return context

        # Process the data
        result_df = self._process_data(con, input_df)

        # Store results
        self._store_results(con, result_df)

        logger.info("FuzzyAuthorNode transformation complete")

        # Pass the input dataframe on unchanged; mappings live in the database
        return TransformContext(input_df)
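

# Minimal usage sketch. Assumptions not shown in this module: a SQLite database whose
# `posts` table has `id` and `author` columns, a TransformContext that wraps a posts
# dataframe, and a hypothetical database path; the real pipeline wiring lives elsewhere.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    con = sqlite3.connect("knack.db")  # hypothetical path
    posts = pd.read_sql("SELECT id, author FROM posts", con)
    context = TransformContext(posts)
    context = NerAuthorNode(device="cpu").run(con, context)    # exact NER pass first
    context = FuzzyAuthorNode(max_l_dist=1).run(con, context)  # then fuzzy back-fill
    con.close()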