Makes transformer script executable via CLI

This commit is contained in:
quorploop 2026-01-27 20:19:05 +01:00
parent 8fae350b34
commit 7c2e34906e
11 changed files with 648 additions and 37 deletions

View file

@@ -11,7 +11,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
    liblapack-dev \
    pkg-config \
    curl \
    jq \
    && rm -rf /var/lib/apt/lists/*
ENV GLINER_MODEL_ID=urchade/gliner_multi-v2.1
@@ -40,7 +39,7 @@ COPY *.py .
# Create cron job that runs every weekend (Sunday at 3 AM) 0 3 * * 0
# Testing every 30 Minutes */30 * * * *
RUN echo "*/15 * * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-transform
RUN echo "*/30 * * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-transform
RUN chmod 0644 /etc/cron.d/knack-transform
RUN crontab /etc/cron.d/knack-transform

View file

@@ -418,3 +418,52 @@ class FuzzyAuthorNode(TransformNode):
        # Return new context with results
        return TransformContext(input_df)


def main():
    import sys

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout)
        ]
    )
    logger = logging.getLogger("knack-transform")

    # Connect to database
    db_path = "/Users/linussilberstein/Documents/Knack-Scraper/data/knack.sqlite"
    con = sqlite3.connect(db_path)

    try:
        # Read posts from database
        df = pd.read_sql('SELECT * FROM posts;', con)
        logger.info(f"Loaded {len(df)} posts from database")

        # Create context
        context = TransformContext(df)

        # Run NerAuthorNode
        logger.info("Running NerAuthorNode...")
        ner_node = NerAuthorNode(device="mps")
        context = ner_node.run(con, context)
        logger.info("NerAuthorNode complete")

        # Run FuzzyAuthorNode
        logger.info("Running FuzzyAuthorNode...")
        fuzzy_node = FuzzyAuthorNode(max_l_dist=1)
        context = fuzzy_node.run(con, context)
        logger.info("FuzzyAuthorNode complete")

        logger.info("All author nodes completed successfully!")
    except Exception as e:
        logger.error(f"Error during transformation: {e}", exc_info=True)
        raise
    finally:
        con.close()


if __name__ == '__main__':
    main()
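For a quick local check without the hardcoded path above, the same two nodes can be driven against an in-memory database; a minimal sketch, assuming the node and context interfaces shown in this diff (the sample post text is made up):

import sqlite3
import pandas as pd
from pipeline import TransformContext
from author_node import NerAuthorNode, FuzzyAuthorNode

con = sqlite3.connect(":memory:")
df = pd.DataFrame({"id": [1], "text": ["Ein Gastbeitrag von Jane Doe."]})  # illustrative post
df.to_sql("posts", con, index=False)

context = TransformContext(df)
context = NerAuthorNode(device="cpu").run(con, context)     # NER pass
context = FuzzyAuthorNode(max_l_dist=1).run(con, context)   # fuzzy-matching pass
con.close()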

View file

@@ -40,7 +40,7 @@ class TextEmbeddingNode(TransformNode):
    of posts.
    """

    def __init__(self,
                 model_name: str = "thenlper/gte-small",
                 model_name: str = "thenlper/gte-large",
                 model_path: str = None,
                 device: str = "cpu"):
        """Initialize the ExampleNode.
@@ -64,8 +64,12 @@ class TextEmbeddingNode(TransformNode):
        model_source = None
        if self.model_path:
            if os.path.exists(self.model_path):
                model_source = self.model_path
                logger.info(f"Loading GTE model from local path: {self.model_path}")
                # Check if it's a valid model directory
                if os.path.exists(os.path.join(self.model_path, 'config.json')):
                    model_source = self.model_path
                    logger.info(f"Loading GTE model from local path: {self.model_path}")
                else:
                    logger.warning(f"GTE_MODEL_PATH '{self.model_path}' found but missing config.json; Falling back to hub model {self.model_name}")
            else:
                logger.warning(f"GTE_MODEL_PATH '{self.model_path}' not found; Falling back to hub model {self.model_name}")
@@ -73,12 +77,17 @@ class TextEmbeddingNode(TransformNode):
            model_source = self.model_name
            logger.info(f"Loading GTE model from the hub: {self.model_name}")

        if self.device == "cuda" and torch.cuda.is_available():
            self.model = SentenceTransformer(model_source).to('cuda', dtype=torch.float16)
        elif self.device == "mps" and torch.backends.mps.is_available():
            self.model = SentenceTransformer(model_source).to('mps', dtype=torch.float16)
        else:
            self.model = SentenceTransformer(model_source)
        try:
            if self.device == "cuda" and torch.cuda.is_available():
                self.model = SentenceTransformer(model_source).to('cuda', dtype=torch.float16)
            elif self.device == "mps" and torch.backends.mps.is_available():
                self.model = SentenceTransformer(model_source).to('mps', dtype=torch.float16)
            else:
                self.model = SentenceTransformer(model_source)
            logger.info(f"Successfully loaded GTE model from: {model_source}")
        except Exception as e:
            logger.error(f"Failed to load GTE model from {model_source}: {e}")
            raise

    def _process_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Process the input dataframe.

View file

@@ -1,16 +1,35 @@
#!/usr/bin/env bash
set -euo pipefail

if [ -d "$GLINER_MODEL_PATH" ] && find "$GLINER_MODEL_PATH" -type f | grep -q .; then
if [ -d "$GLINER_MODEL_PATH" ] && [ -f "$GLINER_MODEL_PATH/config.json" ]; then
    echo "GLiNER model already present at $GLINER_MODEL_PATH"
    exit 0
fi

echo "Downloading GLiNER model to $GLINER_MODEL_PATH"
echo "Downloading GLiNER model $GLINER_MODEL_ID to $GLINER_MODEL_PATH"
mkdir -p "$GLINER_MODEL_PATH"

curl -sL "https://huggingface.co/api/models/${GLINER_MODEL_ID}" | jq -r '.siblings[].rfilename' | while read -r file; do
    target="${GLINER_MODEL_PATH}/${file}"
    mkdir -p "$(dirname "$target")"
    echo "Downloading ${file}"
    curl -sL "https://huggingface.co/${GLINER_MODEL_ID}/resolve/main/${file}" -o "$target"
done

# Use Python with huggingface_hub for reliable model downloading
python3 << 'EOF'
import os
from huggingface_hub import snapshot_download

model_id = os.environ.get('GLINER_MODEL_ID')
model_path = os.environ.get('GLINER_MODEL_PATH')

if not model_id or not model_path:
    raise ValueError(f"GLINER_MODEL_ID and GLINER_MODEL_PATH environment variables must be set")

try:
    print(f"Downloading model {model_id} to {model_path}")
    snapshot_download(
        repo_id=model_id,
        cache_dir=None,  # use the default cache; files are copied into local_dir below
        local_dir=model_path,
        local_dir_use_symlinks=False  # Don't use symlinks, copy files
    )
    print(f"Successfully downloaded GLiNER model to {model_path}")
except Exception as e:
    print(f"Error downloading GLiNER model: {e}")
    exit(1)
EOF
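After the script runs, the config.json guard it starts with doubles as a post-download check; a small sketch (the path comes from the same environment variable used above):

# Sketch: confirm the GLiNER snapshot landed where the image expects it
import os

model_path = os.environ["GLINER_MODEL_PATH"]
assert os.path.isfile(os.path.join(model_path, "config.json")), f"no config.json under {model_path}"
print(f"GLiNER model verified at {model_path}")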

View file

@@ -1,16 +1,35 @@
#!/usr/bin/env bash
set -euo pipefail

if [ -d "$GTE_MODEL_PATH" ] && find "$GTE_MODEL_PATH" -type f | grep -q .; then
if [ -d "$GTE_MODEL_PATH" ] && [ -f "$GTE_MODEL_PATH/config.json" ]; then
    echo "GTE model already present at $GTE_MODEL_PATH"
    exit 0
fi

echo "Downloading GTE model to $GTE_MODEL_PATH"
echo "Downloading GTE model $GTE_MODEL_ID to $GTE_MODEL_PATH"
mkdir -p "$GTE_MODEL_PATH"

curl -sL "https://huggingface.co/api/models/${GTE_MODEL_ID}" | jq -r '.siblings[].rfilename' | while read -r file; do
    target="${GTE_MODEL_PATH}/${file}"
    mkdir -p "$(dirname "$target")"
    echo "Downloading ${file}"
    curl -sL "https://huggingface.co/${GTE_MODEL_ID}/resolve/main/${file}" -o "$target"
done

# Use Python with huggingface_hub for reliable model downloading
python3 << 'EOF'
import os
from huggingface_hub import snapshot_download

model_id = os.environ.get('GTE_MODEL_ID')
model_path = os.environ.get('GTE_MODEL_PATH')

if not model_id or not model_path:
    raise ValueError(f"GTE_MODEL_ID and GTE_MODEL_PATH environment variables must be set")

try:
    print(f"Downloading model {model_id} to {model_path}")
    snapshot_download(
        repo_id=model_id,
        cache_dir=None,  # use the default cache; files are copied into local_dir below
        local_dir=model_path,
        local_dir_use_symlinks=False  # Don't use symlinks, copy files
    )
    print(f"Successfully downloaded GTE model to {model_path}")
except Exception as e:
    print(f"Error downloading GTE model: {e}")
    exit(1)
EOF
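The directory this script populates is what TextEmbeddingNode later loads through GTE_MODEL_PATH; a quick hedged check that the snapshot is actually loadable (run inside the image, with the same environment variable set):

import os
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(os.environ["GTE_MODEL_PATH"])  # local path, as in embeddings_node
print(model.encode(["hello world"]).shape)                 # (1, 1024) for gte-large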

View file

@@ -1,4 +1,5 @@
#! python3
import argparse
import logging
import os
import sqlite3
@@ -23,9 +24,10 @@ logging.basicConfig(
logger = logging.getLogger("knack-transform")


def setup_database_connection():
def setup_database_connection(db_path=None):
    """Create connection to the SQLite database."""
    db_path = os.environ.get('DB_PATH', '/data/knack.sqlite')
    if db_path is None:
        db_path = os.environ.get('DB_PATH', '/data/knack.sqlite')
    logger.info(f"Connecting to database: {db_path}")
    return sqlite3.connect(db_path)
@@ -35,13 +37,12 @@ def table_exists(tablename: str, con: sqlite3.Connection):
    query = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1"
    return len(con.execute(query, [tablename]).fetchall()) > 0


def main():
    """Main entry point for the transform pipeline."""
    logger.info("Starting transform pipeline")
def run_from_database(db_path=None):
    """Run the pipeline using database as input and output."""
    logger.info("Starting transform pipeline (database mode)")

    try:
        con = setup_database_connection()
        con = setup_database_connection(db_path)
        logger.info("Database connection established")

        # Check if posts table exists
@@ -73,8 +74,9 @@ def main():
        max_workers = int(os.environ.get('MAX_WORKERS', 4))
        pipeline = create_default_pipeline(device=device, max_workers=max_workers)

        effective_db_path = db_path or os.environ.get('DB_PATH', '/data/knack.sqlite')
        results = pipeline.run(
            db_path=os.environ.get('DB_PATH', '/data/knack.sqlite'),
            db_path=effective_db_path,
            initial_context=context,
            fail_fast=False  # Continue even if some nodes fail
        )
@@ -97,6 +99,49 @@ def main():
        con.close()
        logger.info("Database connection closed")


def main():
    """Main entry point with command-line argument support."""
    parser = argparse.ArgumentParser(
        description='Transform pipeline for Knack scraper data',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run with the default database (Docker mode)
  python main.py

  # Run with custom device and workers
  python main.py --database /path/to/knack.sqlite --device mps --workers 8

  # Run with a specific database file
  python main.py --database /path/to/knack.sqlite
"""
    )
    parser.add_argument(
        '--database',
        help='Path to the SQLite database. Defaults to the DB_PATH env var or /data/knack.sqlite'
    )
    parser.add_argument(
        '--device',
        default=os.environ.get('COMPUTE_DEVICE', 'cpu'),
        choices=['cpu', 'cuda', 'mps'],
        help='Device to use for compute-intensive operations (default: cpu)'
    )
    parser.add_argument(
        '--workers',
        type=int,
        default=int(os.environ.get('MAX_WORKERS', 4)),
        help='Maximum number of parallel workers (default: 4)'
    )
    args = parser.parse_args()

    # Propagate CLI choices via the environment so run_from_database() and
    # create_default_pipeline() pick them up
    os.environ['COMPUTE_DEVICE'] = args.device
    os.environ['MAX_WORKERS'] = str(args.workers)

    # Database mode: use the given path, or fall back to DB_PATH / the default
    # when --database is omitted
    run_from_database(db_path=args.database)


if __name__ == "__main__":
    main()
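With this entry point in place, the pipeline can be exercised from a test or a notebook by shelling out to the new interface; a sketch that only uses the flags defined above (the database path is illustrative):

import subprocess
import sys

subprocess.run(
    [sys.executable, "main.py",
     "--database", "/data/knack.sqlite",  # any SQLite file containing a posts table
     "--device", "cpu",
     "--workers", "4"],
    check=True,
)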

View file

@@ -214,8 +214,15 @@ def create_default_pipeline(device: str = "cpu",
    """
    from author_node import NerAuthorNode, FuzzyAuthorNode
    from embeddings_node import TextEmbeddingNode, UmapNode
    from url_node import URLNode

    pipeline = ParallelPipeline(max_workers=max_workers, use_processes=False)

    pipeline.add_node(NodeConfig(
        node_class=URLNode,
        dependencies=[],
        name='URLNode'
    ))

    # Add AuthorNode (no dependencies)
    pipeline.add_node(NodeConfig(
@@ -243,7 +250,7 @@ def create_default_pipeline(device: str = "cpu",
            'device': device,
            'model_path': os.environ.get('GTE_MODEL_PATH')
        },
        dependencies=[],
        dependencies=['AuthorNode'],
        name='TextEmbeddingNode'
    ))
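A minimal sketch of registering only the new URLNode, using the ParallelPipeline, NodeConfig and TransformContext shapes visible in this diff (assumed to live in pipeline.py; the database path is illustrative):

import sqlite3
import pandas as pd
from pipeline import ParallelPipeline, NodeConfig, TransformContext
from url_node import URLNode

con = sqlite3.connect('/data/knack.sqlite')
context = TransformContext(pd.read_sql('SELECT * FROM posts;', con))
con.close()

pipeline = ParallelPipeline(max_workers=1, use_processes=False)
pipeline.add_node(NodeConfig(node_class=URLNode, dependencies=[], name='URLNode'))
results = pipeline.run(db_path='/data/knack.sqlite', initial_context=context, fail_fast=False)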

View file

@@ -4,4 +4,6 @@ gliner
torch
fuzzysearch
sentence_transformers
umap-learn
umap-learn
matplotlib
huggingface_hub

transform/url_node.py (new file, 160 additions)
View file

@@ -0,0 +1,160 @@
"""Nodes to extract URLs from text using regex patterns."""
import sqlite3
import pandas as pd
import logging
import re
from urllib.parse import urlparse

from pipeline import TransformContext
from transform_node import TransformNode

logger = logging.getLogger("knack-transform")


class URLNode(TransformNode):
    """Node that looks for URLs in the text column of posts.

    Stores data in a new table urls:
    - id, post_id, url_raw, tld, host
    """

    def __init__(self):
        super().__init__()
        logger.info("Init URL Node")

    def _create_tables(self, con: sqlite3.Connection):
        """Create the urls table if it doesn't exist."""
        con.execute("""
            CREATE TABLE IF NOT EXISTS urls (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                post_id INTEGER,
                url_raw TEXT,
                tld TEXT,
                host TEXT,
                FOREIGN KEY (post_id) REFERENCES posts(id)
            )
        """)
        con.commit()
    def _process_data(self, input_df: pd.DataFrame) -> pd.DataFrame:
        logger.info(f"Processing {len(input_df)} rows")
        mappings = []

        for _, post_row in input_df.iterrows():
            post_id = post_row['id']
            post_text = post_row['text']

            pattern = r"https?://(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b[-a-zA-Z0-9@:%_\+.~#?&/=]*"
            urls = re.findall(pattern, post_text)
            logger.debug(f"Post {post_id}, text preview: {post_text[:50]}, URLs found: {len(urls)}")

            for url in urls:
                try:
                    parsed = urlparse(url)
                    hostname = parsed.netloc

                    # If the hostname starts with www. remove that part.
                    if hostname[:4] == 'www.':
                        hostname = hostname[4:]

                    # Extract TLD (last part after the last dot)
                    tld = ""
                    if hostname:
                        parts = hostname.split('.')
                        if len(parts) > 0:
                            tld = parts[-1]

                    mappings.append({
                        'post_id': post_id,
                        'url_raw': url,
                        'host': hostname,
                        'tld': tld
                    })
                    logger.debug(f"  URL: {url} -> Host: {hostname}, TLD: {tld}")
                except Exception as e:
                    logger.warning(f"Failed to parse URL {url}: {e}")

        result_df = pd.DataFrame(mappings)
        logger.info(f"Extracted {len(result_df)} URLs from {len(input_df)} posts")
        return result_df

    def _store_results(self, con: sqlite3.Connection, result_df: pd.DataFrame):
        if result_df.empty:
            logger.info("No URLs to store")
            return

        result_df.to_sql('urls', con, if_exists='append', index=False)
        logger.info(f"Stored {len(result_df)} URLs to database")
    def run(self, con: sqlite3.Connection, context: TransformContext):
        """Executes the URL Node.

        Writes to a new table urls and creates said table if it does not
        exist currently.

        Args:
            con (sqlite3.Connection): SQLite database connection
            context (TransformContext): TransformContext,
                containing the input dataframe of all posts

        Returns:
            TransformContext with processed dataframe.
        """
        logger.info("Starting URLNode transformation")

        input_df = context.get_dataframe()
        if input_df.empty:
            logger.warning("Empty dataframe. Skipping URLNode")
            return context

        self._create_tables(con)
        result_df = self._process_data(input_df)
        self._store_results(con, result_df)

        logger.info("Node transformation complete")
        return TransformContext(input_df)


def main():
    import sys

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout)
        ]
    )
    logger = logging.getLogger("knack-transform")

    # Connect to database
    db_path = "/Users/linussilberstein/Documents/Knack-Scraper/data/knack.sqlite"
    con = sqlite3.connect(db_path)

    try:
        # Read posts from database
        df = pd.read_sql('SELECT * FROM posts;', con)
        logger.info(f"Loaded {len(df)} posts from database")

        # Create context
        context = TransformContext(df)

        # Run URLNode
        logger.info("Running URLNode...")
        node = URLNode()
        context = node.run(con, context)
        logger.info("URLNode complete")

        logger.info("URL extraction completed successfully!")
    except Exception as e:
        logger.error(f"Error during transformation: {e}", exc_info=True)
        raise
    finally:
        con.close()


if __name__ == '__main__':
    main()
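A tiny worked example of the extraction logic in _process_data (same regex, same host and TLD handling), run against a made-up post text instead of the database:

import re
from urllib.parse import urlparse

pattern = r"https?://(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b[-a-zA-Z0-9@:%_\+.~#?&/=]*"
text = "Siehe https://www.example.org/artikel?id=1 und http://blog.example.co.uk/post"

for url in re.findall(pattern, text):
    host = urlparse(url).netloc
    if host.startswith("www."):
        host = host[4:]
    tld = host.split(".")[-1] if host else ""
    print(url, host, tld)
# https://www.example.org/artikel?id=1 example.org org
# http://blog.example.co.uk/post blog.example.co.uk uk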