104 lines
5 KiB
Python
104 lines
5 KiB
Python
from shiny import module, ui, render, Inputs, Outputs, Session
|
||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||
from sklearn.metrics.pairwise import cosine_similarity
|
||
import pandas as pd
|
||
import numpy as np
|
||
import pickle
|
||
|
||
# Relative locations of the pre-computed search artifacts and the tweet data.
# Both TF-IDF artifacts are pickle files.
tfidf_matrix_path = "data/tfidf_matrix.pckl"
tfidf_vectorizer_path = "data/tfidf_vectorizer.pckl"

# NOTE(review): despite the ".json" suffix this file is parsed with
# pandas.read_csv further down -- confirm the actual on-disk format.
relevance_score_path = "data/tweet_relevance.json"

# CSV file holding the tweets that are searched and displayed.
tweets_path = "data/tweets_all_combined.csv"
|
||
|
||
# The three icons share the same 18px <svg> wrapper and differ only in the
# path data, so the wrapper is written once.
_SVG_HEAD = '<svg width="18px" height="18px" viewBox="0 0 24 24" aria-hidden="true"><g><path d="'
_SVG_TAIL = '"></path></g></svg>'


def _svg_icon(path_data: str) -> str:
    """Wrap SVG path data in the icon markup shared by all counters."""
    return _SVG_HEAD + path_data + _SVG_TAIL


# Inline SVG markup rendered next to the reply, retweet and like counts.
reply_html_svg = _svg_icon(
    "M1.751 10c0-4.42 3.584-8 8.005-8h4.366c4.49 0 8.129 3.64 8.129 8.13 0 2.96-1.607 5.68-4.196 7.11l-8.054 4.46v-3.69h-.067c-4.49.1-8.183-3.51-8.183-8.01zm8.005-6c-3.317 0-6.005 2.69-6.005 6 0 3.37 2.77 6.08 6.138 6.01l.351-.01h1.761v2.3l5.087-2.81c1.951-1.08 3.163-3.13 3.163-5.36 0-3.39-2.744-6.13-6.129-6.13H9.756z"
)
retweet_html_svg = _svg_icon(
    "M4.5 3.88l4.432 4.14-1.364 1.46L5.5 7.55V16c0 1.1.896 2 2 2H13v2H7.5c-2.209 0-4-1.79-4-4V7.55L1.432 9.48.068 8.02 4.5 3.88zM16.5 6H11V4h5.5c2.209 0 4 1.79 4 4v8.45l2.068-1.93 1.364 1.46-4.432 4.14-4.432-4.14 1.364-1.46 2.068 1.93V8c0-1.1-.896-2-2-2z"
)
like_html_svg = _svg_icon(
    "M16.697 5.5c-1.222-.06-2.679.51-3.89 2.16l-.805 1.09-.806-1.09C9.984 6.01 8.526 5.44 7.304 5.5c-1.243.07-2.349.78-2.91 1.91-.552 1.12-.633 2.78.479 4.82 1.074 1.97 3.257 4.27 7.129 6.61 3.87-2.34 6.052-4.64 7.126-6.61 1.111-2.04 1.03-3.7.477-4.82-.561-1.13-1.666-1.84-2.908-1.91zm4.187 7.69c-1.351 2.48-4.001 5.12-8.379 7.67l-.503.3-.504-.3c-4.379-2.55-7.029-5.19-8.382-7.67-1.36-2.5-1.41-4.86-.514-6.67.887-1.79 2.647-2.91 4.601-3.01 1.651-.09 3.368.56 4.798 2.01 1.429-1.45 3.146-2.1 4.796-2.01 1.954.1 3.714 1.22 4.601 3.01.896 1.81.846 4.17-.514 6.67z"
)
|
||
|
||
|
||
def _load_pickle(path: str):
    """Return the object stored in the pickle file at *path*."""
    # NOTE(review): unpickling can execute arbitrary code; these files must
    # come from a trusted source.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Everything below runs once at import time and fills the module-level state
# used by the search function.
print("Loading data from storage")

tweets = pd.read_csv(tweets_path)
# NOTE(review): the relevance file has a ".json" suffix but is read as CSV --
# confirm the file really is CSV-formatted.
relevance_score = pd.read_csv(relevance_score_path)

tfidf_matrix = _load_pickle(tfidf_matrix_path)
tfidf_vectorizer: TfidfVectorizer = _load_pickle(tfidf_vectorizer_path)

