Knack-Scraper/transform/embeddings_node.py
2026-01-18 15:43:35 +01:00

536 lines
19 KiB
Python

"""Classes of Transformernodes that have to do with
text processing.
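- TextEmbeddingNode calculates text embeddings
- UmapNode calculates low-dimensional coordinates (umap_x, umap_y, umap_z) from those vector embeddings
- SimilarityNode is meant to calculate the top n similar posts based on those embeddings
  using the spectral distance (the class is still a template; this is not implemented yet).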
"""
from pipeline import TransformContext
from transform_node import TransformNode
import sqlite3
import pandas as pd
import logging
import os
import numpy as np
import sys
import pickle
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
logger = logging.getLogger("knack-transform")

# Optional heavy dependencies. Each guard records availability in a module
# flag so the nodes can raise a clear error only when they are actually used.
#
# The warnings go through the named module logger on purpose: the module-level
# logging.warning() helper implicitly calls logging.basicConfig() when the root
# logger has no handlers, which would turn a later, explicit
# logging.basicConfig(...) call (see main()) into a silent no-op.
try:
    from sentence_transformers import SentenceTransformer
    import torch
    GTE_AVAILABLE = True
except ImportError:
    GTE_AVAILABLE = False
    logger.warning("GTE not available. Install with pip!")

try:
    import umap
    UMAP_AVAILABLE = True
except ImportError:
    UMAP_AVAILABLE = False
    logger.warning("UMAP not available. Install with pip install umap-learn!")
class TextEmbeddingNode(TransformNode):
    """Calculate vector embeddings for the ``text`` column of a posts dataframe.

    The embeddings are computed with a SentenceTransformer model and written
    to the ``embedding`` column of the ``posts`` table as raw float32 bytes.
    """

    def __init__(self,
                 model_name: str = "thenlper/gte-small",
                 model_path: str = None,
                 device: str = "cpu"):
        """Initialize the TextEmbeddingNode.

        Args:
            model_name: Name of the ML model used to calculate text embeddings
            model_path: Optional local path to a downloaded embedding model.
                Falls back to the ``GTE_MODEL_PATH`` environment variable.
            device: Device to use for computations ('cpu', 'cuda', 'mps')
        """
        self.model_name = model_name
        self.model_path = model_path or os.environ.get('GTE_MODEL_PATH')
        self.device = device
        self.model = None  # loaded lazily by _setup_model()
        # Log the resolved path (argument or environment), not just the argument.
        logger.info(f"Initialized TextEmbeddingNode with model_name={model_name}, model_path={self.model_path}, device={device}")

    def _setup_model(self):
        """Load the text embedding model.

        A local ``model_path`` is preferred when it exists on disk; otherwise
        the model is loaded by ``model_name``.

        Raises:
            ImportError: If sentence_transformers / torch could not be imported.
        """
        if not GTE_AVAILABLE:
            raise ImportError("GTE is required for TextEmbeddingNode. Please install.")

        model_source = None
        if self.model_path:
            if os.path.exists(self.model_path):
                model_source = self.model_path
                logger.info(f"Loading GTE model from local path: {self.model_path}")
            else:
                logger.warning(f"GTE_MODEL_PATH '{self.model_path}' not found; Falling back to hub model {self.model_name}")
        if model_source is None:
            model_source = self.model_name
            logger.info(f"Loading GTE model from the hub: {self.model_name}")

        # On accelerators the model is cast to float16; if the requested
        # device is unavailable the model is loaded with its defaults.
        if self.device == "cuda" and torch.cuda.is_available():
            self.model = SentenceTransformer(model_source).to('cuda', dtype=torch.float16)
        elif self.device == "mps" and torch.backends.mps.is_available():
            self.model = SentenceTransformer(model_source).to('mps', dtype=torch.float16)
        else:
            self.model = SentenceTransformer(model_source)

    def _process_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Compute one embedding per row of ``df['text']``.

        Args:
            df: Input dataframe from context; must contain a ``text`` column

        Returns:
            Copy of ``df`` with an ``embedding`` column holding one numpy
            array per row.
        """
        logger.info(f"Processing {len(df)} rows")
        if self.model is None:
            self._setup_model()

        result_df = df.copy()
        # Encode all texts in one call so the model can batch them instead of
        # running one forward pass per row.
        embeddings = self.model.encode(df['text'].tolist(), convert_to_numpy=True)
        # One array per row, aligned explicitly with the dataframe index.
        result_df['embedding'] = pd.Series(list(embeddings), index=result_df.index, dtype=object)
        logger.info("Processing complete")
        return result_df

    def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
        """Store the embeddings in the ``posts`` table using a batch update.

        Args:
            con: Database connection
            df: Dataframe with ``id`` and ``embedding`` columns
        """
        if df.empty:
            logger.info("No results to store")
            return

        logger.info(f"Storing {len(df)} results")
        # Serialize every embedding as float32 bytes. The model is cast to
        # float16 on cuda/mps, so encode() may hand back non-float32 arrays,
        # while UmapNode decodes the BLOB with dtype=np.float32.
        updates = [
            (np.asarray(embedding, dtype=np.float32).tobytes(), post_id)
            for embedding, post_id in zip(df['embedding'], df['id'].tolist())
        ]
        con.executemany(
            "UPDATE posts SET embedding = ? WHERE id = ?",
            updates
        )
        con.commit()
        logger.info("Results stored successfully")

    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
        """Execute the transformation.

        This is the main entry point called by the pipeline.

        Args:
            con: SQLite database connection
            context: TransformContext containing input dataframe

        Returns:
            TransformContext with the processed dataframe, or the unchanged
            input context when the dataframe is empty or has no ``text`` column.
        """
        logger.info("Starting TextEmbeddingNode transformation")
        input_df = context.get_dataframe()

        # Nothing to embed: hand the context through untouched.
        if input_df.empty:
            logger.warning("Empty dataframe provided to TextEmbeddingNode")
            return context
        if 'text' not in input_df.columns:
            logger.warning("No 'text' column in context dataframe. Skipping TextEmbeddingNode")
            return context

        result_df = self._process_data(input_df)
        self._store_results(con, result_df)
        logger.info("TextEmbeddingNode transformation complete")
        return TransformContext(result_df)
class UmapNode(TransformNode):
    """Calculate low-dimensional coordinates from embeddings using UMAP.

    This node takes text embeddings and reduces them to coordinates that are
    written to the ``umap_x``, ``umap_y`` and ``umap_z`` columns for
    visualization purposes.
    """

    # Output columns, in component order. Components beyond these are not stored.
    _COORD_COLUMNS = ('umap_x', 'umap_y', 'umap_z')

    def __init__(self,
                 n_neighbors: int = 10,
                 min_dist: float = 0.1,
                 n_components: int = 3,
                 metric: str = "cosine",
                 random_state: int = 42,
                 model_path: str = None):
        """Initialize the UmapNode.

        Args:
            n_neighbors: Number of neighbors to consider for UMAP (default: 10)
            min_dist: Minimum distance between points in low-dimensional space (default: 0.1)
            n_components: Number of dimensions to reduce to (default: 3)
            metric: Distance metric to use (default: 'cosine')
            random_state: Random seed for reproducibility (default: 42)
            model_path: Path to save/load the fitted UMAP model. Falls back to
                the ``UMAP_MODEL_PATH`` environment variable; when neither is
                set the fitted model is not persisted.
        """
        self.n_neighbors = n_neighbors
        self.min_dist = min_dist
        self.n_components = n_components
        self.metric = metric
        self.random_state = random_state
        self.model_path = model_path or os.environ.get('UMAP_MODEL_PATH')
        self.reducer = None
        logger.info(f"Initialized UmapNode with n_neighbors={n_neighbors}, min_dist={min_dist}, "
                    f"n_components={n_components}, metric={metric}, random_state={random_state}, "
                    f"model_path={self.model_path}")

    @staticmethod
    def _decode_embedding(value):
        """Turn one stored embedding into a float32 numpy array (or None)."""
        if value is None or (isinstance(value, float) and np.isnan(value)):
            return None
        if isinstance(value, np.ndarray):
            # In-memory arrays (e.g. handed over by a previous node) must be
            # cast, not reinterpreted byte-wise as float32.
            return np.asarray(value, dtype=np.float32)
        return np.frombuffer(value, dtype=np.float32)

    def _save_reducer(self):
        """Persist the fitted reducer to ``self.model_path`` (best effort)."""
        if not self.model_path:
            logger.info("No UMAP model path configured; fitted model is not saved")
            return
        try:
            umap_folder = os.path.dirname(self.model_path)
            if umap_folder:
                # exist_ok: an already existing folder must not prevent saving.
                os.makedirs(umap_folder, exist_ok=True)
            with open(self.model_path, 'wb') as f:
                pickle.dump(self.reducer, f)
            logger.info(f"UMAP model saved to {self.model_path}")
        except Exception as e:
            # Saving is optional; the computed coordinates are still usable.
            logger.error(f"Failed to save UMAP model to {self.model_path}: {e}")

    def _process_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Reduce the embeddings in ``df`` to UMAP coordinates.

        Embeddings are decoded to float32 numpy arrays, stacked into a matrix
        and transformed with a saved UMAP model if one can be loaded,
        otherwise with a newly fitted one.

        Args:
            df: Input dataframe from context; must contain an ``embedding`` column

        Returns:
            Copy of ``df`` with ``umap_x``, ``umap_y`` and ``umap_z`` columns.
            Rows without an embedding and components the reducer does not
            produce are filled with 0 where no value is present.

        Raises:
            ImportError: If umap-learn is not installed.
            ValueError: If there is no ``embedding`` column or no valid embedding.
        """
        logger.info(f"Processing {len(df)} rows")
        if not UMAP_AVAILABLE:
            raise ImportError("UMAP is required for UmapNode. Install with: pip install umap-learn")

        result_df = df.copy()
        if 'embedding' not in result_df.columns:
            logger.error("No 'embedding' column found in dataframe")
            raise ValueError("Input dataframe must contain 'embedding' column")

        logger.info("Converting embeddings from BLOB to numpy arrays")
        result_df['embedding'] = result_df['embedding'].apply(self._decode_embedding)

        # Only rows that actually carry an embedding take part in the reduction.
        valid_rows = result_df['embedding'].notna()
        if not valid_rows.any():
            logger.error("No valid embeddings found in dataframe")
            raise ValueError("No valid embeddings to process")
        logger.info(f"Found {valid_rows.sum()} valid embeddings out of {len(result_df)} rows")

        embeddings_matrix = np.vstack(result_df.loc[valid_rows, 'embedding'].values)
        logger.info(f"Embeddings matrix shape: {embeddings_matrix.shape}")

        umap_coords = None
        if self.model_path and os.path.exists(self.model_path):
            logger.info(f"Loading existing UMAP model from {self.model_path}")
            try:
                # NOTE(security): pickle.load can execute arbitrary code; the
                # model file must come from a trusted location.
                with open(self.model_path, 'rb') as f:
                    self.reducer = pickle.load(f)
                logger.info("UMAP model loaded successfully")
                umap_coords = self.reducer.transform(embeddings_matrix)
                logger.info(f"UMAP transformation complete using existing model. Output shape: {umap_coords.shape}")
            except Exception as e:
                logger.warning(f"Failed to load UMAP model from {self.model_path}: {e}")
                logger.info("Falling back to fitting a new model")
                self.reducer = None
                umap_coords = None

        # No coordinates from a saved model: fit a new reducer.
        if umap_coords is None:
            logger.info("Fitting new UMAP reducer...")
            self.reducer = umap.UMAP(
                n_neighbors=self.n_neighbors,
                min_dist=self.min_dist,
                n_components=self.n_components,
                metric=self.metric,
                random_state=self.random_state
            )
            umap_coords = self.reducer.fit_transform(embeddings_matrix)
            logger.info(f"UMAP transformation complete. Output shape: {umap_coords.shape}")
            self._save_reducer()

        # Write as many components as the reducer produced; a missing
        # component (e.g. z for a 2D projection) is stored as 0.
        n_dims = umap_coords.shape[1]
        for axis, column in enumerate(self._COORD_COLUMNS):
            values = umap_coords[:, axis] if axis < n_dims else 0.0
            result_df.loc[valid_rows, column] = values
            # Rows without an embedding get 0 instead of NaN.
            result_df[column] = result_df[column].fillna(value=0)

        logger.info("Processing complete")
        return result_df

    def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
        """Store UMAP coordinates back to the database.

        Args:
            con: Database connection
            df: Processed dataframe with umap_x, umap_y and umap_z columns
        """
        if df.empty:
            logger.info("No results to store")
            return

        logger.info(f"Storing {len(df)} results")
        if any(column not in df.columns for column in self._COORD_COLUMNS):
            logger.warning("No valid UMAP coordinates to store")
            return

        # tolist() yields plain Python numbers that sqlite3 can bind directly.
        rows = zip(
            df['umap_x'].tolist(),
            df['umap_y'].tolist(),
            df['umap_z'].tolist(),
            df['id'].tolist(),
        )
        updates = [
            (x, y, z, post_id)
            for x, y, z, post_id in rows
            if pd.notna(x) and pd.notna(y) and pd.notna(z)
        ]
        if updates:
            con.executemany(
                "UPDATE posts SET umap_x = ?, umap_y = ?, umap_z = ? WHERE id = ?",
                updates
            )
            con.commit()
            logger.info(f"Stored {len(updates)} UMAP coordinate pairs successfully")
        else:
            logger.warning("No valid UMAP coordinates to store")

    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
        """Execute the transformation.

        This is the main entry point called by the pipeline.

        Args:
            con: SQLite database connection
            context: TransformContext containing input dataframe

        Returns:
            TransformContext with the processed dataframe, or the unchanged
            input context when the dataframe is empty.
        """
        logger.info("Starting UmapNode transformation")
        input_df = context.get_dataframe()

        if input_df.empty:
            logger.warning("Empty dataframe provided to UmapNode")
            return context

        result_df = self._process_data(input_df)
        self._store_results(con, result_df)
        logger.info("UmapNode transformation complete")
        return TransformContext(result_df)
class SimilarityNode(TransformNode):
    """Placeholder node that still carries the transform-node template logic.

    The visible implementation does not compute any similarity yet: it
    creates an ``example_results`` table, tags every row as processed and
    derives an ``example_value`` string from the row id.
    """

    def __init__(self,
                 param1: str = "default_value",
                 param2: int = 42,
                 device: str = "cpu"):
        """Initialize the SimilarityNode.

        Args:
            param1: Template string parameter; used as prefix of ``example_value``
            param2: Template integer parameter (stored, not used by this class)
            device: Device to use for computations ('cpu', 'cuda', 'mps')
                (stored, not used by this class)
        """
        self.param1 = param1
        self.param2 = param2
        self.device = device
        logger.info(f"Initialized SimilarityNode with param1={param1}, param2={param2}")

    def _create_tables(self, con: sqlite3.Connection):
        """Create the ``example_results`` table if it does not exist yet.

        Args:
            con: Database connection
        """
        logger.info("Creating example tables")
        con.execute("""
            CREATE TABLE IF NOT EXISTS example_results (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                post_id INTEGER,
                result_value TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (post_id) REFERENCES posts(id)
            )
        """)
        con.commit()

    def _process_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add the template columns ``processed`` and ``example_value``.

        Args:
            df: Input dataframe from context; must contain an ``id`` column

        Returns:
            Copy of ``df`` with ``processed`` set to True and
            ``example_value`` set to ``"<param1>_<id>"`` for every row.
        """
        logger.info(f"Processing {len(df)} rows")
        result_df = df.copy()
        result_df['processed'] = True
        result_df['example_value'] = result_df['id'].apply(lambda x: f"{self.param1}_{x}")
        logger.info("Processing complete")
        return result_df

    def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
        """Commit the connection; no rows are written by this template.

        Args:
            con: Database connection
            df: Processed dataframe (only checked for emptiness and length)
        """
        if df.empty:
            logger.info("No results to store")
            return

        logger.info(f"Storing {len(df)} results")
        # Template for persisting results once the node produces real output:
        # df[['post_id', 'result_value']].to_sql(
        #     'example_results',
        #     con,
        #     if_exists='append',
        #     index=False
        # )
        con.commit()
        logger.info("Results stored successfully")

    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
        """Execute the transformation.

        This is the main entry point called by the pipeline.

        Args:
            con: SQLite database connection
            context: TransformContext containing input dataframe

        Returns:
            TransformContext with the processed dataframe, or the unchanged
            input context when the dataframe is empty.
        """
        logger.info("Starting SimilarityNode transformation")
        input_df = context.get_dataframe()

        if input_df.empty:
            logger.warning("Empty dataframe provided to SimilarityNode")
            return context

        self._create_tables(con)
        result_df = self._process_data(input_df)
        self._store_results(con, result_df)
        logger.info("SimilarityNode transformation complete")
        return TransformContext(result_df)
def main(db_path: str = None):
    """Plot the stored UMAP coordinates of all posts as a 3D scatter plot.

    Configures logging (file ``app.log`` plus stdout), reads the ``posts``
    table and shows ``umap_x`` / ``umap_y`` / ``umap_z`` colored by post id.

    Args:
        db_path: Path to the SQLite database. Defaults to the original
            developer-local database location when omitted.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler("app.log"),
            logging.StreamHandler(sys.stdout)
        ]
    )
    logger = logging.getLogger("knack-transform")

    if db_path is None:
        db_path = "/Users/linussilberstein/Documents/Knack-Scraper/data/knack.sqlite"

    con = sqlite3.connect(db_path)
    try:
        df = pd.read_sql('select * from posts;', con)
        logger.info(df)
        # To recompute embeddings and coordinates before plotting, run the
        # nodes here while the connection is still open:
        # node = TextEmbeddingNode(device='mps')
        # new_context = node.run(con, TransformContext(df))
        # new_context = UmapNode().run(con, new_context)
        # df = new_context.get_dataframe()
    finally:
        # The plot only needs the dataframe; release the database handle.
        con.close()

    # Create 3D scatter plot of UMAP coordinates
    result_df = df
    fig = plt.figure(figsize=(12, 9))
    ax = fig.add_subplot(111, projection='3d')
    scatter = ax.scatter(
        result_df['umap_x'],
        result_df['umap_y'],
        result_df['umap_z'],
        c=result_df['id'],
        cmap='viridis',
        alpha=0.6,
        s=50
    )
    ax.set_xlabel('UMAP X')
    ax.set_ylabel('UMAP Y')
    ax.set_zlabel('UMAP Z')
    ax.set_title('3D UMAP Visualization of Post Embeddings')
    plt.colorbar(scatter, ax=ax, label='Post Index')
    plt.tight_layout()
    plt.show()
    logger.info("3D plot displayed")


if __name__ == '__main__':
    main()