From bcd210ce0194b73a4f2db913e29082081f53d183 Mon Sep 17 00:00:00 2001
From: quorploop <>
Date: Sat, 20 Dec 2025 20:55:04 +0100
Subject: [PATCH 1/4] Dockerized Scraper

- Implements a Dockerized version of the scraper
- Atomizes the tags and categories columns into separate tables
---
 Dockerfile       |  22 ++++--
 README.md        |  18 +++++
 crontab          |   1 -
 main.py          | 183 +++++++++++++++++++++++++++++++++++++++++------
 requirements.txt |  19 ++---
 5 files changed, 201 insertions(+), 42 deletions(-)
 delete mode 100644 crontab

diff --git a/Dockerfile b/Dockerfile
index 9c94fd6..d9752ef 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -7,9 +7,21 @@ WORKDIR /app
 COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
 
-RUN apt update -y
-RUN apt install -y cron
-COPY crontab .
-RUN crontab crontab
+COPY .env .
 
-COPY main.py .
\ No newline at end of file
+RUN apt update -y
+RUN apt install -y cron locales
+
+COPY main.py .
+
+ENV PYTHONUNBUFFERED=1
+ENV LANG=de_DE.UTF-8
+ENV LC_ALL=de_DE.UTF-8
+
+# Create cron job that runs the scraper every 10 minutes with environment variables
+RUN echo "*/10 * * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-scraper
+RUN chmod 0644 /etc/cron.d/knack-scraper
+RUN crontab /etc/cron.d/knack-scraper
+
+# Start cron in foreground
+CMD ["cron", "-f"]
\ No newline at end of file
diff --git a/README.md b/README.md
index e69de29..ab971fc 100644
--- a/README.md
+++ b/README.md
@@ -0,0 +1,18 @@
+Knack-Scraper does exactly what its name suggests:
+it scrapes knack.news and writes the results to a
+SQLite database for later use.
+
+## Example for .env
+
+```
+NUM_THREADS=8
+NUM_SCRAPES=100
+DATABASE_LOCATION='./data/knack.sqlite'
+```
+
+## Run once
+
+```
+python main.py
+```
+
diff --git a/crontab b/crontab
deleted file mode 100644
index 6b6ae11..0000000
--- a/crontab
+++ /dev/null
@@ -1 +0,0 @@
-5 4 * * * python /app/main.py
diff --git a/main.py b/main.py
index 850ba3c..f5a0b7a 100755
--- a/main.py
+++ b/main.py
@@ -1,25 +1,34 @@
 #! python3
-import locale
 import logging
 import os
 import sqlite3
-import sys
 import time
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
+import sys
 
+from dotenv import load_dotenv
 import pandas as pd
 import requests
 import tqdm
 from bs4 import BeautifulSoup
 
-logger = logging.getLogger("knack-scraper")
-# ch = logging.StreamHandler()
-# formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-# ch.setFormatter(formatter)
-# ch.setLevel(logging.INFO)
-# logger.addHandler(ch)
+load_dotenv()
+
+if (os.environ.get('LOGGING_LEVEL', 'INFO') == 'INFO'):
+    logging_level = logging.INFO
+else:
+    logging_level = logging.DEBUG
+
+logging.basicConfig(
+    level=logging_level,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.FileHandler("app.log"),
+        logging.StreamHandler(sys.stdout)
+    ]
+)
+logger = logging.getLogger("knack-scraper")
 
 def table_exists(tablename: str, con: sqlite3.Connection):
     query = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1"
@@ -39,19 +48,16 @@ def download(id: int):
     if not (200 <= res.status_code <= 300):
         return
 
-    logger.info("Found promising page with id %d!", id)
+    logger.debug("Found promising page with id %d!", id)
 
     content = res.content
     soup = BeautifulSoup(content, "html.parser")
 
-    date_format = "%d. 
%B %Y" - # TODO FIXME: this fails inside the docker container - locale.setlocale(locale.LC_TIME, "de_DE") pC = soup.find("div", {"class": "postContent"}) if pC is None: # not a normal post - logger.info( + logger.debug( "Page with id %d does not have a .pageContent-div. Skipping for now.", id ) return @@ -63,9 +69,13 @@ def download(id: int): # these fields are possible but not required # TODO: cleanup try: - date_string = pC.find("span", {"class": "singledate"}).text - parsed_date = datetime.strptime(date_string, date_format) - except AttributeError: + date_parts = pC.find("span", {"class": "singledate"}).text.split(' ') + day = int(date_parts[0][:-1]) + months = {'Januar': 1, 'Februar': 2, 'März': 3, 'April': 4, 'Mai': 5, 'Juni': 6, 'Juli': 7, 'August': 8, 'September': 9, 'Oktober': 10, 'November': 11, 'Dezember': 12} + month = months[date_parts[1]] + year = int(date_parts[2]) + parsed_date = datetime(year, month, day) + except Exception: parsed_date = None try: @@ -75,7 +85,7 @@ def download(id: int): try: category = pC.find("span", {"class": "categoryInfo"}).find_all() - category = [c.text for c in category] + category = [c.text for c in category if c.text != 'Alle Artikel'] category = ";".join(category) except AttributeError: category = None @@ -129,15 +139,79 @@ def run_downloads(min_id: int, max_id: int, num_threads: int = 8): # sqlite can't handle lists so let's convert them to a single row csv # TODO: make sure our database is properly normalized - df = pd.DataFrame(res) + postdf = pd.DataFrame(res) + tagdf = None + posttotagdf = None + categorydf = None + postcategorydf = None - return df + # Extract and create tags dataframe + if not postdf.empty and 'tags' in postdf.columns: + # Collect all unique tags + all_tags = set() + for tags_str in postdf['tags']: + if pd.notna(tags_str): + tags_list = [tag.strip() for tag in tags_str.split(';')] + all_tags.update(tags_list) + + # Create tagdf with id and text columns + if all_tags: + all_tags = sorted(list(all_tags)) + tagdf = pd.DataFrame({ + 'id': range(len(all_tags)), + 'tag': all_tags + }) + + # Create posttotagdf mapping table + rows = [] + for post_id, tags_str in zip(postdf['id'], postdf['tags']): + if pd.notna(tags_str): + tags_list = [tag.strip() for tag in tags_str.split(';')] + for tag_text in tags_list: + tag_id = tagdf[tagdf['tag'] == tag_text]['id'].values[0] + rows.append({'post_id': post_id, 'tag_id': tag_id}) + + if rows: + posttotagdf = pd.DataFrame(rows) + + # Extract and create categories dataframe + if not postdf.empty and 'category' in postdf.columns: + # Collect all unique categories + all_categories = set() + for category_str in postdf['category']: + if pd.notna(category_str): + category_list = [cat.strip() for cat in category_str.split(';')] + all_categories.update(category_list) + + # Create categorydf with id and category columns + if all_categories: + all_categories = sorted(list(all_categories)) + categorydf = pd.DataFrame({ + 'id': range(len(all_categories)), + 'category': all_categories + }) + + # Create postcategorydf mapping table + rows = [] + for post_id, category_str in zip(postdf['id'], postdf['category']): + if pd.notna(category_str): + category_list = [cat.strip() for cat in category_str.split(';')] + for category_text in category_list: + category_id = categorydf[categorydf['category'] == category_text]['id'].values[0] + rows.append({'post_id': post_id, 'category_id': category_id}) + + if rows: + postcategorydf = pd.DataFrame(rows) + + return postdf, tagdf, posttotagdf, categorydf, postcategorydf 
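+
+# For orientation (comments only, describing the frames built above):
+#   postdf          one row per scraped post
+#   tagdf           unique tags (id, tag)
+#   posttotagdf     post-to-tag mapping (post_id, tag_id)
+#   categorydf      unique categories (id, category)
+#   postcategorydf  post-to-category mapping (post_id, category_id)
+# main() below persists these frames into the SQLite database.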
def main(): num_threads = int(os.environ.get("NUM_THREADS", 8)) n_scrapes = int(os.environ.get("NUM_SCRAPES", 100)) - database_location = os.environ.get("DATABASE_LOCATION", "/data/knack.sqlite") + database_location = os.environ.get("DATABASE_LOCATION", "../data/knack.sqlite") + + logger.debug(f"Started Knack Scraper: \nNUM_THREADS: {num_threads}\nN_SCRAPES: {n_scrapes}\nDATABASE_LOCATION: {database_location}") con = sqlite3.connect(database_location) with con: @@ -155,12 +229,77 @@ def main(): max_id_in_db = -1 con = sqlite3.connect(database_location) - df = run_downloads( + postdf, tagdf, posttotagdf, categorydf, postcategorydf = run_downloads( min_id=max_id_in_db + 1, max_id=max_id_in_db + n_scrapes, num_threads=num_threads, ) - df.to_sql("posts", con, if_exists="append") + postdf.to_sql("posts", con, if_exists="append") + + # Handle tags dataframe merging and storage + if tagdf is not None and not tagdf.empty: + # Check if tags table already exists + if table_exists("tags", con): + # Read existing tags from database + existing_tagdf = pd.read_sql("SELECT id, tag FROM tags", con) + + # Merge new tags with existing tags, avoiding duplicates + merged_tagdf = pd.concat([existing_tagdf, tagdf], ignore_index=False) + merged_tagdf = merged_tagdf.drop_duplicates(subset=['tag'], keep='first') + merged_tagdf = merged_tagdf.reset_index(drop=True) + merged_tagdf['id'] = range(len(merged_tagdf)) + + # Drop the old table and insert the merged data + con.execute("DROP TABLE tags") + con.commit() + merged_tagdf.to_sql("tags", con, if_exists="append", index=False) + + # Update tag_id references in posttotagdf + if posttotagdf is not None and not posttotagdf.empty: + #tag_mapping = dict(zip(tagdf['tag'], tagdf['id'])) + posttotagdf['tag_id'] = posttotagdf['tag_id'].map( + lambda old_id: merged_tagdf[merged_tagdf['tag'] == tagdf.loc[old_id, 'tag']]['id'].values[0] + ) + else: + # First time creating tags table + tagdf.to_sql("tags", con, if_exists="append", index=False) + + # Store posttags (post to tags mapping) + if posttotagdf is not None and not posttotagdf.empty: + posttotagdf.to_sql("posttags", con, if_exists="append", index=False) + + # Handle categories dataframe merging and storage + if categorydf is not None and not categorydf.empty: + # Check if categories table already exists + if table_exists("categories", con): + # Read existing categories from database + existing_categorydf = pd.read_sql("SELECT id, category FROM categories", con) + + # Merge new categories with existing categories, avoiding duplicates + merged_categorydf = pd.concat([existing_categorydf, categorydf], ignore_index=False) + merged_categorydf = merged_categorydf.drop_duplicates(subset=['category'], keep='first') + merged_categorydf = merged_categorydf.reset_index(drop=True) + merged_categorydf['id'] = range(len(merged_categorydf)) + + # Drop the old table and insert the merged data + con.execute("DROP TABLE categories") + con.commit() + merged_categorydf.to_sql("categories", con, if_exists="append", index=False) + + # Update category_id references in postcategorydf + if postcategorydf is not None and not postcategorydf.empty: + postcategorydf['category_id'] = postcategorydf['category_id'].map( + lambda old_id: merged_categorydf[merged_categorydf['category'] == categorydf.loc[old_id, 'category']]['id'].values[0] + ) + else: + # First time creating categories table + categorydf.to_sql("categories", con, if_exists="append", index=False) + + # Store postcategories (post to categories mapping) + if postcategorydf is not None and 
not postcategorydf.empty: + postcategorydf.to_sql("postcategories", con, if_exists="append", index=False) + + logger.info(f"scraped new entries. number of new posts: {len(postdf.index)}") if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index 7792d83..3c59d8d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,5 @@ -beautifulsoup4==4.12.2 -certifi==2023.7.22 -charset-normalizer==3.3.0 -idna==3.4 -numpy==1.26.1 -pandas==2.1.1 -python-dateutil==2.8.2 -pytz==2023.3.post1 -requests==2.31.0 -six==1.16.0 -soupsieve==2.5 -tqdm==4.66.1 -tzdata==2023.3 -urllib3==2.0.7 +pandas +requests +tqdm +bs4 +dotenv \ No newline at end of file From 64df8fb3280a380aa5a1b8e136da1d6ae7e8e978 Mon Sep 17 00:00:00 2001 From: quorploop <> Date: Sun, 21 Dec 2025 21:18:05 +0100 Subject: [PATCH 2/4] Implements Feature to cleanup authors freetext field --- .gitignore | 2 + Makefile | 14 +- docker-compose.yml | 27 ++ main.py | 306 -------------------- Dockerfile => scrape/Dockerfile | 4 +- scrape/main.py | 260 +++++++++++++++++ requirements.txt => scrape/requirements.txt | 1 - transform/.env.example | 4 + transform/Dockerfile | 41 +++ transform/README.md | 62 ++++ transform/author_node.py | 263 +++++++++++++++++ transform/base.py | 37 +++ transform/main.py | 89 ++++++ transform/requirements.txt | 4 + 14 files changed, 804 insertions(+), 310 deletions(-) create mode 100644 docker-compose.yml delete mode 100755 main.py rename Dockerfile => scrape/Dockerfile (77%) create mode 100755 scrape/main.py rename requirements.txt => scrape/requirements.txt (83%) create mode 100644 transform/.env.example create mode 100644 transform/Dockerfile create mode 100644 transform/README.md create mode 100644 transform/author_node.py create mode 100644 transform/base.py create mode 100644 transform/main.py create mode 100644 transform/requirements.txt diff --git a/.gitignore b/.gitignore index 2e7a5cf..8e2fba4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ data/ venv/ +experiment/ .DS_STORE +.env \ No newline at end of file diff --git a/Makefile b/Makefile index a669090..47a8063 100644 --- a/Makefile +++ b/Makefile @@ -1,2 +1,12 @@ -build: - docker build -t knack-scraper . \ No newline at end of file +volume: + docker volume create knack_data + +stop: + docker stop knack-scraper || true + docker rm knack-scraper || true + +up: + docker compose up -d + +down: + docker compose down \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..5c5c4e7 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,27 @@ +services: + scraper: + build: + context: ./scrape + dockerfile: Dockerfile + image: knack-scraper + container_name: knack-scraper + env_file: + - scrape/.env + volumes: + - knack_data:/data + restart: unless-stopped + + transform: + build: + context: ./transform + dockerfile: Dockerfile + image: knack-transform + container_name: knack-transform + env_file: + - transform/.env + volumes: + - knack_data:/data + restart: unless-stopped + +volumes: + knack_data: diff --git a/main.py b/main.py deleted file mode 100755 index f5a0b7a..0000000 --- a/main.py +++ /dev/null @@ -1,306 +0,0 @@ -#! 
python3 -import logging -import os -import sqlite3 -import time -from concurrent.futures import ThreadPoolExecutor -from datetime import datetime -import sys - -from dotenv import load_dotenv -import pandas as pd -import requests -import tqdm -from bs4 import BeautifulSoup - -load_dotenv() - -if (os.environ.get('LOGGING_LEVEL', 'INFO') == 'INFO'): - logging_level = logging.INFO -else: - logging_level = logging.DEBUG - -logging.basicConfig( - level=logging_level, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler("app.log"), - logging.StreamHandler(sys.stdout) - ] -) -logger = logging.getLogger("knack-scraper") - -def table_exists(tablename: str, con: sqlite3.Connection): - query = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1" - return len(con.execute(query, [tablename]).fetchall()) > 0 - - -def download(id: int): - if id == 0: - return - base_url = "https://knack.news/" - url = f"{base_url}{id}" - res = requests.get(url) - - # make sure we don't dos knack - time.sleep(2) - - if not (200 <= res.status_code <= 300): - return - - logger.debug("Found promising page with id %d!", id) - - content = res.content - soup = BeautifulSoup(content, "html.parser") - - pC = soup.find("div", {"class": "postContent"}) - - if pC is None: - # not a normal post - logger.debug( - "Page with id %d does not have a .pageContent-div. Skipping for now.", id - ) - return - - # every post has these fields - title = pC.find("h3", {"class": "postTitle"}).text - postText = pC.find("div", {"class": "postText"}) - - # these fields are possible but not required - # TODO: cleanup - try: - date_parts = pC.find("span", {"class": "singledate"}).text.split(' ') - day = int(date_parts[0][:-1]) - months = {'Januar': 1, 'Februar': 2, 'März': 3, 'April': 4, 'Mai': 5, 'Juni': 6, 'Juli': 7, 'August': 8, 'September': 9, 'Oktober': 10, 'November': 11, 'Dezember': 12} - month = months[date_parts[1]] - year = int(date_parts[2]) - parsed_date = datetime(year, month, day) - except Exception: - parsed_date = None - - try: - author = pC.find("span", {"class": "author"}).text - except AttributeError: - author = None - - try: - category = pC.find("span", {"class": "categoryInfo"}).find_all() - category = [c.text for c in category if c.text != 'Alle Artikel'] - category = ";".join(category) - except AttributeError: - category = None - - try: - tags = [x.text for x in pC.find("div", {"class": "tagsInfo"}).find_all("a")] - tags = ";".join(tags) - except AttributeError: - tags = None - - img = pC.find("img", {"class": "postImage"}) - if img is not None: - img = img["src"] - - res_dict = { - "id": id, - "title": title, - "author": author, - "date": parsed_date, - "category": category, - "url": url, - "img_link": img, - "tags": tags, - "text": postText.text, - "html": str(postText), - "scraped_at": datetime.now(), - } - - return res_dict - - -def run_downloads(min_id: int, max_id: int, num_threads: int = 8): - res = [] - - logger.info( - "Started parallel scrape of posts from id %d to id %d using %d threads.", - min_id, - max_id - 1, - num_threads, - ) - with ThreadPoolExecutor(max_workers=num_threads) as executor: - # Use a list comprehension to create a list of futures - futures = [executor.submit(download, i) for i in range(min_id, max_id)] - - for future in tqdm.tqdm( - futures, total=max_id - min_id - ): # tqdm to track progress - post = future.result() - if post is not None: - res.append(post) - - # sqlite can't handle lists so let's convert them to a single row csv - # 
TODO: make sure our database is properly normalized - postdf = pd.DataFrame(res) - tagdf = None - posttotagdf = None - categorydf = None - postcategorydf = None - - # Extract and create tags dataframe - if not postdf.empty and 'tags' in postdf.columns: - # Collect all unique tags - all_tags = set() - for tags_str in postdf['tags']: - if pd.notna(tags_str): - tags_list = [tag.strip() for tag in tags_str.split(';')] - all_tags.update(tags_list) - - # Create tagdf with id and text columns - if all_tags: - all_tags = sorted(list(all_tags)) - tagdf = pd.DataFrame({ - 'id': range(len(all_tags)), - 'tag': all_tags - }) - - # Create posttotagdf mapping table - rows = [] - for post_id, tags_str in zip(postdf['id'], postdf['tags']): - if pd.notna(tags_str): - tags_list = [tag.strip() for tag in tags_str.split(';')] - for tag_text in tags_list: - tag_id = tagdf[tagdf['tag'] == tag_text]['id'].values[0] - rows.append({'post_id': post_id, 'tag_id': tag_id}) - - if rows: - posttotagdf = pd.DataFrame(rows) - - # Extract and create categories dataframe - if not postdf.empty and 'category' in postdf.columns: - # Collect all unique categories - all_categories = set() - for category_str in postdf['category']: - if pd.notna(category_str): - category_list = [cat.strip() for cat in category_str.split(';')] - all_categories.update(category_list) - - # Create categorydf with id and category columns - if all_categories: - all_categories = sorted(list(all_categories)) - categorydf = pd.DataFrame({ - 'id': range(len(all_categories)), - 'category': all_categories - }) - - # Create postcategorydf mapping table - rows = [] - for post_id, category_str in zip(postdf['id'], postdf['category']): - if pd.notna(category_str): - category_list = [cat.strip() for cat in category_str.split(';')] - for category_text in category_list: - category_id = categorydf[categorydf['category'] == category_text]['id'].values[0] - rows.append({'post_id': post_id, 'category_id': category_id}) - - if rows: - postcategorydf = pd.DataFrame(rows) - - return postdf, tagdf, posttotagdf, categorydf, postcategorydf - - -def main(): - num_threads = int(os.environ.get("NUM_THREADS", 8)) - n_scrapes = int(os.environ.get("NUM_SCRAPES", 100)) - database_location = os.environ.get("DATABASE_LOCATION", "../data/knack.sqlite") - - logger.debug(f"Started Knack Scraper: \nNUM_THREADS: {num_threads}\nN_SCRAPES: {n_scrapes}\nDATABASE_LOCATION: {database_location}") - - con = sqlite3.connect(database_location) - with con: - post_table_exists = table_exists("posts", con) - - if post_table_exists: - logger.info("found posts retrieved earlier") - # retrieve max post id from db so - # we can skip retrieving known posts - max_id_in_db = con.execute("SELECT MAX(id) FROM posts").fetchone()[0] - logger.info("Got max id %d!", max_id_in_db) - else: - logger.info("no posts scraped so far - starting from 0") - # retrieve from 0 onwards - max_id_in_db = -1 - - con = sqlite3.connect(database_location) - postdf, tagdf, posttotagdf, categorydf, postcategorydf = run_downloads( - min_id=max_id_in_db + 1, - max_id=max_id_in_db + n_scrapes, - num_threads=num_threads, - ) - postdf.to_sql("posts", con, if_exists="append") - - # Handle tags dataframe merging and storage - if tagdf is not None and not tagdf.empty: - # Check if tags table already exists - if table_exists("tags", con): - # Read existing tags from database - existing_tagdf = pd.read_sql("SELECT id, tag FROM tags", con) - - # Merge new tags with existing tags, avoiding duplicates - merged_tagdf = pd.concat([existing_tagdf, 
tagdf], ignore_index=False) - merged_tagdf = merged_tagdf.drop_duplicates(subset=['tag'], keep='first') - merged_tagdf = merged_tagdf.reset_index(drop=True) - merged_tagdf['id'] = range(len(merged_tagdf)) - - # Drop the old table and insert the merged data - con.execute("DROP TABLE tags") - con.commit() - merged_tagdf.to_sql("tags", con, if_exists="append", index=False) - - # Update tag_id references in posttotagdf - if posttotagdf is not None and not posttotagdf.empty: - #tag_mapping = dict(zip(tagdf['tag'], tagdf['id'])) - posttotagdf['tag_id'] = posttotagdf['tag_id'].map( - lambda old_id: merged_tagdf[merged_tagdf['tag'] == tagdf.loc[old_id, 'tag']]['id'].values[0] - ) - else: - # First time creating tags table - tagdf.to_sql("tags", con, if_exists="append", index=False) - - # Store posttags (post to tags mapping) - if posttotagdf is not None and not posttotagdf.empty: - posttotagdf.to_sql("posttags", con, if_exists="append", index=False) - - # Handle categories dataframe merging and storage - if categorydf is not None and not categorydf.empty: - # Check if categories table already exists - if table_exists("categories", con): - # Read existing categories from database - existing_categorydf = pd.read_sql("SELECT id, category FROM categories", con) - - # Merge new categories with existing categories, avoiding duplicates - merged_categorydf = pd.concat([existing_categorydf, categorydf], ignore_index=False) - merged_categorydf = merged_categorydf.drop_duplicates(subset=['category'], keep='first') - merged_categorydf = merged_categorydf.reset_index(drop=True) - merged_categorydf['id'] = range(len(merged_categorydf)) - - # Drop the old table and insert the merged data - con.execute("DROP TABLE categories") - con.commit() - merged_categorydf.to_sql("categories", con, if_exists="append", index=False) - - # Update category_id references in postcategorydf - if postcategorydf is not None and not postcategorydf.empty: - postcategorydf['category_id'] = postcategorydf['category_id'].map( - lambda old_id: merged_categorydf[merged_categorydf['category'] == categorydf.loc[old_id, 'category']]['id'].values[0] - ) - else: - # First time creating categories table - categorydf.to_sql("categories", con, if_exists="append", index=False) - - # Store postcategories (post to categories mapping) - if postcategorydf is not None and not postcategorydf.empty: - postcategorydf.to_sql("postcategories", con, if_exists="append", index=False) - - logger.info(f"scraped new entries. number of new posts: {len(postdf.index)}") - - -if __name__ == "__main__": - main() diff --git a/Dockerfile b/scrape/Dockerfile similarity index 77% rename from Dockerfile rename to scrape/Dockerfile index d9752ef..a2fbe2e 100644 --- a/Dockerfile +++ b/scrape/Dockerfile @@ -3,6 +3,8 @@ FROM python:slim RUN mkdir /app RUN mkdir /data +#COPY /data/knack.sqlite /data + WORKDIR /app COPY requirements.txt . RUN pip install --no-cache-dir -r requirements.txt @@ -19,7 +21,7 @@ ENV LANG=de_DE.UTF-8 ENV LC_ALL=de_DE.UTF-8 # Create cron job that runs every 15 minutes with environment variables -RUN echo "*/10 * * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-scraper +RUN echo "5 4 * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-scraper RUN chmod 0644 /etc/cron.d/knack-scraper RUN crontab /etc/cron.d/knack-scraper diff --git a/scrape/main.py b/scrape/main.py new file mode 100755 index 0000000..15f0e72 --- /dev/null +++ b/scrape/main.py @@ -0,0 +1,260 @@ +#! 
python3 +import logging +import os +import sqlite3 +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +import sys + +from dotenv import load_dotenv +import pandas as pd +import requests +from bs4 import BeautifulSoup + +load_dotenv() + +if (os.environ.get('LOGGING_LEVEL', 'INFO') == 'INFO'): + logging_level = logging.INFO +else: + logging_level = logging.DEBUG + +logging.basicConfig( + level=logging_level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler("app.log"), + logging.StreamHandler(sys.stdout) + ] +) +logger = logging.getLogger("knack-scraper") + +def table_exists(tablename: str, con: sqlite3.Connection): + query = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1" + return len(con.execute(query, [tablename]).fetchall()) > 0 + + +def split_semicolon_list(value: str): + if pd.isna(value): + return [] + return [item.strip() for item in str(value).split(';') if item.strip()] + + +def build_dimension_and_mapping(postdf: pd.DataFrame, field_name: str, dim_col: str): + """Extract unique dimension values and post-to-dimension mappings from a column.""" + if postdf.empty or field_name not in postdf.columns: + return None, None + + values = set() + mapping_rows = [] + + for post_id, raw in zip(postdf['id'], postdf[field_name]): + items = split_semicolon_list(raw) + for item in items: + values.add(item) + mapping_rows.append({'post_id': post_id, dim_col: item}) + + if not values: + return None, None + + dim_df = pd.DataFrame({ + 'id': range(len(values)), + dim_col: sorted(values), + }) + map_df = pd.DataFrame(mapping_rows) + return dim_df, map_df + + +def store_dimension_and_mapping( + con: sqlite3.Connection, + dim_df: pd.DataFrame | None, + map_df: pd.DataFrame | None, + table_name: str, + dim_col: str, + mapping_table: str, + mapping_id_col: str, +): + """Persist a dimension table and its mapping table, merging with existing values.""" + if dim_df is None or dim_df.empty: + return + + if table_exists(table_name, con): + existing = pd.read_sql(f"SELECT id, {dim_col} FROM {table_name}", con) + merged = pd.concat([existing, dim_df], ignore_index=True) + merged = merged.drop_duplicates(subset=[dim_col], keep='first').reset_index(drop=True) + merged['id'] = range(len(merged)) + else: + merged = dim_df.copy() + + # Replace table with merged content + merged.to_sql(table_name, con, if_exists="replace", index=False) + + if map_df is None or map_df.empty: + return + + value_to_id = dict(zip(merged[dim_col], merged['id'])) + map_df = map_df.copy() + map_df[mapping_id_col] = map_df[dim_col].map(value_to_id) + map_df = map_df[['post_id', mapping_id_col]].dropna() + map_df.to_sql(mapping_table, con, if_exists="append", index=False) + + +def download(id: int): + if id == 0: + return + base_url = "https://knack.news/" + url = f"{base_url}{id}" + res = requests.get(url) + + # make sure we don't dos knack + time.sleep(2) + + if not (200 <= res.status_code <= 300): + return + + logger.debug("Found promising page with id %d!", id) + + content = res.content + soup = BeautifulSoup(content, "html.parser") + + pC = soup.find("div", {"class": "postContent"}) + + if pC is None: + # not a normal post + logger.debug( + "Page with id %d does not have a .pageContent-div. 
Skipping for now.", id + ) + return + + # every post has these fields + title = pC.find("h3", {"class": "postTitle"}).text + postText = pC.find("div", {"class": "postText"}) + + # these fields are possible but not required + # TODO: cleanup + try: + date_parts = pC.find("span", {"class": "singledate"}).text.split(' ') + day = int(date_parts[0][:-1]) + months = {'Januar': 1, 'Februar': 2, 'März': 3, 'April': 4, 'Mai': 5, 'Juni': 6, 'Juli': 7, 'August': 8, 'September': 9, 'Oktober': 10, 'November': 11, 'Dezember': 12} + month = months[date_parts[1]] + year = int(date_parts[2]) + parsed_date = datetime(year, month, day) + except Exception: + parsed_date = None + + try: + author = pC.find("span", {"class": "author"}).text + except AttributeError: + author = None + + try: + category = pC.find("span", {"class": "categoryInfo"}).find_all() + category = [c.text for c in category if c.text != 'Alle Artikel'] + category = ";".join(category) + except AttributeError: + category = None + + try: + tags = [x.text for x in pC.find("div", {"class": "tagsInfo"}).find_all("a")] + tags = ";".join(tags) + except AttributeError: + tags = None + + img = pC.find("img", {"class": "postImage"}) + if img is not None: + img = img["src"] + + res_dict = { + "id": id, + "title": title, + "author": author, + "date": parsed_date, + "category": category, + "url": url, + "img_link": img, + "tags": tags, + "text": postText.text, + "html": str(postText), + "scraped_at": datetime.now(), + "is_cleaned": False + } + + return res_dict + + +def run_downloads(min_id: int, max_id: int, num_threads: int = 8): + res = [] + + logger.info( + "Started parallel scrape of posts from id %d to id %d using %d threads.", + min_id, + max_id - 1, + num_threads, + ) + with ThreadPoolExecutor(max_workers=num_threads) as executor: + # Use a list comprehension to create a list of futures + futures = [executor.submit(download, i) for i in range(min_id, max_id)] + + for future in futures: + post = future.result() + if post is not None: + res.append(post) + + postdf = pd.DataFrame(res) + return postdf + + +def main(): + num_threads = int(os.environ.get("NUM_THREADS", 8)) + n_scrapes = int(os.environ.get("NUM_SCRAPES", 100)) + database_location = os.environ.get("DATABASE_LOCATION", "../data/knack.sqlite") + + logger.debug(f"Started Knack Scraper: \nNUM_THREADS: {num_threads}\nN_SCRAPES: {n_scrapes}\nDATABASE_LOCATION: {database_location}") + + con = sqlite3.connect(database_location) + with con: + if table_exists("posts", con): + logger.info("found posts retrieved earlier") + max_id_in_db = con.execute("SELECT MAX(id) FROM posts").fetchone()[0] + logger.info("Got max id %d!", max_id_in_db) + else: + logger.info("no posts scraped so far - starting from 0") + max_id_in_db = -1 + + postdf = run_downloads( + min_id=max_id_in_db + 1, + max_id=max_id_in_db + n_scrapes, + num_threads=num_threads, + ) + + postdf.to_sql("posts", con, if_exists="append") + + # Tags + tag_dim, tag_map = build_dimension_and_mapping(postdf, 'tags', 'tag') + store_dimension_and_mapping( + con, + tag_dim, + tag_map, + table_name="tags", + dim_col="tag", + mapping_table="posttags", + mapping_id_col="tag_id", + ) + + # Categories + category_dim, category_map = build_dimension_and_mapping(postdf, 'category', 'category') + store_dimension_and_mapping( + con, + category_dim, + category_map, + table_name="categories", + dim_col="category", + mapping_table="postcategories", + mapping_id_col="category_id", + ) + + logger.info(f"scraped new entries. 
number of new posts: {len(postdf.index)}") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/scrape/requirements.txt similarity index 83% rename from requirements.txt rename to scrape/requirements.txt index 3c59d8d..32d5df2 100644 --- a/requirements.txt +++ b/scrape/requirements.txt @@ -1,5 +1,4 @@ pandas requests -tqdm bs4 dotenv \ No newline at end of file diff --git a/transform/.env.example b/transform/.env.example new file mode 100644 index 0000000..34cd0e0 --- /dev/null +++ b/transform/.env.example @@ -0,0 +1,4 @@ +LOGGING_LEVEL=INFO +DB_PATH=/data/knack.sqlite +MAX_CLEANED_POSTS=1000 +COMPUTE_DEVICE=mps \ No newline at end of file diff --git a/transform/Dockerfile b/transform/Dockerfile new file mode 100644 index 0000000..4c72480 --- /dev/null +++ b/transform/Dockerfile @@ -0,0 +1,41 @@ +FROM python:3.12-slim + +RUN mkdir /app +RUN mkdir /data + +# Install build dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + g++ \ + gfortran \ + libopenblas-dev \ + liblapack-dev \ + pkg-config \ + && rm -rf /var/lib/apt/lists/* + +#COPY /data/knack.sqlite /data + +WORKDIR /app +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY .env . + +RUN apt update -y +RUN apt install -y cron locales + +COPY *.py . + +ENV PYTHONUNBUFFERED=1 +ENV LANG=de_DE.UTF-8 +ENV LC_ALL=de_DE.UTF-8 + +# Create cron job that runs every weekend (Sunday at 3 AM) 0 3 * * 0 +# Testing every 30 Minutes */30 * * * * +RUN echo "0 3 * * 0 cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-transform +RUN chmod 0644 /etc/cron.d/knack-transform +RUN crontab /etc/cron.d/knack-transform + +# Start cron in foreground +CMD ["cron", "-f"] +#CMD ["python", "main.py"] diff --git a/transform/README.md b/transform/README.md new file mode 100644 index 0000000..44ddeb1 --- /dev/null +++ b/transform/README.md @@ -0,0 +1,62 @@ +# Knack Transform + +Data transformation pipeline for the Knack scraper project. + +## Overview + +This folder contains the transformation logic that processes data from the SQLite database. It runs on a scheduled basis (every weekend) via cron. + +## Structure + +- `base.py` - Abstract base class for transform nodes +- `main.py` - Main entry point and pipeline orchestration +- `Dockerfile` - Docker image configuration with cron setup +- `requirements.txt` - Python dependencies + +## Transform Nodes + +Transform nodes inherit from `TransformNode` and implement the `run` method: + +```python +from base import TransformNode, TransformContext +import sqlite3 + +class MyTransform(TransformNode): + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + df = context.get_dataframe() + + # Transform logic here + transformed_df = df.copy() + # ... your transformations ... + + # Optionally write back to database + transformed_df.to_sql("my_table", con, if_exists="replace", index=False) + + return TransformContext(transformed_df) +``` + +## Configuration + +Copy `.env.example` to `.env` and configure: + +- `LOGGING_LEVEL` - Log level (INFO or DEBUG) +- `DB_PATH` - Path to SQLite database + +## Running + +### With Docker + +```bash +docker build -t knack-transform . +docker run -v $(pwd)/data:/data knack-transform +``` + +### Locally + +```bash +python main.py +``` + +## Cron Schedule + +The Docker container runs the transform pipeline every Sunday at 3 AM. 
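+
+## Chaining Nodes
+
+`run` takes a `TransformContext` and returns a new one, so several nodes can be
+chained by feeding each node's output context into the next. A minimal sketch of
+driving a node by hand (assumes the scraper has already created
+`/data/knack.sqlite` and that GLiNER is installed):
+
+```python
+import sqlite3
+
+import pandas as pd
+
+from author_node import AuthorNode
+from base import TransformContext
+
+con = sqlite3.connect("/data/knack.sqlite")
+df = pd.read_sql("SELECT id, author FROM posts", con)
+
+context = TransformContext(df)
+context = AuthorNode(device="cpu").run(con, context)
+# a second node would consume the returned context in the same way, e.g.:
+# context = AnotherNode().run(con, context)
+```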
diff --git a/transform/author_node.py b/transform/author_node.py new file mode 100644 index 0000000..23d3365 --- /dev/null +++ b/transform/author_node.py @@ -0,0 +1,263 @@ +"""Author classification transform node using NER.""" +from base import TransformNode, TransformContext +import sqlite3 +import pandas as pd +import logging +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime + +try: + from gliner import GLiNER + import torch + GLINER_AVAILABLE = True +except ImportError: + GLINER_AVAILABLE = False + logging.warning("GLiNER not available. Install with: pip install gliner") + +logger = logging.getLogger("knack-transform") + + +class AuthorNode(TransformNode): + """Transform node that extracts and classifies authors using NER. + + Creates two tables: + - authors: stores unique authors with their type (Person, Organisation, etc.) + - post_authors: maps posts to their authors + """ + + def __init__(self, model_name: str = "urchade/gliner_medium-v2.1", + threshold: float = 0.5, + max_workers: int = 64, + device: str = "cpu"): + """Initialize the AuthorNode. + + Args: + model_name: GLiNER model to use + threshold: Confidence threshold for entity predictions + max_workers: Number of parallel workers for prediction + device: Device to run model on ('cpu', 'cuda', 'mps') + """ + self.model_name = model_name + self.threshold = threshold + self.max_workers = max_workers + self.device = device + self.model = None + self.labels = ["Person", "Organisation", "Email", "Newspaper", "NGO"] + + def _setup_model(self): + """Initialize the NER model.""" + if not GLINER_AVAILABLE: + raise ImportError("GLiNER is required for AuthorNode. Install with: pip install gliner") + + logger.info(f"Loading GLiNER model: {self.model_name}") + + if self.device == "cuda" and torch.cuda.is_available(): + self.model = GLiNER.from_pretrained( + self.model_name, + max_length=255 + ).to('cuda', dtype=torch.float16) + elif self.device == "mps" and torch.backends.mps.is_available(): + self.model = GLiNER.from_pretrained( + self.model_name, + max_length=255 + ).to('mps', dtype=torch.float16) + else: + self.model = GLiNER.from_pretrained( + self.model_name, + max_length=255 + ) + + logger.info("Model loaded successfully") + + def _predict(self, text_data: dict): + """Predict entities for a single author text. + + Args: + text_data: Dict with 'author' and 'id' keys + + Returns: + Tuple of (predictions, post_id) or None + """ + if text_data is None or text_data.get('author') is None: + return None + + predictions = self.model.predict_entities( + text_data['author'], + self.labels, + threshold=self.threshold + ) + return predictions, text_data['id'] + + def _classify_authors(self, posts_df: pd.DataFrame): + """Classify all authors in the posts dataframe. 
+ + Args: + posts_df: DataFrame with 'id' and 'author' columns + + Returns: + List of dicts with 'text', 'label', 'id' keys + """ + if self.model is None: + self._setup_model() + + # Prepare input data + authors_data = [] + for idx, row in posts_df.iterrows(): + if pd.notna(row['author']): + authors_data.append({ + 'author': row['author'], + 'id': row['id'] + }) + + logger.info(f"Classifying {len(authors_data)} authors") + + results = [] + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = [executor.submit(self._predict, data) for data in authors_data] + + for future in futures: + result = future.result() + if result is not None: + predictions, post_id = result + for pred in predictions: + results.append({ + 'text': pred['text'], + 'label': pred['label'], + 'id': post_id + }) + + logger.info(f"Classification complete. Found {len(results)} author entities") + return results + + def _create_tables(self, con: sqlite3.Connection): + """Create authors and post_authors tables if they don't exist.""" + logger.info("Creating authors tables") + + con.execute(""" + CREATE TABLE IF NOT EXISTS authors ( + id INTEGER PRIMARY KEY, + name TEXT, + type TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + con.execute(""" + CREATE TABLE IF NOT EXISTS post_authors ( + post_id INTEGER, + author_id INTEGER, + PRIMARY KEY (post_id, author_id), + FOREIGN KEY (post_id) REFERENCES posts(id), + FOREIGN KEY (author_id) REFERENCES authors(id) + ) + """) + + con.commit() + + def _store_authors(self, con: sqlite3.Connection, results: list): + """Store classified authors and their mappings. + + Args: + con: Database connection + results: List of classification results + """ + if not results: + logger.info("No authors to store") + return + + # Convert results to DataFrame + results_df = pd.DataFrame(results) + + # Get unique authors with their types + unique_authors = results_df[['text', 'label']].drop_duplicates() + unique_authors.columns = ['name', 'type'] + + # Get existing authors + existing_authors = pd.read_sql("SELECT id, name FROM authors", con) + + # Find new authors to insert + if not existing_authors.empty: + new_authors = unique_authors[~unique_authors['name'].isin(existing_authors['name'])] + else: + new_authors = unique_authors + + if not new_authors.empty: + logger.info(f"Inserting {len(new_authors)} new authors") + new_authors.to_sql('authors', con, if_exists='append', index=False) + + # Get all authors with their IDs + all_authors = pd.read_sql("SELECT id, name FROM authors", con) + name_to_id = dict(zip(all_authors['name'], all_authors['id'])) + + # Create post_authors mappings + mappings = [] + for _, row in results_df.iterrows(): + author_id = name_to_id.get(row['text']) + if author_id: + mappings.append({ + 'post_id': row['id'], + 'author_id': author_id + }) + + if mappings: + mappings_df = pd.DataFrame(mappings).drop_duplicates() + + # Clear existing mappings for these posts (optional, depends on your strategy) + # post_ids = tuple(mappings_df['post_id'].unique()) + # con.execute(f"DELETE FROM post_authors WHERE post_id IN ({','.join('?' * len(post_ids))})", post_ids) + + logger.info(f"Creating {len(mappings_df)} post-author mappings") + mappings_df.to_sql('post_authors', con, if_exists='append', index=False) + + # Mark posts as cleaned + processed_post_ids = mappings_df['post_id'].unique().tolist() + if processed_post_ids: + placeholders = ','.join('?' 
* len(processed_post_ids)) + con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", processed_post_ids) + logger.info(f"Marked {len(processed_post_ids)} posts as cleaned") + + con.commit() + logger.info("Authors and mappings stored successfully") + + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the author classification transformation. + + Args: + con: SQLite database connection + context: TransformContext containing posts dataframe + + Returns: + TransformContext with classified authors dataframe + """ + logger.info("Starting AuthorNode transformation") + + posts_df = context.get_dataframe() + + # Ensure required columns exist + if 'author' not in posts_df.columns: + logger.warning("No 'author' column in dataframe. Skipping AuthorNode.") + return context + + # Create tables + self._create_tables(con) + + # Classify authors + results = self._classify_authors(posts_df) + + # Store results + self._store_authors(con, results) + + # Mark posts without author entities as cleaned too (no authors found) + processed_ids = set([r['id'] for r in results]) if results else set() + unprocessed_ids = [pid for pid in posts_df['id'].tolist() if pid not in processed_ids] + if unprocessed_ids: + placeholders = ','.join('?' * len(unprocessed_ids)) + con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", unprocessed_ids) + con.commit() + logger.info(f"Marked {len(unprocessed_ids)} posts without author entities as cleaned") + + # Return context with results + results_df = pd.DataFrame(results) if results else pd.DataFrame() + logger.info("AuthorNode transformation complete") + + return TransformContext(results_df) diff --git a/transform/base.py b/transform/base.py new file mode 100644 index 0000000..59a4f31 --- /dev/null +++ b/transform/base.py @@ -0,0 +1,37 @@ +"""Base transform node for data pipeline.""" +from abc import ABC, abstractmethod +import sqlite3 +import pandas as pd + + +class TransformContext: + """Context object containing the dataframe for transformation.""" + + def __init__(self, df: pd.DataFrame): + self.df = df + + def get_dataframe(self) -> pd.DataFrame: + """Get the pandas dataframe from the context.""" + return self.df + + +class TransformNode(ABC): + """Abstract base class for transformation nodes. + + Each transform node implements a single transformation step + that takes data from the database, transforms it, and + potentially writes results back. + """ + + @abstractmethod + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the transformation. + + Args: + con: SQLite database connection + context: TransformContext containing the input dataframe + + Returns: + TransformContext with the transformed dataframe + """ + pass diff --git a/transform/main.py b/transform/main.py new file mode 100644 index 0000000..29b9a38 --- /dev/null +++ b/transform/main.py @@ -0,0 +1,89 @@ +#! 
python3
+import logging
+import os
+import sqlite3
+import sys
+from dotenv import load_dotenv
+
+load_dotenv()
+
+if (os.environ.get('LOGGING_LEVEL', 'INFO') == 'INFO'):
+    logging_level = logging.INFO
+else:
+    logging_level = logging.DEBUG
+
+logging.basicConfig(
+    level=logging_level,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.FileHandler("app.log"),
+        logging.StreamHandler(sys.stdout)
+    ]
+)
+logger = logging.getLogger("knack-transform")
+
+
+def setup_database_connection():
+    """Create connection to the SQLite database."""
+    db_path = os.environ.get('DB_PATH', '/data/knack.sqlite')
+    logger.info(f"Connecting to database: {db_path}")
+    return sqlite3.connect(db_path)
+
+
+def table_exists(tablename: str, con: sqlite3.Connection):
+    """Check if a table exists in the database."""
+    query = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1"
+    return len(con.execute(query, [tablename]).fetchall()) > 0
+
+
+def main():
+    """Main entry point for the transform pipeline."""
+    logger.info("Starting transform pipeline")
+
+    try:
+        con = setup_database_connection()
+        logger.info("Database connection established")
+
+        # Check if posts table exists
+        if not table_exists('posts', con):
+            logger.warning("Posts table does not exist yet. Please run the scraper first to populate the database.")
+            logger.info("Transform pipeline skipped - no data available")
+            return
+
+        # Import transform nodes
+        from author_node import AuthorNode
+        from base import TransformContext
+        import pandas as pd
+
+        # Load posts data
+        logger.info("Loading posts from database")
+        sql = "SELECT id, author FROM posts WHERE author IS NOT NULL AND (is_cleaned IS NULL OR is_cleaned = 0) LIMIT ?"
+        MAX_CLEANED_POSTS = int(os.environ.get("MAX_CLEANED_POSTS", 500))
+        df = pd.read_sql(sql, con, params=[MAX_CLEANED_POSTS])
+        logger.info(f"Loaded {len(df)} uncleaned posts with authors")
+
+        if df.empty:
+            logger.info("No uncleaned posts found. Transform pipeline skipped.")
+            return
+
+        # Create context and run author classification
+        context = TransformContext(df)
+        author_transform = AuthorNode(device=os.environ.get('COMPUTE_DEVICE', 'cpu'))  # set COMPUTE_DEVICE to "cuda" or "mps" if available
+        result_context = author_transform.run(con, context)
+
+        # TODO: Create Node to compute Text Embeddings and UMAP.
+        # TODO: Create Node to pre-compute data based on visuals to reduce load time.
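+        # Follow-up nodes would be chained the same way by reusing the returned
+        # context (sketch only - an EmbeddingNode does not exist yet):
+        #   embedding_transform = EmbeddingNode(device=os.environ.get('COMPUTE_DEVICE', 'cpu'))
+        #   result_context = embedding_transform.run(con, result_context)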
+ + logger.info("Transform pipeline completed successfully") + + except Exception as e: + logger.error(f"Error in transform pipeline: {e}", exc_info=True) + sys.exit(1) + finally: + if 'con' in locals(): + con.close() + logger.info("Database connection closed") + + +if __name__ == "__main__": + main() diff --git a/transform/requirements.txt b/transform/requirements.txt new file mode 100644 index 0000000..c95bd6d --- /dev/null +++ b/transform/requirements.txt @@ -0,0 +1,4 @@ +pandas +python-dotenv +gliner +torch From 72765532d39613b65cfec72441b381b8329c594b Mon Sep 17 00:00:00 2001 From: quorploop <> Date: Tue, 23 Dec 2025 17:53:37 +0100 Subject: [PATCH 3/4] Adds TransformNode to FuzzyFind Author Names --- docker-compose.yml | 16 ++ transform/Dockerfile | 25 ++- transform/README.md | 7 +- transform/author_node.py | 205 +++++++++++++++--- transform/ensure_gliner_model.sh | 16 ++ transform/entrypoint.sh | 8 + transform/example_node.py | 170 +++++++++++++++ transform/main.py | 35 ++- transform/pipeline.py | 258 +++++++++++++++++++++++ transform/requirements.txt | 1 + transform/{base.py => transform_node.py} | 13 +- 11 files changed, 696 insertions(+), 58 deletions(-) create mode 100644 transform/ensure_gliner_model.sh create mode 100644 transform/entrypoint.sh create mode 100644 transform/example_node.py create mode 100644 transform/pipeline.py rename transform/{base.py => transform_node.py} (70%) diff --git a/docker-compose.yml b/docker-compose.yml index 5c5c4e7..4ab3b8c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -21,7 +21,23 @@ services: - transform/.env volumes: - knack_data:/data + - models:/models + restart: unless-stopped + + sqlitebrowser: + image: lscr.io/linuxserver/sqlitebrowser:latest + container_name: sqlitebrowser + environment: + - PUID=1000 + - PGID=1000 + - TZ=Etc/UTC + volumes: + - knack_data:/data + ports: + - "3000:3000" # noVNC web UI + - "3001:3001" # VNC server restart: unless-stopped volumes: knack_data: + models: diff --git a/transform/Dockerfile b/transform/Dockerfile index 4c72480..682af4f 100644 --- a/transform/Dockerfile +++ b/transform/Dockerfile @@ -1,7 +1,6 @@ FROM python:3.12-slim -RUN mkdir /app -RUN mkdir /data +RUN mkdir -p /app /data /models # Install build dependencies RUN apt-get update && apt-get install -y --no-install-recommends \ @@ -11,9 +10,12 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ libopenblas-dev \ liblapack-dev \ pkg-config \ + curl \ + jq \ && rm -rf /var/lib/apt/lists/* -#COPY /data/knack.sqlite /data +ENV GLINER_MODEL_ID=urchade/gliner_multi-v2.1 +ENV GLINER_MODEL_PATH=/models/gliner_multi-v2.1 WORKDIR /app COPY requirements.txt . @@ -24,18 +26,21 @@ COPY .env . RUN apt update -y RUN apt install -y cron locales -COPY *.py . +# Ensure GLiNER helper scripts are available +COPY ensure_gliner_model.sh /usr/local/bin/ensure_gliner_model.sh +COPY entrypoint.sh /usr/local/bin/entrypoint.sh +RUN chmod +x /usr/local/bin/ensure_gliner_model.sh /usr/local/bin/entrypoint.sh -ENV PYTHONUNBUFFERED=1 -ENV LANG=de_DE.UTF-8 -ENV LC_ALL=de_DE.UTF-8 +COPY *.py . 
# Create cron job that runs every weekend (Sunday at 3 AM) 0 3 * * 0 # Testing every 30 Minutes */30 * * * * -RUN echo "0 3 * * 0 cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-transform +RUN echo "*/30 * * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-transform RUN chmod 0644 /etc/cron.d/knack-transform RUN crontab /etc/cron.d/knack-transform -# Start cron in foreground -CMD ["cron", "-f"] +# Persist models between container runs +VOLUME /models + +CMD ["/usr/local/bin/entrypoint.sh"] #CMD ["python", "main.py"] diff --git a/transform/README.md b/transform/README.md index 44ddeb1..9e3665a 100644 --- a/transform/README.md +++ b/transform/README.md @@ -6,10 +6,15 @@ Data transformation pipeline for the Knack scraper project. This folder contains the transformation logic that processes data from the SQLite database. It runs on a scheduled basis (every weekend) via cron. +The pipeline supports **parallel execution** of independent transform nodes, allowing you to leverage multi-core processors for faster data transformation. + ## Structure - `base.py` - Abstract base class for transform nodes -- `main.py` - Main entry point and pipeline orchestration +- `pipeline.py` - Parallel pipeline orchestration system +- `main.py` - Main entry point and pipeline execution +- `author_node.py` - NER-based author classification node +- `example_node.py` - Template for creating new nodes - `Dockerfile` - Docker image configuration with cron setup - `requirements.txt` - Python dependencies diff --git a/transform/author_node.py b/transform/author_node.py index 23d3365..719a191 100644 --- a/transform/author_node.py +++ b/transform/author_node.py @@ -1,10 +1,13 @@ """Author classification transform node using NER.""" -from base import TransformNode, TransformContext +import os import sqlite3 import pandas as pd import logging +import fuzzysearch from concurrent.futures import ThreadPoolExecutor -from datetime import datetime + +from pipeline import TransformContext +from transform_node import TransformNode try: from gliner import GLiNER @@ -17,7 +20,7 @@ except ImportError: logger = logging.getLogger("knack-transform") -class AuthorNode(TransformNode): +class NerAuthorNode(TransformNode): """Transform node that extracts and classifies authors using NER. Creates two tables: @@ -25,7 +28,8 @@ class AuthorNode(TransformNode): - post_authors: maps posts to their authors """ - def __init__(self, model_name: str = "urchade/gliner_medium-v2.1", + def __init__(self, model_name: str = "urchade/gliner_multi-v2.1", + model_path: str = None, threshold: float = 0.5, max_workers: int = 64, device: str = "cpu"): @@ -33,11 +37,13 @@ class AuthorNode(TransformNode): Args: model_name: GLiNER model to use + model_path: Optional local path to a downloaded GLiNER model threshold: Confidence threshold for entity predictions max_workers: Number of parallel workers for prediction device: Device to run model on ('cpu', 'cuda', 'mps') """ self.model_name = model_name + self.model_path = model_path or os.environ.get('GLINER_MODEL_PATH') self.threshold = threshold self.max_workers = max_workers self.device = device @@ -49,21 +55,31 @@ class AuthorNode(TransformNode): if not GLINER_AVAILABLE: raise ImportError("GLiNER is required for AuthorNode. 
Install with: pip install gliner") - logger.info(f"Loading GLiNER model: {self.model_name}") + model_source = None + if self.model_path: + if os.path.exists(self.model_path): + model_source = self.model_path + logger.info(f"Loading GLiNER model from local path: {self.model_path}") + else: + logger.warning(f"GLINER_MODEL_PATH '{self.model_path}' not found; falling back to hub model {self.model_name}") + + if model_source is None: + model_source = self.model_name + logger.info(f"Loading GLiNER model from hub: {self.model_name}") if self.device == "cuda" and torch.cuda.is_available(): self.model = GLiNER.from_pretrained( - self.model_name, + model_source, max_length=255 ).to('cuda', dtype=torch.float16) elif self.device == "mps" and torch.backends.mps.is_available(): self.model = GLiNER.from_pretrained( - self.model_name, + model_source, max_length=255 ).to('mps', dtype=torch.float16) else: self.model = GLiNER.from_pretrained( - self.model_name, + model_source, max_length=255 ) @@ -208,13 +224,6 @@ class AuthorNode(TransformNode): logger.info(f"Creating {len(mappings_df)} post-author mappings") mappings_df.to_sql('post_authors', con, if_exists='append', index=False) - - # Mark posts as cleaned - processed_post_ids = mappings_df['post_id'].unique().tolist() - if processed_post_ids: - placeholders = ','.join('?' * len(processed_post_ids)) - con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", processed_post_ids) - logger.info(f"Marked {len(processed_post_ids)} posts as cleaned") con.commit() logger.info("Authors and mappings stored successfully") @@ -247,17 +256,165 @@ class AuthorNode(TransformNode): # Store results self._store_authors(con, results) - # Mark posts without author entities as cleaned too (no authors found) - processed_ids = set([r['id'] for r in results]) if results else set() - unprocessed_ids = [pid for pid in posts_df['id'].tolist() if pid not in processed_ids] - if unprocessed_ids: - placeholders = ','.join('?' * len(unprocessed_ids)) - con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", unprocessed_ids) - con.commit() - logger.info(f"Marked {len(unprocessed_ids)} posts without author entities as cleaned") - # Return context with results results_df = pd.DataFrame(results) if results else pd.DataFrame() logger.info("AuthorNode transformation complete") return TransformContext(results_df) + + +class FuzzyAuthorNode(TransformNode): + """FuzzyAuthorNode + + This Node takes in data and rules of authornames that have been classified already + and uses those 'rule' to find more similar fields. + """ + + def __init__(self, + max_l_dist: int = 1,): + """Initialize FuzzyAuthorNode. + + Args: + max_l_dist: The number of 'errors' that are allowed by the fuzzy search algorithm + """ + self.max_l_dist = max_l_dist + logger.info(f"Initialized FuzzyAuthorNode with max_l_dist={max_l_dist}") + + def _process_data(self, con: sqlite3.Connection, df: pd.DataFrame) -> pd.DataFrame: + """Process the input dataframe. + + This is where your main transformation logic goes. 
+ + Args: + con: Database connection + df: Input dataframe from context + + Returns: + Processed dataframe + """ + logger.info(f"Processing {len(df)} rows") + + # Retrieve all known authors from the authors table as 'rules' + authors_df = pd.read_sql("SELECT id, name FROM authors", con) + + if authors_df.empty: + logger.warning("No authors found in database for fuzzy matching") + return pd.DataFrame(columns=['post_id', 'author_id']) + + # Get existing post-author mappings to avoid duplicates + existing_mappings = pd.read_sql( + "SELECT post_id, author_id FROM post_authors", con + ) + existing_post_ids = set(existing_mappings['post_id'].unique()) + + logger.info(f"Found {len(authors_df)} known authors for fuzzy matching") + logger.info(f"Found {len(existing_post_ids)} posts with existing author mappings") + + # Filter to posts without author mappings and with non-null author field + if 'author' not in df.columns or 'id' not in df.columns: + logger.warning("Missing 'author' or 'id' column in input dataframe") + return pd.DataFrame(columns=['post_id', 'author_id']) + + posts_to_process = df[ + (df['id'].notna()) & + (df['author'].notna()) & + (~df['id'].isin(existing_post_ids)) + ] + + logger.info(f"Processing {len(posts_to_process)} posts for fuzzy matching") + + # Perform fuzzy matching + mappings = [] + for _, post_row in posts_to_process.iterrows(): + post_id = post_row['id'] + post_author = str(post_row['author']) + + # Try to find matches against all known author names + for _, author_row in authors_df.iterrows(): + author_id = author_row['id'] + author_name = str(author_row['name']) + + # Use fuzzysearch to find matches with allowed errors + matches = fuzzysearch.find_near_matches( + author_name, + post_author, + max_l_dist=self.max_l_dist + ) + + if matches: + logger.debug(f"Found fuzzy match: '{author_name}' in '{post_author}' for post {post_id}") + mappings.append({ + 'post_id': post_id, + 'author_id': author_id + }) + # Only take the first match per post to avoid multiple mappings + break + + # Create result dataframe + result_df = pd.DataFrame(mappings, columns=['post_id', 'author_id']) if mappings else pd.DataFrame(columns=['post_id', 'author_id']) + + logger.info(f"Processing complete. Found {len(result_df)} fuzzy matches") + return result_df + + def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame): + """Store results back to the database. + + Uses INSERT OR IGNORE to avoid inserting duplicates. + + Args: + con: Database connection + df: Processed dataframe to store + """ + if df.empty: + logger.info("No results to store") + return + + logger.info(f"Storing {len(df)} results") + + # Use INSERT OR IGNORE to handle duplicates (respects PRIMARY KEY constraint) + cursor = con.cursor() + inserted_count = 0 + + for _, row in df.iterrows(): + cursor.execute( + "INSERT OR IGNORE INTO post_authors (post_id, author_id) VALUES (?, ?)", + (int(row['post_id']), int(row['author_id'])) + ) + if cursor.rowcount > 0: + inserted_count += 1 + + con.commit() + logger.info(f"Results stored successfully. Inserted {inserted_count} new mappings, skipped {len(df) - inserted_count} duplicates") + + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the transformation. + + This is the main entry point called by the pipeline. 
+ + Args: + con: SQLite database connection + context: TransformContext containing input dataframe + + Returns: + TransformContext with processed dataframe + """ + logger.info("Starting FuzzyAuthorNode transformation") + + # Get input dataframe from context + input_df = context.get_dataframe() + + # Validate input + if input_df.empty: + logger.warning("Empty dataframe provided to FuzzyAuthorNode") + return context + + # Process the data + result_df = self._process_data(con, input_df) + + # Store results + self._store_results(con, result_df) + + logger.info("FuzzyAuthorNode transformation complete") + + # Return new context with results + return TransformContext(result_df) diff --git a/transform/ensure_gliner_model.sh b/transform/ensure_gliner_model.sh new file mode 100644 index 0000000..4df8215 --- /dev/null +++ b/transform/ensure_gliner_model.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -euo pipefail + +if [ -d "$GLINER_MODEL_PATH" ] && find "$GLINER_MODEL_PATH" -type f | grep -q .; then + echo "GLiNER model already present at $GLINER_MODEL_PATH" + exit 0 +fi + +echo "Downloading GLiNER model to $GLINER_MODEL_PATH" +mkdir -p "$GLINER_MODEL_PATH" +curl -sL "https://huggingface.co/api/models/${GLINER_MODEL_ID}" | jq -r '.siblings[].rfilename' | while read -r file; do + target="${GLINER_MODEL_PATH}/${file}" + mkdir -p "$(dirname "$target")" + echo "Downloading ${file}" + curl -sL "https://huggingface.co/${GLINER_MODEL_ID}/resolve/main/${file}" -o "$target" +done diff --git a/transform/entrypoint.sh b/transform/entrypoint.sh new file mode 100644 index 0000000..8beab84 --- /dev/null +++ b/transform/entrypoint.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Run model download with output to stdout/stderr +/usr/local/bin/ensure_gliner_model.sh 2>&1 + +# Start cron in foreground with logging +exec cron -f -L 2 diff --git a/transform/example_node.py b/transform/example_node.py new file mode 100644 index 0000000..69900d1 --- /dev/null +++ b/transform/example_node.py @@ -0,0 +1,170 @@ +"""Example template node for the transform pipeline. + +This is a template showing how to create new transform nodes. +Copy this file and modify it for your specific transformation needs. +""" +from pipeline import TransformContext +from transform_node import TransformNode +import sqlite3 +import pandas as pd +import logging + +logger = logging.getLogger("knack-transform") + + +class ExampleNode(TransformNode): + """Example transform node template. + + This node demonstrates the basic structure for creating + new transformation nodes in the pipeline. + """ + + def __init__(self, + param1: str = "default_value", + param2: int = 42, + device: str = "cpu"): + """Initialize the ExampleNode. + + Args: + param1: Example string parameter + param2: Example integer parameter + device: Device to use for computations ('cpu', 'cuda', 'mps') + """ + self.param1 = param1 + self.param2 = param2 + self.device = device + logger.info(f"Initialized ExampleNode with param1={param1}, param2={param2}") + + def _create_tables(self, con: sqlite3.Connection): + """Create any necessary tables in the database. + + This is optional - only needed if your node creates new tables. 
+ """ + logger.info("Creating example tables") + + con.execute(""" + CREATE TABLE IF NOT EXISTS example_results ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + post_id INTEGER, + result_value TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (post_id) REFERENCES posts(id) + ) + """) + + con.commit() + + def _process_data(self, df: pd.DataFrame) -> pd.DataFrame: + """Process the input dataframe. + + This is where your main transformation logic goes. + + Args: + df: Input dataframe from context + + Returns: + Processed dataframe + """ + logger.info(f"Processing {len(df)} rows") + + # Example: Add a new column based on existing data + result_df = df.copy() + result_df['processed'] = True + result_df['example_value'] = result_df['id'].apply(lambda x: f"{self.param1}_{x}") + + logger.info("Processing complete") + return result_df + + def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame): + """Store results back to the database. + + This is optional - only needed if you want to persist results. + + Args: + con: Database connection + df: Processed dataframe to store + """ + if df.empty: + logger.info("No results to store") + return + + logger.info(f"Storing {len(df)} results") + + # Example: Store to database + # df[['post_id', 'result_value']].to_sql( + # 'example_results', + # con, + # if_exists='append', + # index=False + # ) + + con.commit() + logger.info("Results stored successfully") + + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the transformation. + + This is the main entry point called by the pipeline. + + Args: + con: SQLite database connection + context: TransformContext containing input dataframe + + Returns: + TransformContext with processed dataframe + """ + logger.info("Starting ExampleNode transformation") + + # Get input dataframe from context + input_df = context.get_dataframe() + + # Validate input + if input_df.empty: + logger.warning("Empty dataframe provided to ExampleNode") + return context + + # Create any necessary tables + self._create_tables(con) + + # Process the data + result_df = self._process_data(input_df) + + # Store results (optional) + self._store_results(con, result_df) + + logger.info("ExampleNode transformation complete") + + # Return new context with results + return TransformContext(result_df) + + +# Example usage: +if __name__ == "__main__": + # This allows you to test your node independently + import os + os.chdir('/Users/linussilberstein/Documents/Knack-Scraper/transform') + + from pipeline import TransformContext + import sqlite3 + + # Create test data + test_df = pd.DataFrame({ + 'id': [1, 2, 3], + 'author': ['Test Author 1', 'Test Author 2', 'Test Author 3'] + }) + + # Create test database connection + test_con = sqlite3.connect(':memory:') + + # Create and run node + node = ExampleNode(param1="test", param2=100) + context = TransformContext(test_df) + result_context = node.run(test_con, context) + + # Check results + result_df = result_context.get_dataframe() + print("\nResult DataFrame:") + print(result_df) + + test_con.close() + print("\n✓ ExampleNode test completed successfully!") diff --git a/transform/main.py b/transform/main.py index 29b9a38..d07d905 100644 --- a/transform/main.py +++ b/transform/main.py @@ -50,15 +50,14 @@ def main(): logger.info("Transform pipeline skipped - no data available") return - # Import transform nodes - from author_node import AuthorNode - from base import TransformContext + # Import transform components + from pipeline import 
create_default_pipeline, TransformContext import pandas as pd # Load posts data logger.info("Loading posts from database") sql = "SELECT id, author FROM posts WHERE author IS NOT NULL AND (is_cleaned IS NULL OR is_cleaned = 0) LIMIT ?" - MAX_CLEANED_POSTS = os.environ.get("MAX_CLEANED_POSTS", 500) + MAX_CLEANED_POSTS = os.environ.get("MAX_CLEANED_POSTS", 100) df = pd.read_sql(sql, con, params=[MAX_CLEANED_POSTS]) logger.info(f"Loaded {len(df)} uncleaned posts with authors") @@ -66,15 +65,29 @@ def main(): logger.info("No uncleaned posts found. Transform pipeline skipped.") return - # Create context and run author classification + # Create initial context context = TransformContext(df) - author_transform = AuthorNode(device=os.environ.get('COMPUTE_DEVICE', 'cpu')) # Change to "cuda" or "mps" if available - result_context = author_transform.run(con, context) - - # TODO: Create Node to compute Text Embeddings and UMAP. - # TODO: Create Node to pre-compute data based on visuals to reduce load time. - logger.info("Transform pipeline completed successfully") + # Create and run parallel pipeline + device = os.environ.get('COMPUTE_DEVICE', 'cpu') + max_workers = int(os.environ.get('MAX_WORKERS', 4)) + + pipeline = create_default_pipeline(device=device, max_workers=max_workers) + results = pipeline.run( + db_path=os.environ.get('DB_PATH', '/data/knack.sqlite'), + initial_context=context, + fail_fast=False # Continue even if some nodes fail + ) + + logger.info(f"Pipeline completed. Processed {len(results)} node(s)") + + # Mark all processed posts as cleaned + post_ids = df['id'].tolist() + if post_ids: + placeholders = ','.join('?' * len(post_ids)) + con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", post_ids) + con.commit() + logger.info(f"Marked {len(post_ids)} posts as cleaned") except Exception as e: logger.error(f"Error in transform pipeline: {e}", exc_info=True) diff --git a/transform/pipeline.py b/transform/pipeline.py new file mode 100644 index 0000000..1a97f1f --- /dev/null +++ b/transform/pipeline.py @@ -0,0 +1,258 @@ +"""Parallel pipeline orchestration for transform nodes.""" +import logging +import os +import sqlite3 +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed +from typing import List, Dict, Optional + +import pandas as pd +import multiprocessing as mp + +logger = logging.getLogger("knack-transform") + +class TransformContext: + """Context object containing the dataframe for transformation.""" + + def __init__(self, df: pd.DataFrame): + self.df = df + + def get_dataframe(self) -> pd.DataFrame: + """Get the pandas dataframe from the context.""" + return self.df + +class NodeConfig: + """Configuration for a transform node.""" + + def __init__(self, + node_class: type, + node_kwargs: Dict = None, + dependencies: List[str] = None, + name: str = None): + """Initialize node configuration. + + Args: + node_class: The TransformNode class to instantiate + node_kwargs: Keyword arguments to pass to node constructor + dependencies: List of node names that must complete before this one + name: Optional name for the node (defaults to class name) + """ + self.node_class = node_class + self.node_kwargs = node_kwargs or {} + self.dependencies = dependencies or [] + self.name = name or node_class.__name__ + +class ParallelPipeline: + """Pipeline for executing transform nodes in parallel where possible. + + The pipeline analyzes dependencies between nodes and executes + independent nodes concurrently using multiprocessing or threading. 
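+
+    A minimal usage sketch (SomeNode and OtherNode are illustrative placeholders):
+
+        pipeline = ParallelPipeline(max_workers=2)
+        pipeline.add_node(NodeConfig(node_class=SomeNode, name='SomeNode'))
+        pipeline.add_node(NodeConfig(node_class=OtherNode, dependencies=['SomeNode']))
+        results = pipeline.run(db_path='/data/knack.sqlite',
+                               initial_context=TransformContext(df))
+
+    Nodes whose dependencies are already satisfied run in the same stage, so
+    SomeNode executes in stage 1 and OtherNode in stage 2.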
+ """ + + def __init__(self, + max_workers: Optional[int] = None, + use_processes: bool = False): + """Initialize the parallel pipeline. + + Args: + max_workers: Maximum number of parallel workers (defaults to CPU count) + use_processes: If True, use ProcessPoolExecutor; if False, use ThreadPoolExecutor + """ + self.max_workers = max_workers or mp.cpu_count() + self.use_processes = use_processes + self.nodes: Dict[str, NodeConfig] = {} + logger.info(f"Initialized ParallelPipeline with {self.max_workers} workers " + f"({'processes' if use_processes else 'threads'})") + + def add_node(self, config: NodeConfig): + """Add a node to the pipeline. + + Args: + config: NodeConfig with node details and dependencies + """ + self.nodes[config.name] = config + logger.info(f"Added node '{config.name}' with dependencies: {config.dependencies}") + + def _get_execution_stages(self) -> List[List[str]]: + """Determine execution stages based on dependencies. + + Returns: + List of stages, where each stage contains node names that can run in parallel + """ + stages = [] + completed = set() + remaining = set(self.nodes.keys()) + + while remaining: + # Find nodes whose dependencies are all completed + ready = [] + for node_name in remaining: + config = self.nodes[node_name] + if all(dep in completed for dep in config.dependencies): + ready.append(node_name) + + if not ready: + # Circular dependency or missing dependency + raise ValueError(f"Cannot resolve dependencies. Remaining nodes: {remaining}") + + stages.append(ready) + completed.update(ready) + remaining -= set(ready) + + return stages + + def _execute_node(self, + node_name: str, + db_path: str, + context: TransformContext) -> tuple: + """Execute a single node. + + Args: + node_name: Name of the node to execute + db_path: Path to the SQLite database + context: TransformContext for the node + + Returns: + Tuple of (node_name, result_context, error) + """ + try: + # Create fresh database connection (not shared across processes/threads) + con = sqlite3.connect(db_path) + + config = self.nodes[node_name] + node = config.node_class(**config.node_kwargs) + + logger.info(f"Executing node: {node_name}") + result_context = node.run(con, context) + + con.close() + logger.info(f"Node '{node_name}' completed successfully") + + return node_name, result_context, None + + except Exception as e: + logger.error(f"Error executing node '{node_name}': {e}", exc_info=True) + return node_name, None, str(e) + + def run(self, + db_path: str, + initial_context: TransformContext, + fail_fast: bool = False) -> Dict[str, TransformContext]: + """Execute the pipeline. 
+ + Args: + db_path: Path to the SQLite database + initial_context: Initial TransformContext for the pipeline + fail_fast: If True, stop execution on first error + + Returns: + Dict mapping node names to their output TransformContext + """ + logger.info("Starting parallel pipeline execution") + + stages = self._get_execution_stages() + logger.info(f"Pipeline has {len(stages)} execution stage(s)") + + results = {} + contexts = {None: initial_context} # Track contexts from each node + errors = [] + + ExecutorClass = ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor + + for stage_num, stage_nodes in enumerate(stages, 1): + logger.info(f"Stage {stage_num}/{len(stages)}: Executing {len(stage_nodes)} node(s) in parallel: {stage_nodes}") + + # For nodes in this stage, use the context from their dependencies + # If multiple dependencies, we'll use the most recent one (or could merge) + stage_futures = {} + + with ExecutorClass(max_workers=min(self.max_workers, len(stage_nodes))) as executor: + for node_name in stage_nodes: + config = self.nodes[node_name] + + # Get context from dependencies (use the last dependency's output) + if config.dependencies: + context = results.get(config.dependencies[-1], initial_context) + else: + context = initial_context + + future = executor.submit(self._execute_node, node_name, db_path, context) + stage_futures[future] = node_name + + # Wait for all nodes in this stage to complete + for future in as_completed(stage_futures): + node_name = stage_futures[future] + name, result_context, error = future.result() + + if error: + errors.append((name, error)) + if fail_fast: + logger.error(f"Pipeline failed at node '{name}': {error}") + raise RuntimeError(f"Node '{name}' failed: {error}") + else: + results[name] = result_context + + if errors: + logger.warning(f"Pipeline completed with {len(errors)} error(s)") + for name, error in errors: + logger.error(f" - {name}: {error}") + else: + logger.info("Pipeline completed successfully") + + return results + + +def create_default_pipeline(device: str = "cpu", + max_workers: Optional[int] = None) -> ParallelPipeline: + """Create a pipeline with default transform nodes. + + Args: + device: Device to use for compute-intensive nodes ('cpu', 'cuda', 'mps') + max_workers: Maximum number of parallel workers + + Returns: + Configured ParallelPipeline + """ + from author_node import NerAuthorNode, FuzzyAuthorNode + + pipeline = ParallelPipeline(max_workers=max_workers, use_processes=False) + + # Add AuthorNode (no dependencies) + pipeline.add_node(NodeConfig( + node_class=NerAuthorNode, + node_kwargs={ + 'device': device, + 'model_path': os.environ.get('GLINER_MODEL_PATH') + }, + dependencies=[], + name='AuthorNode' + )) + + pipeline.add_node(NodeConfig( + node_class=FuzzyAuthorNode, + node_kwargs={ + 'max_l_dist': 1 + }, + dependencies=['AuthorNode'], + name='FuzzyAuthorNode' + )) + + # TODO: Create Node to compute Text Embeddings and UMAP. + # TODO: Create Node to pre-compute data based on visuals to reduce load time. 
+ + # TODO: Add more nodes here as they are implemented + # Example: + # pipeline.add_node(NodeConfig( + # node_class=EmbeddingNode, + # node_kwargs={'device': device}, + # dependencies=[], # Runs after AuthorNode + # name='EmbeddingNode' + # )) + + # pipeline.add_node(NodeConfig( + # node_class=UMAPNode, + # node_kwargs={'device': device}, + # dependencies=['EmbeddingNode'], # Runs after EmbeddingNode + # name='UMAPNode' + # )) + + return pipeline diff --git a/transform/requirements.txt b/transform/requirements.txt index c95bd6d..e210d05 100644 --- a/transform/requirements.txt +++ b/transform/requirements.txt @@ -2,3 +2,4 @@ pandas python-dotenv gliner torch +fuzzysearch \ No newline at end of file diff --git a/transform/base.py b/transform/transform_node.py similarity index 70% rename from transform/base.py rename to transform/transform_node.py index 59a4f31..54e6bed 100644 --- a/transform/base.py +++ b/transform/transform_node.py @@ -1,19 +1,8 @@ """Base transform node for data pipeline.""" from abc import ABC, abstractmethod import sqlite3 -import pandas as pd - - -class TransformContext: - """Context object containing the dataframe for transformation.""" - - def __init__(self, df: pd.DataFrame): - self.df = df - - def get_dataframe(self) -> pd.DataFrame: - """Get the pandas dataframe from the context.""" - return self.df +from pipeline import TransformContext class TransformNode(ABC): """Abstract base class for transformation nodes. From 49239e7e25dbf0b409c8dc964bf536e28ba7af95 Mon Sep 17 00:00:00 2001 From: quorploop <> Date: Wed, 24 Dec 2025 17:58:23 +0100 Subject: [PATCH 4/4] Implement Nodes to compute text embeddings --- scrape/main.py | 4 +- transform/Dockerfile | 7 +- transform/author_node.py | 16 +- transform/embeddings_node.py | 445 +++++++++++++++++++++++++++++++ transform/ensure_minilm_model.sh | 16 ++ transform/entrypoint.sh | 2 + transform/main.py | 6 +- transform/pipeline.py | 30 ++- transform/requirements.txt | 4 +- 9 files changed, 505 insertions(+), 25 deletions(-) create mode 100644 transform/embeddings_node.py create mode 100644 transform/ensure_minilm_model.sh diff --git a/scrape/main.py b/scrape/main.py index 15f0e72..10b66dd 100755 --- a/scrape/main.py +++ b/scrape/main.py @@ -227,7 +227,9 @@ def main(): num_threads=num_threads, ) - postdf.to_sql("posts", con, if_exists="append") + # Drop category and tags columns as they're stored in separate tables + postdf = postdf.drop(columns=['category', 'tags']) + postdf.to_sql("posts", con, if_exists="append", index=False) # Tags tag_dim, tag_map = build_dimension_and_mapping(postdf, 'tags', 'tag') diff --git a/transform/Dockerfile b/transform/Dockerfile index 682af4f..6f148bd 100644 --- a/transform/Dockerfile +++ b/transform/Dockerfile @@ -17,6 +17,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ ENV GLINER_MODEL_ID=urchade/gliner_multi-v2.1 ENV GLINER_MODEL_PATH=/models/gliner_multi-v2.1 +ENV MINILM_MODEL_ID=sentence-transformers/all-MiniLM-L6-v2 +ENV MINILM_MODEL_PATH=/models/all-MiniLM-L6-v2 + WORKDIR /app COPY requirements.txt . 
RUN pip install --no-cache-dir -r requirements.txt @@ -28,8 +31,10 @@ RUN apt install -y cron locales # Ensure GLiNER helper scripts are available COPY ensure_gliner_model.sh /usr/local/bin/ensure_gliner_model.sh +# Ensure MiniLM helper scripts are available +COPY ensure_minilm_model.sh /usr/local/bin/ensure_minilm_model.sh COPY entrypoint.sh /usr/local/bin/entrypoint.sh -RUN chmod +x /usr/local/bin/ensure_gliner_model.sh /usr/local/bin/entrypoint.sh +RUN chmod +x /usr/local/bin/ensure_gliner_model.sh /usr/local/bin/ensure_minilm_model.sh /usr/local/bin/entrypoint.sh COPY *.py . diff --git a/transform/author_node.py b/transform/author_node.py index 719a191..845e87a 100644 --- a/transform/author_node.py +++ b/transform/author_node.py @@ -9,6 +9,8 @@ from concurrent.futures import ThreadPoolExecutor from pipeline import TransformContext from transform_node import TransformNode +logger = logging.getLogger("knack-transform") + try: from gliner import GLiNER import torch @@ -17,9 +19,6 @@ except ImportError: GLINER_AVAILABLE = False logging.warning("GLiNER not available. Install with: pip install gliner") -logger = logging.getLogger("knack-transform") - - class NerAuthorNode(TransformNode): """Transform node that extracts and classifies authors using NER. @@ -257,10 +256,9 @@ class NerAuthorNode(TransformNode): self._store_authors(con, results) # Return context with results - results_df = pd.DataFrame(results) if results else pd.DataFrame() logger.info("AuthorNode transformation complete") - return TransformContext(results_df) + return TransformContext(posts_df) class FuzzyAuthorNode(TransformNode): @@ -309,7 +307,7 @@ class FuzzyAuthorNode(TransformNode): logger.info(f"Found {len(authors_df)} known authors for fuzzy matching") logger.info(f"Found {len(existing_post_ids)} posts with existing author mappings") - + # Filter to posts without author mappings and with non-null author field if 'author' not in df.columns or 'id' not in df.columns: logger.warning("Missing 'author' or 'id' column in input dataframe") @@ -333,12 +331,14 @@ class FuzzyAuthorNode(TransformNode): for _, author_row in authors_df.iterrows(): author_id = author_row['id'] author_name = str(author_row['name']) + # for author names < than 2 characters I want a fault tolerance of 0! + l_dist = self.max_l_dist if len(author_name) > 2 else 0 # Use fuzzysearch to find matches with allowed errors matches = fuzzysearch.find_near_matches( author_name, post_author, - max_l_dist=self.max_l_dist + max_l_dist=l_dist, ) if matches: @@ -417,4 +417,4 @@ class FuzzyAuthorNode(TransformNode): logger.info("FuzzyAuthorNode transformation complete") # Return new context with results - return TransformContext(result_df) + return TransformContext(input_df) diff --git a/transform/embeddings_node.py b/transform/embeddings_node.py new file mode 100644 index 0000000..9821eca --- /dev/null +++ b/transform/embeddings_node.py @@ -0,0 +1,445 @@ +"""Classes of Transformernodes that have to do with +text processing. + +- TextEmbeddingNode calculates text embeddings +- UmapNode calculates xy coordinates on those vector embeddings +- SimilarityNode calculates top n similar posts based on those embeddings + using the spectral distance. 
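+
+Embeddings are persisted on the posts table as BLOBs holding the raw float32
+bytes of each vector, so a stored value can be restored with, for example:
+
+    vec = np.frombuffer(row['embedding'], dtype=np.float32)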
+""" +from pipeline import TransformContext +from transform_node import TransformNode +import sqlite3 +import pandas as pd +import logging +import os +import numpy as np + +logger = logging.getLogger("knack-transform") + +try: + from sentence_transformers import SentenceTransformer + import torch + MINILM_AVAILABLE = True +except ImportError: + MINILM_AVAILABLE = False + logging.warning("MiniLM not available. Install with pip!") + +try: + import umap + UMAP_AVAILABLE = True +except ImportError: + UMAP_AVAILABLE = False + logging.warning("UMAP not available. Install with pip install umap-learn!") + +class TextEmbeddingNode(TransformNode): + """Calculates vector embeddings based on a dataframe + of posts. + """ + def __init__(self, + model_name: str = "sentence-transformers/all-MiniLM-L6-v2", + model_path: str = None, + device: str = "cpu"): + """Initialize the ExampleNode. + + Args: + model_name: Name of the ML Model to calculate text embeddings + model_path: Optional local path to a downloaded embedding model + device: Device to use for computations ('cpu', 'cuda', 'mps') + """ + self.model_name = model_name + self.model_path = model_path or os.environ.get('MINILM_MODEL_PATH') + self.device = device + self.model = None + logger.info(f"Initialized TextEmbeddingNode with model_name={model_name}, model_path={model_path}, device={device}") + + def _setup_model(self): + """Init the Text Embedding Model.""" + if not MINILM_AVAILABLE: + raise ImportError("MiniLM is required for TextEmbeddingNode. Please install.") + + model_source = None + if self.model_path: + if os.path.exists(self.model_path): + model_source = self.model_path + logger.info(f"Loading MiniLM model from local path: {self.model_path}") + else: + logger.warning(f"MiniLM_MODEL_PATH '{self.model_path}' not found; Falling back to hub model {self.model_name}") + + if model_source is None: + model_source = self.model_name + logger.info(f"Loading MiniLM model from the hub: {self.model_name}") + + if self.device == "cuda" and torch.cuda.is_available(): + self.model = SentenceTransformer(model_source).to('cuda', dtype=torch.float16) + elif self.device == "mps" and torch.backends.mps.is_available(): + self.model = SentenceTransformer(model_source).to('mps', dtype=torch.float16) + else: + self.model = SentenceTransformer(model_source) + + def _process_data(self, df: pd.DataFrame) -> pd.DataFrame: + """Process the input dataframe. + + Calculates an embedding as a np.array. + Also pickles that array to prepare it to + storage in the database. + + Args: + df: Input dataframe from context + + Returns: + Processed dataframe + """ + logger.info(f"Processing {len(df)} rows") + + if self.model is None: + self._setup_model() + + # Example: Add a new column based on existing data + result_df = df.copy() + + df['embedding'] = df['text'].apply(lambda x: self.model.encode(x, convert_to_numpy=True)) + + logger.info("Processing complete") + return result_df + + def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame): + """Store results back to the database using batch updates.""" + if df.empty: + logger.info("No results to store") + return + + logger.info(f"Storing {len(df)} results") + + # Convert numpy arrays to bytes for BLOB storage + # Use tobytes() to serialize numpy arrays efficiently + updates = [(row['embedding'].tobytes(), row['id']) for _, row in df.iterrows()] + con.executemany( + "UPDATE posts SET embedding = ? 
WHERE id = ?", + updates + ) + + con.commit() + logger.info("Results stored successfully") + + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the transformation. + + This is the main entry point called by the pipeline. + + Args: + con: SQLite database connection + context: TransformContext containing input dataframe + + Returns: + TransformContext with processed dataframe + """ + logger.info("Starting TextEmbeddingNode transformation") + + # Get input dataframe from context + input_df = context.get_dataframe() + + # Validate input + if input_df.empty: + logger.warning("Empty dataframe provided to TextEmbeddingNdode") + return context + + if 'text' not in input_df.columns: + logger.warning("No 'text' column in context dataframe. Skipping TextEmbeddingNode") + return context + + # Process the data + result_df = self._process_data(input_df) + + # Store results (optional) + self._store_results(con, result_df) + + logger.info("TextEmbeddingNode transformation complete") + + # Return new context with results + return TransformContext(result_df) + + +class UmapNode(TransformNode): + """Calculates 2D coordinates from embeddings using UMAP dimensionality reduction. + + This node takes text embeddings and reduces them to 2D coordinates + for visualization purposes. + """ + + def __init__(self, + n_neighbors: int = 15, + min_dist: float = 0.1, + n_components: int = 2, + metric: str = "cosine", + random_state: int = 42): + """Initialize the UmapNode. + + Args: + n_neighbors: Number of neighbors to consider for UMAP (default: 15) + min_dist: Minimum distance between points in low-dimensional space (default: 0.1) + n_components: Number of dimensions to reduce to (default: 2) + metric: Distance metric to use (default: 'cosine') + random_state: Random seed for reproducibility (default: 42) + """ + self.n_neighbors = n_neighbors + self.min_dist = min_dist + self.n_components = n_components + self.metric = metric + self.random_state = random_state + self.reducer = None + logger.info(f"Initialized UmapNode with n_neighbors={n_neighbors}, min_dist={min_dist}, " + f"n_components={n_components}, metric={metric}, random_state={random_state}") + + def _process_data(self, df: pd.DataFrame) -> pd.DataFrame: + """Process the input dataframe. + + Retrieves embeddings from BLOB storage, converts them back to numpy arrays, + and applies UMAP dimensionality reduction to create 2D coordinates. + + Args: + df: Input dataframe from context + + Returns: + Processed dataframe with umap_x and umap_y columns + """ + logger.info(f"Processing {len(df)} rows") + + if not UMAP_AVAILABLE: + raise ImportError("UMAP is required for UmapNode. 
Install with: pip install umap-learn")
+
+        result_df = df.copy()
+
+        # Convert BLOB embeddings back to numpy arrays
+        # (values that are already numpy arrays are passed through unchanged)
+        if 'embedding' not in result_df.columns:
+            logger.error("No 'embedding' column found in dataframe")
+            raise ValueError("Input dataframe must contain 'embedding' column")
+
+        logger.info("Converting embeddings from BLOB to numpy arrays")
+        result_df['embedding'] = result_df['embedding'].apply(
+            lambda x: np.frombuffer(x, dtype=np.float32) if isinstance(x, (bytes, bytearray)) else x
+        )
+
+        # Filter out rows with missing embeddings
+        valid_rows = result_df['embedding'].notna()
+        if not valid_rows.any():
+            logger.error("No valid embeddings found in dataframe")
+            raise ValueError("No valid embeddings to process")
+
+        logger.info(f"Found {valid_rows.sum()} valid embeddings out of {len(result_df)} rows")
+
+        # Stack embeddings into a matrix
+        embeddings_matrix = np.vstack(result_df.loc[valid_rows, 'embedding'].values)
+        logger.info(f"Embeddings matrix shape: {embeddings_matrix.shape}")
+
+        # Apply UMAP
+        logger.info("Fitting UMAP reducer...")
+        self.reducer = umap.UMAP(
+            n_neighbors=self.n_neighbors,
+            min_dist=self.min_dist,
+            n_components=self.n_components,
+            metric=self.metric,
+            random_state=self.random_state
+        )
+
+        umap_coords = self.reducer.fit_transform(embeddings_matrix)
+        logger.info(f"UMAP transformation complete. Output shape: {umap_coords.shape}")
+
+        # Add UMAP coordinates to dataframe; rows without embeddings keep NaN
+        result_df.loc[valid_rows, 'umap_x'] = umap_coords[:, 0]
+        result_df.loc[valid_rows, 'umap_y'] = umap_coords[:, 1]
+
+        logger.info("Processing complete")
+        return result_df
+
+    def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
+        """Store UMAP coordinates back to the database.
+
+        Args:
+            con: Database connection
+            df: Processed dataframe with umap_x and umap_y columns
+        """
+        if df.empty:
+            logger.info("No results to store")
+            return
+
+        logger.info(f"Storing {len(df)} results")
+
+        # Batch update UMAP coordinates
+        updates = [
+            (row['umap_x'], row['umap_y'], row['id'])
+            for _, row in df.iterrows()
+            if pd.notna(row.get('umap_x')) and pd.notna(row.get('umap_y'))
+        ]
+
+        if updates:
+            con.executemany(
+                "UPDATE posts SET umap_x = ?, umap_y = ? WHERE id = ?",
+                updates
+            )
+            con.commit()
+            logger.info(f"Stored {len(updates)} UMAP coordinate pairs successfully")
+        else:
+            logger.warning("No valid UMAP coordinates to store")
+
+    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
+        """Execute the transformation.
+
+        This is the main entry point called by the pipeline.
+
+        Args:
+            con: SQLite database connection
+            context: TransformContext containing input dataframe
+
+        Returns:
+            TransformContext with processed dataframe
+        """
+        logger.info("Starting UmapNode transformation")
+
+        # Get input dataframe from context
+        input_df = context.get_dataframe()
+
+        # Validate input
+        if input_df.empty:
+            logger.warning("Empty dataframe provided to UmapNode")
+            return context
+
+        # Process the data
+        result_df = self._process_data(input_df)
+
+        # Store results
+        self._store_results(con, result_df)
+
+        logger.info("UmapNode transformation complete")
+
+        # Return new context with results
+        return TransformContext(result_df)
+
+
+class SimilarityNode(TransformNode):
+    """Placeholder for the similarity computation (currently the example node template). 
+ + This node demonstrates the basic structure for creating + new transformation nodes in the pipeline. + """ + + def __init__(self, + param1: str = "default_value", + param2: int = 42, + device: str = "cpu"): + """Initialize the ExampleNode. + + Args: + param1: Example string parameter + param2: Example integer parameter + device: Device to use for computations ('cpu', 'cuda', 'mps') + """ + self.param1 = param1 + self.param2 = param2 + self.device = device + logger.info(f"Initialized ExampleNode with param1={param1}, param2={param2}") + + def _create_tables(self, con: sqlite3.Connection): + """Create any necessary tables in the database. + + This is optional - only needed if your node creates new tables. + """ + logger.info("Creating example tables") + + con.execute(""" + CREATE TABLE IF NOT EXISTS example_results ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + post_id INTEGER, + result_value TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (post_id) REFERENCES posts(id) + ) + """) + + con.commit() + + def _process_data(self, df: pd.DataFrame) -> pd.DataFrame: + """Process the input dataframe. + + This is where your main transformation logic goes. + + Args: + df: Input dataframe from context + + Returns: + Processed dataframe + """ + logger.info(f"Processing {len(df)} rows") + + # Example: Add a new column based on existing data + result_df = df.copy() + result_df['processed'] = True + result_df['example_value'] = result_df['id'].apply(lambda x: f"{self.param1}_{x}") + + logger.info("Processing complete") + return result_df + + def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame): + """Store results back to the database. + + This is optional - only needed if you want to persist results. + + Args: + con: Database connection + df: Processed dataframe to store + """ + if df.empty: + logger.info("No results to store") + return + + logger.info(f"Storing {len(df)} results") + + # Example: Store to database + # df[['post_id', 'result_value']].to_sql( + # 'example_results', + # con, + # if_exists='append', + # index=False + # ) + + con.commit() + logger.info("Results stored successfully") + + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the transformation. + + This is the main entry point called by the pipeline. 
+ + Args: + con: SQLite database connection + context: TransformContext containing input dataframe + + Returns: + TransformContext with processed dataframe + """ + logger.info("Starting ExampleNode transformation") + + # Get input dataframe from context + input_df = context.get_dataframe() + + # Validate input + if input_df.empty: + logger.warning("Empty dataframe provided to ExampleNode") + return context + + # Create any necessary tables + self._create_tables(con) + + # Process the data + result_df = self._process_data(input_df) + + # Store results (optional) + self._store_results(con, result_df) + + logger.info("ExampleNode transformation complete") + + # Return new context with results + return TransformContext(result_df) diff --git a/transform/ensure_minilm_model.sh b/transform/ensure_minilm_model.sh new file mode 100644 index 0000000..2d58f24 --- /dev/null +++ b/transform/ensure_minilm_model.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -euo pipefail + +if [ -d "$MINILM_MODEL_PATH" ] && find "$MINILM_MODEL_PATH" -type f | grep -q .; then + echo "MiniLM model already present at $MINILM_MODEL_PATH" + exit 0 +fi + +echo "Downloading MiniLM model to $MINILM_MODEL_PATH" +mkdir -p "$MINILM_MODEL_PATH" +curl -sL "https://huggingface.co/api/models/${MINILM_MODEL_ID}" | jq -r '.siblings[].rfilename' | while read -r file; do + target="${MINILM_MODEL_PATH}/${file}" + mkdir -p "$(dirname "$target")" + echo "Downloading ${file}" + curl -sL "https://huggingface.co/${MINILM_MODEL_ID}/resolve/main/${file}" -o "$target" +done diff --git a/transform/entrypoint.sh b/transform/entrypoint.sh index 8beab84..96f5932 100644 --- a/transform/entrypoint.sh +++ b/transform/entrypoint.sh @@ -2,7 +2,9 @@ set -euo pipefail # Run model download with output to stdout/stderr +/usr/local/bin/ensure_minilm_model.sh 2>&1 /usr/local/bin/ensure_gliner_model.sh 2>&1 # Start cron in foreground with logging exec cron -f -L 2 +# cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1 \ No newline at end of file diff --git a/transform/main.py b/transform/main.py index d07d905..9922eed 100644 --- a/transform/main.py +++ b/transform/main.py @@ -56,9 +56,9 @@ def main(): # Load posts data logger.info("Loading posts from database") - sql = "SELECT id, author FROM posts WHERE author IS NOT NULL AND (is_cleaned IS NULL OR is_cleaned = 0) LIMIT ?" 
- MAX_CLEANED_POSTS = os.environ.get("MAX_CLEANED_POSTS", 100) - df = pd.read_sql(sql, con, params=[MAX_CLEANED_POSTS]) + sql = "SELECT * FROM posts WHERE author IS NOT NULL AND (is_cleaned IS NULL OR is_cleaned = 0)" + # MAX_CLEANED_POSTS = os.environ.get("MAX_CLEANED_POSTS", 100) + df = pd.read_sql(sql, con) logger.info(f"Loaded {len(df)} uncleaned posts with authors") if df.empty: diff --git a/transform/pipeline.py b/transform/pipeline.py index 1a97f1f..e1f4e9c 100644 --- a/transform/pipeline.py +++ b/transform/pipeline.py @@ -12,6 +12,7 @@ logger = logging.getLogger("knack-transform") class TransformContext: """Context object containing the dataframe for transformation.""" + # Possibly add a dict for the context to give more Information def __init__(self, df: pd.DataFrame): self.df = df @@ -153,7 +154,6 @@ class ParallelPipeline: logger.info(f"Pipeline has {len(stages)} execution stage(s)") results = {} - contexts = {None: initial_context} # Track contexts from each node errors = [] ExecutorClass = ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor @@ -213,6 +213,7 @@ def create_default_pipeline(device: str = "cpu", Configured ParallelPipeline """ from author_node import NerAuthorNode, FuzzyAuthorNode + from embeddings_node import TextEmbeddingNode, UmapNode pipeline = ParallelPipeline(max_workers=max_workers, use_processes=False) @@ -236,17 +237,24 @@ def create_default_pipeline(device: str = "cpu", name='FuzzyAuthorNode' )) + pipeline.add_node(NodeConfig( + node_class=TextEmbeddingNode, + node_kwargs={ + 'device': device, + 'model_path': os.environ.get('MINILM_MODEL_PATH') + }, + dependencies=[], + name='TextEmbeddingNode' + )) + + pipeline.add_node(NodeConfig( + node_class=UmapNode, + node_kwargs={}, + dependencies=['TextEmbeddingNode'], + name='UmapNode' + )) + # TODO: Create Node to compute Text Embeddings and UMAP. - # TODO: Create Node to pre-compute data based on visuals to reduce load time. - - # TODO: Add more nodes here as they are implemented - # Example: - # pipeline.add_node(NodeConfig( - # node_class=EmbeddingNode, - # node_kwargs={'device': device}, - # dependencies=[], # Runs after AuthorNode - # name='EmbeddingNode' - # )) # pipeline.add_node(NodeConfig( # node_class=UMAPNode, diff --git a/transform/requirements.txt b/transform/requirements.txt index e210d05..023d14f 100644 --- a/transform/requirements.txt +++ b/transform/requirements.txt @@ -2,4 +2,6 @@ pandas python-dotenv gliner torch -fuzzysearch \ No newline at end of file +fuzzysearch +sentence_transformers +umap-learn \ No newline at end of file
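
The SimilarityNode committed above is still the example template. As a reference point, here is a minimal sketch of how top-n similar posts could be derived from the embeddings that TextEmbeddingNode stores; it uses plain cosine similarity as a stand-in for the spectral distance mentioned in the module docstring, and the `compute_top_n_similar` helper and `post_similarities` table are illustrative assumptions, not part of this patch.

```
import sqlite3

import numpy as np


def compute_top_n_similar(con: sqlite3.Connection, top_n: int = 5):
    # Load the float32 embedding BLOBs written by TextEmbeddingNode
    rows = con.execute(
        "SELECT id, embedding FROM posts WHERE embedding IS NOT NULL"
    ).fetchall()
    if not rows:
        return

    ids = [r[0] for r in rows]
    vecs = np.vstack([np.frombuffer(r[1], dtype=np.float32) for r in rows])

    # Normalise so the dot product equals cosine similarity
    vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
    sims = vecs @ vecs.T
    np.fill_diagonal(sims, -np.inf)  # exclude self-matches

    # post_similarities is an assumed helper table, not part of this patch
    con.execute(
        "CREATE TABLE IF NOT EXISTS post_similarities "
        "(post_id INTEGER, similar_post_id INTEGER, score REAL)"
    )
    for i, post_id in enumerate(ids):
        best = np.argsort(sims[i])[::-1][:top_n]
        con.executemany(
            "INSERT INTO post_similarities (post_id, similar_post_id, score) "
            "VALUES (?, ?, ?)",
            [(post_id, ids[j], float(sims[i, j])) for j in best],
        )
    con.commit()
```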