# Attach the relevance score to each tweet (pandas aligns the two frames on
# their index), then drop the columns that are not needed afterwards.
tweets["relevance_score"] = relevance_score["relevance_score"]
tweets = tweets.drop(columns=["user_id", "measured_at", "tweet_id"])
|
||
|
||
|
||
def search_query(query: str, limit: int = 5) -> pd.DataFrame:
    """Return the tweets that best match *query*.

    The query is vectorized with the module-level ``tfidf_vectorizer`` and
    compared against ``tfidf_matrix`` by cosine similarity. Rows whose
    similarity is exactly zero are discarded; the remaining rows are ranked
    by ``relevance_score * similarity`` in descending order.

    Args:
        query: Free-text search term.
        limit: Number of rows kept from the top of the ranking (passed to
            ``DataFrame.head``).

    Returns:
        The best-ranked rows of ``tweets``, or ``None`` when no row has a
        non-zero similarity to the query.
    """
    scores = cosine_similarity(
        tfidf_vectorizer.transform([query]), tfidf_matrix
    ).flatten()

    matching = np.flatnonzero(scores != 0)
    if matching.size == 0:
        return None

    # Pre-order the matches by similarity (highest first).
    # Assumes row i of tfidf_matrix describes row i of tweets -- TODO confirm.
    by_similarity = matching[np.argsort(-scores[matching])]
    hits = tweets.iloc[by_similarity]

    # The final ranking weights each similarity with the tweet's relevance.
    combined = hits["relevance_score"] * scores[by_similarity]
    ranking = combined.sort_values(ascending=False).index
    return hits.loc[ranking].head(limit)
|
||
|
||
|
||
@module.ui
def searchable_ui():
    """Build the search module UI: heading, query text box and result area."""
    heading = ui.h2("Tweet Suchmaschine")
    query_box = ui.input_text(
        "search_input",
        "Suche:",
        placeholder="Gebe Suchterm ein",
        value="Leipzig",
    )
    # Placeholder that the server side fills under the same output id.
    results = ui.output_ui(id="searchable_tweet_ui")

    return ui.div(heading, query_box, ui.HTML("<br>"), results)
|
||
|
||
|
||
def _safe_tweet_html(raw_text) -> str:
    """Return *raw_text* as HTML-safe markup with its line breaks kept.

    Tweet text is untrusted input, so it must not reach the page as raw HTML.
    Entities are decoded first so text that is already entity-encoded is not
    escaped twice, then the whole string is escaped.
    """
    import html  # local import keeps this block self-contained

    escaped = html.escape(html.unescape(str(raw_text)))
    # Turn the two-character sequence backslash + "n" into a real line break.
    return escaped.replace("\\n", "<br>")


def _tweet_card(row, cell_style: str):
    """Build the bordered card (author/date, text, counters) for one result row."""
    header = ui.row(
        ui.column(
            9,
            # Plain string children are escaped on render, unlike markdown
            # input, which could smuggle raw HTML through the user name.
            ui.p(
                ui.tags.strong(str(row["user_name"])),
                " ",
                ui.tags.em(f"@{row['handle']}"),
            ),
            style=cell_style,
        ),
        ui.column(3, ui.p(f"{row['created_at']}"), style=cell_style),
    )
    body = ui.row(
        ui.column(
            12,
            ui.HTML(_safe_tweet_html(row["tweet_text"])),
            style=cell_style + "font-size: 20px; padding:1em;",
        ),
    )
    counters = ui.row(
        ui.column(3, ui.HTML(reply_html_svg), ui.p(f"{row['reply_count']}"), style=cell_style),
        ui.column(3, ui.HTML(retweet_html_svg), ui.p(f"{row['retweet_count']}"), style=cell_style),
        ui.column(3, ui.HTML(like_html_svg), ui.p(f"{row['like_count']}"), style=cell_style),
        # quote_count: approximately how many times the tweet has been quoted.
        # TODO: use a nice svg for quote_count
        ui.column(3, ui.p(f"Quote Count: {row['quote_count']}"), style=cell_style),
    )
    return ui.div(
        header,
        body,
        counters,
        style="border: 1px solid #954; margin-bottom: 1em;",
    )


@module.server
def searchable_server(input: Inputs, output: Outputs, session: Session):
    """Server logic of the search module.

    Registers the ``searchable_tweet_ui`` output, which re-runs the search
    whenever the ``search_input`` text changes and renders one card per hit.
    """

    @output
    @render.ui
    def searchable_tweet_ui():
        result_pd = search_query(input.search_input(), 15)

        # search_query returns None when nothing matches the query.
        if result_pd is None:
            return ui.div(ui.h5("Keine Ergebnisse gefunden!"))

        style = "text-align: center; padding-top: 0.5em;"
        # Row-wise iteration is acceptable here: at most 15 rows are rendered.
        cards = [_tweet_card(row, style) for _, row in result_pd.iterrows()]

        # A plain fluid container instead of ui.page_fluid(): page_* builds a
        # whole page and is not meant to be nested inside an output.
        return ui.div(*cards, class_="container-fluid")
|