Compare commits

4 commits: 49239e7e25, 72765532d3, 64df8fb328, bcd210ce01

24 changed files with 1970 additions and 199 deletions

.gitignore (vendored, 2 changes)
@@ -1,3 +1,5 @@
 data/
 venv/
+experiment/
 .DS_STORE
+.env

Dockerfile (15 changes, file deleted)
@@ -1,15 +0,0 @@
-FROM python:slim
-
-RUN mkdir /app
-RUN mkdir /data
-
-WORKDIR /app
-COPY requirements.txt .
-RUN pip install --no-cache-dir -r requirements.txt
-
-RUN apt update -y
-RUN apt install -y cron
-COPY crontab .
-RUN crontab crontab
-
-COPY main.py .

Makefile (14 changes)
@@ -1,2 +1,12 @@
-build:
-	docker build -t knack-scraper .
+volume:
+	docker volume create knack_data
+
+stop:
+	docker stop knack-scraper || true
+	docker rm knack-scraper || true
+
+up:
+	docker compose up -d
+
+down:
+	docker compose down

README.md (18 changes)
@@ -0,0 +1,18 @@
+Knack-Scraper does exactly what its name suggests it does.
+Knack-Scraper scrapes knack.news and writes to an SQLite
+database for later usage.
+
+## Example for .env
+
+```
+NUM_THREADS=8
+NUM_SCRAPES=100
+DATABASE_LOCATION='./data/knack.sqlite'
+```
+
+## Run once
+
+```
+python main.py
+```
+

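For the "later usage" side the README mentions, a minimal sketch of reading the scraped posts back out of SQLite could look like this. The `posts` table and its columns come from `main.py` below; the path assumes the `DATABASE_LOCATION` from the .env example above.

```python
import sqlite3

import pandas as pd

# Open the database the scraper writes to (DATABASE_LOCATION in .env).
con = sqlite3.connect("./data/knack.sqlite")

# The scraper appends one row per post to the "posts" table via DataFrame.to_sql,
# so columns such as id, title, author, date and url are available to query.
df = pd.read_sql(
    "SELECT id, title, author, date, url FROM posts ORDER BY id DESC LIMIT 10",
    con,
)
print(df)

con.close()
```
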
crontab (1 change, file deleted)
@@ -1 +0,0 @@
-5 4 * * * python /app/main.py

docker-compose.yml (43 changes, new file)
@@ -0,0 +1,43 @@
+services:
+  scraper:
+    build:
+      context: ./scrape
+      dockerfile: Dockerfile
+    image: knack-scraper
+    container_name: knack-scraper
+    env_file:
+      - scrape/.env
+    volumes:
+      - knack_data:/data
+    restart: unless-stopped
+
+  transform:
+    build:
+      context: ./transform
+      dockerfile: Dockerfile
+    image: knack-transform
+    container_name: knack-transform
+    env_file:
+      - transform/.env
+    volumes:
+      - knack_data:/data
+      - models:/models
+    restart: unless-stopped
+
+  sqlitebrowser:
+    image: lscr.io/linuxserver/sqlitebrowser:latest
+    container_name: sqlitebrowser
+    environment:
+      - PUID=1000
+      - PGID=1000
+      - TZ=Etc/UTC
+    volumes:
+      - knack_data:/data
+    ports:
+      - "3000:3000" # noVNC web UI
+      - "3001:3001" # VNC server
+    restart: unless-stopped
+
+volumes:
+  knack_data:
+  models:

main.py (167 changes, file deleted)
@@ -1,167 +0,0 @@
-#! python3
-import locale
-import logging
-import os
-import sqlite3
-import sys
-import time
-from concurrent.futures import ThreadPoolExecutor
-from datetime import datetime
-
-import pandas as pd
-import requests
-import tqdm
-from bs4 import BeautifulSoup
-
-logger = logging.getLogger("knack-scraper")
-# ch = logging.StreamHandler()
-# formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-# ch.setFormatter(formatter)
-# ch.setLevel(logging.INFO)
-# logger.addHandler(ch)
-
-
-def table_exists(tablename: str, con: sqlite3.Connection):
-    query = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1"
-    return len(con.execute(query, [tablename]).fetchall()) > 0
-
-
-def download(id: int):
-    if id == 0:
-        return
-    base_url = "https://knack.news/"
-    url = f"{base_url}{id}"
-    res = requests.get(url)
-
-    # make sure we don't dos knack
-    time.sleep(2)
-
-    if not (200 <= res.status_code <= 300):
-        return
-
-    logger.info("Found promising page with id %d!", id)
-
-    content = res.content
-    soup = BeautifulSoup(content, "html.parser")
-    date_format = "%d. %B %Y"
-
-    # TODO FIXME: this fails inside the docker container
-    locale.setlocale(locale.LC_TIME, "de_DE")
-    pC = soup.find("div", {"class": "postContent"})
-
-    if pC is None:
-        # not a normal post
-        logger.info(
-            "Page with id %d does not have a .pageContent-div. Skipping for now.", id
-        )
-        return
-
-    # every post has these fields
-    title = pC.find("h3", {"class": "postTitle"}).text
-    postText = pC.find("div", {"class": "postText"})
-
-    # these fields are possible but not required
-    # TODO: cleanup
-    try:
-        date_string = pC.find("span", {"class": "singledate"}).text
-        parsed_date = datetime.strptime(date_string, date_format)
-    except AttributeError:
-        parsed_date = None
-
-    try:
-        author = pC.find("span", {"class": "author"}).text
-    except AttributeError:
-        author = None
-
-    try:
-        category = pC.find("span", {"class": "categoryInfo"}).find_all()
-        category = [c.text for c in category]
-        category = ";".join(category)
-    except AttributeError:
-        category = None
-
-    try:
-        tags = [x.text for x in pC.find("div", {"class": "tagsInfo"}).find_all("a")]
-        tags = ";".join(tags)
-    except AttributeError:
-        tags = None
-
-    img = pC.find("img", {"class": "postImage"})
-    if img is not None:
-        img = img["src"]
-
-    res_dict = {
-        "id": id,
-        "title": title,
-        "author": author,
-        "date": parsed_date,
-        "category": category,
-        "url": url,
-        "img_link": img,
-        "tags": tags,
-        "text": postText.text,
-        "html": str(postText),
-        "scraped_at": datetime.now(),
-    }
-
-    return res_dict
-
-
-def run_downloads(min_id: int, max_id: int, num_threads: int = 8):
-    res = []
-
-    logger.info(
-        "Started parallel scrape of posts from id %d to id %d using %d threads.",
-        min_id,
-        max_id - 1,
-        num_threads,
-    )
-    with ThreadPoolExecutor(max_workers=num_threads) as executor:
-        # Use a list comprehension to create a list of futures
-        futures = [executor.submit(download, i) for i in range(min_id, max_id)]
-
-        for future in tqdm.tqdm(
-            futures, total=max_id - min_id
-        ):  # tqdm to track progress
-            post = future.result()
-            if post is not None:
-                res.append(post)
-
-    # sqlite can't handle lists so let's convert them to a single row csv
-    # TODO: make sure our database is properly normalized
-    df = pd.DataFrame(res)
-
-    return df
-
-
-def main():
-    num_threads = int(os.environ.get("NUM_THREADS", 8))
-    n_scrapes = int(os.environ.get("NUM_SCRAPES", 100))
-    database_location = os.environ.get("DATABASE_LOCATION", "/data/knack.sqlite")
-
-    con = sqlite3.connect(database_location)
-    with con:
-        post_table_exists = table_exists("posts", con)
-
-        if post_table_exists:
-            logger.info("found posts retrieved earlier")
-            # retrieve max post id from db so
-            # we can skip retrieving known posts
-            max_id_in_db = con.execute("SELECT MAX(id) FROM posts").fetchone()[0]
-            logger.info("Got max id %d!", max_id_in_db)
-        else:
-            logger.info("no posts scraped so far - starting from 0")
-            # retrieve from 0 onwards
-            max_id_in_db = -1
-
-    con = sqlite3.connect(database_location)
-    df = run_downloads(
-        min_id=max_id_in_db + 1,
-        max_id=max_id_in_db + n_scrapes,
-        num_threads=num_threads,
-    )
-    df.to_sql("posts", con, if_exists="append")
-
-
-if __name__ == "__main__":
-    main()

@@ -1,14 +0,0 @@
-beautifulsoup4==4.12.2
-certifi==2023.7.22
-charset-normalizer==3.3.0
-idna==3.4
-numpy==1.26.1
-pandas==2.1.1
-python-dateutil==2.8.2
-pytz==2023.3.post1
-requests==2.31.0
-six==1.16.0
-soupsieve==2.5
-tqdm==4.66.1
-tzdata==2023.3
-urllib3==2.0.7

scrape/Dockerfile (29 changes, new file)
@@ -0,0 +1,29 @@
+FROM python:slim
+
+RUN mkdir /app
+RUN mkdir /data
+
+#COPY /data/knack.sqlite /data
+
+WORKDIR /app
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY .env .
+
+RUN apt update -y
+RUN apt install -y cron locales
+
+COPY main.py .
+
+ENV PYTHONUNBUFFERED=1
+ENV LANG=de_DE.UTF-8
+ENV LC_ALL=de_DE.UTF-8
+
+# Create cron job that runs the scraper daily at 04:05 and logs to the container's stdout
+RUN echo "5 4 * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-scraper
+RUN chmod 0644 /etc/cron.d/knack-scraper
+RUN crontab /etc/cron.d/knack-scraper
+
+# Start cron in foreground
+CMD ["cron", "-f"]

scrape/main.py (262 changes, new executable file)
@@ -0,0 +1,262 @@
+#! python3
+import logging
+import os
+import sqlite3
+import time
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
+import sys
+
+from dotenv import load_dotenv
+import pandas as pd
+import requests
+from bs4 import BeautifulSoup
+
+load_dotenv()
+
+if (os.environ.get('LOGGING_LEVEL', 'INFO') == 'INFO'):
+    logging_level = logging.INFO
+else:
+    logging_level = logging.DEBUG
+
+logging.basicConfig(
+    level=logging_level,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.FileHandler("app.log"),
+        logging.StreamHandler(sys.stdout)
+    ]
+)
+logger = logging.getLogger("knack-scraper")
+
+def table_exists(tablename: str, con: sqlite3.Connection):
+    query = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1"
+    return len(con.execute(query, [tablename]).fetchall()) > 0
+
+
+def split_semicolon_list(value: str):
+    if pd.isna(value):
+        return []
+    return [item.strip() for item in str(value).split(';') if item.strip()]
+
+
+def build_dimension_and_mapping(postdf: pd.DataFrame, field_name: str, dim_col: str):
+    """Extract unique dimension values and post-to-dimension mappings from a column."""
+    if postdf.empty or field_name not in postdf.columns:
+        return None, None
+
+    values = set()
+    mapping_rows = []
+
+    for post_id, raw in zip(postdf['id'], postdf[field_name]):
+        items = split_semicolon_list(raw)
+        for item in items:
+            values.add(item)
+            mapping_rows.append({'post_id': post_id, dim_col: item})
+
+    if not values:
+        return None, None
+
+    dim_df = pd.DataFrame({
+        'id': range(len(values)),
+        dim_col: sorted(values),
+    })
+    map_df = pd.DataFrame(mapping_rows)
+    return dim_df, map_df
+
+
+def store_dimension_and_mapping(
+    con: sqlite3.Connection,
+    dim_df: pd.DataFrame | None,
+    map_df: pd.DataFrame | None,
+    table_name: str,
+    dim_col: str,
+    mapping_table: str,
+    mapping_id_col: str,
+):
+    """Persist a dimension table and its mapping table, merging with existing values."""
+    if dim_df is None or dim_df.empty:
+        return
+
+    if table_exists(table_name, con):
+        existing = pd.read_sql(f"SELECT id, {dim_col} FROM {table_name}", con)
+        merged = pd.concat([existing, dim_df], ignore_index=True)
+        merged = merged.drop_duplicates(subset=[dim_col], keep='first').reset_index(drop=True)
+        merged['id'] = range(len(merged))
+    else:
+        merged = dim_df.copy()
+
+    # Replace table with merged content
+    merged.to_sql(table_name, con, if_exists="replace", index=False)
+
+    if map_df is None or map_df.empty:
+        return
+
+    value_to_id = dict(zip(merged[dim_col], merged['id']))
+    map_df = map_df.copy()
+    map_df[mapping_id_col] = map_df[dim_col].map(value_to_id)
+    map_df = map_df[['post_id', mapping_id_col]].dropna()
+    map_df.to_sql(mapping_table, con, if_exists="append", index=False)
+
+
+def download(id: int):
+    if id == 0:
+        return
+    base_url = "https://knack.news/"
+    url = f"{base_url}{id}"
+    res = requests.get(url)
+
+    # make sure we don't dos knack
+    time.sleep(2)
+
+    if not (200 <= res.status_code <= 300):
+        return
+
+    logger.debug("Found promising page with id %d!", id)
+
+    content = res.content
+    soup = BeautifulSoup(content, "html.parser")
+
+    pC = soup.find("div", {"class": "postContent"})
+
+    if pC is None:
+        # not a normal post
+        logger.debug(
+            "Page with id %d does not have a .pageContent-div. Skipping for now.", id
+        )
+        return
+
+    # every post has these fields
+    title = pC.find("h3", {"class": "postTitle"}).text
+    postText = pC.find("div", {"class": "postText"})
+
+    # these fields are possible but not required
+    # TODO: cleanup
+    try:
+        date_parts = pC.find("span", {"class": "singledate"}).text.split(' ')
+        day = int(date_parts[0][:-1])
+        months = {'Januar': 1, 'Februar': 2, 'März': 3, 'April': 4, 'Mai': 5, 'Juni': 6, 'Juli': 7, 'August': 8, 'September': 9, 'Oktober': 10, 'November': 11, 'Dezember': 12}
+        month = months[date_parts[1]]
+        year = int(date_parts[2])
+        parsed_date = datetime(year, month, day)
+    except Exception:
+        parsed_date = None
+
+    try:
+        author = pC.find("span", {"class": "author"}).text
+    except AttributeError:
+        author = None
+
+    try:
+        category = pC.find("span", {"class": "categoryInfo"}).find_all()
+        category = [c.text for c in category if c.text != 'Alle Artikel']
+        category = ";".join(category)
+    except AttributeError:
+        category = None
+
+    try:
+        tags = [x.text for x in pC.find("div", {"class": "tagsInfo"}).find_all("a")]
+        tags = ";".join(tags)
+    except AttributeError:
+        tags = None
+
+    img = pC.find("img", {"class": "postImage"})
+    if img is not None:
+        img = img["src"]
+
+    res_dict = {
+        "id": id,
+        "title": title,
+        "author": author,
+        "date": parsed_date,
+        "category": category,
+        "url": url,
+        "img_link": img,
+        "tags": tags,
+        "text": postText.text,
+        "html": str(postText),
+        "scraped_at": datetime.now(),
+        "is_cleaned": False
+    }
+
+    return res_dict
+
+
+def run_downloads(min_id: int, max_id: int, num_threads: int = 8):
+    res = []
+
+    logger.info(
+        "Started parallel scrape of posts from id %d to id %d using %d threads.",
+        min_id,
+        max_id - 1,
+        num_threads,
+    )
+    with ThreadPoolExecutor(max_workers=num_threads) as executor:
+        # Use a list comprehension to create a list of futures
+        futures = [executor.submit(download, i) for i in range(min_id, max_id)]
+
+        for future in futures:
+            post = future.result()
+            if post is not None:
+                res.append(post)
+
+    postdf = pd.DataFrame(res)
+    return postdf
+
+
+def main():
+    num_threads = int(os.environ.get("NUM_THREADS", 8))
+    n_scrapes = int(os.environ.get("NUM_SCRAPES", 100))
+    database_location = os.environ.get("DATABASE_LOCATION", "../data/knack.sqlite")
+
+    logger.debug(f"Started Knack Scraper: \nNUM_THREADS: {num_threads}\nN_SCRAPES: {n_scrapes}\nDATABASE_LOCATION: {database_location}")
+
+    con = sqlite3.connect(database_location)
+    with con:
+        if table_exists("posts", con):
+            logger.info("found posts retrieved earlier")
+            max_id_in_db = con.execute("SELECT MAX(id) FROM posts").fetchone()[0]
+            logger.info("Got max id %d!", max_id_in_db)
+        else:
+            logger.info("no posts scraped so far - starting from 0")
+            max_id_in_db = -1
+
+        postdf = run_downloads(
+            min_id=max_id_in_db + 1,
+            max_id=max_id_in_db + n_scrapes,
+            num_threads=num_threads,
+        )
+
+        # Build tag and category mappings before dropping those columns from posts,
+        # otherwise build_dimension_and_mapping would no longer find them.
+        tag_dim, tag_map = build_dimension_and_mapping(postdf, 'tags', 'tag')
+        category_dim, category_map = build_dimension_and_mapping(postdf, 'category', 'category')
+
+        # Drop category and tags columns as they're stored in separate tables
+        postdf = postdf.drop(columns=['category', 'tags'])
+        postdf.to_sql("posts", con, if_exists="append", index=False)
+
+        # Tags
+        store_dimension_and_mapping(
+            con,
+            tag_dim,
+            tag_map,
+            table_name="tags",
+            dim_col="tag",
+            mapping_table="posttags",
+            mapping_id_col="tag_id",
+        )
+
+        # Categories
+        store_dimension_and_mapping(
+            con,
+            category_dim,
+            category_map,
+            table_name="categories",
+            dim_col="category",
+            mapping_table="postcategories",
+            mapping_id_col="category_id",
+        )
+
+        logger.info(f"scraped new entries. number of new posts: {len(postdf.index)}")
+
+
+if __name__ == "__main__":
+    main()

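Since the scraper above normalizes tags and categories into their own tables (`tags`/`posttags` and `categories`/`postcategories`), getting a post together with its tags means joining through the mapping table. A read-only sketch, using only the table and column names defined above:

```python
import sqlite3

# DATABASE_LOCATION defaults to ../data/knack.sqlite in scrape/main.py; adjust as needed.
con = sqlite3.connect("../data/knack.sqlite")

# Resolve each post's tags through the posttags mapping table.
rows = con.execute("""
    SELECT p.id, p.title, t.tag
    FROM posts AS p
    JOIN posttags AS pt ON pt.post_id = p.id
    JOIN tags AS t ON t.id = pt.tag_id
    ORDER BY p.id
""").fetchall()

for post_id, title, tag in rows:
    print(post_id, title, tag)

con.close()
```
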
scrape/requirements.txt (4 changes, new file)
@@ -0,0 +1,4 @@
+pandas
+requests
+bs4
+dotenv

transform/.env.example (4 changes, new file)
@@ -0,0 +1,4 @@
+LOGGING_LEVEL=INFO
+DB_PATH=/data/knack.sqlite
+MAX_CLEANED_POSTS=1000
+COMPUTE_DEVICE=mps

transform/Dockerfile (51 changes, new file)
@@ -0,0 +1,51 @@
+FROM python:3.12-slim
+
+RUN mkdir -p /app /data /models
+
+# Install build dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    gcc \
+    g++ \
+    gfortran \
+    libopenblas-dev \
+    liblapack-dev \
+    pkg-config \
+    curl \
+    jq \
+    && rm -rf /var/lib/apt/lists/*
+
+ENV GLINER_MODEL_ID=urchade/gliner_multi-v2.1
+ENV GLINER_MODEL_PATH=/models/gliner_multi-v2.1
+
+ENV MINILM_MODEL_ID=sentence-transformers/all-MiniLM-L6-v2
+ENV MINILM_MODEL_PATH=/models/all-MiniLM-L6-v2
+
+WORKDIR /app
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY .env .
+
+RUN apt update -y
+RUN apt install -y cron locales
+
+# Ensure GLiNER helper scripts are available
+COPY ensure_gliner_model.sh /usr/local/bin/ensure_gliner_model.sh
+# Ensure MiniLM helper scripts are available
+COPY ensure_minilm_model.sh /usr/local/bin/ensure_minilm_model.sh
+COPY entrypoint.sh /usr/local/bin/entrypoint.sh
+RUN chmod +x /usr/local/bin/ensure_gliner_model.sh /usr/local/bin/ensure_minilm_model.sh /usr/local/bin/entrypoint.sh
+
+COPY *.py .
+
+# Create cron job that runs every weekend (Sunday at 3 AM) 0 3 * * 0
+# Testing every 30 Minutes */30 * * * *
+RUN echo "*/30 * * * * cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1" > /etc/cron.d/knack-transform
+RUN chmod 0644 /etc/cron.d/knack-transform
+RUN crontab /etc/cron.d/knack-transform
+
+# Persist models between container runs
+VOLUME /models
+
+CMD ["/usr/local/bin/entrypoint.sh"]
+#CMD ["python", "main.py"]

transform/README.md (67 changes, new file)
@@ -0,0 +1,67 @@
+# Knack Transform
+
+Data transformation pipeline for the Knack scraper project.
+
+## Overview
+
+This folder contains the transformation logic that processes data from the SQLite database. It runs on a scheduled basis (every weekend) via cron.
+
+The pipeline supports **parallel execution** of independent transform nodes, allowing you to leverage multi-core processors for faster data transformation.
+
+## Structure
+
+- `base.py` - Abstract base class for transform nodes
+- `pipeline.py` - Parallel pipeline orchestration system
+- `main.py` - Main entry point and pipeline execution
+- `author_node.py` - NER-based author classification node
+- `example_node.py` - Template for creating new nodes
+- `Dockerfile` - Docker image configuration with cron setup
+- `requirements.txt` - Python dependencies
+
+## Transform Nodes
+
+Transform nodes inherit from `TransformNode` and implement the `run` method:
+
+```python
+from base import TransformNode, TransformContext
+import sqlite3
+
+class MyTransform(TransformNode):
+    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
+        df = context.get_dataframe()
+
+        # Transform logic here
+        transformed_df = df.copy()
+        # ... your transformations ...
+
+        # Optionally write back to database
+        transformed_df.to_sql("my_table", con, if_exists="replace", index=False)
+
+        return TransformContext(transformed_df)
+```
+
+## Configuration
+
+Copy `.env.example` to `.env` and configure:
+
+- `LOGGING_LEVEL` - Log level (INFO or DEBUG)
+- `DB_PATH` - Path to SQLite database
+
+## Running
+
+### With Docker
+
+```bash
+docker build -t knack-transform .
+docker run -v $(pwd)/data:/data knack-transform
+```
+
+### Locally
+
+```bash
+python main.py
+```
+
+## Cron Schedule
+
+The Docker container runs the transform pipeline every Sunday at 3 AM.

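`pipeline.py` itself is not part of this diff, so the parallel orchestration mentioned in the overview is only described, not shown. As a rough sketch of what running independent nodes in parallel could look like, the `ParallelPipeline` class below is hypothetical; only `TransformNode` and `TransformContext` come from the files in this PR, imported here the way the node modules import them:

```python
import sqlite3
from concurrent.futures import ThreadPoolExecutor

from pipeline import TransformContext
from transform_node import TransformNode


class ParallelPipeline:
    """Hypothetical orchestrator that runs independent transform nodes side by side."""

    def __init__(self, nodes: list[TransformNode], max_workers: int = 4):
        self.nodes = nodes
        self.max_workers = max_workers

    def run(self, db_path: str, context: TransformContext) -> list[TransformContext]:
        # Each node gets its own connection: sqlite3 connections should not be
        # shared across threads by default.
        def _run_node(node: TransformNode) -> TransformContext:
            con = sqlite3.connect(db_path)
            try:
                return node.run(con, context)
            finally:
                con.close()

        with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
            futures = [pool.submit(_run_node, node) for node in self.nodes]
            return [future.result() for future in futures]
```

Nodes that depend on each other's output (for example, UmapNode needs the embeddings written by TextEmbeddingNode) would still have to run in sequence; only genuinely independent nodes gain anything from this.
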
transform/author_node.py (420 changes, new file)
@@ -0,0 +1,420 @@
+"""Author classification transform node using NER."""
+import os
+import sqlite3
+import pandas as pd
+import logging
+import fuzzysearch
+from concurrent.futures import ThreadPoolExecutor
+
+from pipeline import TransformContext
+from transform_node import TransformNode
+
+logger = logging.getLogger("knack-transform")
+
+try:
+    from gliner import GLiNER
+    import torch
+    GLINER_AVAILABLE = True
+except ImportError:
+    GLINER_AVAILABLE = False
+    logging.warning("GLiNER not available. Install with: pip install gliner")
+
+class NerAuthorNode(TransformNode):
+    """Transform node that extracts and classifies authors using NER.
+
+    Creates two tables:
+    - authors: stores unique authors with their type (Person, Organisation, etc.)
+    - post_authors: maps posts to their authors
+    """
+
+    def __init__(self, model_name: str = "urchade/gliner_multi-v2.1",
+                 model_path: str = None,
+                 threshold: float = 0.5,
+                 max_workers: int = 64,
+                 device: str = "cpu"):
+        """Initialize the AuthorNode.
+
+        Args:
+            model_name: GLiNER model to use
+            model_path: Optional local path to a downloaded GLiNER model
+            threshold: Confidence threshold for entity predictions
+            max_workers: Number of parallel workers for prediction
+            device: Device to run model on ('cpu', 'cuda', 'mps')
+        """
+        self.model_name = model_name
+        self.model_path = model_path or os.environ.get('GLINER_MODEL_PATH')
+        self.threshold = threshold
+        self.max_workers = max_workers
+        self.device = device
+        self.model = None
+        self.labels = ["Person", "Organisation", "Email", "Newspaper", "NGO"]
+
+    def _setup_model(self):
+        """Initialize the NER model."""
+        if not GLINER_AVAILABLE:
+            raise ImportError("GLiNER is required for AuthorNode. Install with: pip install gliner")
+
+        model_source = None
+        if self.model_path:
+            if os.path.exists(self.model_path):
+                model_source = self.model_path
+                logger.info(f"Loading GLiNER model from local path: {self.model_path}")
+            else:
+                logger.warning(f"GLINER_MODEL_PATH '{self.model_path}' not found; falling back to hub model {self.model_name}")
+
+        if model_source is None:
+            model_source = self.model_name
+            logger.info(f"Loading GLiNER model from hub: {self.model_name}")
+
+        if self.device == "cuda" and torch.cuda.is_available():
+            self.model = GLiNER.from_pretrained(
+                model_source,
+                max_length=255
+            ).to('cuda', dtype=torch.float16)
+        elif self.device == "mps" and torch.backends.mps.is_available():
+            self.model = GLiNER.from_pretrained(
+                model_source,
+                max_length=255
+            ).to('mps', dtype=torch.float16)
+        else:
+            self.model = GLiNER.from_pretrained(
+                model_source,
+                max_length=255
+            )
+
+        logger.info("Model loaded successfully")
+
+    def _predict(self, text_data: dict):
+        """Predict entities for a single author text.
+
+        Args:
+            text_data: Dict with 'author' and 'id' keys
+
+        Returns:
+            Tuple of (predictions, post_id) or None
+        """
+        if text_data is None or text_data.get('author') is None:
+            return None
+
+        predictions = self.model.predict_entities(
+            text_data['author'],
+            self.labels,
+            threshold=self.threshold
+        )
+        return predictions, text_data['id']
+
+    def _classify_authors(self, posts_df: pd.DataFrame):
+        """Classify all authors in the posts dataframe.
+
+        Args:
+            posts_df: DataFrame with 'id' and 'author' columns
+
+        Returns:
+            List of dicts with 'text', 'label', 'id' keys
+        """
+        if self.model is None:
+            self._setup_model()
+
+        # Prepare input data
+        authors_data = []
+        for idx, row in posts_df.iterrows():
+            if pd.notna(row['author']):
+                authors_data.append({
+                    'author': row['author'],
+                    'id': row['id']
+                })
+
+        logger.info(f"Classifying {len(authors_data)} authors")
+
+        results = []
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            futures = [executor.submit(self._predict, data) for data in authors_data]
+
+            for future in futures:
+                result = future.result()
+                if result is not None:
+                    predictions, post_id = result
+                    for pred in predictions:
+                        results.append({
+                            'text': pred['text'],
+                            'label': pred['label'],
+                            'id': post_id
+                        })
+
+        logger.info(f"Classification complete. Found {len(results)} author entities")
+        return results
+
+    def _create_tables(self, con: sqlite3.Connection):
+        """Create authors and post_authors tables if they don't exist."""
+        logger.info("Creating authors tables")
+
+        con.execute("""
+            CREATE TABLE IF NOT EXISTS authors (
+                id INTEGER PRIMARY KEY,
+                name TEXT,
+                type TEXT,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+            )
+        """)
+
+        con.execute("""
+            CREATE TABLE IF NOT EXISTS post_authors (
+                post_id INTEGER,
+                author_id INTEGER,
+                PRIMARY KEY (post_id, author_id),
+                FOREIGN KEY (post_id) REFERENCES posts(id),
+                FOREIGN KEY (author_id) REFERENCES authors(id)
+            )
+        """)
+
+        con.commit()
+
+    def _store_authors(self, con: sqlite3.Connection, results: list):
+        """Store classified authors and their mappings.
+
+        Args:
+            con: Database connection
+            results: List of classification results
+        """
+        if not results:
+            logger.info("No authors to store")
+            return
+
+        # Convert results to DataFrame
+        results_df = pd.DataFrame(results)
+
+        # Get unique authors with their types
+        unique_authors = results_df[['text', 'label']].drop_duplicates()
+        unique_authors.columns = ['name', 'type']
+
+        # Get existing authors
+        existing_authors = pd.read_sql("SELECT id, name FROM authors", con)
+
+        # Find new authors to insert
+        if not existing_authors.empty:
+            new_authors = unique_authors[~unique_authors['name'].isin(existing_authors['name'])]
+        else:
+            new_authors = unique_authors
+
+        if not new_authors.empty:
+            logger.info(f"Inserting {len(new_authors)} new authors")
+            new_authors.to_sql('authors', con, if_exists='append', index=False)
+
+        # Get all authors with their IDs
+        all_authors = pd.read_sql("SELECT id, name FROM authors", con)
+        name_to_id = dict(zip(all_authors['name'], all_authors['id']))
+
+        # Create post_authors mappings
+        mappings = []
+        for _, row in results_df.iterrows():
+            author_id = name_to_id.get(row['text'])
+            if author_id:
+                mappings.append({
+                    'post_id': row['id'],
+                    'author_id': author_id
+                })
+
+        if mappings:
+            mappings_df = pd.DataFrame(mappings).drop_duplicates()
+
+            # Clear existing mappings for these posts (optional, depends on your strategy)
+            # post_ids = tuple(mappings_df['post_id'].unique())
+            # con.execute(f"DELETE FROM post_authors WHERE post_id IN ({','.join('?' * len(post_ids))})", post_ids)
+
+            logger.info(f"Creating {len(mappings_df)} post-author mappings")
+            mappings_df.to_sql('post_authors', con, if_exists='append', index=False)
+
+        con.commit()
+        logger.info("Authors and mappings stored successfully")
+
+    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
+        """Execute the author classification transformation.
+
+        Args:
+            con: SQLite database connection
+            context: TransformContext containing posts dataframe
+
+        Returns:
+            TransformContext with classified authors dataframe
+        """
+        logger.info("Starting AuthorNode transformation")
+
+        posts_df = context.get_dataframe()
+
+        # Ensure required columns exist
+        if 'author' not in posts_df.columns:
+            logger.warning("No 'author' column in dataframe. Skipping AuthorNode.")
+            return context
+
+        # Create tables
+        self._create_tables(con)
+
+        # Classify authors
+        results = self._classify_authors(posts_df)
+
+        # Store results
+        self._store_authors(con, results)
+
+        # Return context with results
+        logger.info("AuthorNode transformation complete")
+
+        return TransformContext(posts_df)
+
+
+class FuzzyAuthorNode(TransformNode):
+    """FuzzyAuthorNode
+
+    This node takes in the author names that have already been classified
+    and uses those as 'rules' to find more, similar fields.
+    """
+
+    def __init__(self,
+                 max_l_dist: int = 1,):
+        """Initialize FuzzyAuthorNode.
+
+        Args:
+            max_l_dist: The number of 'errors' that are allowed by the fuzzy search algorithm
+        """
+        self.max_l_dist = max_l_dist
+        logger.info(f"Initialized FuzzyAuthorNode with max_l_dist={max_l_dist}")
+
+    def _process_data(self, con: sqlite3.Connection, df: pd.DataFrame) -> pd.DataFrame:
+        """Fuzzy-match known author names against posts that have no author mapping yet.
+
+        Args:
+            con: Database connection
+            df: Input dataframe from context
+
+        Returns:
+            Processed dataframe
+        """
+        logger.info(f"Processing {len(df)} rows")
+
+        # Retrieve all known authors from the authors table as 'rules'
+        authors_df = pd.read_sql("SELECT id, name FROM authors", con)
+
+        if authors_df.empty:
+            logger.warning("No authors found in database for fuzzy matching")
+            return pd.DataFrame(columns=['post_id', 'author_id'])
+
+        # Get existing post-author mappings to avoid duplicates
+        existing_mappings = pd.read_sql(
+            "SELECT post_id, author_id FROM post_authors", con
+        )
+        existing_post_ids = set(existing_mappings['post_id'].unique())
+
+        logger.info(f"Found {len(authors_df)} known authors for fuzzy matching")
+        logger.info(f"Found {len(existing_post_ids)} posts with existing author mappings")
+
+        # Filter to posts without author mappings and with non-null author field
+        if 'author' not in df.columns or 'id' not in df.columns:
+            logger.warning("Missing 'author' or 'id' column in input dataframe")
+            return pd.DataFrame(columns=['post_id', 'author_id'])
+
+        posts_to_process = df[
+            (df['id'].notna()) &
+            (df['author'].notna()) &
+            (~df['id'].isin(existing_post_ids))
+        ]
+
+        logger.info(f"Processing {len(posts_to_process)} posts for fuzzy matching")
+
+        # Perform fuzzy matching
+        mappings = []
+        for _, post_row in posts_to_process.iterrows():
+            post_id = post_row['id']
+            post_author = str(post_row['author'])
+
+            # Try to find matches against all known author names
+            for _, author_row in authors_df.iterrows():
+                author_id = author_row['id']
+                author_name = str(author_row['name'])
+                # for author names of 2 characters or fewer, allow no errors at all
+                l_dist = self.max_l_dist if len(author_name) > 2 else 0
+
+                # Use fuzzysearch to find matches with allowed errors
+                matches = fuzzysearch.find_near_matches(
+                    author_name,
+                    post_author,
+                    max_l_dist=l_dist,
+                )
+
+                if matches:
+                    logger.debug(f"Found fuzzy match: '{author_name}' in '{post_author}' for post {post_id}")
+                    mappings.append({
+                        'post_id': post_id,
+                        'author_id': author_id
+                    })
+                    # Only take the first match per post to avoid multiple mappings
+                    break
+
+        # Create result dataframe
+        result_df = pd.DataFrame(mappings, columns=['post_id', 'author_id']) if mappings else pd.DataFrame(columns=['post_id', 'author_id'])
+
+        logger.info(f"Processing complete. Found {len(result_df)} fuzzy matches")
+        return result_df
+
+    def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
+        """Store results back to the database.
+
+        Uses INSERT OR IGNORE to avoid inserting duplicates.
+
+        Args:
+            con: Database connection
+            df: Processed dataframe to store
+        """
+        if df.empty:
+            logger.info("No results to store")
+            return
+
+        logger.info(f"Storing {len(df)} results")
+
+        # Use INSERT OR IGNORE to handle duplicates (respects PRIMARY KEY constraint)
+        cursor = con.cursor()
+        inserted_count = 0
+
+        for _, row in df.iterrows():
+            cursor.execute(
+                "INSERT OR IGNORE INTO post_authors (post_id, author_id) VALUES (?, ?)",
+                (int(row['post_id']), int(row['author_id']))
+            )
+            if cursor.rowcount > 0:
+                inserted_count += 1
+
+        con.commit()
+        logger.info(f"Results stored successfully. Inserted {inserted_count} new mappings, skipped {len(df) - inserted_count} duplicates")
+
+    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
+        """Execute the transformation.
+
+        This is the main entry point called by the pipeline.
+
+        Args:
+            con: SQLite database connection
+            context: TransformContext containing input dataframe
+
+        Returns:
+            TransformContext with processed dataframe
+        """
+        logger.info("Starting FuzzyAuthorNode transformation")
+
+        # Get input dataframe from context
+        input_df = context.get_dataframe()
+
+        # Validate input
+        if input_df.empty:
+            logger.warning("Empty dataframe provided to FuzzyAuthorNode")
+            return context
+
+        # Process the data
+        result_df = self._process_data(con, input_df)
+
+        # Store results
+        self._store_results(con, result_df)
+
+        logger.info("FuzzyAuthorNode transformation complete")
+
+        # Return new context with results
+        return TransformContext(input_df)

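For reference, the `authors` and `post_authors` tables created by `NerAuthorNode` (and reused by `FuzzyAuthorNode`) can be queried back, for example to list the most frequently classified authors. A read-only sketch; the table and column names are the ones defined in `_create_tables` above:

```python
import sqlite3

con = sqlite3.connect("/data/knack.sqlite")  # DB_PATH from transform/.env.example

# Count posts per classified author, resolved through the post_authors mapping table.
rows = con.execute("""
    SELECT a.type, a.name, COUNT(pa.post_id) AS n_posts
    FROM authors AS a
    JOIN post_authors AS pa ON pa.author_id = a.id
    GROUP BY a.id
    ORDER BY n_posts DESC
    LIMIT 20
""").fetchall()

for author_type, name, n_posts in rows:
    print(f"{author_type:12} {name:40} {n_posts}")

con.close()
```
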
transform/embeddings_node.py (445 changes, new file)
@@ -0,0 +1,445 @@
+"""Transform node classes that deal with text processing.
+
+- TextEmbeddingNode calculates text embeddings
+- UmapNode calculates xy coordinates on those vector embeddings
+- SimilarityNode calculates top n similar posts based on those embeddings
+  using the spectral distance.
+"""
+from pipeline import TransformContext
+from transform_node import TransformNode
+import sqlite3
+import pandas as pd
+import logging
+import os
+import numpy as np
+
+logger = logging.getLogger("knack-transform")
+
+try:
+    from sentence_transformers import SentenceTransformer
+    import torch
+    MINILM_AVAILABLE = True
+except ImportError:
+    MINILM_AVAILABLE = False
+    logging.warning("MiniLM not available. Install with pip!")
+
+try:
+    import umap
+    UMAP_AVAILABLE = True
+except ImportError:
+    UMAP_AVAILABLE = False
+    logging.warning("UMAP not available. Install with pip install umap-learn!")
+
+class TextEmbeddingNode(TransformNode):
+    """Calculates vector embeddings based on a dataframe
+    of posts.
+    """
+    def __init__(self,
+                 model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
+                 model_path: str = None,
+                 device: str = "cpu"):
+        """Initialize the TextEmbeddingNode.
+
+        Args:
+            model_name: Name of the ML Model to calculate text embeddings
+            model_path: Optional local path to a downloaded embedding model
+            device: Device to use for computations ('cpu', 'cuda', 'mps')
+        """
+        self.model_name = model_name
+        self.model_path = model_path or os.environ.get('MINILM_MODEL_PATH')
+        self.device = device
+        self.model = None
+        logger.info(f"Initialized TextEmbeddingNode with model_name={model_name}, model_path={model_path}, device={device}")
+
+    def _setup_model(self):
+        """Init the Text Embedding Model."""
+        if not MINILM_AVAILABLE:
+            raise ImportError("MiniLM is required for TextEmbeddingNode. Please install.")
+
+        model_source = None
+        if self.model_path:
+            if os.path.exists(self.model_path):
+                model_source = self.model_path
+                logger.info(f"Loading MiniLM model from local path: {self.model_path}")
+            else:
+                logger.warning(f"MINILM_MODEL_PATH '{self.model_path}' not found; falling back to hub model {self.model_name}")
+
+        if model_source is None:
+            model_source = self.model_name
+            logger.info(f"Loading MiniLM model from the hub: {self.model_name}")
+
+        if self.device == "cuda" and torch.cuda.is_available():
+            self.model = SentenceTransformer(model_source).to('cuda', dtype=torch.float16)
+        elif self.device == "mps" and torch.backends.mps.is_available():
+            self.model = SentenceTransformer(model_source).to('mps', dtype=torch.float16)
+        else:
+            self.model = SentenceTransformer(model_source)
+
+    def _process_data(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Process the input dataframe.
+
+        Calculates an embedding as a np.array for each post's text.
+        The arrays are later serialized with tobytes() for BLOB
+        storage in the database.
+
+        Args:
+            df: Input dataframe from context
+
+        Returns:
+            Processed dataframe
+        """
+        logger.info(f"Processing {len(df)} rows")
+
+        if self.model is None:
+            self._setup_model()
+
+        result_df = df.copy()
+
+        # Encode on the copy so the embedding column actually ends up in the
+        # returned dataframe and the context dataframe stays untouched.
+        result_df['embedding'] = result_df['text'].apply(lambda x: self.model.encode(x, convert_to_numpy=True))
+
+        logger.info("Processing complete")
+        return result_df
+
+    def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
+        """Store results back to the database using batch updates."""
+        if df.empty:
+            logger.info("No results to store")
+            return
+
+        logger.info(f"Storing {len(df)} results")
+
+        # Convert numpy arrays to bytes for BLOB storage
+        # Use tobytes() to serialize numpy arrays efficiently
+        updates = [(row['embedding'].tobytes(), row['id']) for _, row in df.iterrows()]
+        con.executemany(
+            "UPDATE posts SET embedding = ? WHERE id = ?",
+            updates
+        )
+
+        con.commit()
+        logger.info("Results stored successfully")
+
+    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
+        """Execute the transformation.
+
+        This is the main entry point called by the pipeline.
+
+        Args:
+            con: SQLite database connection
+            context: TransformContext containing input dataframe
+
+        Returns:
+            TransformContext with processed dataframe
+        """
+        logger.info("Starting TextEmbeddingNode transformation")
+
+        # Get input dataframe from context
+        input_df = context.get_dataframe()
+
+        # Validate input
+        if input_df.empty:
+            logger.warning("Empty dataframe provided to TextEmbeddingNode")
+            return context
+
+        if 'text' not in input_df.columns:
+            logger.warning("No 'text' column in context dataframe. Skipping TextEmbeddingNode")
+            return context
+
+        # Process the data
+        result_df = self._process_data(input_df)
+
+        # Store results (optional)
+        self._store_results(con, result_df)
+
+        logger.info("TextEmbeddingNode transformation complete")
+
+        # Return new context with results
+        return TransformContext(result_df)
+
+
+class UmapNode(TransformNode):
+    """Calculates 2D coordinates from embeddings using UMAP dimensionality reduction.
+
+    This node takes text embeddings and reduces them to 2D coordinates
+    for visualization purposes.
+    """
+
+    def __init__(self,
+                 n_neighbors: int = 15,
+                 min_dist: float = 0.1,
+                 n_components: int = 2,
+                 metric: str = "cosine",
+                 random_state: int = 42):
+        """Initialize the UmapNode.
+
+        Args:
+            n_neighbors: Number of neighbors to consider for UMAP (default: 15)
+            min_dist: Minimum distance between points in low-dimensional space (default: 0.1)
+            n_components: Number of dimensions to reduce to (default: 2)
+            metric: Distance metric to use (default: 'cosine')
+            random_state: Random seed for reproducibility (default: 42)
+        """
+        self.n_neighbors = n_neighbors
+        self.min_dist = min_dist
+        self.n_components = n_components
+        self.metric = metric
+        self.random_state = random_state
+        self.reducer = None
+        logger.info(f"Initialized UmapNode with n_neighbors={n_neighbors}, min_dist={min_dist}, "
+                    f"n_components={n_components}, metric={metric}, random_state={random_state}")
+
+    def _process_data(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Process the input dataframe.
+
+        Retrieves embeddings from BLOB storage, converts them back to numpy arrays,
+        and applies UMAP dimensionality reduction to create 2D coordinates.
+
+        Args:
+            df: Input dataframe from context
+
+        Returns:
+            Processed dataframe with umap_x and umap_y columns
+        """
+        logger.info(f"Processing {len(df)} rows")
+
+        if not UMAP_AVAILABLE:
+            raise ImportError("UMAP is required for UmapNode. Install with: pip install umap-learn")
+
+        result_df = df.copy()
+
+        # Convert BLOB embeddings back to numpy arrays
+        if 'embedding' not in result_df.columns:
+            logger.error("No 'embedding' column found in dataframe")
+            raise ValueError("Input dataframe must contain 'embedding' column")
+
+        logger.info("Converting embeddings from BLOB to numpy arrays")
+        result_df['embedding'] = result_df['embedding'].apply(
+            lambda x: np.frombuffer(x, dtype=np.float32) if x is not None else None
+        )
+
+        # Filter out rows with None embeddings
+        valid_rows = result_df['embedding'].notna()
+        if not valid_rows.any():
+            logger.error("No valid embeddings found in dataframe")
+            raise ValueError("No valid embeddings to process")
+
+        logger.info(f"Found {valid_rows.sum()} valid embeddings out of {len(result_df)} rows")
+
+        # Stack embeddings into a matrix
+        embeddings_matrix = np.vstack(result_df.loc[valid_rows, 'embedding'].values)
+        logger.info(f"Embeddings matrix shape: {embeddings_matrix.shape}")
+
+        # Apply UMAP
+        logger.info("Fitting UMAP reducer...")
+        self.reducer = umap.UMAP(
+            n_neighbors=self.n_neighbors,
+            min_dist=self.min_dist,
+            n_components=self.n_components,
+            metric=self.metric,
+            random_state=self.random_state
+        )
+
+        umap_coords = self.reducer.fit_transform(embeddings_matrix)
+        logger.info(f"UMAP transformation complete. Output shape: {umap_coords.shape}")
+
+        # Add UMAP coordinates to dataframe
+        result_df.loc[valid_rows, 'umap_x'] = umap_coords[:, 0]
+        result_df.loc[valid_rows, 'umap_y'] = umap_coords[:, 1]
+
+        # Rows without a valid embedding keep NaN coordinates; _store_results skips them.
+
+        logger.info("Processing complete")
+        return result_df
+
+    def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
+        """Store UMAP coordinates back to the database.
+
+        Args:
+            con: Database connection
+            df: Processed dataframe with umap_x and umap_y columns
+        """
+        if df.empty:
+            logger.info("No results to store")
+            return
+
+        logger.info(f"Storing {len(df)} results")
+
+        # Batch update UMAP coordinates
+        updates = [
+            (row['umap_x'], row['umap_y'], row['id'])
+            for _, row in df.iterrows()
+            if pd.notna(row.get('umap_x')) and pd.notna(row.get('umap_y'))
+        ]
+
+        if updates:
+            con.executemany(
+                "UPDATE posts SET umap_x = ?, umap_y = ? WHERE id = ?",
+                updates
+            )
+            con.commit()
+            logger.info(f"Stored {len(updates)} UMAP coordinate pairs successfully")
+        else:
+            logger.warning("No valid UMAP coordinates to store")
+
+    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
+        """Execute the transformation.
+
+        This is the main entry point called by the pipeline.
+
+        Args:
+            con: SQLite database connection
+            context: TransformContext containing input dataframe
+
+        Returns:
+            TransformContext with processed dataframe
+        """
+        logger.info("Starting UmapNode transformation")
+
+        # Get input dataframe from context
+        input_df = context.get_dataframe()
+
+        # Validate input
+        if input_df.empty:
+            logger.warning("Empty dataframe provided to UmapNode")
+            return context
+
+        # Process the data
+        result_df = self._process_data(input_df)
+
+        # Store results (optional)
+        self._store_results(con, result_df)
+
+        logger.info("UmapNode transformation complete")
+
+        # Return new context with results
+        return TransformContext(result_df)
+
+
+class SimilarityNode(TransformNode):
+    """Example transform node template.
+
+    This node demonstrates the basic structure for creating
+    new transformation nodes in the pipeline.
+    """
+
+    def __init__(self,
+                 param1: str = "default_value",
+                 param2: int = 42,
+                 device: str = "cpu"):
+        """Initialize the ExampleNode.
+
+        Args:
+            param1: Example string parameter
+            param2: Example integer parameter
+            device: Device to use for computations ('cpu', 'cuda', 'mps')
+        """
+        self.param1 = param1
+        self.param2 = param2
+        self.device = device
+        logger.info(f"Initialized ExampleNode with param1={param1}, param2={param2}")
+
+    def _create_tables(self, con: sqlite3.Connection):
+        """Create any necessary tables in the database.
+
+        This is optional - only needed if your node creates new tables.
+        """
+        logger.info("Creating example tables")
+
+        con.execute("""
+            CREATE TABLE IF NOT EXISTS example_results (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                post_id INTEGER,
+                result_value TEXT,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                FOREIGN KEY (post_id) REFERENCES posts(id)
+            )
+        """)
+
+        con.commit()
+
+    def _process_data(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Process the input dataframe.
+
+        This is where your main transformation logic goes.
+
+        Args:
+            df: Input dataframe from context
+
+        Returns:
+            Processed dataframe
+        """
+        logger.info(f"Processing {len(df)} rows")
+
+        # Example: Add a new column based on existing data
+        result_df = df.copy()
+        result_df['processed'] = True
+        result_df['example_value'] = result_df['id'].apply(lambda x: f"{self.param1}_{x}")
+
+        logger.info("Processing complete")
+        return result_df
+
+    def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
+        """Store results back to the database.
+
+        This is optional - only needed if you want to persist results.
+
+        Args:
+            con: Database connection
+            df: Processed dataframe to store
+        """
+        if df.empty:
+            logger.info("No results to store")
+            return
+
+        logger.info(f"Storing {len(df)} results")
+
+        # Example: Store to database
+        # df[['post_id', 'result_value']].to_sql(
+        #     'example_results',
+        #     con,
+        #     if_exists='append',
+        #     index=False
+        # )
+
+        con.commit()
+        logger.info("Results stored successfully")
+
+    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
+        """Execute the transformation.
+
+        This is the main entry point called by the pipeline.
+
+        Args:
+            con: SQLite database connection
+            context: TransformContext containing input dataframe
+
+        Returns:
+            TransformContext with processed dataframe
+        """
+        logger.info("Starting ExampleNode transformation")
+
+        # Get input dataframe from context
+        input_df = context.get_dataframe()
+
+        # Validate input
+        if input_df.empty:
+            logger.warning("Empty dataframe provided to ExampleNode")
+            return context
+
+        # Create any necessary tables
+        self._create_tables(con)
+
+        # Process the data
+        result_df = self._process_data(input_df)
+
+        # Store results (optional)
+        self._store_results(con, result_df)
+
+        logger.info("ExampleNode transformation complete")
+
+        # Return new context with results
+        return TransformContext(result_df)

transform/ensure_gliner_model.sh (Normal file, 16 lines)
@@ -0,0 +1,16 @@
#!/usr/bin/env bash
set -euo pipefail

if [ -d "$GLINER_MODEL_PATH" ] && find "$GLINER_MODEL_PATH" -type f | grep -q .; then
    echo "GLiNER model already present at $GLINER_MODEL_PATH"
    exit 0
fi

echo "Downloading GLiNER model to $GLINER_MODEL_PATH"
mkdir -p "$GLINER_MODEL_PATH"
curl -sL "https://huggingface.co/api/models/${GLINER_MODEL_ID}" | jq -r '.siblings[].rfilename' | while read -r file; do
    target="${GLINER_MODEL_PATH}/${file}"
    mkdir -p "$(dirname "$target")"
    echo "Downloading ${file}"
    curl -sL "https://huggingface.co/${GLINER_MODEL_ID}/resolve/main/${file}" -o "$target"
done
transform/ensure_minilm_model.sh (Normal file, 16 lines)
@@ -0,0 +1,16 @@
#!/usr/bin/env bash
set -euo pipefail

if [ -d "$MINILM_MODEL_PATH" ] && find "$MINILM_MODEL_PATH" -type f | grep -q .; then
    echo "MiniLM model already present at $MINILM_MODEL_PATH"
    exit 0
fi

echo "Downloading MiniLM model to $MINILM_MODEL_PATH"
mkdir -p "$MINILM_MODEL_PATH"
curl -sL "https://huggingface.co/api/models/${MINILM_MODEL_ID}" | jq -r '.siblings[].rfilename' | while read -r file; do
    target="${MINILM_MODEL_PATH}/${file}"
    mkdir -p "$(dirname "$target")"
    echo "Downloading ${file}"
    curl -sL "https://huggingface.co/${MINILM_MODEL_ID}/resolve/main/${file}" -o "$target"
done
transform/entrypoint.sh (Normal file, 10 lines)
@@ -0,0 +1,10 @@
#!/usr/bin/env bash
set -euo pipefail

# Run model download with output to stdout/stderr
/usr/local/bin/ensure_minilm_model.sh 2>&1
/usr/local/bin/ensure_gliner_model.sh 2>&1

# Start cron in foreground with logging
exec cron -f -L 2
# cd /app && /usr/local/bin/python main.py >> /proc/1/fd/1 2>&1
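Both ensure_*_model.sh scripts expect `curl` and `jq` on the PATH and read the Hugging Face repo id plus the local target directory from the environment. A minimal sketch of the matching transform/.env entries; only the variable names come from the scripts above, the model ids and paths shown here are illustrative assumptions and are not pinned anywhere in this diff:

```
# Hypothetical values; only the variable names are taken from the scripts
GLINER_MODEL_ID='urchade/gliner_multi-v2.1'
GLINER_MODEL_PATH='/models/gliner'
MINILM_MODEL_ID='sentence-transformers/all-MiniLM-L6-v2'
MINILM_MODEL_PATH='/models/minilm'
```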
transform/example_node.py (Normal file, 170 lines)
@@ -0,0 +1,170 @@
"""Example template node for the transform pipeline.
|
||||||
|
|
||||||
|
This is a template showing how to create new transform nodes.
|
||||||
|
Copy this file and modify it for your specific transformation needs.
|
||||||
|
"""
|
||||||
|
from pipeline import TransformContext
|
||||||
|
from transform_node import TransformNode
|
||||||
|
import sqlite3
|
||||||
|
import pandas as pd
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger("knack-transform")
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleNode(TransformNode):
|
||||||
|
"""Example transform node template.
|
||||||
|
|
||||||
|
This node demonstrates the basic structure for creating
|
||||||
|
new transformation nodes in the pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
param1: str = "default_value",
|
||||||
|
param2: int = 42,
|
||||||
|
device: str = "cpu"):
|
||||||
|
"""Initialize the ExampleNode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param1: Example string parameter
|
||||||
|
param2: Example integer parameter
|
||||||
|
device: Device to use for computations ('cpu', 'cuda', 'mps')
|
||||||
|
"""
|
||||||
|
self.param1 = param1
|
||||||
|
self.param2 = param2
|
||||||
|
self.device = device
|
||||||
|
logger.info(f"Initialized ExampleNode with param1={param1}, param2={param2}")
|
||||||
|
|
||||||
|
def _create_tables(self, con: sqlite3.Connection):
|
||||||
|
"""Create any necessary tables in the database.
|
||||||
|
|
||||||
|
This is optional - only needed if your node creates new tables.
|
||||||
|
"""
|
||||||
|
logger.info("Creating example tables")
|
||||||
|
|
||||||
|
con.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS example_results (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
post_id INTEGER,
|
||||||
|
result_value TEXT,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (post_id) REFERENCES posts(id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
con.commit()
|
||||||
|
|
||||||
|
def _process_data(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""Process the input dataframe.
|
||||||
|
|
||||||
|
This is where your main transformation logic goes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: Input dataframe from context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed dataframe
|
||||||
|
"""
|
||||||
|
logger.info(f"Processing {len(df)} rows")
|
||||||
|
|
||||||
|
# Example: Add a new column based on existing data
|
||||||
|
result_df = df.copy()
|
||||||
|
result_df['processed'] = True
|
||||||
|
result_df['example_value'] = result_df['id'].apply(lambda x: f"{self.param1}_{x}")
|
||||||
|
|
||||||
|
logger.info("Processing complete")
|
||||||
|
return result_df
|
||||||
|
|
||||||
|
def _store_results(self, con: sqlite3.Connection, df: pd.DataFrame):
|
||||||
|
"""Store results back to the database.
|
||||||
|
|
||||||
|
This is optional - only needed if you want to persist results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
con: Database connection
|
||||||
|
df: Processed dataframe to store
|
||||||
|
"""
|
||||||
|
if df.empty:
|
||||||
|
logger.info("No results to store")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Storing {len(df)} results")
|
||||||
|
|
||||||
|
# Example: Store to database
|
||||||
|
# df[['post_id', 'result_value']].to_sql(
|
||||||
|
# 'example_results',
|
||||||
|
# con,
|
||||||
|
# if_exists='append',
|
||||||
|
# index=False
|
||||||
|
# )
|
||||||
|
|
||||||
|
con.commit()
|
||||||
|
logger.info("Results stored successfully")
|
||||||
|
|
||||||
|
def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
|
||||||
|
"""Execute the transformation.
|
||||||
|
|
||||||
|
This is the main entry point called by the pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
con: SQLite database connection
|
||||||
|
context: TransformContext containing input dataframe
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TransformContext with processed dataframe
|
||||||
|
"""
|
||||||
|
logger.info("Starting ExampleNode transformation")
|
||||||
|
|
||||||
|
# Get input dataframe from context
|
||||||
|
input_df = context.get_dataframe()
|
||||||
|
|
||||||
|
# Validate input
|
||||||
|
if input_df.empty:
|
||||||
|
logger.warning("Empty dataframe provided to ExampleNode")
|
||||||
|
return context
|
||||||
|
|
||||||
|
# Create any necessary tables
|
||||||
|
self._create_tables(con)
|
||||||
|
|
||||||
|
# Process the data
|
||||||
|
result_df = self._process_data(input_df)
|
||||||
|
|
||||||
|
# Store results (optional)
|
||||||
|
self._store_results(con, result_df)
|
||||||
|
|
||||||
|
logger.info("ExampleNode transformation complete")
|
||||||
|
|
||||||
|
# Return new context with results
|
||||||
|
return TransformContext(result_df)
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage:
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# This allows you to test your node independently
|
||||||
|
import os
|
||||||
|
os.chdir('/Users/linussilberstein/Documents/Knack-Scraper/transform')
|
||||||
|
|
||||||
|
from pipeline import TransformContext
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
# Create test data
|
||||||
|
test_df = pd.DataFrame({
|
||||||
|
'id': [1, 2, 3],
|
||||||
|
'author': ['Test Author 1', 'Test Author 2', 'Test Author 3']
|
||||||
|
})
|
||||||
|
|
||||||
|
# Create test database connection
|
||||||
|
test_con = sqlite3.connect(':memory:')
|
||||||
|
|
||||||
|
# Create and run node
|
||||||
|
node = ExampleNode(param1="test", param2=100)
|
||||||
|
context = TransformContext(test_df)
|
||||||
|
result_context = node.run(test_con, context)
|
||||||
|
|
||||||
|
# Check results
|
||||||
|
result_df = result_context.get_dataframe()
|
||||||
|
print("\nResult DataFrame:")
|
||||||
|
print(result_df)
|
||||||
|
|
||||||
|
test_con.close()
|
||||||
|
print("\n✓ ExampleNode test completed successfully!")
|
||||||
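Besides the built-in `__main__` self-test, a node copied from this template would be registered with the pipeline through `NodeConfig`. A short sketch using only the API defined in pipeline.py; the node name and kwargs here are illustrative:

```python
# Hypothetical registration of a node based on this template
from pipeline import ParallelPipeline, NodeConfig
from example_node import ExampleNode

pipeline = ParallelPipeline(max_workers=2, use_processes=False)
pipeline.add_node(NodeConfig(
    node_class=ExampleNode,
    node_kwargs={'param1': 'demo', 'device': 'cpu'},
    dependencies=[],  # no upstream nodes, so it runs in the first stage
    name='ExampleNode'
))
```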
transform/main.py (Normal file, 102 lines)
@@ -0,0 +1,102 @@
#! python3
import logging
import os
import sqlite3
import sys
from dotenv import load_dotenv

load_dotenv()

if (os.environ.get('LOGGING_LEVEL', 'INFO') == 'INFO'):
    logging_level = logging.INFO
else:
    logging_level = logging.DEBUG

logging.basicConfig(
    level=logging_level,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger("knack-transform")


def setup_database_connection():
    """Create connection to the SQLite database."""
    db_path = os.environ.get('DB_PATH', '/data/knack.sqlite')
    logger.info(f"Connecting to database: {db_path}")
    return sqlite3.connect(db_path)


def table_exists(tablename: str, con: sqlite3.Connection):
    """Check if a table exists in the database."""
    query = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1"
    return len(con.execute(query, [tablename]).fetchall()) > 0


def main():
    """Main entry point for the transform pipeline."""
    logger.info("Starting transform pipeline")

    try:
        con = setup_database_connection()
        logger.info("Database connection established")

        # Check if posts table exists
        if not table_exists('posts', con):
            logger.warning("Posts table does not exist yet. Please run the scraper first to populate the database.")
            logger.info("Transform pipeline skipped - no data available")
            return

        # Import transform components
        from pipeline import create_default_pipeline, TransformContext
        import pandas as pd

        # Load posts data
        logger.info("Loading posts from database")
        sql = "SELECT * FROM posts WHERE author IS NOT NULL AND (is_cleaned IS NULL OR is_cleaned = 0)"
        # MAX_CLEANED_POSTS = os.environ.get("MAX_CLEANED_POSTS", 100)
        df = pd.read_sql(sql, con)
        logger.info(f"Loaded {len(df)} uncleaned posts with authors")

        if df.empty:
            logger.info("No uncleaned posts found. Transform pipeline skipped.")
            return

        # Create initial context
        context = TransformContext(df)

        # Create and run parallel pipeline
        device = os.environ.get('COMPUTE_DEVICE', 'cpu')
        max_workers = int(os.environ.get('MAX_WORKERS', 4))

        pipeline = create_default_pipeline(device=device, max_workers=max_workers)
        results = pipeline.run(
            db_path=os.environ.get('DB_PATH', '/data/knack.sqlite'),
            initial_context=context,
            fail_fast=False  # Continue even if some nodes fail
        )

        logger.info(f"Pipeline completed. Processed {len(results)} node(s)")

        # Mark all processed posts as cleaned
        post_ids = df['id'].tolist()
        if post_ids:
            placeholders = ','.join('?' * len(post_ids))
            con.execute(f"UPDATE posts SET is_cleaned = 1 WHERE id IN ({placeholders})", post_ids)
            con.commit()
            logger.info(f"Marked {len(post_ids)} posts as cleaned")

    except Exception as e:
        logger.error(f"Error in transform pipeline: {e}", exc_info=True)
        sys.exit(1)
    finally:
        if 'con' in locals():
            con.close()
            logger.info("Database connection closed")


if __name__ == "__main__":
    main()
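main.py reads its runtime settings from the environment (loaded via python-dotenv) and falls back to /data/knack.sqlite, 'cpu' and 4 workers when a variable is missing. A sketch of a matching transform/.env with the variables referenced in this file; the values shown are illustrative:

```
DB_PATH='/data/knack.sqlite'
COMPUTE_DEVICE='cpu'
MAX_WORKERS=4
LOGGING_LEVEL='INFO'
```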
transform/pipeline.py (Normal file, 266 lines)
@@ -0,0 +1,266 @@
"""Parallel pipeline orchestration for transform nodes."""
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import multiprocessing as mp
|
||||||
|
|
||||||
|
logger = logging.getLogger("knack-transform")
|
||||||
|
|
||||||
|
class TransformContext:
|
||||||
|
"""Context object containing the dataframe for transformation."""
|
||||||
|
# Possibly add a dict for the context to give more Information
|
||||||
|
|
||||||
|
def __init__(self, df: pd.DataFrame):
|
||||||
|
self.df = df
|
||||||
|
|
||||||
|
def get_dataframe(self) -> pd.DataFrame:
|
||||||
|
"""Get the pandas dataframe from the context."""
|
||||||
|
return self.df
|
||||||
|
|
||||||
|
class NodeConfig:
|
||||||
|
"""Configuration for a transform node."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
node_class: type,
|
||||||
|
node_kwargs: Dict = None,
|
||||||
|
dependencies: List[str] = None,
|
||||||
|
name: str = None):
|
||||||
|
"""Initialize node configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_class: The TransformNode class to instantiate
|
||||||
|
node_kwargs: Keyword arguments to pass to node constructor
|
||||||
|
dependencies: List of node names that must complete before this one
|
||||||
|
name: Optional name for the node (defaults to class name)
|
||||||
|
"""
|
||||||
|
self.node_class = node_class
|
||||||
|
self.node_kwargs = node_kwargs or {}
|
||||||
|
self.dependencies = dependencies or []
|
||||||
|
self.name = name or node_class.__name__
|
||||||
|
|
||||||
|
class ParallelPipeline:
|
||||||
|
"""Pipeline for executing transform nodes in parallel where possible.
|
||||||
|
|
||||||
|
The pipeline analyzes dependencies between nodes and executes
|
||||||
|
independent nodes concurrently using multiprocessing or threading.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
max_workers: Optional[int] = None,
|
||||||
|
use_processes: bool = False):
|
||||||
|
"""Initialize the parallel pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_workers: Maximum number of parallel workers (defaults to CPU count)
|
||||||
|
use_processes: If True, use ProcessPoolExecutor; if False, use ThreadPoolExecutor
|
||||||
|
"""
|
||||||
|
self.max_workers = max_workers or mp.cpu_count()
|
||||||
|
self.use_processes = use_processes
|
||||||
|
self.nodes: Dict[str, NodeConfig] = {}
|
||||||
|
logger.info(f"Initialized ParallelPipeline with {self.max_workers} workers "
|
||||||
|
f"({'processes' if use_processes else 'threads'})")
|
||||||
|
|
||||||
|
def add_node(self, config: NodeConfig):
|
||||||
|
"""Add a node to the pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: NodeConfig with node details and dependencies
|
||||||
|
"""
|
||||||
|
self.nodes[config.name] = config
|
||||||
|
logger.info(f"Added node '{config.name}' with dependencies: {config.dependencies}")
|
||||||
|
|
||||||
|
def _get_execution_stages(self) -> List[List[str]]:
|
||||||
|
"""Determine execution stages based on dependencies.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of stages, where each stage contains node names that can run in parallel
|
||||||
|
"""
|
||||||
|
stages = []
|
||||||
|
completed = set()
|
||||||
|
remaining = set(self.nodes.keys())
|
||||||
|
|
||||||
|
while remaining:
|
||||||
|
# Find nodes whose dependencies are all completed
|
||||||
|
ready = []
|
||||||
|
for node_name in remaining:
|
||||||
|
config = self.nodes[node_name]
|
||||||
|
if all(dep in completed for dep in config.dependencies):
|
||||||
|
ready.append(node_name)
|
||||||
|
|
||||||
|
if not ready:
|
||||||
|
# Circular dependency or missing dependency
|
||||||
|
raise ValueError(f"Cannot resolve dependencies. Remaining nodes: {remaining}")
|
||||||
|
|
||||||
|
stages.append(ready)
|
||||||
|
completed.update(ready)
|
||||||
|
remaining -= set(ready)
|
||||||
|
|
||||||
|
return stages
|
||||||
|
|
||||||
|
def _execute_node(self,
|
||||||
|
node_name: str,
|
||||||
|
db_path: str,
|
||||||
|
context: TransformContext) -> tuple:
|
||||||
|
"""Execute a single node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_name: Name of the node to execute
|
||||||
|
db_path: Path to the SQLite database
|
||||||
|
context: TransformContext for the node
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (node_name, result_context, error)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Create fresh database connection (not shared across processes/threads)
|
||||||
|
con = sqlite3.connect(db_path)
|
||||||
|
|
||||||
|
config = self.nodes[node_name]
|
||||||
|
node = config.node_class(**config.node_kwargs)
|
||||||
|
|
||||||
|
logger.info(f"Executing node: {node_name}")
|
||||||
|
result_context = node.run(con, context)
|
||||||
|
|
||||||
|
con.close()
|
||||||
|
logger.info(f"Node '{node_name}' completed successfully")
|
||||||
|
|
||||||
|
return node_name, result_context, None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error executing node '{node_name}': {e}", exc_info=True)
|
||||||
|
return node_name, None, str(e)
|
||||||
|
|
||||||
|
def run(self,
|
||||||
|
db_path: str,
|
||||||
|
initial_context: TransformContext,
|
||||||
|
fail_fast: bool = False) -> Dict[str, TransformContext]:
|
||||||
|
"""Execute the pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to the SQLite database
|
||||||
|
initial_context: Initial TransformContext for the pipeline
|
||||||
|
fail_fast: If True, stop execution on first error
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping node names to their output TransformContext
|
||||||
|
"""
|
||||||
|
logger.info("Starting parallel pipeline execution")
|
||||||
|
|
||||||
|
stages = self._get_execution_stages()
|
||||||
|
logger.info(f"Pipeline has {len(stages)} execution stage(s)")
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
ExecutorClass = ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor
|
||||||
|
|
||||||
|
for stage_num, stage_nodes in enumerate(stages, 1):
|
||||||
|
logger.info(f"Stage {stage_num}/{len(stages)}: Executing {len(stage_nodes)} node(s) in parallel: {stage_nodes}")
|
||||||
|
|
||||||
|
# For nodes in this stage, use the context from their dependencies
|
||||||
|
# If multiple dependencies, we'll use the most recent one (or could merge)
|
||||||
|
stage_futures = {}
|
||||||
|
|
||||||
|
with ExecutorClass(max_workers=min(self.max_workers, len(stage_nodes))) as executor:
|
||||||
|
for node_name in stage_nodes:
|
||||||
|
config = self.nodes[node_name]
|
||||||
|
|
||||||
|
# Get context from dependencies (use the last dependency's output)
|
||||||
|
if config.dependencies:
|
||||||
|
context = results.get(config.dependencies[-1], initial_context)
|
||||||
|
else:
|
||||||
|
context = initial_context
|
||||||
|
|
||||||
|
future = executor.submit(self._execute_node, node_name, db_path, context)
|
||||||
|
stage_futures[future] = node_name
|
||||||
|
|
||||||
|
# Wait for all nodes in this stage to complete
|
||||||
|
for future in as_completed(stage_futures):
|
||||||
|
node_name = stage_futures[future]
|
||||||
|
name, result_context, error = future.result()
|
||||||
|
|
||||||
|
if error:
|
||||||
|
errors.append((name, error))
|
||||||
|
if fail_fast:
|
||||||
|
logger.error(f"Pipeline failed at node '{name}': {error}")
|
||||||
|
raise RuntimeError(f"Node '{name}' failed: {error}")
|
||||||
|
else:
|
||||||
|
results[name] = result_context
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
logger.warning(f"Pipeline completed with {len(errors)} error(s)")
|
||||||
|
for name, error in errors:
|
||||||
|
logger.error(f" - {name}: {error}")
|
||||||
|
else:
|
||||||
|
logger.info("Pipeline completed successfully")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def create_default_pipeline(device: str = "cpu",
|
||||||
|
max_workers: Optional[int] = None) -> ParallelPipeline:
|
||||||
|
"""Create a pipeline with default transform nodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device: Device to use for compute-intensive nodes ('cpu', 'cuda', 'mps')
|
||||||
|
max_workers: Maximum number of parallel workers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured ParallelPipeline
|
||||||
|
"""
|
||||||
|
from author_node import NerAuthorNode, FuzzyAuthorNode
|
||||||
|
from embeddings_node import TextEmbeddingNode, UmapNode
|
||||||
|
|
||||||
|
pipeline = ParallelPipeline(max_workers=max_workers, use_processes=False)
|
||||||
|
|
||||||
|
# Add AuthorNode (no dependencies)
|
||||||
|
pipeline.add_node(NodeConfig(
|
||||||
|
node_class=NerAuthorNode,
|
||||||
|
node_kwargs={
|
||||||
|
'device': device,
|
||||||
|
'model_path': os.environ.get('GLINER_MODEL_PATH')
|
||||||
|
},
|
||||||
|
dependencies=[],
|
||||||
|
name='AuthorNode'
|
||||||
|
))
|
||||||
|
|
||||||
|
pipeline.add_node(NodeConfig(
|
||||||
|
node_class=FuzzyAuthorNode,
|
||||||
|
node_kwargs={
|
||||||
|
'max_l_dist': 1
|
||||||
|
},
|
||||||
|
dependencies=['AuthorNode'],
|
||||||
|
name='FuzzyAuthorNode'
|
||||||
|
))
|
||||||
|
|
||||||
|
pipeline.add_node(NodeConfig(
|
||||||
|
node_class=TextEmbeddingNode,
|
||||||
|
node_kwargs={
|
||||||
|
'device': device,
|
||||||
|
'model_path': os.environ.get('MINILM_MODEL_PATH')
|
||||||
|
},
|
||||||
|
dependencies=[],
|
||||||
|
name='TextEmbeddingNode'
|
||||||
|
))
|
||||||
|
|
||||||
|
pipeline.add_node(NodeConfig(
|
||||||
|
node_class=UmapNode,
|
||||||
|
node_kwargs={},
|
||||||
|
dependencies=['TextEmbeddingNode'],
|
||||||
|
name='UmapNode'
|
||||||
|
))
|
||||||
|
|
||||||
|
# TODO: Create Node to compute Text Embeddings and UMAP.
|
||||||
|
|
||||||
|
# pipeline.add_node(NodeConfig(
|
||||||
|
# node_class=UMAPNode,
|
||||||
|
# node_kwargs={'device': device},
|
||||||
|
# dependencies=['EmbeddingNode'], # Runs after EmbeddingNode
|
||||||
|
# name='UMAPNode'
|
||||||
|
# ))
|
||||||
|
|
||||||
|
return pipeline
|
||||||
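create_default_pipeline wires the nodes into two independent chains (AuthorNode followed by FuzzyAuthorNode, and TextEmbeddingNode followed by UmapNode), so the stage resolver can run the two chains side by side. A minimal usage sketch against the API defined above, assuming the model environment variables are set and using an illustrative dataframe in place of the real posts table:

```python
import pandas as pd
from pipeline import create_default_pipeline, TransformContext

# Illustrative input; the real caller loads uncleaned posts from SQLite
df = pd.DataFrame({'id': [1, 2],
                   'author': ['A. Author', None],
                   'content': ['first post', 'second post']})

pipeline = create_default_pipeline(device='cpu', max_workers=4)
results = pipeline.run(
    db_path='/data/knack.sqlite',      # each node opens its own connection
    initial_context=TransformContext(df),
    fail_fast=False,                   # collect per-node errors instead of aborting
)
for name, ctx in results.items():
    print(name, len(ctx.get_dataframe()))
```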
transform/requirements.txt (Normal file, 7 lines)
@@ -0,0 +1,7 @@
pandas
python-dotenv
gliner
torch
fuzzysearch
sentence_transformers
umap-learn
transform/transform_node.py (Normal file, 26 lines)
@@ -0,0 +1,26 @@
"""Base transform node for data pipeline."""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
from pipeline import TransformContext
|
||||||
|
|
||||||
|
class TransformNode(ABC):
|
||||||
|
"""Abstract base class for transformation nodes.
|
||||||
|
|
||||||
|
Each transform node implements a single transformation step
|
||||||
|
that takes data from the database, transforms it, and
|
||||||
|
potentially writes results back.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
|
||||||
|
"""Execute the transformation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
con: SQLite database connection
|
||||||
|
context: TransformContext containing the input dataframe
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TransformContext with the transformed dataframe
|
||||||
|
"""
|
||||||
|
pass
|
||||||
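For comparison with the full template in example_node.py, the smallest concrete node only needs a run method that hands back a TransformContext. A toy sketch, not part of this diff:

```python
import sqlite3

from pipeline import TransformContext
from transform_node import TransformNode


class LowercaseAuthorNode(TransformNode):
    """Toy node: normalizes the author column to lower case."""

    def run(self, con: sqlite3.Connection, context: TransformContext) -> TransformContext:
        df = context.get_dataframe().copy()
        df['author'] = df['author'].str.lower()  # purely in-memory, no DB writes
        return TransformContext(df)
```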