"""Classes of Transformernodes that have to do with text processing. - TextEmbeddingNode calculates text embeddings - UmapNode calculates xy coordinates on those vector embeddings - SimilarityNode calculates top n similar posts based on those embeddings using the spectral distance. """ from pipeline import TransformContext from transform_node import TransformNode import sqlite3 import pandas as pd import logging import os import numpy as np import sys import pickle import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D logger = logging.getLogger("knack-transform") try: from sentence_transformers import SentenceTransformer import torch GTE_AVAILABLE = True except ImportError: GTE_AVAILABLE = False logging.warning("GTE not available. Install with pip!") try: import umap UMAP_AVAILABLE = True except ImportError: UMAP_AVAILABLE = False logging.warning("UMAP not available. Install with pip install umap-learn!") class TextEmbeddingNode(TransformNode): """Calculates vector embeddings based on a dataframe of posts. """ def __init__(self, model_name: str = "thenlper/gte-small", model_path: str = None, device: str = "cpu"): """Initialize the ExampleNode. Args: model_name: Name of the ML Model to calculate text embeddings model_path: Optional local path to a downloaded embedding model device: Device to use for computations ('cpu', 'cuda', 'mps') """ self.model_name = model_name self.model_path = model_path or os.environ.get('GTE_MODEL_PATH') self.device = device self.model = None logger.info(f"Initialized TextEmbeddingNode with model_name={model_name}, model_path={model_path}, device={device}") def _setup_model(self): """Init the Text Embedding Model.""" if not GTE_AVAILABLE: raise ImportError("GTE is required for TextEmbeddingNode. Please install.") model_source = None if self.model_path: if os.path.exists(self.model_path): model_source = self.model_path logger.info(f"Loading GTE model from local path: {self.model_path}") else: logger.warning(f"GTE_MODEL_PATH '{self.model_path}' not found; Falling back to hub model {self.model_name}") if model_source is None: model_source = self.model_name logger.info(f"Loading GTE model from the hub: {self.model_name}") if self.device == "cuda" and torch.cuda.is_available(): self.model = SentenceTransformer(model_source).to('cuda', dtype=torch.float16) elif self.device == "mps" and torch.backends.mps.is_available(): self.model = SentenceTransformer(model_source).to('mps', dtype=torch.float16) else: self.model = SentenceTransformer(model_source) def _process_data(self, df: pd.DataFrame) -> pd.DataFrame: """Process the input dataframe. Calculates an embedding as a np.array. Also pickles that array to prepare it to storage in the database. Args: df: Input dataframe from context Returns: Processed dataframe """ logger.info(f"Processing {len(df)} rows") if self.model is None: self._setup_model() # Example: Add a new column based on existing data result_df = df.copy() result_df['embedding'] = df['text'].apply(lambda x: self.model.encode(x, convert_to_numpy=True)) logger.info("Processing complete") return result_df def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame): """Store results back to the database using batch updates.""" if df.empty: logger.info("No results to store") return logger.info(f"Storing {len(df)} results") # Convert numpy arrays to bytes for BLOB storage updates = [(row['embedding'], row['id']) for _, row in df.iterrows()] con.executemany( "UPDATE posts SET embedding = ? WHERE id = ?", updates ) con.commit() logger.info("Results stored successfully") def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: """Execute the transformation. This is the main entry point called by the pipeline. Args: con: SQLite database connection context: TransformContext containing input dataframe Returns: TransformContext with processed dataframe """ logger.info("Starting TextEmbeddingNode transformation") # Get input dataframe from context input_df = context.get_dataframe() # Validate input if input_df.empty: logger.warning("Empty dataframe provided to TextEmbeddingNdode") return context if 'text' not in input_df.columns: logger.warning("No 'text' column in context dataframe. Skipping TextEmbeddingNode") return context # Process the data result_df = self._process_data(input_df) # Store results (optional) self._store_results(con, result_df) logger.info("TextEmbeddingNode transformation complete") # Return new context with results return TransformContext(result_df) class UmapNode(TransformNode): """Calculates 2D coordinates from embeddings using UMAP dimensionality reduction. This node takes text embeddings and reduces them to 2D coordinates for visualization purposes. """ def __init__(self, n_neighbors: int = 10, min_dist: float = 0.1, n_components: int = 3, metric: str = "cosine", random_state: int = 42, model_path: str = None): """Initialize the UmapNode. Args: n_neighbors: Number of neighbors to consider for UMAP (default: 15) min_dist: Minimum distance between points in low-dimensional space (default: 0.1) n_components: Number of dimensions to reduce to (default: 2) metric: Distance metric to use (default: 'cosine') random_state: Random seed for reproducibility (default: 42) model_path: Path to save/load the fitted UMAP model (default: None, uses 'umap_model.pkl') """ self.n_neighbors = n_neighbors self.min_dist = min_dist self.n_components = n_components self.metric = metric self.random_state = random_state self.model_path = model_path or os.environ.get('UMAP_MODEL_PATH') self.reducer = None logger.info(f"Initialized UmapNode with n_neighbors={n_neighbors}, min_dist={min_dist}, " f"n_components={n_components}, metric={metric}, random_state={random_state}, " f"model_path={self.model_path}") def _process_data(self, df: pd.DataFrame) -> pd.DataFrame: """Process the input dataframe. Retrieves embeddings from BLOB storage, converts them back to numpy arrays, and applies UMAP dimensionality reduction to create 2D coordinates. Args: df: Input dataframe from context Returns: Processed dataframe with umap_x and umap_y columns """ logger.info(f"Processing {len(df)} rows") if not UMAP_AVAILABLE: raise ImportError("UMAP is required for UmapNode. Install with: pip install umap-learn") result_df = df.copy() # Convert BLOB embeddings back to numpy arrays if 'embedding' not in result_df.columns: logger.error("No 'embedding' column found in dataframe") raise ValueError("Input dataframe must contain 'embedding' column") logger.info("Converting embeddings from BLOB to numpy arrays") result_df['embedding'] = result_df['embedding'].apply( lambda x: np.frombuffer(x, dtype=np.float32) if x is not None else None ) # Filter out rows with None embeddings valid_rows = result_df['embedding'].notna() if not valid_rows.any(): logger.error("No valid embeddings found in dataframe") raise ValueError("No valid embeddings to process") logger.info(f"Found {valid_rows.sum()} valid embeddings out of {len(result_df)} rows") # Stack embeddings into a matrix embeddings_matrix = np.vstack(result_df.loc[valid_rows, 'embedding'].values) logger.info(f"Embeddings matrix shape: {embeddings_matrix.shape}") # Check if a saved UMAP model exists if self.model_path and os.path.exists(self.model_path): logger.info(f"Loading existing UMAP model from {self.model_path}") try: with open(self.model_path, 'rb') as f: self.reducer = pickle.load(f) logger.info("UMAP model loaded successfully") umap_coords = self.reducer.transform(embeddings_matrix) logger.info(f"UMAP transformation complete using existing model. Output shape: {umap_coords.shape}") except Exception as e: logger.warning(f"Failed to load UMAP model from {self.model_path}: {e}") logger.info("Falling back to fitting a new model") self.reducer = None # If no saved model or loading failed, fit a new model if self.reducer is None: logger.info("Fitting new UMAP reducer...") self.reducer = umap.UMAP( n_neighbors=self.n_neighbors, min_dist=self.min_dist, n_components=self.n_components, metric=self.metric, random_state=self.random_state ) umap_coords = self.reducer.fit_transform(embeddings_matrix) logger.info(f"UMAP transformation complete. Output shape: {umap_coords.shape}") # Save the fitted model try: umap_folder = '/'.join(self.model_path.split('/')[:1]) os.mkdir(umap_folder) with open(self.model_path, 'wb') as f: pickle.dump(self.reducer, f) logger.info(f"UMAP model saved to {self.model_path}") except Exception as e: logger.error(f"Failed to save UMAP model to {self.model_path}: {e}") # Add UMAP coordinates to dataframe result_df.loc[valid_rows, 'umap_x'] = umap_coords[:, 0] result_df.loc[valid_rows, 'umap_y'] = umap_coords[:, 1] result_df.loc[valid_rows, 'umap_z'] = umap_coords[:, 2] # Fill NaN for invalid rows result_df['umap_x'] = result_df['umap_x'].fillna(value=0) result_df['umap_y'] = result_df['umap_y'].fillna(value=0) result_df['umap_z'] = result_df['umap_z'].fillna(value=0) logger.info("Processing complete") return result_df def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame): """Store UMAP coordinates back to the database. Args: con: Database connection df: Processed dataframe with umap_x and umap_y columns """ if df.empty: logger.info("No results to store") return logger.info(f"Storing {len(df)} results") # Batch update UMAP coordinates updates = [ (row['umap_x'], row['umap_y'], row['umap_z'], row['id']) for _, row in df.iterrows() if pd.notna(row.get('umap_x')) and pd.notna(row.get('umap_y')) and pd.notna(row.get('umap_z')) ] if updates: con.executemany( "UPDATE posts SET umap_x = ?, umap_y = ?, umap_z = ? WHERE id = ?", updates ) con.commit() logger.info(f"Stored {len(updates)} UMAP coordinate pairs successfully") else: logger.warning("No valid UMAP coordinates to store") def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: """Execute the transformation. This is the main entry point called by the pipeline. Args: con: SQLite database connection context: TransformContext containing input dataframe Returns: TransformContext with processed dataframe """ logger.info("Starting ExampleNode transformation") # Get input dataframe from context input_df = context.get_dataframe() # Validate input if input_df.empty: logger.warning("Empty dataframe provided to ExampleNode") return context # Process the data result_df = self._process_data(input_df) # Store results (optional) self._store_results(con, result_df) logger.info("ExampleNode transformation complete") # Return new context with results return TransformContext(result_df) class SimilarityNode(TransformNode): """Example transform node template. This node demonstrates the basic structure for creating new transformation nodes in the pipeline. """ def __init__(self, param1: str = "default_value", param2: int = 42, device: str = "cpu"): """Initialize the ExampleNode. Args: param1: Example string parameter param2: Example integer parameter device: Device to use for computations ('cpu', 'cuda', 'mps') """ self.param1 = param1 self.param2 = param2 self.device = device logger.info(f"Initialized ExampleNode with param1={param1}, param2={param2}") def _create_tables(self, con: sqlite3.Connection): """Create any necessary tables in the database. This is optional - only needed if your node creates new tables. """ logger.info("Creating example tables") con.execute(""" CREATE TABLE IF NOT EXISTS example_results ( id INTEGER PRIMARY KEY AUTOINCREMENT, post_id INTEGER, result_value TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (post_id) REFERENCES posts(id) ) """) con.commit() def _process_data(self, df: pd.DataFrame) -> pd.DataFrame: """Process the input dataframe. This is where your main transformation logic goes. Args: df: Input dataframe from context Returns: Processed dataframe """ logger.info(f"Processing {len(df)} rows") # Example: Add a new column based on existing data result_df = df.copy() result_df['processed'] = True result_df['example_value'] = result_df['id'].apply(lambda x: f"{self.param1}_{x}") logger.info("Processing complete") return result_df def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame): """Store results back to the database. This is optional - only needed if you want to persist results. Args: con: Database connection df: Processed dataframe to store """ if df.empty: logger.info("No results to store") return logger.info(f"Storing {len(df)} results") # Example: Store to database # df[['post_id', 'result_value']].to_sql( # 'example_results', # con, # if_exists='append', # index=False # ) con.commit() logger.info("Results stored successfully") def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: """Execute the transformation. This is the main entry point called by the pipeline. Args: con: SQLite database connection context: TransformContext containing input dataframe Returns: TransformContext with processed dataframe """ logger.info("Starting ExampleNode transformation") # Get input dataframe from context input_df = context.get_dataframe() # Validate input if input_df.empty: logger.warning("Empty dataframe provided to ExampleNode") return context # Create any necessary tables self._create_tables(con) # Process the data result_df = self._process_data(input_df) # Store results (optional) self._store_results(con, result_df) logger.info("ExampleNode transformation complete") # Return new context with results return TransformContext(result_df) def main(): logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler("app.log"), logging.StreamHandler(sys.stdout) ] ) logger = logging.getLogger("knack-transform") con = sqlite3.connect("/Users/linussilberstein/Documents/Knack-Scraper/data/knack.sqlite") df = pd.read_sql('select * from posts;', con) #node = TextEmbeddingNode(device='mps') #context = TransformContext(df) logger.info(df) #new_context = node.run(con, context) #logger.info(new_context.get_dataframe()) #umapNode = UmapNode() #new_context = umapNode.run(con, new_context) #logger.info(new_context.get_dataframe()) # Create 3D scatter plot of UMAP coordinates result_df = df fig = plt.figure(figsize=(12, 9)) ax = fig.add_subplot(111, projection='3d') scatter = ax.scatter( result_df['umap_x'], result_df['umap_y'], result_df['umap_z'], c=result_df['id'], cmap='viridis', alpha=0.6, s=50 ) ax.set_xlabel('UMAP X') ax.set_ylabel('UMAP Y') ax.set_zlabel('UMAP Z') ax.set_title('3D UMAP Visualization of Post Embeddings') plt.colorbar(scatter, ax=ax, label='Post Index') plt.tight_layout() plt.show() logger.info("3D plot displayed") if __name__ == '__main__': main()