xxxxxx commited on
Commit
eb36871
·
1 Parent(s): dd50be3
Files changed (1) hide show
  1. app.py +39 -50
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
  import json
4
  from onnxruntime import InferenceSession
5
- from transformers import AutoTokenizer
6
 
7
  # 设置页面配置
8
  st.set_page_config(page_title="中文垃圾信息分类器", page_icon="🚫", layout="wide")
@@ -13,9 +13,15 @@ def load_classifiers():
13
  hf_classifier = pipeline("text-classification", model="app-x/chinese_spam_classifier")
14
  onnx_session = InferenceSession("app-x/chinese_spam_classifier_onnx/model_optimized.onnx")
15
  tokenizer = AutoTokenizer.from_pretrained("app-x/chinese_spam_classifier_onnx")
16
- return hf_classifier, onnx_session, tokenizer
 
 
 
 
 
 
17
 
18
- hf_classifier, onnx_session, tokenizer = load_classifiers()
19
 
20
  st.title("🚫 中文垃圾信息分类器")
21
  st.write("使用两个模型进行中文文本的垃圾信息分类。")
@@ -23,6 +29,24 @@ st.write("使用两个模型进行中文文本的垃圾信息分类。")
23
  # 创建两列布局
24
  col1, col2 = st.columns([2, 1])
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  with col1:
27
  # 创建文本输入框
28
  text_input = st.text_area("请输入中文文本:", height=200)
@@ -30,52 +54,17 @@ with col1:
30
  if st.button("分类", key="classify_button"):
31
  if text_input:
32
  with st.spinner("正在分析..."):
33
- # HuggingFace模型分类
34
- hf_result = hf_classifier(text_input)[0]
35
- hf_label = "垃圾信息" if hf_result["label"] == "spam" else "正常信息"
36
- hf_confidence = hf_result["score"]
37
-
38
- # ONNX模型分类
39
- inputs = tokenizer(text_input, return_tensors="np", padding=True, truncation=True)
40
- onnx_result = onnx_session.run(None, dict(inputs))
41
- onnx_label = "垃圾信息" if onnx_result[0][0][1] > onnx_result[0][0][0] else "正常信息"
42
- onnx_confidence = max(onnx_result[0][0])
43
-
44
- # 创建JSON格式的结果
45
- json_result = {
46
- "input_text": text_input,
47
- "huggingface_model": {
48
- "classification": hf_label,
49
- "confidence": hf_confidence,
50
- "raw_output": hf_result
51
- },
52
- "onnx_model": {
53
- "classification": onnx_label,
54
- "confidence": float(onnx_confidence),
55
- "raw_output": onnx_result[0].tolist()
56
- }
57
- }
58
-
59
- # 显示结果
60
- st.subheader("HuggingFace模型分类结果:")
61
- if hf_label == "垃圾信息":
62
- st.error(f"⚠️ {hf_label}")
63
- else:
64
- st.success(f"✅ {hf_label}")
65
- st.write(f"概率: {hf_confidence:.2f}")
66
- st.progress(hf_confidence)
67
-
68
- st.subheader("ONNX模型分类结果:")
69
- if onnx_label == "垃圾信息":
70
- st.error(f"⚠️ {onnx_label}")
71
- else:
72
- st.success(f"✅ {onnx_label}")
73
- st.write(f"概率: {onnx_confidence:.2f}")
74
- st.progress(float(onnx_confidence))
75
 
76
- # 显示JSON格式的结果
77
- st.subheader("JSON 格式的详细结果:")
78
- st.json(json_result)
 
 
 
 
 
 
79
  else:
80
  st.warning("请输入文本后再进行分类。")
81
 
@@ -104,4 +93,4 @@ with col2:
104
 
105
  # 添加页脚
106
  st.markdown("---")
107
- st.markdown("由 Streamlit 和 Hugging Face 提供支持 | 作者:[app-x]")
 
1
  import streamlit as st
2
+ from transformers import pipeline, AutoTokenizer
3
  import json
4
  from onnxruntime import InferenceSession
5
+ import numpy as np
6
 
7
  # 设置页面配置
8
  st.set_page_config(page_title="中文垃圾信息分类器", page_icon="🚫", layout="wide")
 
13
  hf_classifier = pipeline("text-classification", model="app-x/chinese_spam_classifier")
14
  onnx_session = InferenceSession("app-x/chinese_spam_classifier_onnx/model_optimized.onnx")
15
  tokenizer = AutoTokenizer.from_pretrained("app-x/chinese_spam_classifier_onnx")
16
+
17
+ # 加载配置文件
18
+ with open("app-x/chinese_spam_classifier_onnx/config.json", "r") as f:
19
+ config = json.load(f)
20
+
21
+ id2label = config["id2label"]
22
+ return hf_classifier, onnx_session, tokenizer, id2label
23
 
24
+ hf_classifier, onnx_session, tokenizer, id2label = load_classifiers()
25
 
26
  st.title("🚫 中文垃圾信息分类器")
27
  st.write("使用两个模型进行中文文本的垃圾信息分类。")
 
29
  # 创建两列布局
30
  col1, col2 = st.columns([2, 1])
31
 
32
+ def classify_text(text):
33
+ # HuggingFace模型分类
34
+ hf_result = hf_classifier(text)[0]
35
+ hf_label = id2label[str(int(hf_result["label"].split("_")[-1]))]
36
+ hf_confidence = hf_result["score"]
37
+
38
+ # ONNX模型分类
39
+ inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True)
40
+ onnx_result = onnx_session.run(None, dict(inputs))
41
+ onnx_probs = onnx_result[0][0]
42
+ onnx_label = id2label[str(np.argmax(onnx_probs))]
43
+ onnx_confidence = np.max(onnx_probs)
44
+
45
+ return {
46
+ "hf": {"label": hf_label, "confidence": hf_confidence},
47
+ "onnx": {"label": onnx_label, "confidence": float(onnx_confidence)}
48
+ }
49
+
50
  with col1:
51
  # 创建文本输入框
52
  text_input = st.text_area("请输入中文文本:", height=200)
 
54
  if st.button("分类", key="classify_button"):
55
  if text_input:
56
  with st.spinner("正在分析..."):
57
+ results = classify_text(text_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ for model, result in results.items():
60
+ st.subheader(f"{model.upper()} 模型分类结果:")
61
+ label = "垃圾信息" if result["label"] == "spam" else "正常信息"
62
+ if label == "垃圾信息":
63
+ st.error(f"⚠️ {label}")
64
+ else:
65
+ st.success(f"✅ {label}")
66
+ st.write(f"概率: {result['confidence']:.2f}")
67
+ st.progress(result['confidence'])
68
  else:
69
  st.warning("请输入文本后再进行分类。")
70
 
 
93
 
94
  # 添加页脚
95
  st.markdown("---")
96
+ st.markdown("由 Streamlit 和 Hugging Face 提供支持 | 作者:[app-x]")