diff --git a/.gitignore b/.gitignore index 2e7a5cf..8e2fba4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ data/ venv/ +experiment/ .DS_STORE +.env \ No newline at end of file diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 9c94fd6..0000000 --- a/Dockerfile +++ /dev/null @@ -1,15 +0,0 @@ -FROM python:slim - -RUN mkdir /app -RUN mkdir /data - -WORKDIR /app -COPY requirements.txt . -RUN pip install --no-cache-dir -r requirements.txt - -RUN apt update -y -RUN apt install -y cron -COPY crontab . -RUN crontab crontab - -COPY main.py . \ No newline at end of file diff --git a/Makefile b/Makefile index a669090..47a8063 100644 --- a/Makefile +++ b/Makefile @@ -1,2 +1,12 @@ -build: - docker build -t knack-scraper . \ No newline at end of file +volume: + docker volume create knack_data + +stop: + docker stop knack-scraper || true + docker rm knack-scraper || true + +up: + docker compose up -d + +down: + docker compose down \ No newline at end of file diff --git a/README.md b/README.md index e69de29..ab971fc 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,18 @@ +Knack-Scraper does exacly what its name suggests it does. +Knack-Scraper scrapes knack.news and writes to an sqlite +database for later usage. + +## Example for .env + +``` +NUM_THREADS=8 +NUM_SCRAPES=100 +DATABASE_LOCATION='./data/knack.sqlite' +``` + +## Run once + +``` +python main.py +``` + diff --git a/crontab b/crontab deleted file mode 100644 index 6b6ae11..0000000 --- a/crontab +++ /dev/null @@ -1 +0,0 @@ -5 4 * * * python /app/main.py diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..4ab3b8c --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,43 @@ +services: + scraper: + build: + context: ./scrape + dockerfile: Dockerfile + image: knack-scraper + container_name: knack-scraper + env_file: + - scrape/.env + volumes: + - knack_data:/data + restart: unless-stopped + + transform: + build: + context: ./transform + dockerfile: Dockerfile + image: knack-transform + container_name: knack-transform + env_file: + - transform/.env + volumes: + - knack_data:/data + - models:/models + restart: unless-stopped + + sqlitebrowser: + image: lscr.io/linuxserver/sqlitebrowser:latest + container_name: sqlitebrowser + environment: + - PUID=1000 + - PGID=1000 + - TZ=Etc/UTC + volumes: + - knack_data:/data + ports: + - "3000:3000" # noVNC web UI + - "3001:3001" # VNC server + restart: unless-stopped + +volumes: + knack_data: + models: diff --git a/main.py b/main.py deleted file mode 100755 index 850ba3c..0000000 --- a/main.py +++ /dev/null @@ -1,167 +0,0 @@ -#! python3 -import locale -import logging -import os -import sqlite3 -import sys -import time -from concurrent.futures import ThreadPoolExecutor -from datetime import datetime - -import pandas as pd -import requests -import tqdm -from bs4 import BeautifulSoup - -logger = logging.getLogger("knack-scraper") -# ch = logging.StreamHandler() -# formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") -# ch.setFormatter(formatter) -# ch.setLevel(logging.INFO) -# logger.addHandler(ch) - - -def table_exists(tablename: str, con: sqlite3.Connection): - query = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? 
LIMIT 1" - return len(con.execute(query, [tablename]).fetchall()) > 0 - - -def download(id: int): - if id == 0: - return - base_url = "https://knack.news/" - url = f"{base_url}{id}" - res = requests.get(url) - - # make sure we don't dos knack - time.sleep(2) - - if not (200 <= res.status_code <= 300): - return - - logger.info("Found promising page with id %d!", id) - - content = res.content - soup = BeautifulSoup(content, "html.parser") - date_format = "%d. %B %Y" - - # TODO FIXME: this fails inside the docker container - locale.setlocale(locale.LC_TIME, "de_DE") - pC = soup.find("div", {"class": "postContent"}) - - if pC is None: - # not a normal post - logger.info( - "Page with id %d does not have a .pageContent-div. Skipping for now.", id - ) - return - - # every post has these fields - title = pC.find("h3", {"class": "postTitle"}).text - postText = pC.find("div", {"class": "postText"}) - - # these fields are possible but not required - # TODO: cleanup - try: - date_string = pC.find("span", {"class": "singledate"}).text - parsed_date = datetime.strptime(date_string, date_format) - except AttributeError: - parsed_date = None - - try: - author = pC.find("span", {"class": "author"}).text - except AttributeError: - author = None - - try: - category = pC.find("span", {"class": "categoryInfo"}).find_all() - category = [c.text for c in category] - category = ";".join(category) - except AttributeError: - category = None - - try: - tags = [x.text for x in pC.find("div", {"class": "tagsInfo"}).find_all("a")] - tags = ";".join(tags) - except AttributeError: - tags = None - - img = pC.find("img", {"class": "postImage"}) - if img is not None: - img = img["src"] - - res_dict = { - "id": id, - "title": title, - "author": author, - "date": parsed_date, - "category": category, - "url": url, - "img_link": img, - "tags": tags, - "text": postText.text, - "html": str(postText), - "scraped_at": datetime.now(), - } - - return res_dict - - -def run_downloads(min_id: int, max_id: int, num_threads: int = 8): - res = [] - - logger.info( - "Started parallel scrape of posts from id %d to id %d using %d threads.", - min_id, - max_id - 1, - num_threads, - ) - with ThreadPoolExecutor(max_workers=num_threads) as executor: - # Use a list comprehension to create a list of futures - futures = [executor.submit(download, i) for i in range(min_id, max_id)] - - for future in tqdm.tqdm( - futures, total=max_id - min_id - ): # tqdm to track progress - post = future.result() - if post is not None: - res.append(post) - - # sqlite can't handle lists so let's convert them to a single row csv - # TODO: make sure our database is properly normalized - df = pd.DataFrame(res) - - return df - - -def main(): - num_threads = int(os.environ.get("NUM_THREADS", 8)) - n_scrapes = int(os.environ.get("NUM_SCRAPES", 100)) - database_location = os.environ.get("DATABASE_LOCATION", "/data/knack.sqlite") - - con = sqlite3.connect(database_location) - with con: - post_table_exists = table_exists("posts", con) - - if post_table_exists: - logger.info("found posts retrieved earlier") - # retrieve max post id from db so - # we can skip retrieving known posts - max_id_in_db = con.execute("SELECT MAX(id) FROM posts").fetchone()[0] - logger.info("Got max id %d!", max_id_in_db) - else: - logger.info("no posts scraped so far - starting from 0") - # retrieve from 0 onwards - max_id_in_db = -1 - - con = sqlite3.connect(database_location) - df = run_downloads( - min_id=max_id_in_db + 1, - max_id=max_id_in_db + n_scrapes, - num_threads=num_threads, - ) - 
df.to_sql("posts", con, if_exists="append") - - -if __name__ == "__main__": - main() diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 7792d83..0000000 --- a/requirements.txt +++ /dev/null @@ -1,14 +0,0 @@ -beautifulsoup4==4.12.2 -certifi==2023.7.22 -charset-normalizer==3.3.0 -idna==3.4 -numpy==1.26.1 -pandas==2.1.1 -python-dateutil==2.8.2 -pytz==2023.3.post1 -requests==2.31.0 -six==1.16.0 -soupsieve==2.5 -tqdm==4.66.1 -tzdata==2023.3 -urllib3==2.0.7 diff --git a/scrape/Dockerfile b/scrape/Dockerfile new file mode 100644 index 0000000..a2fbe2e --- /dev/null +++ b/scrape/Dockerfile @@ -0,0 +1,29 @@ +FROM python:slim + +RUN mkdir /app +RUN mkdir /data + +#COPY /data/knack.sqlite /data + +WORKDIR /app +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY .env . + +RUN apt update -y +RUN apt install -y cron locales + +COPY main.py . + +ENV PYTHONUNBUFFERED=1 +ENV LANG=de_DE.UTF-8 +ENV LC_ALL=de_DE.UTF-8 + +# Create cron job that runs every 15 minutes with environment variables +RUN echo "5 4 * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-scraper +RUN chmod 0644 /etc/cron.d/knack-scraper +RUN crontab /etc/cron.d/knack-scraper + +# Start cron in foreground +CMD ["cron", "-f"] \ No newline at end of file diff --git a/scrape/main.py b/scrape/main.py new file mode 100755 index 0000000..10b66dd --- /dev/null +++ b/scrape/main.py @@ -0,0 +1,262 @@ +#! python3 +import logging +import os +import sqlite3 +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +import sys + +from dotenv import load_dotenv +import pandas as pd +import requests +from bs4 import BeautifulSoup + +load_dotenv() + +if (os.environ.get('LOGGING_LEVEL', 'INFO') == 'INFO'): + logging_level = logging.INFO +else: + logging_level = logging.DEBUG + +logging.basicConfig( + level=logging_level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler("app.log"), + logging.StreamHandler(sys.stdout) + ] +) +logger = logging.getLogger("knack-scraper") + +def table_exists(tablename: str, con: sqlite3.Connection): + query = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? 
LIMIT 1" + return len(con.execute(query, [tablename]).fetchall()) > 0 + + +def split_semicolon_list(value: str): + if pd.isna(value): + return [] + return [item.strip() for item in str(value).split(';') if item.strip()] + + +def build_dimension_and_mapping(postdf: pd.DataFrame, field_name: str, dim_col: str): + """Extract unique dimension values and post-to-dimension mappings from a column.""" + if postdf.empty or field_name not in postdf.columns: + return None, None + + values = set() + mapping_rows = [] + + for post_id, raw in zip(postdf['id'], postdf[field_name]): + items = split_semicolon_list(raw) + for item in items: + values.add(item) + mapping_rows.append({'post_id': post_id, dim_col: item}) + + if not values: + return None, None + + dim_df = pd.DataFrame({ + 'id': range(len(values)), + dim_col: sorted(values), + }) + map_df = pd.DataFrame(mapping_rows) + return dim_df, map_df + + +def store_dimension_and_mapping( + con: sqlite3.Connection, + dim_df: pd.DataFrame | None, + map_df: pd.DataFrame | None, + table_name: str, + dim_col: str, + mapping_table: str, + mapping_id_col: str, +): + """Persist a dimension table and its mapping table, merging with existing values.""" + if dim_df is None or dim_df.empty: + return + + if table_exists(table_name, con): + existing = pd.read_sql(f"SELECT id, {dim_col} FROM {table_name}", con) + merged = pd.concat([existing, dim_df], ignore_index=True) + merged = merged.drop_duplicates(subset=[dim_col], keep='first').reset_index(drop=True) + merged['id'] = range(len(merged)) + else: + merged = dim_df.copy() + + # Replace table with merged content + merged.to_sql(table_name, con, if_exists="replace", index=False) + + if map_df is None or map_df.empty: + return + + value_to_id = dict(zip(merged[dim_col], merged['id'])) + map_df = map_df.copy() + map_df[mapping_id_col] = map_df[dim_col].map(value_to_id) + map_df = map_df[['post_id', mapping_id_col]].dropna() + map_df.to_sql(mapping_table, con, if_exists="append", index=False) + + +def download(id: int): + if id == 0: + return + base_url = "https://knack.news/" + url = f"{base_url}{id}" + res = requests.get(url) + + # make sure we don't dos knack + time.sleep(2) + + if not (200 <= res.status_code <= 300): + return + + logger.debug("Found promising page with id %d!", id) + + content = res.content + soup = BeautifulSoup(content, "html.parser") + + pC = soup.find("div", {"class": "postContent"}) + + if pC is None: + # not a normal post + logger.debug( + "Page with id %d does not have a .pageContent-div. 
Skipping for now.", id + ) + return + + # every post has these fields + title = pC.find("h3", {"class": "postTitle"}).text + postText = pC.find("div", {"class": "postText"}) + + # these fields are possible but not required + # TODO: cleanup + try: + date_parts = pC.find("span", {"class": "singledate"}).text.split(' ') + day = int(date_parts[0][:-1]) + months = {'Januar': 1, 'Februar': 2, 'März': 3, 'April': 4, 'Mai': 5, 'Juni': 6, 'Juli': 7, 'August': 8, 'September': 9, 'Oktober': 10, 'November': 11, 'Dezember': 12} + month = months[date_parts[1]] + year = int(date_parts[2]) + parsed_date = datetime(year, month, day) + except Exception: + parsed_date = None + + try: + author = pC.find("span", {"class": "author"}).text + except AttributeError: + author = None + + try: + category = pC.find("span", {"class": "categoryInfo"}).find_all() + category = [c.text for c in category if c.text != 'Alle Artikel'] + category = ";".join(category) + except AttributeError: + category = None + + try: + tags = [x.text for x in pC.find("div", {"class": "tagsInfo"}).find_all("a")] + tags = ";".join(tags) + except AttributeError: + tags = None + + img = pC.find("img", {"class": "postImage"}) + if img is not None: + img = img["src"] + + res_dict = { + "id": id, + "title": title, + "author": author, + "date": parsed_date, + "category": category, + "url": url, + "img_link": img, + "tags": tags, + "text": postText.text, + "html": str(postText), + "scraped_at": datetime.now(), + "is_cleaned": False + } + + return res_dict + + +def run_downloads(min_id: int, max_id: int, num_threads: int = 8): + res = [] + + logger.info( + "Started parallel scrape of posts from id %d to id %d using %d threads.", + min_id, + max_id - 1, + num_threads, + ) + with ThreadPoolExecutor(max_workers=num_threads) as executor: + # Use a list comprehension to create a list of futures + futures = [executor.submit(download, i) for i in range(min_id, max_id)] + + for future in futures: + post = future.result() + if post is not None: + res.append(post) + + postdf = pd.DataFrame(res) + return postdf + + +def main(): + num_threads = int(os.environ.get("NUM_THREADS", 8)) + n_scrapes = int(os.environ.get("NUM_SCRAPES", 100)) + database_location = os.environ.get("DATABASE_LOCATION", "../data/knack.sqlite") + + logger.debug(f"Started Knack Scraper: \nNUM_THREADS: {num_threads}\nN_SCRAPES: {n_scrapes}\nDATABASE_LOCATION: {database_location}") + + con = sqlite3.connect(database_location) + with con: + if table_exists("posts", con): + logger.info("found posts retrieved earlier") + max_id_in_db = con.execute("SELECT MAX(id) FROM posts").fetchone()[0] + logger.info("Got max id %d!", max_id_in_db) + else: + logger.info("no posts scraped so far - starting from 0") + max_id_in_db = -1 + + postdf = run_downloads( + min_id=max_id_in_db + 1, + max_id=max_id_in_db + n_scrapes, + num_threads=num_threads, + ) + + # Drop category and tags columns as they're stored in separate tables + postdf = postdf.drop(columns=['category', 'tags']) + postdf.to_sql("posts", con, if_exists="append", index=False) + + # Tags + tag_dim, tag_map = build_dimension_and_mapping(postdf, 'tags', 'tag') + store_dimension_and_mapping( + con, + tag_dim, + tag_map, + table_name="tags", + dim_col="tag", + mapping_table="posttags", + mapping_id_col="tag_id", + ) + + # Categories + category_dim, category_map = build_dimension_and_mapping(postdf, 'category', 'category') + store_dimension_and_mapping( + con, + category_dim, + category_map, + table_name="categories", + dim_col="category", + 
mapping_table="postcategories", + mapping_id_col="category_id", + ) + + logger.info(f"scraped new entries. number of new posts: {len(postdf.index)}") + + +if __name__ == "__main__": + main() diff --git a/scrape/requirements.txt b/scrape/requirements.txt new file mode 100644 index 0000000..32d5df2 --- /dev/null +++ b/scrape/requirements.txt @@ -0,0 +1,4 @@ +pandas +requests +bs4 +dotenv \ No newline at end of file diff --git a/transform/.env.example b/transform/.env.example new file mode 100644 index 0000000..34cd0e0 --- /dev/null +++ b/transform/.env.example @@ -0,0 +1,4 @@ +LOGGING_LEVEL=INFO +DB_PATH=/data/knack.sqlite +MAX_CLEANED_POSTS=1000 +COMPUTE_DEVICE=mps \ No newline at end of file diff --git a/transform/Dockerfile b/transform/Dockerfile new file mode 100644 index 0000000..6f148bd --- /dev/null +++ b/transform/Dockerfile @@ -0,0 +1,51 @@ +FROM python:3.12-slim + +RUN mkdir -p /app /data /models + +# Install build dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + g++ \ + gfortran \ + libopenblas-dev \ + liblapack-dev \ + pkg-config \ + curl \ + jq \ + && rm -rf /var/lib/apt/lists/* + +ENV GLINER_MODEL_ID=urchade/gliner_multi-v2.1 +ENV GLINER_MODEL_PATH=/models/gliner_multi-v2.1 + +ENV MINILM_MODEL_ID=sentence-transformers/all-MiniLM-L6-v2 +ENV MINILM_MODEL_PATH=/models/all-MiniLM-L6-v2 + +WORKDIR /app +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY .env . + +RUN apt update -y +RUN apt install -y cron locales + +# Ensure GLiNER helper scripts are available +COPY ensure_gliner_model.sh /usr/local/bin/ensure_gliner_model.sh +# Ensure MiniLM helper scripts are available +COPY ensure_minilm_model.sh /usr/local/bin/ensure_minilm_model.sh +COPY entrypoint.sh /usr/local/bin/entrypoint.sh +RUN chmod +x /usr/local/bin/ensure_gliner_model.sh /usr/local/bin/ensure_minilm_model.sh /usr/local/bin/entrypoint.sh + +COPY *.py . + +# Create cron job that runs every weekend (Sunday at 3 AM) 0 3 * * 0 +# Testing every 30 Minutes */30 * * * * +RUN echo "*/30 * * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-transform +RUN chmod 0644 /etc/cron.d/knack-transform +RUN crontab /etc/cron.d/knack-transform + +# Persist models between container runs +VOLUME /models + +CMD ["/usr/local/bin/entrypoint.sh"] +#CMD ["python", "main.py"] diff --git a/transform/README.md b/transform/README.md new file mode 100644 index 0000000..9e3665a --- /dev/null +++ b/transform/README.md @@ -0,0 +1,67 @@ +# Knack Transform + +Data transformation pipeline for the Knack scraper project. + +## Overview + +This folder contains the transformation logic that processes data from the SQLite database. It runs on a scheduled basis (every weekend) via cron. + +The pipeline supports **parallel execution** of independent transform nodes, allowing you to leverage multi-core processors for faster data transformation. 
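Two nodes declared without dependencies end up in the same execution stage and run concurrently. A minimal sketch of wiring this up by hand, assuming the `ParallelPipeline`, `NodeConfig`, and `TransformContext` classes from `pipeline.py` and two of the nodes defined in this folder (paths and table names as used elsewhere in this repo):

```python
import sqlite3

import pandas as pd

from pipeline import ParallelPipeline, NodeConfig, TransformContext
from author_node import NerAuthorNode
from embeddings_node import TextEmbeddingNode

# Load the posts to transform (DB_PATH defaults to /data/knack.sqlite).
con = sqlite3.connect("/data/knack.sqlite")
df = pd.read_sql("SELECT * FROM posts", con)
con.close()

pipeline = ParallelPipeline(max_workers=4, use_processes=False)
# Neither node declares dependencies, so both land in the first stage
# and are executed in parallel, each with its own database connection.
pipeline.add_node(NodeConfig(node_class=NerAuthorNode, name="AuthorNode"))
pipeline.add_node(NodeConfig(node_class=TextEmbeddingNode, name="TextEmbeddingNode"))

results = pipeline.run(
    db_path="/data/knack.sqlite",
    initial_context=TransformContext(df),
    fail_fast=False,
)
```

In practice `main.py` builds this graph via `create_default_pipeline()`; the sketch only illustrates how independent nodes are staged.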
+ +## Structure + +- `base.py` - Abstract base class for transform nodes +- `pipeline.py` - Parallel pipeline orchestration system +- `main.py` - Main entry point and pipeline execution +- `author_node.py` - NER-based author classification node +- `example_node.py` - Template for creating new nodes +- `Dockerfile` - Docker image configuration with cron setup +- `requirements.txt` - Python dependencies + +## Transform Nodes + +Transform nodes inherit from `TransformNode` and implement the `run` method: + +```python +from base import TransformNode, TransformContext +import sqlite3 + +class MyTransform(TransformNode): + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + df = context.get_dataframe() + + # Transform logic here + transformed_df = df.copy() + # ... your transformations ... + + # Optionally write back to database + transformed_df.to_sql("my_table", con, if_exists="replace", index=False) + + return TransformContext(transformed_df) +``` + +## Configuration + +Copy `.env.example` to `.env` and configure: + +- `LOGGING_LEVEL` - Log level (INFO or DEBUG) +- `DB_PATH` - Path to SQLite database + +## Running + +### With Docker + +```bash +docker build -t knack-transform . +docker run -v $(pwd)/data:/data knack-transform +``` + +### Locally + +```bash +python main.py +``` + +## Cron Schedule + +The Docker container runs the transform pipeline every Sunday at 3 AM. diff --git a/transform/author_node.py b/transform/author_node.py new file mode 100644 index 0000000..845e87a --- /dev/null +++ b/transform/author_node.py @@ -0,0 +1,420 @@ +"""Author classification transform node using NER.""" +import os +import sqlite3 +import pandas as pd +import logging +import fuzzysearch +from concurrent.futures import ThreadPoolExecutor + +from pipeline import TransformContext +from transform_node import TransformNode + +logger = logging.getLogger("knack-transform") + +try: + from gliner import GLiNER + import torch + GLINER_AVAILABLE = True +except ImportError: + GLINER_AVAILABLE = False + logging.warning("GLiNER not available. Install with: pip install gliner") + +class NerAuthorNode(TransformNode): + """Transform node that extracts and classifies authors using NER. + + Creates two tables: + - authors: stores unique authors with their type (Person, Organisation, etc.) + - post_authors: maps posts to their authors + """ + + def __init__(self, model_name: str = "urchade/gliner_multi-v2.1", + model_path: str = None, + threshold: float = 0.5, + max_workers: int = 64, + device: str = "cpu"): + """Initialize the AuthorNode. + + Args: + model_name: GLiNER model to use + model_path: Optional local path to a downloaded GLiNER model + threshold: Confidence threshold for entity predictions + max_workers: Number of parallel workers for prediction + device: Device to run model on ('cpu', 'cuda', 'mps') + """ + self.model_name = model_name + self.model_path = model_path or os.environ.get('GLINER_MODEL_PATH') + self.threshold = threshold + self.max_workers = max_workers + self.device = device + self.model = None + self.labels = ["Person", "Organisation", "Email", "Newspaper", "NGO"] + + def _setup_model(self): + """Initialize the NER model.""" + if not GLINER_AVAILABLE: + raise ImportError("GLiNER is required for AuthorNode. 
Install with: pip install gliner") + + model_source = None + if self.model_path: + if os.path.exists(self.model_path): + model_source = self.model_path + logger.info(f"Loading GLiNER model from local path: {self.model_path}") + else: + logger.warning(f"GLINER_MODEL_PATH '{self.model_path}' not found; falling back to hub model {self.model_name}") + + if model_source is None: + model_source = self.model_name + logger.info(f"Loading GLiNER model from hub: {self.model_name}") + + if self.device == "cuda" and torch.cuda.is_available(): + self.model = GLiNER.from_pretrained( + model_source, + max_length=255 + ).to('cuda', dtype=torch.float16) + elif self.device == "mps" and torch.backends.mps.is_available(): + self.model = GLiNER.from_pretrained( + model_source, + max_length=255 + ).to('mps', dtype=torch.float16) + else: + self.model = GLiNER.from_pretrained( + model_source, + max_length=255 + ) + + logger.info("Model loaded successfully") + + def _predict(self, text_data: dict): + """Predict entities for a single author text. + + Args: + text_data: Dict with 'author' and 'id' keys + + Returns: + Tuple of (predictions, post_id) or None + """ + if text_data is None or text_data.get('author') is None: + return None + + predictions = self.model.predict_entities( + text_data['author'], + self.labels, + threshold=self.threshold + ) + return predictions, text_data['id'] + + def _classify_authors(self, posts_df: pd.DataFrame): + """Classify all authors in the posts dataframe. + + Args: + posts_df: DataFrame with 'id' and 'author' columns + + Returns: + List of dicts with 'text', 'label', 'id' keys + """ + if self.model is None: + self._setup_model() + + # Prepare input data + authors_data = [] + for idx, row in posts_df.iterrows(): + if pd.notna(row['author']): + authors_data.append({ + 'author': row['author'], + 'id': row['id'] + }) + + logger.info(f"Classifying {len(authors_data)} authors") + + results = [] + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = [executor.submit(self._predict, data) for data in authors_data] + + for future in futures: + result = future.result() + if result is not None: + predictions, post_id = result + for pred in predictions: + results.append({ + 'text': pred['text'], + 'label': pred['label'], + 'id': post_id + }) + + logger.info(f"Classification complete. Found {len(results)} author entities") + return results + + def _create_tables(self, con: sqlite3.Connection): + """Create authors and post_authors tables if they don't exist.""" + logger.info("Creating authors tables") + + con.execute(""" + CREATE TABLE IF NOT EXISTS authors ( + id INTEGER PRIMARY KEY, + name TEXT, + type TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + con.execute(""" + CREATE TABLE IF NOT EXISTS post_authors ( + post_id INTEGER, + author_id INTEGER, + PRIMARY KEY (post_id, author_id), + FOREIGN KEY (post_id) REFERENCES posts(id), + FOREIGN KEY (author_id) REFERENCES authors(id) + ) + """) + + con.commit() + + def _store_authors(self, con: sqlite3.Connection, results: list): + """Store classified authors and their mappings. 
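        New author names are appended to the authors table as (name, type) rows; every classified entity is then linked to its post through the post_authors mapping table.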
+ + Args: + con: Database connection + results: List of classification results + """ + if not results: + logger.info("No authors to store") + return + + # Convert results to DataFrame + results_df = pd.DataFrame(results) + + # Get unique authors with their types + unique_authors = results_df[['text', 'label']].drop_duplicates() + unique_authors.columns = ['name', 'type'] + + # Get existing authors + existing_authors = pd.read_sql("SELECT id, name FROM authors", con) + + # Find new authors to insert + if not existing_authors.empty: + new_authors = unique_authors[~unique_authors['name'].isin(existing_authors['name'])] + else: + new_authors = unique_authors + + if not new_authors.empty: + logger.info(f"Inserting {len(new_authors)} new authors") + new_authors.to_sql('authors', con, if_exists='append', index=False) + + # Get all authors with their IDs + all_authors = pd.read_sql("SELECT id, name FROM authors", con) + name_to_id = dict(zip(all_authors['name'], all_authors['id'])) + + # Create post_authors mappings + mappings = [] + for _, row in results_df.iterrows(): + author_id = name_to_id.get(row['text']) + if author_id: + mappings.append({ + 'post_id': row['id'], + 'author_id': author_id + }) + + if mappings: + mappings_df = pd.DataFrame(mappings).drop_duplicates() + + # Clear existing mappings for these posts (optional, depends on your strategy) + # post_ids = tuple(mappings_df['post_id'].unique()) + # con.execute(f"DELETE FROM post_authors WHERE post_id IN ({','.join('?' * len(post_ids))})", post_ids) + + logger.info(f"Creating {len(mappings_df)} post-author mappings") + mappings_df.to_sql('post_authors', con, if_exists='append', index=False) + + con.commit() + logger.info("Authors and mappings stored successfully") + + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the author classification transformation. + + Args: + con: SQLite database connection + context: TransformContext containing posts dataframe + + Returns: + TransformContext with classified authors dataframe + """ + logger.info("Starting AuthorNode transformation") + + posts_df = context.get_dataframe() + + # Ensure required columns exist + if 'author' not in posts_df.columns: + logger.warning("No 'author' column in dataframe. Skipping AuthorNode.") + return context + + # Create tables + self._create_tables(con) + + # Classify authors + results = self._classify_authors(posts_df) + + # Store results + self._store_authors(con, results) + + # Return context with results + logger.info("AuthorNode transformation complete") + + return TransformContext(posts_df) + + +class FuzzyAuthorNode(TransformNode): + """FuzzyAuthorNode + + This Node takes in data and rules of authornames that have been classified already + and uses those 'rule' to find more similar fields. + """ + + def __init__(self, + max_l_dist: int = 1,): + """Initialize FuzzyAuthorNode. + + Args: + max_l_dist: The number of 'errors' that are allowed by the fuzzy search algorithm + """ + self.max_l_dist = max_l_dist + logger.info(f"Initialized FuzzyAuthorNode with max_l_dist={max_l_dist}") + + def _process_data(self, con: sqlite3.Connection, df: pd.DataFrame) -> pd.DataFrame: + """Process the input dataframe. + + This is where your main transformation logic goes. 
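        For this node that means: every post that has no post_authors mapping yet and a non-empty author field is compared against all known names from the authors table via fuzzysearch.find_near_matches, and the first match per post is recorded as a (post_id, author_id) pair.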
+ + Args: + con: Database connection + df: Input dataframe from context + + Returns: + Processed dataframe + """ + logger.info(f"Processing {len(df)} rows") + + # Retrieve all known authors from the authors table as 'rules' + authors_df = pd.read_sql("SELECT id, name FROM authors", con) + + if authors_df.empty: + logger.warning("No authors found in database for fuzzy matching") + return pd.DataFrame(columns=['post_id', 'author_id']) + + # Get existing post-author mappings to avoid duplicates + existing_mappings = pd.read_sql( + "SELECT post_id, author_id FROM post_authors", con + ) + existing_post_ids = set(existing_mappings['post_id'].unique()) + + logger.info(f"Found {len(authors_df)} known authors for fuzzy matching") + logger.info(f"Found {len(existing_post_ids)} posts with existing author mappings") + + # Filter to posts without author mappings and with non-null author field + if 'author' not in df.columns or 'id' not in df.columns: + logger.warning("Missing 'author' or 'id' column in input dataframe") + return pd.DataFrame(columns=['post_id', 'author_id']) + + posts_to_process = df[ + (df['id'].notna()) & + (df['author'].notna()) & + (~df['id'].isin(existing_post_ids)) + ] + + logger.info(f"Processing {len(posts_to_process)} posts for fuzzy matching") + + # Perform fuzzy matching + mappings = [] + for _, post_row in posts_to_process.iterrows(): + post_id = post_row['id'] + post_author = str(post_row['author']) + + # Try to find matches against all known author names + for _, author_row in authors_df.iterrows(): + author_id = author_row['id'] + author_name = str(author_row['name']) + # for author names < than 2 characters I want a fault tolerance of 0! + l_dist = self.max_l_dist if len(author_name) > 2 else 0 + + # Use fuzzysearch to find matches with allowed errors + matches = fuzzysearch.find_near_matches( + author_name, + post_author, + max_l_dist=l_dist, + ) + + if matches: + logger.debug(f"Found fuzzy match: '{author_name}' in '{post_author}' for post {post_id}") + mappings.append({ + 'post_id': post_id, + 'author_id': author_id + }) + # Only take the first match per post to avoid multiple mappings + break + + # Create result dataframe + result_df = pd.DataFrame(mappings, columns=['post_id', 'author_id']) if mappings else pd.DataFrame(columns=['post_id', 'author_id']) + + logger.info(f"Processing complete. Found {len(result_df)} fuzzy matches") + return result_df + + def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame): + """Store results back to the database. + + Uses INSERT OR IGNORE to avoid inserting duplicates. + + Args: + con: Database connection + df: Processed dataframe to store + """ + if df.empty: + logger.info("No results to store") + return + + logger.info(f"Storing {len(df)} results") + + # Use INSERT OR IGNORE to handle duplicates (respects PRIMARY KEY constraint) + cursor = con.cursor() + inserted_count = 0 + + for _, row in df.iterrows(): + cursor.execute( + "INSERT OR IGNORE INTO post_authors (post_id, author_id) VALUES (?, ?)", + (int(row['post_id']), int(row['author_id'])) + ) + if cursor.rowcount > 0: + inserted_count += 1 + + con.commit() + logger.info(f"Results stored successfully. Inserted {inserted_count} new mappings, skipped {len(df) - inserted_count} duplicates") + + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the transformation. + + This is the main entry point called by the pipeline. 
+ + Args: + con: SQLite database connection + context: TransformContext containing input dataframe + + Returns: + TransformContext with processed dataframe + """ + logger.info("Starting FuzzyAuthorNode transformation") + + # Get input dataframe from context + input_df = context.get_dataframe() + + # Validate input + if input_df.empty: + logger.warning("Empty dataframe provided to FuzzyAuthorNode") + return context + + # Process the data + result_df = self._process_data(con, input_df) + + # Store results + self._store_results(con, result_df) + + logger.info("FuzzyAuthorNode transformation complete") + + # Return new context with results + return TransformContext(input_df) diff --git a/transform/embeddings_node.py b/transform/embeddings_node.py new file mode 100644 index 0000000..9821eca --- /dev/null +++ b/transform/embeddings_node.py @@ -0,0 +1,445 @@ +"""Classes of Transformernodes that have to do with +text processing. + +- TextEmbeddingNode calculates text embeddings +- UmapNode calculates xy coordinates on those vector embeddings +- SimilarityNode calculates top n similar posts based on those embeddings + using the spectral distance. +""" +from pipeline import TransformContext +from transform_node import TransformNode +import sqlite3 +import pandas as pd +import logging +import os +import numpy as np + +logger = logging.getLogger("knack-transform") + +try: + from sentence_transformers import SentenceTransformer + import torch + MINILM_AVAILABLE = True +except ImportError: + MINILM_AVAILABLE = False + logging.warning("MiniLM not available. Install with pip!") + +try: + import umap + UMAP_AVAILABLE = True +except ImportError: + UMAP_AVAILABLE = False + logging.warning("UMAP not available. Install with pip install umap-learn!") + +class TextEmbeddingNode(TransformNode): + """Calculates vector embeddings based on a dataframe + of posts. + """ + def __init__(self, + model_name: str = "sentence-transformers/all-MiniLM-L6-v2", + model_path: str = None, + device: str = "cpu"): + """Initialize the ExampleNode. + + Args: + model_name: Name of the ML Model to calculate text embeddings + model_path: Optional local path to a downloaded embedding model + device: Device to use for computations ('cpu', 'cuda', 'mps') + """ + self.model_name = model_name + self.model_path = model_path or os.environ.get('MINILM_MODEL_PATH') + self.device = device + self.model = None + logger.info(f"Initialized TextEmbeddingNode with model_name={model_name}, model_path={model_path}, device={device}") + + def _setup_model(self): + """Init the Text Embedding Model.""" + if not MINILM_AVAILABLE: + raise ImportError("MiniLM is required for TextEmbeddingNode. 
Please install.") + + model_source = None + if self.model_path: + if os.path.exists(self.model_path): + model_source = self.model_path + logger.info(f"Loading MiniLM model from local path: {self.model_path}") + else: + logger.warning(f"MiniLM_MODEL_PATH '{self.model_path}' not found; Falling back to hub model {self.model_name}") + + if model_source is None: + model_source = self.model_name + logger.info(f"Loading MiniLM model from the hub: {self.model_name}") + + if self.device == "cuda" and torch.cuda.is_available(): + self.model = SentenceTransformer(model_source).to('cuda', dtype=torch.float16) + elif self.device == "mps" and torch.backends.mps.is_available(): + self.model = SentenceTransformer(model_source).to('mps', dtype=torch.float16) + else: + self.model = SentenceTransformer(model_source) + + def _process_data(self, df: pd.DataFrame) -> pd.DataFrame: + """Process the input dataframe. + + Calculates an embedding as a np.array. + Also pickles that array to prepare it to + storage in the database. + + Args: + df: Input dataframe from context + + Returns: + Processed dataframe + """ + logger.info(f"Processing {len(df)} rows") + + if self.model is None: + self._setup_model() + + # Example: Add a new column based on existing data + result_df = df.copy() + + df['embedding'] = df['text'].apply(lambda x: self.model.encode(x, convert_to_numpy=True)) + + logger.info("Processing complete") + return result_df + + def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame): + """Store results back to the database using batch updates.""" + if df.empty: + logger.info("No results to store") + return + + logger.info(f"Storing {len(df)} results") + + # Convert numpy arrays to bytes for BLOB storage + # Use tobytes() to serialize numpy arrays efficiently + updates = [(row['embedding'].tobytes(), row['id']) for _, row in df.iterrows()] + con.executemany( + "UPDATE posts SET embedding = ? WHERE id = ?", + updates + ) + + con.commit() + logger.info("Results stored successfully") + + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the transformation. + + This is the main entry point called by the pipeline. + + Args: + con: SQLite database connection + context: TransformContext containing input dataframe + + Returns: + TransformContext with processed dataframe + """ + logger.info("Starting TextEmbeddingNode transformation") + + # Get input dataframe from context + input_df = context.get_dataframe() + + # Validate input + if input_df.empty: + logger.warning("Empty dataframe provided to TextEmbeddingNdode") + return context + + if 'text' not in input_df.columns: + logger.warning("No 'text' column in context dataframe. Skipping TextEmbeddingNode") + return context + + # Process the data + result_df = self._process_data(input_df) + + # Store results (optional) + self._store_results(con, result_df) + + logger.info("TextEmbeddingNode transformation complete") + + # Return new context with results + return TransformContext(result_df) + + +class UmapNode(TransformNode): + """Calculates 2D coordinates from embeddings using UMAP dimensionality reduction. + + This node takes text embeddings and reduces them to 2D coordinates + for visualization purposes. + """ + + def __init__(self, + n_neighbors: int = 15, + min_dist: float = 0.1, + n_components: int = 2, + metric: str = "cosine", + random_state: int = 42): + """Initialize the UmapNode. 
+ + Args: + n_neighbors: Number of neighbors to consider for UMAP (default: 15) + min_dist: Minimum distance between points in low-dimensional space (default: 0.1) + n_components: Number of dimensions to reduce to (default: 2) + metric: Distance metric to use (default: 'cosine') + random_state: Random seed for reproducibility (default: 42) + """ + self.n_neighbors = n_neighbors + self.min_dist = min_dist + self.n_components = n_components + self.metric = metric + self.random_state = random_state + self.reducer = None + logger.info(f"Initialized UmapNode with n_neighbors={n_neighbors}, min_dist={min_dist}, " + f"n_components={n_components}, metric={metric}, random_state={random_state}") + + def _process_data(self, df: pd.DataFrame) -> pd.DataFrame: + """Process the input dataframe. + + Retrieves embeddings from BLOB storage, converts them back to numpy arrays, + and applies UMAP dimensionality reduction to create 2D coordinates. + + Args: + df: Input dataframe from context + + Returns: + Processed dataframe with umap_x and umap_y columns + """ + logger.info(f"Processing {len(df)} rows") + + if not UMAP_AVAILABLE: + raise ImportError("UMAP is required for UmapNode. Install with: pip install umap-learn") + + result_df = df.copy() + + # Convert BLOB embeddings back to numpy arrays + if 'embedding' not in result_df.columns: + logger.error("No 'embedding' column found in dataframe") + raise ValueError("Input dataframe must contain 'embedding' column") + + logger.info("Converting embeddings from BLOB to numpy arrays") + result_df['embedding'] = result_df['embedding'].apply( + lambda x: np.frombuffer(x, dtype=np.float32) if x is not None else None + ) + + # Filter out rows with None embeddings + valid_rows = result_df['embedding'].notna() + if not valid_rows.any(): + logger.error("No valid embeddings found in dataframe") + raise ValueError("No valid embeddings to process") + + logger.info(f"Found {valid_rows.sum()} valid embeddings out of {len(result_df)} rows") + + # Stack embeddings into a matrix + embeddings_matrix = np.vstack(result_df.loc[valid_rows, 'embedding'].values) + logger.info(f"Embeddings matrix shape: {embeddings_matrix.shape}") + + # Apply UMAP + logger.info("Fitting UMAP reducer...") + self.reducer = umap.UMAP( + n_neighbors=self.n_neighbors, + min_dist=self.min_dist, + n_components=self.n_components, + metric=self.metric, + random_state=self.random_state + ) + + umap_coords = self.reducer.fit_transform(embeddings_matrix) + logger.info(f"UMAP transformation complete. Output shape: {umap_coords.shape}") + + # Add UMAP coordinates to dataframe + result_df.loc[valid_rows, 'umap_x'] = umap_coords[:, 0] + result_df.loc[valid_rows, 'umap_y'] = umap_coords[:, 1] + + # Fill NaN for invalid rows + result_df['umap_x'] = result_df['umap_x'].fillna(None) + result_df['umap_y'] = result_df['umap_y'].fillna(None) + + logger.info("Processing complete") + return result_df + + def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame): + """Store UMAP coordinates back to the database. + + Args: + con: Database connection + df: Processed dataframe with umap_x and umap_y columns + """ + if df.empty: + logger.info("No results to store") + return + + logger.info(f"Storing {len(df)} results") + + # Batch update UMAP coordinates + updates = [ + (row['umap_x'], row['umap_y'], row['id']) + for _, row in df.iterrows() + if pd.notna(row.get('umap_x')) and pd.notna(row.get('umap_y')) + ] + + if updates: + con.executemany( + "UPDATE posts SET umap_x = ?, umap_y = ? 
WHERE id = ?", + updates + ) + con.commit() + logger.info(f"Stored {len(updates)} UMAP coordinate pairs successfully") + else: + logger.warning("No valid UMAP coordinates to store") + + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the transformation. + + This is the main entry point called by the pipeline. + + Args: + con: SQLite database connection + context: TransformContext containing input dataframe + + Returns: + TransformContext with processed dataframe + """ + logger.info("Starting ExampleNode transformation") + + # Get input dataframe from context + input_df = context.get_dataframe() + + # Validate input + if input_df.empty: + logger.warning("Empty dataframe provided to ExampleNode") + return context + + # Process the data + result_df = self._process_data(input_df) + + # Store results (optional) + self._store_results(con, result_df) + + logger.info("ExampleNode transformation complete") + + # Return new context with results + return TransformContext(result_df) + + +class SimilarityNode(TransformNode): + """Example transform node template. + + This node demonstrates the basic structure for creating + new transformation nodes in the pipeline. + """ + + def __init__(self, + param1: str = "default_value", + param2: int = 42, + device: str = "cpu"): + """Initialize the ExampleNode. + + Args: + param1: Example string parameter + param2: Example integer parameter + device: Device to use for computations ('cpu', 'cuda', 'mps') + """ + self.param1 = param1 + self.param2 = param2 + self.device = device + logger.info(f"Initialized ExampleNode with param1={param1}, param2={param2}") + + def _create_tables(self, con: sqlite3.Connection): + """Create any necessary tables in the database. + + This is optional - only needed if your node creates new tables. + """ + logger.info("Creating example tables") + + con.execute(""" + CREATE TABLE IF NOT EXISTS example_results ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + post_id INTEGER, + result_value TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (post_id) REFERENCES posts(id) + ) + """) + + con.commit() + + def _process_data(self, df: pd.DataFrame) -> pd.DataFrame: + """Process the input dataframe. + + This is where your main transformation logic goes. + + Args: + df: Input dataframe from context + + Returns: + Processed dataframe + """ + logger.info(f"Processing {len(df)} rows") + + # Example: Add a new column based on existing data + result_df = df.copy() + result_df['processed'] = True + result_df['example_value'] = result_df['id'].apply(lambda x: f"{self.param1}_{x}") + + logger.info("Processing complete") + return result_df + + def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame): + """Store results back to the database. + + This is optional - only needed if you want to persist results. + + Args: + con: Database connection + df: Processed dataframe to store + """ + if df.empty: + logger.info("No results to store") + return + + logger.info(f"Storing {len(df)} results") + + # Example: Store to database + # df[['post_id', 'result_value']].to_sql( + # 'example_results', + # con, + # if_exists='append', + # index=False + # ) + + con.commit() + logger.info("Results stored successfully") + + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the transformation. + + This is the main entry point called by the pipeline. 
+ + Args: + con: SQLite database connection + context: TransformContext containing input dataframe + + Returns: + TransformContext with processed dataframe + """ + logger.info("Starting ExampleNode transformation") + + # Get input dataframe from context + input_df = context.get_dataframe() + + # Validate input + if input_df.empty: + logger.warning("Empty dataframe provided to ExampleNode") + return context + + # Create any necessary tables + self._create_tables(con) + + # Process the data + result_df = self._process_data(input_df) + + # Store results (optional) + self._store_results(con, result_df) + + logger.info("ExampleNode transformation complete") + + # Return new context with results + return TransformContext(result_df) diff --git a/transform/ensure_gliner_model.sh b/transform/ensure_gliner_model.sh new file mode 100644 index 0000000..4df8215 --- /dev/null +++ b/transform/ensure_gliner_model.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -euo pipefail + +if [ -d "$GLINER_MODEL_PATH" ] && find "$GLINER_MODEL_PATH" -type f | grep -q .; then + echo "GLiNER model already present at $GLINER_MODEL_PATH" + exit 0 +fi + +echo "Downloading GLiNER model to $GLINER_MODEL_PATH" +mkdir -p "$GLINER_MODEL_PATH" +curl -sL "https://huggingface.co/api/models/${GLINER_MODEL_ID}" | jq -r '.siblings[].rfilename' | while read -r file; do + target="${GLINER_MODEL_PATH}/${file}" + mkdir -p "$(dirname "$target")" + echo "Downloading ${file}" + curl -sL "https://huggingface.co/${GLINER_MODEL_ID}/resolve/main/${file}" -o "$target" +done diff --git a/transform/ensure_minilm_model.sh b/transform/ensure_minilm_model.sh new file mode 100644 index 0000000..2d58f24 --- /dev/null +++ b/transform/ensure_minilm_model.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -euo pipefail + +if [ -d "$MINILM_MODEL_PATH" ] && find "$MINILM_MODEL_PATH" -type f | grep -q .; then + echo "MiniLM model already present at $MINILM_MODEL_PATH" + exit 0 +fi + +echo "Downloading MiniLM model to $MINILM_MODEL_PATH" +mkdir -p "$MINILM_MODEL_PATH" +curl -sL "https://huggingface.co/api/models/${MINILM_MODEL_ID}" | jq -r '.siblings[].rfilename' | while read -r file; do + target="${MINILM_MODEL_PATH}/${file}" + mkdir -p "$(dirname "$target")" + echo "Downloading ${file}" + curl -sL "https://huggingface.co/${MINILM_MODEL_ID}/resolve/main/${file}" -o "$target" +done diff --git a/transform/entrypoint.sh b/transform/entrypoint.sh new file mode 100644 index 0000000..96f5932 --- /dev/null +++ b/transform/entrypoint.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Run model download with output to stdout/stderr +/usr/local/bin/ensure_minilm_model.sh 2>&1 +/usr/local/bin/ensure_gliner_model.sh 2>&1 + +# Start cron in foreground with logging +exec cron -f -L 2 +# cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1 \ No newline at end of file diff --git a/transform/example_node.py b/transform/example_node.py new file mode 100644 index 0000000..69900d1 --- /dev/null +++ b/transform/example_node.py @@ -0,0 +1,170 @@ +"""Example template node for the transform pipeline. + +This is a template showing how to create new transform nodes. +Copy this file and modify it for your specific transformation needs. +""" +from pipeline import TransformContext +from transform_node import TransformNode +import sqlite3 +import pandas as pd +import logging + +logger = logging.getLogger("knack-transform") + + +class ExampleNode(TransformNode): + """Example transform node template. 
+ + This node demonstrates the basic structure for creating + new transformation nodes in the pipeline. + """ + + def __init__(self, + param1: str = "default_value", + param2: int = 42, + device: str = "cpu"): + """Initialize the ExampleNode. + + Args: + param1: Example string parameter + param2: Example integer parameter + device: Device to use for computations ('cpu', 'cuda', 'mps') + """ + self.param1 = param1 + self.param2 = param2 + self.device = device + logger.info(f"Initialized ExampleNode with param1={param1}, param2={param2}") + + def _create_tables(self, con: sqlite3.Connection): + """Create any necessary tables in the database. + + This is optional - only needed if your node creates new tables. + """ + logger.info("Creating example tables") + + con.execute(""" + CREATE TABLE IF NOT EXISTS example_results ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + post_id INTEGER, + result_value TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (post_id) REFERENCES posts(id) + ) + """) + + con.commit() + + def _process_data(self, df: pd.DataFrame) -> pd.DataFrame: + """Process the input dataframe. + + This is where your main transformation logic goes. + + Args: + df: Input dataframe from context + + Returns: + Processed dataframe + """ + logger.info(f"Processing {len(df)} rows") + + # Example: Add a new column based on existing data + result_df = df.copy() + result_df['processed'] = True + result_df['example_value'] = result_df['id'].apply(lambda x: f"{self.param1}_{x}") + + logger.info("Processing complete") + return result_df + + def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame): + """Store results back to the database. + + This is optional - only needed if you want to persist results. + + Args: + con: Database connection + df: Processed dataframe to store + """ + if df.empty: + logger.info("No results to store") + return + + logger.info(f"Storing {len(df)} results") + + # Example: Store to database + # df[['post_id', 'result_value']].to_sql( + # 'example_results', + # con, + # if_exists='append', + # index=False + # ) + + con.commit() + logger.info("Results stored successfully") + + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the transformation. + + This is the main entry point called by the pipeline. 
+ + Args: + con: SQLite database connection + context: TransformContext containing input dataframe + + Returns: + TransformContext with processed dataframe + """ + logger.info("Starting ExampleNode transformation") + + # Get input dataframe from context + input_df = context.get_dataframe() + + # Validate input + if input_df.empty: + logger.warning("Empty dataframe provided to ExampleNode") + return context + + # Create any necessary tables + self._create_tables(con) + + # Process the data + result_df = self._process_data(input_df) + + # Store results (optional) + self._store_results(con, result_df) + + logger.info("ExampleNode transformation complete") + + # Return new context with results + return TransformContext(result_df) + + +# Example usage: +if __name__ == "__main__": + # This allows you to test your node independently + import os + os.chdir('/Users/linussilberstein/Documents/Knack-Scraper/transform') + + from pipeline import TransformContext + import sqlite3 + + # Create test data + test_df = pd.DataFrame({ + 'id': [1, 2, 3], + 'author': ['Test Author 1', 'Test Author 2', 'Test Author 3'] + }) + + # Create test database connection + test_con = sqlite3.connect(':memory:') + + # Create and run node + node = ExampleNode(param1="test", param2=100) + context = TransformContext(test_df) + result_context = node.run(test_con, context) + + # Check results + result_df = result_context.get_dataframe() + print("\nResult DataFrame:") + print(result_df) + + test_con.close() + print("\n✓ ExampleNode test completed successfully!") diff --git a/transform/main.py b/transform/main.py new file mode 100644 index 0000000..9922eed --- /dev/null +++ b/transform/main.py @@ -0,0 +1,102 @@ +#! python3 +import logging +import os +import sqlite3 +import sys +from dotenv import load_dotenv + +load_dotenv() + +if (os.environ.get('LOGGING_LEVEL', 'INFO') == 'INFO'): + logging_level = logging.INFO +else: + logging_level = logging.DEBUG + +logging.basicConfig( + level=logging_level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler("app.log"), + logging.StreamHandler(sys.stdout) + ] +) +logger = logging.getLogger("knack-transform") + + +def setup_database_connection(): + """Create connection to the SQLite database.""" + db_path = os.environ.get('DB_PATH', '/data/knack.sqlite') + logger.info(f"Connecting to database: {db_path}") + return sqlite3.connect(db_path) + + +def table_exists(tablename: str, con: sqlite3.Connection): + """Check if a table exists in the database.""" + query = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1" + return len(con.execute(query, [tablename]).fetchall()) > 0 + + +def main(): + """Main entry point for the transform pipeline.""" + logger.info("Starting transform pipeline") + + try: + con = setup_database_connection() + logger.info("Database connection established") + + # Check if posts table exists + if not table_exists('posts', con): + logger.warning("Posts table does not exist yet. 
Please run the scraper first to populate the database.") + logger.info("Transform pipeline skipped - no data available") + return + + # Import transform components + from pipeline import create_default_pipeline, TransformContext + import pandas as pd + + # Load posts data + logger.info("Loading posts from database") + sql = "SELECT * FROM posts WHERE author IS NOT NULL AND (is_cleaned IS NULL OR is_cleaned = 0)" + # MAX_CLEANED_POSTS = os.environ.get("MAX_CLEANED_POSTS", 100) + df = pd.read_sql(sql, con) + logger.info(f"Loaded {len(df)} uncleaned posts with authors") + + if df.empty: + logger.info("No uncleaned posts found. Transform pipeline skipped.") + return + + # Create initial context + context = TransformContext(df) + + # Create and run parallel pipeline + device = os.environ.get('COMPUTE_DEVICE', 'cpu') + max_workers = int(os.environ.get('MAX_WORKERS', 4)) + + pipeline = create_default_pipeline(device=device, max_workers=max_workers) + results = pipeline.run( + db_path=os.environ.get('DB_PATH', '/data/knack.sqlite'), + initial_context=context, + fail_fast=False # Continue even if some nodes fail + ) + + logger.info(f"Pipeline completed. Processed {len(results)} node(s)") + + # Mark all processed posts as cleaned + post_ids = df['id'].tolist() + if post_ids: + placeholders = ','.join('?' * len(post_ids)) + con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", post_ids) + con.commit() + logger.info(f"Marked {len(post_ids)} posts as cleaned") + + except Exception as e: + logger.error(f"Error in transform pipeline: {e}", exc_info=True) + sys.exit(1) + finally: + if 'con' in locals(): + con.close() + logger.info("Database connection closed") + + +if __name__ == "__main__": + main() diff --git a/transform/pipeline.py b/transform/pipeline.py new file mode 100644 index 0000000..e1f4e9c --- /dev/null +++ b/transform/pipeline.py @@ -0,0 +1,266 @@ +"""Parallel pipeline orchestration for transform nodes.""" +import logging +import os +import sqlite3 +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed +from typing import List, Dict, Optional + +import pandas as pd +import multiprocessing as mp + +logger = logging.getLogger("knack-transform") + +class TransformContext: + """Context object containing the dataframe for transformation.""" + # Possibly add a dict for the context to give more Information + + def __init__(self, df: pd.DataFrame): + self.df = df + + def get_dataframe(self) -> pd.DataFrame: + """Get the pandas dataframe from the context.""" + return self.df + +class NodeConfig: + """Configuration for a transform node.""" + + def __init__(self, + node_class: type, + node_kwargs: Dict = None, + dependencies: List[str] = None, + name: str = None): + """Initialize node configuration. + + Args: + node_class: The TransformNode class to instantiate + node_kwargs: Keyword arguments to pass to node constructor + dependencies: List of node names that must complete before this one + name: Optional name for the node (defaults to class name) + """ + self.node_class = node_class + self.node_kwargs = node_kwargs or {} + self.dependencies = dependencies or [] + self.name = name or node_class.__name__ + +class ParallelPipeline: + """Pipeline for executing transform nodes in parallel where possible. + + The pipeline analyzes dependencies between nodes and executes + independent nodes concurrently using multiprocessing or threading. 
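    Stages are derived from the declared dependencies: every node whose dependencies have all completed is grouped into the next stage, and all nodes of a stage are submitted to the executor together (see _get_execution_stages).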
+ """ + + def __init__(self, + max_workers: Optional[int] = None, + use_processes: bool = False): + """Initialize the parallel pipeline. + + Args: + max_workers: Maximum number of parallel workers (defaults to CPU count) + use_processes: If True, use ProcessPoolExecutor; if False, use ThreadPoolExecutor + """ + self.max_workers = max_workers or mp.cpu_count() + self.use_processes = use_processes + self.nodes: Dict[str, NodeConfig] = {} + logger.info(f"Initialized ParallelPipeline with {self.max_workers} workers " + f"({'processes' if use_processes else 'threads'})") + + def add_node(self, config: NodeConfig): + """Add a node to the pipeline. + + Args: + config: NodeConfig with node details and dependencies + """ + self.nodes[config.name] = config + logger.info(f"Added node '{config.name}' with dependencies: {config.dependencies}") + + def _get_execution_stages(self) -> List[List[str]]: + """Determine execution stages based on dependencies. + + Returns: + List of stages, where each stage contains node names that can run in parallel + """ + stages = [] + completed = set() + remaining = set(self.nodes.keys()) + + while remaining: + # Find nodes whose dependencies are all completed + ready = [] + for node_name in remaining: + config = self.nodes[node_name] + if all(dep in completed for dep in config.dependencies): + ready.append(node_name) + + if not ready: + # Circular dependency or missing dependency + raise ValueError(f"Cannot resolve dependencies. Remaining nodes: {remaining}") + + stages.append(ready) + completed.update(ready) + remaining -= set(ready) + + return stages + + def _execute_node(self, + node_name: str, + db_path: str, + context: TransformContext) -> tuple: + """Execute a single node. + + Args: + node_name: Name of the node to execute + db_path: Path to the SQLite database + context: TransformContext for the node + + Returns: + Tuple of (node_name, result_context, error) + """ + try: + # Create fresh database connection (not shared across processes/threads) + con = sqlite3.connect(db_path) + + config = self.nodes[node_name] + node = config.node_class(**config.node_kwargs) + + logger.info(f"Executing node: {node_name}") + result_context = node.run(con, context) + + con.close() + logger.info(f"Node '{node_name}' completed successfully") + + return node_name, result_context, None + + except Exception as e: + logger.error(f"Error executing node '{node_name}': {e}", exc_info=True) + return node_name, None, str(e) + + def run(self, + db_path: str, + initial_context: TransformContext, + fail_fast: bool = False) -> Dict[str, TransformContext]: + """Execute the pipeline. 
+ + Args: + db_path: Path to the SQLite database + initial_context: Initial TransformContext for the pipeline + fail_fast: If True, stop execution on first error + + Returns: + Dict mapping node names to their output TransformContext + """ + logger.info("Starting parallel pipeline execution") + + stages = self._get_execution_stages() + logger.info(f"Pipeline has {len(stages)} execution stage(s)") + + results = {} + errors = [] + + ExecutorClass = ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor + + for stage_num, stage_nodes in enumerate(stages, 1): + logger.info(f"Stage {stage_num}/{len(stages)}: Executing {len(stage_nodes)} node(s) in parallel: {stage_nodes}") + + # For nodes in this stage, use the context from their dependencies + # If multiple dependencies, we'll use the most recent one (or could merge) + stage_futures = {} + + with ExecutorClass(max_workers=min(self.max_workers, len(stage_nodes))) as executor: + for node_name in stage_nodes: + config = self.nodes[node_name] + + # Get context from dependencies (use the last dependency's output) + if config.dependencies: + context = results.get(config.dependencies[-1], initial_context) + else: + context = initial_context + + future = executor.submit(self._execute_node, node_name, db_path, context) + stage_futures[future] = node_name + + # Wait for all nodes in this stage to complete + for future in as_completed(stage_futures): + node_name = stage_futures[future] + name, result_context, error = future.result() + + if error: + errors.append((name, error)) + if fail_fast: + logger.error(f"Pipeline failed at node '{name}': {error}") + raise RuntimeError(f"Node '{name}' failed: {error}") + else: + results[name] = result_context + + if errors: + logger.warning(f"Pipeline completed with {len(errors)} error(s)") + for name, error in errors: + logger.error(f" - {name}: {error}") + else: + logger.info("Pipeline completed successfully") + + return results + + +def create_default_pipeline(device: str = "cpu", + max_workers: Optional[int] = None) -> ParallelPipeline: + """Create a pipeline with default transform nodes. + + Args: + device: Device to use for compute-intensive nodes ('cpu', 'cuda', 'mps') + max_workers: Maximum number of parallel workers + + Returns: + Configured ParallelPipeline + """ + from author_node import NerAuthorNode, FuzzyAuthorNode + from embeddings_node import TextEmbeddingNode, UmapNode + + pipeline = ParallelPipeline(max_workers=max_workers, use_processes=False) + + # Add AuthorNode (no dependencies) + pipeline.add_node(NodeConfig( + node_class=NerAuthorNode, + node_kwargs={ + 'device': device, + 'model_path': os.environ.get('GLINER_MODEL_PATH') + }, + dependencies=[], + name='AuthorNode' + )) + + pipeline.add_node(NodeConfig( + node_class=FuzzyAuthorNode, + node_kwargs={ + 'max_l_dist': 1 + }, + dependencies=['AuthorNode'], + name='FuzzyAuthorNode' + )) + + pipeline.add_node(NodeConfig( + node_class=TextEmbeddingNode, + node_kwargs={ + 'device': device, + 'model_path': os.environ.get('MINILM_MODEL_PATH') + }, + dependencies=[], + name='TextEmbeddingNode' + )) + + pipeline.add_node(NodeConfig( + node_class=UmapNode, + node_kwargs={}, + dependencies=['TextEmbeddingNode'], + name='UmapNode' + )) + + # TODO: Create Node to compute Text Embeddings and UMAP. 
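    # A similarity step could be chained onto the embeddings in the same way once
    # SimilarityNode (embeddings_node.py) computes real similarities; sketch only:
    # pipeline.add_node(NodeConfig(
    #     node_class=SimilarityNode,
    #     node_kwargs={'device': device},
    #     dependencies=['TextEmbeddingNode'],
    #     name='SimilarityNode'
    # ))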
+ + # pipeline.add_node(NodeConfig( + # node_class=UMAPNode, + # node_kwargs={'device': device}, + # dependencies=['EmbeddingNode'], # Runs after EmbeddingNode + # name='UMAPNode' + # )) + + return pipeline diff --git a/transform/requirements.txt b/transform/requirements.txt new file mode 100644 index 0000000..023d14f --- /dev/null +++ b/transform/requirements.txt @@ -0,0 +1,7 @@ +pandas +python-dotenv +gliner +torch +fuzzysearch +sentence_transformers +umap-learn \ No newline at end of file diff --git a/transform/transform_node.py b/transform/transform_node.py new file mode 100644 index 0000000..54e6bed --- /dev/null +++ b/transform/transform_node.py @@ -0,0 +1,26 @@ +"""Base transform node for data pipeline.""" +from abc import ABC, abstractmethod +import sqlite3 + +from pipeline import TransformContext + +class TransformNode(ABC): + """Abstract base class for transformation nodes. + + Each transform node implements a single transformation step + that takes data from the database, transforms it, and + potentially writes results back. + """ + + @abstractmethod + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the transformation. + + Args: + con: SQLite database connection + context: TransformContext containing the input dataframe + + Returns: + TransformContext with the transformed dataframe + """ + pass