From 7a8e01f9d185b9a96b99d4b33562dc4b97928fd2 Mon Sep 17 00:00:00 2001 From: procrastimax Date: Wed, 26 Jul 2023 20:56:27 +0200 Subject: [PATCH] WIP: adds site for displaying topics --- app.py | 11 +++++++++-- src/data_loader.py | 26 ++++++++++++++++++++++++++ src/mod_searchable.py | 39 +++++++++++++++------------------------ src/mod_topics.py | 37 +++++++++++++++++++++++++++++++++++++ 4 files changed, 87 insertions(+), 26 deletions(-) create mode 100644 src/data_loader.py create mode 100644 src/mod_topics.py diff --git a/app.py b/app.py index 4c0fba7..3546e3e 100644 --- a/app.py +++ b/app.py @@ -2,9 +2,13 @@ from pathlib import Path from typing import List from shiny import App, ui, Inputs, Outputs, Session from shiny.types import NavSetArg -from src import mod_welcome, mod_searchable +from src import mod_welcome, mod_searchable, mod_topics from src.util import load_html_str_from_file + +# by importing this module, the tweets are loaded into the tweet_store variable at program start +import src.data_loader + import os @@ -15,7 +19,9 @@ def nav_controls() -> List[NavSetArg]: return [ ui.nav(ui.h5("Intro"), mod_welcome.welcome_ui("intro"), value="intro"), ui.nav(ui.h5("Analyse"), "Analyse"), - ui.nav(ui.h5("Suchmaschine"), mod_searchable.searchable_ui("search_engine"), value="search_engine"), + ui.nav(ui.h5("Suchmaschine"), mod_searchable.searchable_ui( + "search_engine"), value="search_engine"), + ui.nav(ui.h5("Topics"), mod_topics.topics_ui("topics"), value="topics"), ui.nav_control( ui.a( ui.h5("AG-Link"), @@ -60,6 +66,7 @@ app_ui = ui.page_navbar( def server(input: Inputs, output: Outputs, session: Session): mod_welcome.welcome_server("intro") mod_searchable.searchable_server("search_engine") + mod_topics.topics_server("topics") static_dir = Path(__file__).parent / "www" diff --git a/src/data_loader.py b/src/data_loader.py new file mode 100644 index 0000000..1a19cc3 --- /dev/null +++ b/src/data_loader.py @@ -0,0 +1,26 @@ +import pandas as pd +from sklearn.feature_extraction.text import TfidfVectorizer +import pickle + + +class TweetStore(): + + tweets_path: str = "data/tweets_all_combined.csv" + tfidf_matrix_path = "data/tfidf_matrix.pckl" + tfidf_vectorizer_path = "data/tfidf_vectorizer.pckl" + + def __init__(self): + print("Loading tweets from dataframe") + self.tweets = pd.read_csv(self.tweets_path) + + print("Loading tfidf from file") + self.tfidf_matrix = None + with open(self.tfidf_matrix_path, "rb") as f: + self.tfidf_matrix = pickle.load(f) + + self.tfidf_vectorizer: TfidfVectorizer = None + with open(self.tfidf_vectorizer_path, "rb") as f: + self.tfidf_vectorizer = pickle.load(f) + + +tweet_store = TweetStore() diff --git a/src/mod_searchable.py b/src/mod_searchable.py index 3ba4c34..318ba1f 100644 --- a/src/mod_searchable.py +++ b/src/mod_searchable.py @@ -1,15 +1,12 @@ from shiny import module, ui, render, Inputs, Outputs, Session -from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity import pandas as pd import numpy as np -import pickle import re -tfidf_matrix_path = "data/tfidf_matrix.pckl" -tfidf_vectorizer_path = "data/tfidf_vectorizer.pckl" +from src.data_loader import tweet_store + relevance_score_path = "data/tweet_relevance.json" -tweets_path = "data/tweets_all_combined.csv" reply_html_svg = '' retweet_html_svg = '' @@ -31,31 +28,21 @@ def replace_hastag(match): return f'{hashtag}' -print("Loading data from storage") -tweets = pd.read_csv(tweets_path) relevance_score = pd.read_csv(relevance_score_path) -tfidf_matrix = None -with open(tfidf_matrix_path, "rb") as f: - tfidf_matrix = pickle.load(f) -tfidf_vectorizer: TfidfVectorizer = None -with open(tfidf_vectorizer_path, "rb") as f: - tfidf_vectorizer = pickle.load(f) - - -tweets["relevance_score"] = relevance_score["relevance_score"] -tweets = tweets.drop(["user_id", "measured_at"], axis=1) +tweet_store.tweets["relevance_score"] = relevance_score["relevance_score"] +tweet_store.tweets = tweet_store.tweets.drop(["user_id", "measured_at"], axis=1) def search_query(query: str, limit: int = 5, sorting_method: str = "score") -> (pd.DataFrame, int): - query_vec = tfidf_vectorizer.transform([query]) - similarity = cosine_similarity(query_vec, tfidf_matrix).flatten() + query_vec = tweet_store.tfidf_vectorizer.transform([query]) + similarity = cosine_similarity(query_vec, tweet_store.tfidf_matrix).flatten() filtered = np.where(similarity != 0)[0] indices = np.argsort(-similarity[filtered]) correct_indices = filtered[indices] - result = tweets.iloc[correct_indices] + result = tweet_store.tweets.iloc[correct_indices] if not len(result): return None, 0 @@ -81,10 +68,13 @@ def searchable_ui(): ui.h2("Tweet Suchmaschine"), ui.HTML("
"), ui.row( - ui.column(6, ui.input_text("search_input", "Suche", placeholder="Gib Suchterm ein", value="Leipzig", width="100%")), + ui.column(6, ui.input_text("search_input", "Suche", + placeholder="Gib Suchterm ein", value="Leipzig", width="100%")), ui.column(3, - ui.input_select("sorting_method", "Sortierung", {"score": "Relevanz", "date_new": "Neuste Zuerst", "date_old": "Älteste Zuerst"}, selected="score", selectize=True, width="12em"), - ui.input_select("tweet_count", "Ergebnisse", {"5": "5", "20": "20", "50": "50", "all": "alle"}, selected="5", selectize=True, width="12em"), + ui.input_select("sorting_method", "Sortierung", { + "score": "Relevanz", "date_new": "Neuste Zuerst", "date_old": "Älteste Zuerst"}, selected="score", selectize=True, width="12em"), + ui.input_select("tweet_count", "Ergebnisse", { + "5": "5", "20": "20", "50": "50", "all": "alle"}, selected="5", selectize=True, width="12em"), style="display: flex; flex-direction: column; align-items: center; justify-content: center;"), style="justify-content:space-between;" @@ -138,7 +128,8 @@ def searchable_server(input: Inputs, output: Outputs, session: Session): ui.row( ui.column(6, ui.HTML( f"{user_name}@{user_handle}"), style=style + "padding-top: 1.5em; "), - ui.column(6, ui.p(f"{row['created_at']}"), style=style + "padding-top: 1.5em;"), + ui.column(6, ui.p(f"{row['created_at']}"), + style=style + "padding-top: 1.5em;"), ), ui.row( ui.column(12, ui.HTML("
"), diff --git a/src/mod_topics.py b/src/mod_topics.py new file mode 100644 index 0000000..eee6223 --- /dev/null +++ b/src/mod_topics.py @@ -0,0 +1,37 @@ +from shiny import module, ui, render, Inputs, Outputs, Session + +from sklearn.decomposition import NMF +from src.data_loader import tweet_store + +classes = 10 + +# Fit the NMF model +nmf = NMF( + n_components=classes, + random_state=42, + init=None, + beta_loss="frobenius", + alpha_W=0.0, + alpha_H="same", + l1_ratio=0.0, +).fit(tweet_store.tfidf_matrix) + + +# TODO: dont do this live -> load the feature_names and values from a pre-calculated list for each day +tfidf_feature_names = tweet_store.tfidf_vectorizer.get_feature_names_out() +print(tfidf_feature_names) + + +@ module.ui +def topics_ui(): + return ui.div( + ui.h2("Tweet Topics"), + ) + + +@ module.server +def topics_server(input: Inputs, output: Outputs, session: Session): + @ output + @ render.ui + def searchable_tweet_ui(): + pass