|
import os |
|
|
|
import streamlit as st |
|
from yasem import SpladeEmbedder |
|
|
|
if os.getenv("SPACE_ID"): |
|
USE_HF_SPACE = True |
|
os.environ["HF_HOME"] = "/data/.huggingface" |
|
os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface" |
|
else: |
|
USE_HF_SPACE = False |
|
|
|
MODEL_NAME = os.environ.get("MODEL_NAME", "hotchpotch/japanese-splade-base-v1") |
|
|
|
|
|
@st.cache_resource |
|
def get_embedder(model_name: str = MODEL_NAME) -> SpladeEmbedder: |
|
embedder = SpladeEmbedder( |
|
model_name, |
|
) |
|
return embedder |
|
|
|
|
|
def get_token_values_sorted(input_text: str) -> list[tuple[float, str]]: |
|
embedder = get_embedder() |
|
embeddings = embedder.encode([input_text]) |
|
token_values = embedder.get_token_values(embeddings[0]) |
|
sorted_tokens = sorted(token_values.items(), key=lambda item: item[1], reverse=True) |
|
return [(value, key) for key, value in sorted_tokens] |
|
|
|
|
|
def main(): |
|
st.set_page_config( |
|
page_title="SPLADE 日本語 demo", |
|
layout="centered", |
|
initial_sidebar_state="auto", |
|
) |
|
|
|
st.title("SPLADE 日本語 demo") |
|
|
|
get_embedder() |
|
|
|
st.markdown(""" |
|
[hotchpotch/japanese-splade-base-v1](https://huggingface.co/hotchpotch/japanese-splade-base-v1)を使って、テキストからSPLADEのスパースベクトルに変換するデモです。 |
|
|
|
""") |
|
|
|
input_text = st.text_area("テキスト", height=200) |
|
|
|
if st.button("変換"): |
|
if input_text.strip(): |
|
with st.spinner("変換中..."): |
|
sorted_tokens = get_token_values_sorted(input_text) |
|
|
|
total_tokens = len(sorted_tokens) |
|
st.markdown(f"### 結果 (トークン数: {total_tokens})") |
|
if sorted_tokens: |
|
formatted_data = [ |
|
{"スコア": freq, "単語(vocab)": word} |
|
for freq, word in sorted_tokens |
|
] |
|
st.table(formatted_data) |
|
else: |
|
st.warning("入力テキストから有効な単語が見つかりませんでした。") |
|
else: |
|
st.warning("テキストを入力してください。") |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|