diff --git a/app.py b/app.py index 4c0fba7..3546e3e 100644 --- a/app.py +++ b/app.py @@ -2,9 +2,13 @@ from pathlib import Path from typing import List from shiny import App, ui, Inputs, Outputs, Session from shiny.types import NavSetArg -from src import mod_welcome, mod_searchable +from src import mod_welcome, mod_searchable, mod_topics from src.util import load_html_str_from_file + +# by importing this module, the tweets are loaded into the tweet_store variable at program start +import src.data_loader + import os @@ -15,7 +19,9 @@ def nav_controls() -> List[NavSetArg]: return [ ui.nav(ui.h5("Intro"), mod_welcome.welcome_ui("intro"), value="intro"), ui.nav(ui.h5("Analyse"), "Analyse"), - ui.nav(ui.h5("Suchmaschine"), mod_searchable.searchable_ui("search_engine"), value="search_engine"), + ui.nav(ui.h5("Suchmaschine"), mod_searchable.searchable_ui( + "search_engine"), value="search_engine"), + ui.nav(ui.h5("Topics"), mod_topics.topics_ui("topics"), value="topics"), ui.nav_control( ui.a( ui.h5("AG-Link"), @@ -60,6 +66,7 @@ app_ui = ui.page_navbar( def server(input: Inputs, output: Outputs, session: Session): mod_welcome.welcome_server("intro") mod_searchable.searchable_server("search_engine") + mod_topics.topics_server("topics") static_dir = Path(__file__).parent / "www" diff --git a/src/data_loader.py b/src/data_loader.py new file mode 100644 index 0000000..1a19cc3 --- /dev/null +++ b/src/data_loader.py @@ -0,0 +1,26 @@ +import pandas as pd +from sklearn.feature_extraction.text import TfidfVectorizer +import pickle + + +class TweetStore(): + + tweets_path: str = "data/tweets_all_combined.csv" + tfidf_matrix_path = "data/tfidf_matrix.pckl" + tfidf_vectorizer_path = "data/tfidf_vectorizer.pckl" + + def __init__(self): + print("Loading tweets from dataframe") + self.tweets = pd.read_csv(self.tweets_path) + + print("Loading tfidf from file") + self.tfidf_matrix = None + with open(self.tfidf_matrix_path, "rb") as f: + self.tfidf_matrix = pickle.load(f) + + self.tfidf_vectorizer: TfidfVectorizer = None + with open(self.tfidf_vectorizer_path, "rb") as f: + self.tfidf_vectorizer = pickle.load(f) + + +tweet_store = TweetStore() diff --git a/src/mod_searchable.py b/src/mod_searchable.py index 3ba4c34..318ba1f 100644 --- a/src/mod_searchable.py +++ b/src/mod_searchable.py @@ -1,15 +1,12 @@ from shiny import module, ui, render, Inputs, Outputs, Session -from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity import pandas as pd import numpy as np -import pickle import re -tfidf_matrix_path = "data/tfidf_matrix.pckl" -tfidf_vectorizer_path = "data/tfidf_vectorizer.pckl" +from src.data_loader import tweet_store + relevance_score_path = "data/tweet_relevance.json" -tweets_path = "data/tweets_all_combined.csv" reply_html_svg = '' retweet_html_svg = '' @@ -31,31 +28,21 @@ def replace_hastag(match): return f'{hashtag}' -print("Loading data from storage") -tweets = pd.read_csv(tweets_path) relevance_score = pd.read_csv(relevance_score_path) -tfidf_matrix = None -with open(tfidf_matrix_path, "rb") as f: - tfidf_matrix = pickle.load(f) -tfidf_vectorizer: TfidfVectorizer = None -with open(tfidf_vectorizer_path, "rb") as f: - tfidf_vectorizer = pickle.load(f) - - -tweets["relevance_score"] = relevance_score["relevance_score"] -tweets = tweets.drop(["user_id", "measured_at"], axis=1) +tweet_store.tweets["relevance_score"] = relevance_score["relevance_score"] +tweet_store.tweets = tweet_store.tweets.drop(["user_id", "measured_at"], axis=1) def search_query(query: str, limit: int = 5, sorting_method: str = "score") -> (pd.DataFrame, int): - query_vec = tfidf_vectorizer.transform([query]) - similarity = cosine_similarity(query_vec, tfidf_matrix).flatten() + query_vec = tweet_store.tfidf_vectorizer.transform([query]) + similarity = cosine_similarity(query_vec, tweet_store.tfidf_matrix).flatten() filtered = np.where(similarity != 0)[0] indices = np.argsort(-similarity[filtered]) correct_indices = filtered[indices] - result = tweets.iloc[correct_indices] + result = tweet_store.tweets.iloc[correct_indices] if not len(result): return None, 0 @@ -81,10 +68,13 @@ def searchable_ui(): ui.h2("Tweet Suchmaschine"), ui.HTML("