Tran Xuan Huy committed on
Commit
6961a96
1 Parent(s): 7b84b5b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from copy import deepcopy
5
+ import torch
6
+ import json
7
+ from numpy.linalg import norm
8
+ import gradio as gr
9
+ from sentence_transformers import SentenceTransformer
10
+
11
+ # necessary function
12
def cosinesimilarity(vector1, vector2):
    """Return the cosine similarity between two 1-D vectors.

    Args:
        vector1: First vector (array-like of numbers).
        vector2: Second vector, same length as ``vector1``.

    Returns:
        ``dot(vector1, vector2) / (norm(vector1) * norm(vector2))``.
        When either vector has zero norm the cosine is undefined; 0.0 is
        returned instead of dividing by zero and yielding NaN.
    """
    denominator = norm(vector1) * norm(vector2)
    if denominator == 0:
        # A zero vector has no direction: report "no similarity" rather
        # than propagate NaN (and a RuntimeWarning) to the caller.
        return 0.0
    return np.dot(vector1, vector2) / denominator
15
+
16
def encode_input_and_return_top_n(input_in, db_dff, top_k, new2oldmatching):
    """Score a query against every stored command and return the best matches.

    The query is embedded with the module-level ``model`` and compared, via
    ``cosinesimilarity``, with each row's precomputed embedding.

    Args:
        input_in: Query text to embed.
        db_dff: DataFrame with a 'Câu lệnh có sẵn' column (command text) and
            an 'Embedding' column (one embedding per row). It is not modified.
        top_k: Number of highest-scoring rows to return.
        new2oldmatching: Mapping from stripped command text to the label
            used as the key in the result.

    Returns:
        Dict mapping ``new2oldmatching[command.strip()]`` to the similarity
        score rounded to 2 decimals, for the ``top_k`` best rows.

    Raises:
        KeyError: If a selected command is missing from ``new2oldmatching``.
    """
    query_embedding = model.encode(input_in)

    # Score row by row straight from the embedding column. The previous
    # version looped over a de-duplicated command list and looked rows up by
    # an enumerate counter, which broke on duplicate commands (score/row
    # length mismatch) and on any non-default index.
    scores = [
        round(cosinesimilarity(query_embedding, row_embedding), 3)
        for row_embedding in db_dff["Embedding"]
    ]

    # Rank in a small throwaway frame instead of copying the whole database
    # (embeddings included) on every request.
    ranked = pd.DataFrame(
        {"Câu lệnh có sẵn": db_dff["Câu lệnh có sẵn"].tolist(), "Điểm": scores}
    )
    # mergesort is stable, so tied scores keep database order.
    ranked = ranked.sort_values(by="Điểm", ascending=False, kind="mergesort")
    best = ranked[:top_k]

    return {
        new2oldmatching[command.strip()]: round(float(score), 2)
        for command, score in zip(best["Câu lệnh có sẵn"], best["Điểm"])
    }
29
+
30
def image_classifier(Input):
    """Gradio callback: return the 3 best-matching commands for ``Input``.

    The text is lower-cased, then matched against the module-level ``db_df``
    using the ``new2oldmatch`` mapping. Returns the dict produced by
    ``encode_input_and_return_top_n``.
    """
    return encode_input_and_return_top_n(Input.lower(), db_df, 3, new2oldmatch)
34
+
35
def encode_database(db_in):
    """Embed every command and return them with their embeddings.

    Args:
        db_in: Sequence of command strings to embed with the module-level
            ``model``.

    Returns:
        DataFrame with columns 'Câu lệnh có sẵn' (the command text) and
        'Embedding' (the value returned by ``model.encode`` for that text),
        one row per command, in input order.
    """
    commands = list(db_in)

    # Collect embeddings first and build the frame once. The previous
    # ``df['Embedding'].loc[i] = ...`` was chained assignment, which does not
    # write through to ``df`` under pandas Copy-on-Write and would leave
    # every embedding empty.
    # Iterating tqdm over the list itself lets it show a total.
    embeddings = [model.encode(command) for command in tqdm(commands)]

    df = pd.DataFrame({"Câu lệnh có sẵn": commands, "Embedding": embeddings})

    print()
    print("Encode database successfully")
    return df
44
+
45
# Load the sentence-embedding model once at start-up and put it in eval mode.
model = SentenceTransformer("something/model")
model.eval()

# Read the command -> label mapping. The encoding is given explicitly so the
# non-ASCII (Vietnamese) text decodes the same way on every platform rather
# than depending on the locale default.
with open('something/new2oldmatch.json', 'r', encoding='utf-8') as openfile:
    new2oldmatch = json.load(openfile)
# Normalise keys (strip + lower-case) and strip the values.
new2oldmatch = {u.strip().lower(): v.strip() for u, v in new2oldmatch.items()}

# Keys are already lower-cased above, so they form the database directly.
database = list(new2oldmatch)
db_df = encode_database(database)

# Text in, label-style (name -> score) output.
demo = gr.Interface(fn=image_classifier, inputs="text", outputs="label")
demo.launch(share=True)