{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8495708c",
   "metadata": {},
   "source": [
    "# Knack Database Visualization\n",
    "\n",
    "This notebook explores and visualizes the contents of the `knack.transformed.sqlite` database: table schemas, posts over time, article counts per category, tag and author, and 2D (UMAP) / 3D (PCA) views of the post embeddings. Charts use Altair, except the 3D view, which uses Plotly.\n",
    "\n",
    "Outputs were cleared when the code was revised; run **Kernel → Restart & Run All** to regenerate them."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75cdd349",
   "metadata": {},
   "source": [
    "## 1. Imports and Configuration\n",
    "\n",
    "Import every library used in the notebook and set the tunable constants in one place."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c99dde85",
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import json\n",
    "import sqlite3\n",
    "from pathlib import Path\n",
    "\n",
    "import altair as alt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.graph_objects as go\n",
    "from IPython.display import display\n",
    "\n",
    "# Configure Altair: lift the default 5000-row limit and use the default renderer\n",
    "alt.data_transformers.disable_max_rows()\n",
    "alt.renderers.enable('default')\n",
    "\n",
    "# Tunable constants\n",
    "DB_PATH = Path('../data/knack.transformed.sqlite')\n",
    "TOP_N_BARS = 30          # bars shown in the tag and author charts\n",
    "TOP_N_AUTHORS_UMAP = 15  # authors with their own colour in the UMAP plot\n",
    "TOP_N_AUTHORS_3D = 10    # authors with their own colour in the 3D plot\n",
    "# NOTE(review): assumed element type of raw-bytes embeddings in posts.embedding;\n",
    "# confirm against the script that wrote the column.\n",
    "EMBEDDING_DTYPE = np.float32\n",
    "\n",
    "print(\"Libraries imported successfully!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "198121f5",
   "metadata": {},
   "source": [
    "## 2. Connect to SQLite Database\n",
    "\n",
    "Open a read-only connection to `knack.transformed.sqlite` and list its tables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98ddc787",
   "metadata": {},
   "outputs": [],
   "source": [
    "def quote_identifier(name: str) -> str:\n",
    "    \"\"\"Quote an SQLite identifier (e.g. a table name) for interpolation into SQL text.\"\"\"\n",
    "    return '\"' + name.replace('\"', '\"\"') + '\"'\n",
    "\n",
    "\n",
    "# sqlite3.connect() silently creates an empty database when the path is wrong, which\n",
    "# would make every later cell report missing tables instead of failing here.\n",
    "if not DB_PATH.exists():\n",
    "    raise FileNotFoundError(f\"Database not found: {DB_PATH.resolve()}\")\n",
    "\n",
    "# Read-only URI connection: this notebook never writes to the database\n",
    "conn = sqlite3.connect(f\"{DB_PATH.resolve().as_uri()}?mode=ro\", uri=True)\n",
    "cursor = conn.cursor()\n",
    "\n",
    "# User tables only (skip SQLite-internal tables such as sqlite_sequence)\n",
    "cursor.execute(\"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';\")\n",
    "tables = cursor.fetchall()\n",
    "\n",
    "print(\"Tables in the database:\")\n",
    "for (table_name,) in tables:\n",
    "    print(f\"  - {table_name}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f216388",
   "metadata": {},
   "source": [
    "## 3. Explore Database Schema\n",
    "\n",
    "Examine the columns and row count of each table to understand the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e51dd105",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the columns and row count of each table\n",
    "for (table_name,) in tables:\n",
    "    print(f\"\\n{'=' * 60}\")\n",
    "    print(f\"Table: {table_name}\")\n",
    "    print('=' * 60)\n",
    "\n",
    "    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk)\n",
    "    columns = cursor.execute(f\"PRAGMA table_info({quote_identifier(table_name)})\").fetchall()\n",
    "\n",
    "    print(\"\\nColumns:\")\n",
    "    for _cid, col_name, col_type, not_null, _default, _pk in columns:\n",
    "        print(f\"  {col_name:20} {col_type:15} {'NOT NULL' if not_null else ''}\")\n",
    "\n",
    "    row_count = cursor.execute(f\"SELECT COUNT(*) FROM {quote_identifier(table_name)}\").fetchone()[0]\n",
    "    print(f\"\\nTotal rows: {row_count}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25ffce32",
   "metadata": {},
   "source": [
    "## 4. Load Data from Database\n",
    "\n",
    "Load every table into a pandas DataFrame, keyed by table name in `dataframes`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1459d68a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every table into a DataFrame, keyed by table name\n",
    "dataframes = {}\n",
    "\n",
    "for (table_name,) in tables:\n",
    "    df_table = pd.read_sql_query(f\"SELECT * FROM {quote_identifier(table_name)}\", conn)\n",
    "    dataframes[table_name] = df_table\n",
    "    print(f\"Loaded {table_name}: {df_table.shape[0]} rows, {df_table.shape[1]} columns\")\n",
    "\n",
    "print(f\"\\nAvailable dataframes: {list(dataframes)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c34b1bc5",
   "metadata": {},
   "source": [
    "## 5. Explore Data Structure\n",
    "\n",
    "Examine the `posts` table (or the first loaded table if `posts` is missing): shape, dtypes and missing values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91616185",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explore the posts table, falling back to the first loaded table\n",
    "if dataframes:\n",
    "    explore_table = 'posts' if 'posts' in dataframes else next(iter(dataframes))\n",
    "    df_explore = dataframes[explore_table]\n",
    "\n",
    "    print(f\"Exploring: {explore_table}\")\n",
    "    print(f\"\\nShape: {df_explore.shape}\")\n",
    "    print(f\"\\nData types:\\n{df_explore.dtypes}\")\n",
    "\n",
    "    print(\"\\nMissing values:\")\n",
    "    print(df_explore.isnull().sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9b0e8d7",
   "metadata": {},
   "source": [
    "## 6. Create Time Series Visualizations\n",
    "\n",
    "If the data contains temporal information, plot the number of records per month."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2190a06b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot monthly record counts for the first date-like column of the posts table\n",
    "if dataframes:\n",
    "    ts_table = 'posts' if 'posts' in dataframes else next(iter(dataframes))\n",
    "    df_ts_source = dataframes[ts_table]\n",
    "\n",
    "    # Look for columns that might contain dates (check column names)\n",
    "    date_keywords = ['date', 'time', 'created', 'updated', 'timestamp']\n",
    "    date_like_cols = [\n",
    "        col for col in df_ts_source.columns\n",
    "        if any(keyword in col.lower() for keyword in date_keywords)\n",
    "    ]\n",
    "\n",
    "    if date_like_cols:\n",
    "        print(f\"Found potential date columns: {date_like_cols}\")\n",
    "        date_col = date_like_cols[0]\n",
    "        try:\n",
    "            # Parse into a new frame instead of overwriting the column inside the shared\n",
    "            # `dataframes` entry, so the cell is idempotent and other cells see raw data.\n",
    "            dates = pd.to_datetime(df_ts_source[date_col], errors='coerce').to_frame(name=date_col)\n",
    "            n_unparsed = int(dates[date_col].isna().sum())\n",
    "            if n_unparsed:\n",
    "                print(f\"Skipping {n_unparsed} rows with a missing or unparseable '{date_col}'\")\n",
    "\n",
    "            # 'ME' = month-end bins; the old 'M' alias is deprecated since pandas 2.2\n",
    "            time_series = (\n",
    "                dates.groupby(pd.Grouper(key=date_col, freq='ME'))\n",
    "                .size()\n",
    "                .reset_index(name='count')\n",
    "            )\n",
    "\n",
    "            chart = alt.Chart(time_series).mark_line(point=True).encode(\n",
    "                x=alt.X(f'{date_col}:T', title='Date'),\n",
    "                y=alt.Y('count:Q', title='Count'),\n",
    "                tooltip=[alt.Tooltip(f'{date_col}:T', title='Month', format='%Y-%m'), 'count:Q'],\n",
    "            ).properties(\n",
    "                title=f'Records per Month ({ts_table}.{date_col})',\n",
    "                width=700,\n",
    "                height=400,\n",
    "            ).interactive()\n",
    "\n",
    "            display(chart)\n",
    "        except Exception as e:  # best effort: report and let the rest of the notebook run\n",
    "            print(f\"Could not create time series chart: {e}\")\n",
    "    else:\n",
    "        print(\"No date/time columns found\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7e2c4a1",
   "metadata": {},
   "source": [
    "## 7. Article Distributions\n",
    "\n",
    "Categories, tags and authors are each linked to posts through a link table (`postcategories`, `posttags`, `post_authors`) and a lookup table (`categories`, `tags`, `authors`). `show_article_counts` joins the two, charts the number of linked posts per label and prints summary statistics, so the three sections below differ only in their arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3f5a9c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_article_counts(\n",
    "    link_table: str,\n",
    "    lookup_table: str,\n",
    "    fk_col: str,\n",
    "    label_col: str,\n",
    "    label_name: str,\n",
    "    plural: str,\n",
    "    scheme: str = 'viridis',\n",
    "    top_n: int | None = None,\n",
    ") -> pd.DataFrame | None:\n",
    "    \"\"\"Count posts per label through a link table, chart the counts and print summary stats.\n",
    "\n",
    "    Joins ``dataframes[link_table]`` (one row per post/label pair) to\n",
    "    ``dataframes[lookup_table]`` on ``link_table.fk_col == lookup_table.id`` and counts\n",
    "    link rows per ``lookup_table.label_col``. Link rows whose id has no match in the\n",
    "    lookup table are not counted.\n",
    "\n",
    "    Args:\n",
    "        link_table: link table name, e.g. 'posttags'.\n",
    "        lookup_table: lookup table name, e.g. 'tags'.\n",
    "        fk_col: link-table column that references ``lookup_table.id``, e.g. 'tag_id'.\n",
    "        label_col: label column in the lookup table, e.g. 'tag'.\n",
    "        label_name: singular noun for the result column and chart labels, e.g. 'tag'.\n",
    "        plural: plural noun used in the printed summary, e.g. 'tags'.\n",
    "        scheme: Vega colour scheme for the bars.\n",
    "        top_n: if set, only the ``top_n`` largest labels are charted; the printed\n",
    "            statistics always use every label.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns ``[label_name, 'article_count']`` sorted by count\n",
    "        (descending), or None if the required tables or columns are missing.\n",
    "    \"\"\"\n",
    "    if link_table not in dataframes or lookup_table not in dataframes:\n",
    "        print(f\"Need both '{link_table}' and '{lookup_table}' tables in database\")\n",
    "        return None\n",
    "\n",
    "    df_link = dataframes[link_table]\n",
    "    df_lookup = dataframes[lookup_table]\n",
    "    if fk_col not in df_link.columns or not {'id', label_col} <= set(df_lookup.columns):\n",
    "        print(\"Could not find required columns for joining tables\")\n",
    "        return None\n",
    "\n",
    "    counts = (\n",
    "        df_link.merge(df_lookup[['id', label_col]], left_on=fk_col, right_on='id', how='left')[label_col]\n",
    "        .value_counts()  # sorted descending; NaN (unmatched ids) is dropped\n",
    "        .rename_axis(label_name)\n",
    "        .reset_index(name='article_count')\n",
    "    )\n",
    "\n",
    "    title = f'Distribution of Articles per {label_name.title()}'\n",
    "    if top_n:\n",
    "        title += f' (Top {top_n})'\n",
    "\n",
    "    chart = alt.Chart(counts.head(top_n) if top_n else counts).mark_bar().encode(\n",
    "        x=alt.X(f'{label_name}:N', sort='-y', title=label_name.title(), axis=alt.Axis(labelAngle=-45)),\n",
    "        y=alt.Y('article_count:Q', title='Number of Articles'),\n",
    "        color=alt.Color('article_count:Q', scale=alt.Scale(scheme=scheme), legend=None),\n",
    "        tooltip=[\n",
    "            alt.Tooltip(f'{label_name}:N', title=label_name.title()),\n",
    "            alt.Tooltip('article_count:Q', title='Articles'),\n",
    "        ],\n",
    "    ).properties(title=title, width=700, height=450)\n",
    "    display(chart)\n",
    "\n",
    "    article_counts = counts['article_count']\n",
    "    print(f\"\\nTotal {plural}: {len(counts)}\")\n",
    "    print(f\"Most articles for a single {label_name}: {article_counts.max()}\")\n",
    "    print(f\"Average articles per {label_name}: {article_counts.mean():.2f}\")\n",
    "    print(f\"Median articles per {label_name}: {article_counts.median():.2f}\")\n",
    "    return counts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "793026df",
   "metadata": {},
   "source": [
    "### Articles per Category\n",
    "\n",
    "Visualize the distribution of articles across different categories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ac9fae5",
   "metadata": {},
   "outputs": [],
   "source": [
    "category_counts = show_article_counts(\n",
    "    'postcategories', 'categories',\n",
    "    fk_col='category_id', label_col='category',\n",
    "    label_name='category', plural='categories',\n",
    "    scheme='viridis',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56c89ec3",
   "metadata": {},
   "source": [
    "### Articles per Tag\n",
    "\n",
    "Visualize the distribution of articles across different tags (chart limited to the most frequent tags)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95a28c5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "tag_counts = show_article_counts(\n",
    "    'posttags', 'tags',\n",
    "    fk_col='tag_id', label_col='tag',\n",
    "    label_name='tag', plural='tags',\n",
    "    scheme='oranges', top_n=TOP_N_BARS,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "549e6f38",
   "metadata": {},
   "source": [
    "### Articles per Author\n",
    "\n",
    "Visualize the distribution of articles across different authors (chart limited to the most prolific authors)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a49be6f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "author_counts = show_article_counts(\n",
    "    'post_authors', 'authors',\n",
    "    fk_col='author_id', label_col='name',\n",
    "    label_name='author', plural='authors',\n",
    "    scheme='oranges', top_n=TOP_N_BARS,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f6f1539",
   "metadata": {},
   "source": [
    "## 8. UMAP Visualization\n",
    "\n",
    "Visualize the stored UMAP projection (`umap_x`, `umap_y`) in 2D. A post can have several authors, so each post is drawn once, coloured by its first-listed author, with all of its authors in the tooltip."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6c0e3b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def authors_per_post() -> pd.DataFrame | None:\n",
    "    \"\"\"Collapse ``post_authors`` to exactly one row per post.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns\n",
    "          * ``post_id``     -- post identifier from ``post_authors``;\n",
    "          * ``author``      -- first author listed for the post (row order of\n",
    "            ``post_authors``), used for colouring;\n",
    "          * ``all_authors`` -- all of the post's authors joined with ', ', for tooltips.\n",
    "        Author ids with no match in ``authors`` appear as 'Unknown'. Returns None if the\n",
    "        ``post_authors``/``authors`` tables or their required columns are missing.\n",
    "    \"\"\"\n",
    "    if 'post_authors' not in dataframes or 'authors' not in dataframes:\n",
    "        return None\n",
    "    df_links = dataframes['post_authors']\n",
    "    df_names = dataframes['authors']\n",
    "    if not {'post_id', 'author_id'} <= set(df_links.columns) or not {'id', 'name'} <= set(df_names.columns):\n",
    "        return None\n",
    "\n",
    "    return (\n",
    "        df_links[['post_id', 'author_id']]\n",
    "        .merge(df_names[['id', 'name']].rename(columns={'id': 'author_id'}), on='author_id', how='left')\n",
    "        .assign(name=lambda d: d['name'].fillna('Unknown').astype(str))\n",
    "        .groupby('post_id', sort=False)['name']\n",
    "        .agg(author='first', all_authors=', '.join)\n",
    "        .reset_index()\n",
    "    )\n",
    "\n",
    "\n",
    "def group_top_n(labels: pd.Series, n: int, other: str = 'Other') -> pd.Series:\n",
    "    \"\"\"Keep the ``n`` most frequent values of ``labels`` and replace every other value with ``other``.\"\"\"\n",
    "    top_values = labels.value_counts().head(n).index\n",
    "    return labels.where(labels.isin(top_values), other)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "196be503",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot each post once at its UMAP coordinates, coloured by its first-listed author\n",
    "umap_table = next(\n",
    "    (name for name, frame in dataframes.items() if {'umap_x', 'umap_y'} <= set(frame.columns)),\n",
    "    None,\n",
    ")\n",
    "\n",
    "if umap_table is None:\n",
    "    print(\"No UMAP coordinates (umap_x, umap_y) found in any table\")\n",
    "else:\n",
    "    print(f\"Found UMAP coordinates in table: {umap_table}\")\n",
    "    df_umap_source = dataframes[umap_table]\n",
    "\n",
    "    # Pass Altair only the columns it needs: it serialises the whole frame to JSON, the\n",
    "    # bytes `embedding` column cannot be serialised and `text`/`html` would bloat the output.\n",
    "    keep_cols = [col for col in ('id', 'title', 'umap_x', 'umap_y') if col in df_umap_source.columns]\n",
    "    df_umap = df_umap_source[keep_cols].dropna(subset=['umap_x', 'umap_y'])\n",
    "    title_tooltip = [alt.Tooltip('title:N', title='Title')] if 'title' in df_umap.columns else []\n",
    "    post_authors = authors_per_post()\n",
    "\n",
    "    if post_authors is not None and 'id' in df_umap.columns:\n",
    "        n_posts = len(df_umap)\n",
    "        df_umap = df_umap.merge(post_authors, left_on='id', right_on='post_id', how='left')\n",
    "        assert len(df_umap) == n_posts, \"author join must keep exactly one point per post\"\n",
    "        df_umap[['author', 'all_authors']] = df_umap[['author', 'all_authors']].fillna('Unknown')\n",
    "        df_umap['author_group'] = group_top_n(df_umap['author'], TOP_N_AUTHORS_UMAP)\n",
    "\n",
    "        scatter = alt.Chart(df_umap).mark_circle(size=40, opacity=0.7).encode(\n",
    "            x=alt.X('umap_x:Q', title='UMAP Dimension 1'),\n",
    "            y=alt.Y('umap_y:Q', title='UMAP Dimension 2'),\n",
    "            color=alt.Color('author_group:N', title='Author', scale=alt.Scale(scheme='tableau20')),\n",
    "            tooltip=title_tooltip + [alt.Tooltip('all_authors:N', title='Authors'), 'umap_x:Q', 'umap_y:Q'],\n",
    "        ).properties(\n",
    "            title='UMAP 2D Projection by Author',\n",
    "            width=800,\n",
    "            height=600,\n",
    "        ).interactive()\n",
    "\n",
    "        display(scatter)\n",
    "\n",
    "        print(f\"\\nTotal points (one per post): {len(df_umap)}\")\n",
    "        print(f\"Unique first-listed authors: {df_umap['author'].nunique()}\")\n",
    "        print(f\"Top {TOP_N_AUTHORS_UMAP} authors shown in legend (others grouped as 'Other')\")\n",
    "    else:\n",
    "        # Fallback: plot without author colouring\n",
    "        scatter = alt.Chart(df_umap).mark_circle(size=30, opacity=0.6).encode(\n",
    "            x=alt.X('umap_x:Q', title='UMAP Dimension 1'),\n",
    "            y=alt.Y('umap_y:Q', title='UMAP Dimension 2'),\n",
    "            tooltip=title_tooltip + ['umap_x:Q', 'umap_y:Q'],\n",
    "        ).properties(\n",
    "            title='UMAP 2D Projection',\n",
    "            width=700,\n",
    "            height=600,\n",
    "        ).interactive()\n",
    "\n",
    "        display(scatter)\n",
    "\n",
    "        print(f\"\\nTotal points: {len(df_umap)}\")\n",
    "        print(\"Note: Author coloring not available (missing required tables)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c57a57fa",
   "metadata": {},
   "source": [
    "## 9. 3D Embedding Visualization\n",
    "\n",
    "Project the high-dimensional post embeddings onto their first three principal components (PCA computed with NumPy's SVD) and show them in an interactive Plotly 3D scatter plot, coloured by first-listed author. Axis titles report the share of variance each component explains."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b8d2f61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_embedding(value, dtype=EMBEDDING_DTYPE) -> np.ndarray | None:\n",
    "    \"\"\"Decode one stored embedding into a 1-D float64 array (None for missing values).\n",
    "\n",
    "    Recognised formats:\n",
    "      * bytes holding a ``.npy`` file (written with ``np.save``);\n",
    "      * bytes or str holding a JSON list of numbers;\n",
    "      * any other bytes, read as a raw buffer of ``dtype`` values (``ndarray.tobytes()``);\n",
    "      * lists, tuples and arrays, converted directly.\n",
    "\n",
    "    Returns None for None, empty bytes, or bytes whose length is not a multiple of the\n",
    "    ``dtype`` item size. Raises ValueError/TypeError for values that cannot be converted.\n",
    "    \"\"\"\n",
    "    if value is None:\n",
    "        return None\n",
    "    if isinstance(value, (bytes, bytearray, memoryview)):\n",
    "        raw = bytes(value)\n",
    "        if raw.startswith(b'\\x93NUMPY'):  # .npy magic string\n",
    "            return np.load(io.BytesIO(raw), allow_pickle=False).astype(float).ravel()\n",
    "        try:\n",
    "            return np.asarray(json.loads(raw.decode('utf-8')), dtype=float).ravel()\n",
    "        except ValueError:  # not UTF-8 text or not JSON: treat as a raw numeric buffer\n",
    "            pass\n",
    "        itemsize = np.dtype(dtype).itemsize\n",
    "        if not raw or len(raw) % itemsize:\n",
    "            return None\n",
    "        return np.frombuffer(raw, dtype=dtype).astype(float)\n",
    "    if isinstance(value, str):\n",
    "        return np.asarray(json.loads(value), dtype=float).ravel()\n",
    "    return np.asarray(value, dtype=float).ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42352fef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project post embeddings onto their first three principal components and plot them in 3D\n",
    "df_posts = dataframes.get('posts')\n",
    "\n",
    "if df_posts is None:\n",
    "    print(\"No 'posts' table found in database\")\n",
    "elif 'embedding' not in df_posts.columns:\n",
    "    print(\"No 'embedding' column found in posts table\")\n",
    "else:\n",
    "    print(\"Found embedding column in posts table\")\n",
    "\n",
    "    # Decode every stored embedding; count failures instead of silently skipping them\n",
    "    decoded = []\n",
    "    n_undecodable = 0\n",
    "    for value in df_posts['embedding']:\n",
    "        try:\n",
    "            decoded.append(decode_embedding(value))\n",
    "        except (ValueError, TypeError):\n",
    "            decoded.append(None)\n",
    "            n_undecodable += 1\n",
    "\n",
    "    # Keep only embeddings of the most common dimensionality (>= 3) so they stack into one matrix\n",
    "    lengths = pd.Series([0 if vec is None else len(vec) for vec in decoded])\n",
    "    usable_lengths = lengths[lengths >= 3]\n",
    "    if usable_lengths.empty:\n",
    "        keep = np.zeros(len(decoded), dtype=bool)\n",
    "    else:\n",
    "        keep = (lengths == usable_lengths.mode().iloc[0]).to_numpy()\n",
    "\n",
    "    if keep.sum() < 3:\n",
    "        print(f\"No valid embeddings found ({n_undecodable} could not be decoded)\")\n",
    "    else:\n",
    "        embeddings = np.vstack([vec for vec, ok in zip(decoded, keep) if ok])\n",
    "        print(f\"Using {embeddings.shape[0]} embeddings of dimension {embeddings.shape[1]} \"\n",
    "              f\"({len(decoded) - embeddings.shape[0]} skipped, {n_undecodable} undecodable)\")\n",
    "\n",
    "        # PCA: SVD of the mean-centred matrix; the rows of vt are the principal axes\n",
    "        centered = embeddings - embeddings.mean(axis=0)\n",
    "        _, singular_values, vt = np.linalg.svd(centered, full_matrices=False)\n",
    "        variance = singular_values ** 2\n",
    "        explained = variance / variance.sum() if variance.sum() > 0 else np.zeros_like(variance)\n",
    "\n",
    "        df_3d = pd.DataFrame(centered @ vt[:3].T, columns=['pc_1', 'pc_2', 'pc_3'])\n",
    "        df_3d['title'] = (\n",
    "            df_posts.loc[keep, 'title'].fillna('').astype(str).to_numpy()\n",
    "            if 'title' in df_posts.columns else ''\n",
    "        )\n",
    "\n",
    "        post_authors = authors_per_post()\n",
    "        if post_authors is not None and 'id' in df_posts.columns:\n",
    "            df_3d['post_id'] = df_posts.loc[keep, 'id'].to_numpy()\n",
    "            # post_authors has one row per post, so this left join keeps one point per post\n",
    "            df_3d = df_3d.merge(post_authors, on='post_id', how='left')\n",
    "            df_3d[['author', 'all_authors']] = df_3d[['author', 'all_authors']].fillna('Unknown')\n",
    "        else:\n",
    "            print(\"Could not add author coloring (missing post_authors/authors data)\")\n",
    "            df_3d['author'] = 'Unknown'\n",
    "            df_3d['all_authors'] = 'Unknown'\n",
    "        df_3d['author_group'] = group_top_n(df_3d['author'], TOP_N_AUTHORS_3D)\n",
    "\n",
    "        # One trace per author group gives a categorical legend; a continuous colour\n",
    "        # scale over group indices would suggest an ordering that does not exist.\n",
    "        group_order = [group for group in df_3d['author_group'].value_counts().index if group != 'Other']\n",
    "        if (df_3d['author_group'] == 'Other').any():\n",
    "            group_order.append('Other')\n",
    "\n",
    "        fig = go.Figure()\n",
    "        for group in group_order:\n",
    "            points = df_3d[df_3d['author_group'] == group]\n",
    "            fig.add_trace(go.Scatter3d(\n",
    "                x=points['pc_1'],\n",
    "                y=points['pc_2'],\n",
    "                z=points['pc_3'],\n",
    "                mode='markers',\n",
    "                name=group,\n",
    "                marker=dict(size=3, opacity=0.7),\n",
    "                customdata=points[['title', 'all_authors']].to_numpy(),\n",
    "                hovertemplate='%{customdata[0]} | %{customdata[1]} | '\n",
    "                              'PC 1: %{x:.3f}, PC 2: %{y:.3f}, PC 3: %{z:.3f}',\n",
    "            ))\n",
    "\n",
    "        fig.update_layout(\n",
    "            title='3D PCA Projection of Post Embeddings (coloured by first-listed author)',\n",
    "            scene=dict(\n",
    "                xaxis_title=f'PC 1 ({explained[0]:.1%} of variance)',\n",
    "                yaxis_title=f'PC 2 ({explained[1]:.1%} of variance)',\n",
    "                zaxis_title=f'PC 3 ({explained[2]:.1%} of variance)',\n",
    "            ),\n",
    "            legend_title_text='Author',\n",
    "            width=900,\n",
    "            height=700,\n",
    "        )\n",
    "        fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9a1c7f4",
   "metadata": {},
   "source": [
    "## 10. Clean Up\n",
    "\n",
    "Close the database connection opened in section 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0b2d8e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "conn.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "knack-viz",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}