diff --git a/.gitignore b/.gitignore index 2e7a5cf..8e2fba4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ data/ venv/ +experiment/ .DS_STORE +.env \ No newline at end of file diff --git a/Makefile b/Makefile index a669090..47a8063 100644 --- a/Makefile +++ b/Makefile @@ -1,2 +1,12 @@ -build: - docker build -t knack-scraper . \ No newline at end of file +volume: + docker volume create knack_data + +stop: + docker stop knack-scraper || true + docker rm knack-scraper || true + +up: + docker compose up -d + +down: + docker compose down \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..5c5c4e7 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,27 @@ +services: + scraper: + build: + context: ./scrape + dockerfile: Dockerfile + image: knack-scraper + container_name: knack-scraper + env_file: + - scrape/.env + volumes: + - knack_data:/data + restart: unless-stopped + + transform: + build: + context: ./transform + dockerfile: Dockerfile + image: knack-transform + container_name: knack-transform + env_file: + - transform/.env + volumes: + - knack_data:/data + restart: unless-stopped + +volumes: + knack_data: diff --git a/main.py b/main.py deleted file mode 100755 index f5a0b7a..0000000 --- a/main.py +++ /dev/null @@ -1,306 +0,0 @@ -#! python3 -import logging -import os -import sqlite3 -import time -from concurrent.futures import ThreadPoolExecutor -from datetime import datetime -import sys - -from dotenv import load_dotenv -import pandas as pd -import requests -import tqdm -from bs4 import BeautifulSoup - -load_dotenv() - -if (os.environ.get('LOGGING_LEVEL', 'INFO') == 'INFO'): - logging_level = logging.INFO -else: - logging_level = logging.DEBUG - -logging.basicConfig( - level=logging_level, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler("app.log"), - logging.StreamHandler(sys.stdout) - ] -) -logger = logging.getLogger("knack-scraper") - -def table_exists(tablename: str, con: sqlite3.Connection): - query = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1" - return len(con.execute(query, [tablename]).fetchall()) > 0 - - -def download(id: int): - if id == 0: - return - base_url = "https://knack.news/" - url = f"{base_url}{id}" - res = requests.get(url) - - # make sure we don't dos knack - time.sleep(2) - - if not (200 <= res.status_code <= 300): - return - - logger.debug("Found promising page with id %d!", id) - - content = res.content - soup = BeautifulSoup(content, "html.parser") - - pC = soup.find("div", {"class": "postContent"}) - - if pC is None: - # not a normal post - logger.debug( - "Page with id %d does not have a .pageContent-div. 
Skipping for now.", id - ) - return - - # every post has these fields - title = pC.find("h3", {"class": "postTitle"}).text - postText = pC.find("div", {"class": "postText"}) - - # these fields are possible but not required - # TODO: cleanup - try: - date_parts = pC.find("span", {"class": "singledate"}).text.split(' ') - day = int(date_parts[0][:-1]) - months = {'Januar': 1, 'Februar': 2, 'März': 3, 'April': 4, 'Mai': 5, 'Juni': 6, 'Juli': 7, 'August': 8, 'September': 9, 'Oktober': 10, 'November': 11, 'Dezember': 12} - month = months[date_parts[1]] - year = int(date_parts[2]) - parsed_date = datetime(year, month, day) - except Exception: - parsed_date = None - - try: - author = pC.find("span", {"class": "author"}).text - except AttributeError: - author = None - - try: - category = pC.find("span", {"class": "categoryInfo"}).find_all() - category = [c.text for c in category if c.text != 'Alle Artikel'] - category = ";".join(category) - except AttributeError: - category = None - - try: - tags = [x.text for x in pC.find("div", {"class": "tagsInfo"}).find_all("a")] - tags = ";".join(tags) - except AttributeError: - tags = None - - img = pC.find("img", {"class": "postImage"}) - if img is not None: - img = img["src"] - - res_dict = { - "id": id, - "title": title, - "author": author, - "date": parsed_date, - "category": category, - "url": url, - "img_link": img, - "tags": tags, - "text": postText.text, - "html": str(postText), - "scraped_at": datetime.now(), - } - - return res_dict - - -def run_downloads(min_id: int, max_id: int, num_threads: int = 8): - res = [] - - logger.info( - "Started parallel scrape of posts from id %d to id %d using %d threads.", - min_id, - max_id - 1, - num_threads, - ) - with ThreadPoolExecutor(max_workers=num_threads) as executor: - # Use a list comprehension to create a list of futures - futures = [executor.submit(download, i) for i in range(min_id, max_id)] - - for future in tqdm.tqdm( - futures, total=max_id - min_id - ): # tqdm to track progress - post = future.result() - if post is not None: - res.append(post) - - # sqlite can't handle lists so let's convert them to a single row csv - # TODO: make sure our database is properly normalized - postdf = pd.DataFrame(res) - tagdf = None - posttotagdf = None - categorydf = None - postcategorydf = None - - # Extract and create tags dataframe - if not postdf.empty and 'tags' in postdf.columns: - # Collect all unique tags - all_tags = set() - for tags_str in postdf['tags']: - if pd.notna(tags_str): - tags_list = [tag.strip() for tag in tags_str.split(';')] - all_tags.update(tags_list) - - # Create tagdf with id and text columns - if all_tags: - all_tags = sorted(list(all_tags)) - tagdf = pd.DataFrame({ - 'id': range(len(all_tags)), - 'tag': all_tags - }) - - # Create posttotagdf mapping table - rows = [] - for post_id, tags_str in zip(postdf['id'], postdf['tags']): - if pd.notna(tags_str): - tags_list = [tag.strip() for tag in tags_str.split(';')] - for tag_text in tags_list: - tag_id = tagdf[tagdf['tag'] == tag_text]['id'].values[0] - rows.append({'post_id': post_id, 'tag_id': tag_id}) - - if rows: - posttotagdf = pd.DataFrame(rows) - - # Extract and create categories dataframe - if not postdf.empty and 'category' in postdf.columns: - # Collect all unique categories - all_categories = set() - for category_str in postdf['category']: - if pd.notna(category_str): - category_list = [cat.strip() for cat in category_str.split(';')] - all_categories.update(category_list) - - # Create categorydf with id and category columns - if 
all_categories: - all_categories = sorted(list(all_categories)) - categorydf = pd.DataFrame({ - 'id': range(len(all_categories)), - 'category': all_categories - }) - - # Create postcategorydf mapping table - rows = [] - for post_id, category_str in zip(postdf['id'], postdf['category']): - if pd.notna(category_str): - category_list = [cat.strip() for cat in category_str.split(';')] - for category_text in category_list: - category_id = categorydf[categorydf['category'] == category_text]['id'].values[0] - rows.append({'post_id': post_id, 'category_id': category_id}) - - if rows: - postcategorydf = pd.DataFrame(rows) - - return postdf, tagdf, posttotagdf, categorydf, postcategorydf - - -def main(): - num_threads = int(os.environ.get("NUM_THREADS", 8)) - n_scrapes = int(os.environ.get("NUM_SCRAPES", 100)) - database_location = os.environ.get("DATABASE_LOCATION", "../data/knack.sqlite") - - logger.debug(f"Started Knack Scraper: \nNUM_THREADS: {num_threads}\nN_SCRAPES: {n_scrapes}\nDATABASE_LOCATION: {database_location}") - - con = sqlite3.connect(database_location) - with con: - post_table_exists = table_exists("posts", con) - - if post_table_exists: - logger.info("found posts retrieved earlier") - # retrieve max post id from db so - # we can skip retrieving known posts - max_id_in_db = con.execute("SELECT MAX(id) FROM posts").fetchone()[0] - logger.info("Got max id %d!", max_id_in_db) - else: - logger.info("no posts scraped so far - starting from 0") - # retrieve from 0 onwards - max_id_in_db = -1 - - con = sqlite3.connect(database_location) - postdf, tagdf, posttotagdf, categorydf, postcategorydf = run_downloads( - min_id=max_id_in_db + 1, - max_id=max_id_in_db + n_scrapes, - num_threads=num_threads, - ) - postdf.to_sql("posts", con, if_exists="append") - - # Handle tags dataframe merging and storage - if tagdf is not None and not tagdf.empty: - # Check if tags table already exists - if table_exists("tags", con): - # Read existing tags from database - existing_tagdf = pd.read_sql("SELECT id, tag FROM tags", con) - - # Merge new tags with existing tags, avoiding duplicates - merged_tagdf = pd.concat([existing_tagdf, tagdf], ignore_index=False) - merged_tagdf = merged_tagdf.drop_duplicates(subset=['tag'], keep='first') - merged_tagdf = merged_tagdf.reset_index(drop=True) - merged_tagdf['id'] = range(len(merged_tagdf)) - - # Drop the old table and insert the merged data - con.execute("DROP TABLE tags") - con.commit() - merged_tagdf.to_sql("tags", con, if_exists="append", index=False) - - # Update tag_id references in posttotagdf - if posttotagdf is not None and not posttotagdf.empty: - #tag_mapping = dict(zip(tagdf['tag'], tagdf['id'])) - posttotagdf['tag_id'] = posttotagdf['tag_id'].map( - lambda old_id: merged_tagdf[merged_tagdf['tag'] == tagdf.loc[old_id, 'tag']]['id'].values[0] - ) - else: - # First time creating tags table - tagdf.to_sql("tags", con, if_exists="append", index=False) - - # Store posttags (post to tags mapping) - if posttotagdf is not None and not posttotagdf.empty: - posttotagdf.to_sql("posttags", con, if_exists="append", index=False) - - # Handle categories dataframe merging and storage - if categorydf is not None and not categorydf.empty: - # Check if categories table already exists - if table_exists("categories", con): - # Read existing categories from database - existing_categorydf = pd.read_sql("SELECT id, category FROM categories", con) - - # Merge new categories with existing categories, avoiding duplicates - merged_categorydf = pd.concat([existing_categorydf, 
categorydf], ignore_index=False) - merged_categorydf = merged_categorydf.drop_duplicates(subset=['category'], keep='first') - merged_categorydf = merged_categorydf.reset_index(drop=True) - merged_categorydf['id'] = range(len(merged_categorydf)) - - # Drop the old table and insert the merged data - con.execute("DROP TABLE categories") - con.commit() - merged_categorydf.to_sql("categories", con, if_exists="append", index=False) - - # Update category_id references in postcategorydf - if postcategorydf is not None and not postcategorydf.empty: - postcategorydf['category_id'] = postcategorydf['category_id'].map( - lambda old_id: merged_categorydf[merged_categorydf['category'] == categorydf.loc[old_id, 'category']]['id'].values[0] - ) - else: - # First time creating categories table - categorydf.to_sql("categories", con, if_exists="append", index=False) - - # Store postcategories (post to categories mapping) - if postcategorydf is not None and not postcategorydf.empty: - postcategorydf.to_sql("postcategories", con, if_exists="append", index=False) - - logger.info(f"scraped new entries. number of new posts: {len(postdf.index)}") - - -if __name__ == "__main__": - main() diff --git a/Dockerfile b/scrape/Dockerfile similarity index 77% rename from Dockerfile rename to scrape/Dockerfile index d9752ef..a2fbe2e 100644 --- a/Dockerfile +++ b/scrape/Dockerfile @@ -3,6 +3,8 @@ FROM python:slim RUN mkdir /app RUN mkdir /data +#COPY /data/knack.sqlite /data + WORKDIR /app COPY requirements.txt . RUN pip install --no-cache-dir -r requirements.txt @@ -19,7 +21,7 @@ ENV LANG=de_DE.UTF-8 ENV LC_ALL=de_DE.UTF-8 # Create cron job that runs every 15 minutes with environment variables -RUN echo "*/10 * * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-scraper +RUN echo "5 4 * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-scraper RUN chmod 0644 /etc/cron.d/knack-scraper RUN crontab /etc/cron.d/knack-scraper diff --git a/scrape/main.py b/scrape/main.py new file mode 100755 index 0000000..15f0e72 --- /dev/null +++ b/scrape/main.py @@ -0,0 +1,260 @@ +#! python3 +import logging +import os +import sqlite3 +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +import sys + +from dotenv import load_dotenv +import pandas as pd +import requests +from bs4 import BeautifulSoup + +load_dotenv() + +if (os.environ.get('LOGGING_LEVEL', 'INFO') == 'INFO'): + logging_level = logging.INFO +else: + logging_level = logging.DEBUG + +logging.basicConfig( + level=logging_level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler("app.log"), + logging.StreamHandler(sys.stdout) + ] +) +logger = logging.getLogger("knack-scraper") + +def table_exists(tablename: str, con: sqlite3.Connection): + query = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? 
LIMIT 1" + return len(con.execute(query, [tablename]).fetchall()) > 0 + + +def split_semicolon_list(value: str): + if pd.isna(value): + return [] + return [item.strip() for item in str(value).split(';') if item.strip()] + + +def build_dimension_and_mapping(postdf: pd.DataFrame, field_name: str, dim_col: str): + """Extract unique dimension values and post-to-dimension mappings from a column.""" + if postdf.empty or field_name not in postdf.columns: + return None, None + + values = set() + mapping_rows = [] + + for post_id, raw in zip(postdf['id'], postdf[field_name]): + items = split_semicolon_list(raw) + for item in items: + values.add(item) + mapping_rows.append({'post_id': post_id, dim_col: item}) + + if not values: + return None, None + + dim_df = pd.DataFrame({ + 'id': range(len(values)), + dim_col: sorted(values), + }) + map_df = pd.DataFrame(mapping_rows) + return dim_df, map_df + + +def store_dimension_and_mapping( + con: sqlite3.Connection, + dim_df: pd.DataFrame | None, + map_df: pd.DataFrame | None, + table_name: str, + dim_col: str, + mapping_table: str, + mapping_id_col: str, +): + """Persist a dimension table and its mapping table, merging with existing values.""" + if dim_df is None or dim_df.empty: + return + + if table_exists(table_name, con): + existing = pd.read_sql(f"SELECT id, {dim_col} FROM {table_name}", con) + merged = pd.concat([existing, dim_df], ignore_index=True) + merged = merged.drop_duplicates(subset=[dim_col], keep='first').reset_index(drop=True) + merged['id'] = range(len(merged)) + else: + merged = dim_df.copy() + + # Replace table with merged content + merged.to_sql(table_name, con, if_exists="replace", index=False) + + if map_df is None or map_df.empty: + return + + value_to_id = dict(zip(merged[dim_col], merged['id'])) + map_df = map_df.copy() + map_df[mapping_id_col] = map_df[dim_col].map(value_to_id) + map_df = map_df[['post_id', mapping_id_col]].dropna() + map_df.to_sql(mapping_table, con, if_exists="append", index=False) + + +def download(id: int): + if id == 0: + return + base_url = "https://knack.news/" + url = f"{base_url}{id}" + res = requests.get(url) + + # make sure we don't dos knack + time.sleep(2) + + if not (200 <= res.status_code <= 300): + return + + logger.debug("Found promising page with id %d!", id) + + content = res.content + soup = BeautifulSoup(content, "html.parser") + + pC = soup.find("div", {"class": "postContent"}) + + if pC is None: + # not a normal post + logger.debug( + "Page with id %d does not have a .pageContent-div. 
Skipping for now.", id + ) + return + + # every post has these fields + title = pC.find("h3", {"class": "postTitle"}).text + postText = pC.find("div", {"class": "postText"}) + + # these fields are possible but not required + # TODO: cleanup + try: + date_parts = pC.find("span", {"class": "singledate"}).text.split(' ') + day = int(date_parts[0][:-1]) + months = {'Januar': 1, 'Februar': 2, 'März': 3, 'April': 4, 'Mai': 5, 'Juni': 6, 'Juli': 7, 'August': 8, 'September': 9, 'Oktober': 10, 'November': 11, 'Dezember': 12} + month = months[date_parts[1]] + year = int(date_parts[2]) + parsed_date = datetime(year, month, day) + except Exception: + parsed_date = None + + try: + author = pC.find("span", {"class": "author"}).text + except AttributeError: + author = None + + try: + category = pC.find("span", {"class": "categoryInfo"}).find_all() + category = [c.text for c in category if c.text != 'Alle Artikel'] + category = ";".join(category) + except AttributeError: + category = None + + try: + tags = [x.text for x in pC.find("div", {"class": "tagsInfo"}).find_all("a")] + tags = ";".join(tags) + except AttributeError: + tags = None + + img = pC.find("img", {"class": "postImage"}) + if img is not None: + img = img["src"] + + res_dict = { + "id": id, + "title": title, + "author": author, + "date": parsed_date, + "category": category, + "url": url, + "img_link": img, + "tags": tags, + "text": postText.text, + "html": str(postText), + "scraped_at": datetime.now(), + "is_cleaned": False + } + + return res_dict + + +def run_downloads(min_id: int, max_id: int, num_threads: int = 8): + res = [] + + logger.info( + "Started parallel scrape of posts from id %d to id %d using %d threads.", + min_id, + max_id - 1, + num_threads, + ) + with ThreadPoolExecutor(max_workers=num_threads) as executor: + # Use a list comprehension to create a list of futures + futures = [executor.submit(download, i) for i in range(min_id, max_id)] + + for future in futures: + post = future.result() + if post is not None: + res.append(post) + + postdf = pd.DataFrame(res) + return postdf + + +def main(): + num_threads = int(os.environ.get("NUM_THREADS", 8)) + n_scrapes = int(os.environ.get("NUM_SCRAPES", 100)) + database_location = os.environ.get("DATABASE_LOCATION", "../data/knack.sqlite") + + logger.debug(f"Started Knack Scraper: \nNUM_THREADS: {num_threads}\nN_SCRAPES: {n_scrapes}\nDATABASE_LOCATION: {database_location}") + + con = sqlite3.connect(database_location) + with con: + if table_exists("posts", con): + logger.info("found posts retrieved earlier") + max_id_in_db = con.execute("SELECT MAX(id) FROM posts").fetchone()[0] + logger.info("Got max id %d!", max_id_in_db) + else: + logger.info("no posts scraped so far - starting from 0") + max_id_in_db = -1 + + postdf = run_downloads( + min_id=max_id_in_db + 1, + max_id=max_id_in_db + n_scrapes, + num_threads=num_threads, + ) + + postdf.to_sql("posts", con, if_exists="append") + + # Tags + tag_dim, tag_map = build_dimension_and_mapping(postdf, 'tags', 'tag') + store_dimension_and_mapping( + con, + tag_dim, + tag_map, + table_name="tags", + dim_col="tag", + mapping_table="posttags", + mapping_id_col="tag_id", + ) + + # Categories + category_dim, category_map = build_dimension_and_mapping(postdf, 'category', 'category') + store_dimension_and_mapping( + con, + category_dim, + category_map, + table_name="categories", + dim_col="category", + mapping_table="postcategories", + mapping_id_col="category_id", + ) + + logger.info(f"scraped new entries. 
number of new posts: {len(postdf.index)}") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/scrape/requirements.txt similarity index 83% rename from requirements.txt rename to scrape/requirements.txt index 3c59d8d..32d5df2 100644 --- a/requirements.txt +++ b/scrape/requirements.txt @@ -1,5 +1,4 @@ pandas requests -tqdm bs4 dotenv \ No newline at end of file diff --git a/transform/.env.example b/transform/.env.example new file mode 100644 index 0000000..34cd0e0 --- /dev/null +++ b/transform/.env.example @@ -0,0 +1,4 @@ +LOGGING_LEVEL=INFO +DB_PATH=/data/knack.sqlite +MAX_CLEANED_POSTS=1000 +COMPUTE_DEVICE=mps \ No newline at end of file diff --git a/transform/Dockerfile b/transform/Dockerfile new file mode 100644 index 0000000..4c72480 --- /dev/null +++ b/transform/Dockerfile @@ -0,0 +1,41 @@ +FROM python:3.12-slim + +RUN mkdir /app +RUN mkdir /data + +# Install build dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + g++ \ + gfortran \ + libopenblas-dev \ + liblapack-dev \ + pkg-config \ + && rm -rf /var/lib/apt/lists/* + +#COPY /data/knack.sqlite /data + +WORKDIR /app +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY .env . + +RUN apt update -y +RUN apt install -y cron locales + +COPY *.py . + +ENV PYTHONUNBUFFERED=1 +ENV LANG=de_DE.UTF-8 +ENV LC_ALL=de_DE.UTF-8 + +# Create cron job that runs every weekend (Sunday at 3 AM) 0 3 * * 0 +# Testing every 30 Minutes */30 * * * * +RUN echo "0 3 * * 0 cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-transform +RUN chmod 0644 /etc/cron.d/knack-transform +RUN crontab /etc/cron.d/knack-transform + +# Start cron in foreground +CMD ["cron", "-f"] +#CMD ["python", "main.py"] diff --git a/transform/README.md b/transform/README.md new file mode 100644 index 0000000..44ddeb1 --- /dev/null +++ b/transform/README.md @@ -0,0 +1,62 @@ +# Knack Transform + +Data transformation pipeline for the Knack scraper project. + +## Overview + +This folder contains the transformation logic that processes data from the SQLite database. It runs on a scheduled basis (every weekend) via cron. + +## Structure + +- `base.py` - Abstract base class for transform nodes +- `main.py` - Main entry point and pipeline orchestration +- `Dockerfile` - Docker image configuration with cron setup +- `requirements.txt` - Python dependencies + +## Transform Nodes + +Transform nodes inherit from `TransformNode` and implement the `run` method: + +```python +from base import TransformNode, TransformContext +import sqlite3 + +class MyTransform(TransformNode): + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + df = context.get_dataframe() + + # Transform logic here + transformed_df = df.copy() + # ... your transformations ... + + # Optionally write back to database + transformed_df.to_sql("my_table", con, if_exists="replace", index=False) + + return TransformContext(transformed_df) +``` + +## Configuration + +Copy `.env.example` to `.env` and configure: + +- `LOGGING_LEVEL` - Log level (INFO or DEBUG) +- `DB_PATH` - Path to SQLite database + +## Running + +### With Docker + +```bash +docker build -t knack-transform . +docker run -v $(pwd)/data:/data knack-transform +``` + +### Locally + +```bash +python main.py +``` + +## Cron Schedule + +The Docker container runs the transform pipeline every Sunday at 3 AM. 
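The README's node pattern generalizes once more than one transform exists (the TODOs in `transform/main.py` point at embedding and pre-aggregation nodes). A minimal sketch of running several nodes in sequence, each consuming the previous node's output context — `run_pipeline` is a hypothetical helper, not part of this PR; only `TransformNode` and `TransformContext` come from `base.py`:

```python
# Hypothetical pipeline runner: only TransformNode/TransformContext exist in base.py today.
import sqlite3
from typing import Iterable

import pandas as pd

from base import TransformContext, TransformNode


def run_pipeline(con: sqlite3.Connection,
                 nodes: Iterable[TransformNode],
                 initial_df: pd.DataFrame) -> TransformContext:
    """Run transform nodes in order, feeding each one the previous node's output context."""
    context = TransformContext(initial_df)
    for node in nodes:
        context = node.run(con, context)
    return context


# Usage sketch: AuthorNode is defined in author_node.py; further nodes are placeholders.
# con = sqlite3.connect("/data/knack.sqlite")
# posts = pd.read_sql("SELECT id, author FROM posts WHERE is_cleaned = 0", con)
# run_pipeline(con, [AuthorNode(device="cpu")], posts)
```

Each node decides for itself what it persists, so the runner stays a thin loop and the context only carries the working dataframe between steps.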
diff --git a/transform/author_node.py b/transform/author_node.py new file mode 100644 index 0000000..23d3365 --- /dev/null +++ b/transform/author_node.py @@ -0,0 +1,263 @@ +"""Author classification transform node using NER.""" +from base import TransformNode, TransformContext +import sqlite3 +import pandas as pd +import logging +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime + +try: + from gliner import GLiNER + import torch + GLINER_AVAILABLE = True +except ImportError: + GLINER_AVAILABLE = False + logging.warning("GLiNER not available. Install with: pip install gliner") + +logger = logging.getLogger("knack-transform") + + +class AuthorNode(TransformNode): + """Transform node that extracts and classifies authors using NER. + + Creates two tables: + - authors: stores unique authors with their type (Person, Organisation, etc.) + - post_authors: maps posts to their authors + """ + + def __init__(self, model_name: str = "urchade/gliner_medium-v2.1", + threshold: float = 0.5, + max_workers: int = 64, + device: str = "cpu"): + """Initialize the AuthorNode. + + Args: + model_name: GLiNER model to use + threshold: Confidence threshold for entity predictions + max_workers: Number of parallel workers for prediction + device: Device to run model on ('cpu', 'cuda', 'mps') + """ + self.model_name = model_name + self.threshold = threshold + self.max_workers = max_workers + self.device = device + self.model = None + self.labels = ["Person", "Organisation", "Email", "Newspaper", "NGO"] + + def _setup_model(self): + """Initialize the NER model.""" + if not GLINER_AVAILABLE: + raise ImportError("GLiNER is required for AuthorNode. Install with: pip install gliner") + + logger.info(f"Loading GLiNER model: {self.model_name}") + + if self.device == "cuda" and torch.cuda.is_available(): + self.model = GLiNER.from_pretrained( + self.model_name, + max_length=255 + ).to('cuda', dtype=torch.float16) + elif self.device == "mps" and torch.backends.mps.is_available(): + self.model = GLiNER.from_pretrained( + self.model_name, + max_length=255 + ).to('mps', dtype=torch.float16) + else: + self.model = GLiNER.from_pretrained( + self.model_name, + max_length=255 + ) + + logger.info("Model loaded successfully") + + def _predict(self, text_data: dict): + """Predict entities for a single author text. + + Args: + text_data: Dict with 'author' and 'id' keys + + Returns: + Tuple of (predictions, post_id) or None + """ + if text_data is None or text_data.get('author') is None: + return None + + predictions = self.model.predict_entities( + text_data['author'], + self.labels, + threshold=self.threshold + ) + return predictions, text_data['id'] + + def _classify_authors(self, posts_df: pd.DataFrame): + """Classify all authors in the posts dataframe. 
+ + Args: + posts_df: DataFrame with 'id' and 'author' columns + + Returns: + List of dicts with 'text', 'label', 'id' keys + """ + if self.model is None: + self._setup_model() + + # Prepare input data + authors_data = [] + for idx, row in posts_df.iterrows(): + if pd.notna(row['author']): + authors_data.append({ + 'author': row['author'], + 'id': row['id'] + }) + + logger.info(f"Classifying {len(authors_data)} authors") + + results = [] + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = [executor.submit(self._predict, data) for data in authors_data] + + for future in futures: + result = future.result() + if result is not None: + predictions, post_id = result + for pred in predictions: + results.append({ + 'text': pred['text'], + 'label': pred['label'], + 'id': post_id + }) + + logger.info(f"Classification complete. Found {len(results)} author entities") + return results + + def _create_tables(self, con: sqlite3.Connection): + """Create authors and post_authors tables if they don't exist.""" + logger.info("Creating authors tables") + + con.execute(""" + CREATE TABLE IF NOT EXISTS authors ( + id INTEGER PRIMARY KEY, + name TEXT, + type TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + con.execute(""" + CREATE TABLE IF NOT EXISTS post_authors ( + post_id INTEGER, + author_id INTEGER, + PRIMARY KEY (post_id, author_id), + FOREIGN KEY (post_id) REFERENCES posts(id), + FOREIGN KEY (author_id) REFERENCES authors(id) + ) + """) + + con.commit() + + def _store_authors(self, con: sqlite3.Connection, results: list): + """Store classified authors and their mappings. + + Args: + con: Database connection + results: List of classification results + """ + if not results: + logger.info("No authors to store") + return + + # Convert results to DataFrame + results_df = pd.DataFrame(results) + + # Get unique authors with their types + unique_authors = results_df[['text', 'label']].drop_duplicates() + unique_authors.columns = ['name', 'type'] + + # Get existing authors + existing_authors = pd.read_sql("SELECT id, name FROM authors", con) + + # Find new authors to insert + if not existing_authors.empty: + new_authors = unique_authors[~unique_authors['name'].isin(existing_authors['name'])] + else: + new_authors = unique_authors + + if not new_authors.empty: + logger.info(f"Inserting {len(new_authors)} new authors") + new_authors.to_sql('authors', con, if_exists='append', index=False) + + # Get all authors with their IDs + all_authors = pd.read_sql("SELECT id, name FROM authors", con) + name_to_id = dict(zip(all_authors['name'], all_authors['id'])) + + # Create post_authors mappings + mappings = [] + for _, row in results_df.iterrows(): + author_id = name_to_id.get(row['text']) + if author_id: + mappings.append({ + 'post_id': row['id'], + 'author_id': author_id + }) + + if mappings: + mappings_df = pd.DataFrame(mappings).drop_duplicates() + + # Clear existing mappings for these posts (optional, depends on your strategy) + # post_ids = tuple(mappings_df['post_id'].unique()) + # con.execute(f"DELETE FROM post_authors WHERE post_id IN ({','.join('?' * len(post_ids))})", post_ids) + + logger.info(f"Creating {len(mappings_df)} post-author mappings") + mappings_df.to_sql('post_authors', con, if_exists='append', index=False) + + # Mark posts as cleaned + processed_post_ids = mappings_df['post_id'].unique().tolist() + if processed_post_ids: + placeholders = ','.join('?' 
* len(processed_post_ids)) + con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", processed_post_ids) + logger.info(f"Marked {len(processed_post_ids)} posts as cleaned") + + con.commit() + logger.info("Authors and mappings stored successfully") + + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the author classification transformation. + + Args: + con: SQLite database connection + context: TransformContext containing posts dataframe + + Returns: + TransformContext with classified authors dataframe + """ + logger.info("Starting AuthorNode transformation") + + posts_df = context.get_dataframe() + + # Ensure required columns exist + if 'author' not in posts_df.columns: + logger.warning("No 'author' column in dataframe. Skipping AuthorNode.") + return context + + # Create tables + self._create_tables(con) + + # Classify authors + results = self._classify_authors(posts_df) + + # Store results + self._store_authors(con, results) + + # Mark posts without author entities as cleaned too (no authors found) + processed_ids = set([r['id'] for r in results]) if results else set() + unprocessed_ids = [pid for pid in posts_df['id'].tolist() if pid not in processed_ids] + if unprocessed_ids: + placeholders = ','.join('?' * len(unprocessed_ids)) + con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", unprocessed_ids) + con.commit() + logger.info(f"Marked {len(unprocessed_ids)} posts without author entities as cleaned") + + # Return context with results + results_df = pd.DataFrame(results) if results else pd.DataFrame() + logger.info("AuthorNode transformation complete") + + return TransformContext(results_df) diff --git a/transform/base.py b/transform/base.py new file mode 100644 index 0000000..59a4f31 --- /dev/null +++ b/transform/base.py @@ -0,0 +1,37 @@ +"""Base transform node for data pipeline.""" +from abc import ABC, abstractmethod +import sqlite3 +import pandas as pd + + +class TransformContext: + """Context object containing the dataframe for transformation.""" + + def __init__(self, df: pd.DataFrame): + self.df = df + + def get_dataframe(self) -> pd.DataFrame: + """Get the pandas dataframe from the context.""" + return self.df + + +class TransformNode(ABC): + """Abstract base class for transformation nodes. + + Each transform node implements a single transformation step + that takes data from the database, transforms it, and + potentially writes results back. + """ + + @abstractmethod + def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext: + """Execute the transformation. + + Args: + con: SQLite database connection + context: TransformContext containing the input dataframe + + Returns: + TransformContext with the transformed dataframe + """ + pass diff --git a/transform/main.py b/transform/main.py new file mode 100644 index 0000000..29b9a38 --- /dev/null +++ b/transform/main.py @@ -0,0 +1,89 @@ +#! 
python3 +import logging +import os +import sqlite3 +import sys +from dotenv import load_dotenv + +load_dotenv() + +if (os.environ.get('LOGGING_LEVEL', 'INFO') == 'INFO'): + logging_level = logging.INFO +else: + logging_level = logging.DEBUG + +logging.basicConfig( + level=logging_level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler("app.log"), + logging.StreamHandler(sys.stdout) + ] +) +logger = logging.getLogger("knack-transform") + + +def setup_database_connection(): + """Create connection to the SQLite database.""" + db_path = os.environ.get('DB_PATH', '/data/knack.sqlite') + logger.info(f"Connecting to database: {db_path}") + return sqlite3.connect(db_path) + + +def table_exists(tablename: str, con: sqlite3.Connection): + """Check if a table exists in the database.""" + query = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1" + return len(con.execute(query, [tablename]).fetchall()) > 0 + + +def main(): + """Main entry point for the transform pipeline.""" + logger.info("Starting transform pipeline") + + try: + con = setup_database_connection() + logger.info("Database connection established") + + # Check if posts table exists + if not table_exists('posts', con): + logger.warning("Posts table does not exist yet. Please run the scraper first to populate the database.") + logger.info("Transform pipeline skipped - no data available") + return + + # Import transform nodes + from author_node import AuthorNode + from base import TransformContext + import pandas as pd + + # Load posts data + logger.info("Loading posts from database") + sql = "SELECT id, author FROM posts WHERE author IS NOT NULL AND (is_cleaned IS NULL OR is_cleaned = 0) LIMIT ?" + MAX_CLEANED_POSTS = os.environ.get("MAX_CLEANED_POSTS", 500) + df = pd.read_sql(sql, con, params=[MAX_CLEANED_POSTS]) + logger.info(f"Loaded {len(df)} uncleaned posts with authors") + + if df.empty: + logger.info("No uncleaned posts found. Transform pipeline skipped.") + return + + # Create context and run author classification + context = TransformContext(df) + author_transform = AuthorNode(device=os.environ.get('COMPUTE_DEVICE', 'cpu')) # Change to "cuda" or "mps" if available + result_context = author_transform.run(con, context) + + # TODO: Create Node to compute Text Embeddings and UMAP. + # TODO: Create Node to pre-compute data based on visuals to reduce load time. + + logger.info("Transform pipeline completed successfully") + + except Exception as e: + logger.error(f"Error in transform pipeline: {e}", exc_info=True) + sys.exit(1) + finally: + if 'con' in locals(): + con.close() + logger.info("Database connection closed") + + +if __name__ == "__main__": + main() diff --git a/transform/requirements.txt b/transform/requirements.txt new file mode 100644 index 0000000..c95bd6d --- /dev/null +++ b/transform/requirements.txt @@ -0,0 +1,4 @@ +pandas +python-dotenv +gliner +torch
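Once the transform container has run, its output can be spot-checked directly against the SQLite file on the shared `knack_data` volume. A small sketch, assuming the default `DB_PATH` from `transform/.env.example`; the tables and columns used (`authors.name`, `authors.type`, `post_authors`, `posts.is_cleaned`) are the ones created by `AuthorNode` and the scraper:

```python
# Sanity-check AuthorNode output; table and column names come from author_node.py / scrape/main.py.
import sqlite3

import pandas as pd

con = sqlite3.connect("/data/knack.sqlite")  # default DB_PATH from transform/.env.example

sample = pd.read_sql(
    """
    SELECT p.id, p.author AS raw_author, a.name, a.type
    FROM posts p
    JOIN post_authors pa ON pa.post_id = p.id
    JOIN authors a ON a.id = pa.author_id
    ORDER BY p.id
    LIMIT 20
    """,
    con,
)
print(sample)

# How many posts the pipeline has already marked as processed:
print(con.execute("SELECT COUNT(*) FROM posts WHERE is_cleaned = 1").fetchone()[0])

con.close()
```

Running this from the host requires mounting or copying the volume's `knack.sqlite`; inside either container the `/data` path is already available.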