rifatramadhani commited on
Commit
fb5d1d1
1 Parent(s): 62a3d9e

wip: basic sentiment analysis

Browse files
Files changed (2) hide show
  1. app.py +45 -4
  2. requirements.txt +129 -0
app.py CHANGED
@@ -1,7 +1,48 @@
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import spaces
3
+ import torch
4
+ from transformers import pipeline
5
+ import datetime
6
+ import json
7
+ import logging
8
 
9
+ model_path = "cardiffnlp/twitter-roberta-base-sentiment-latest"
10
+ # Load model for cache
11
+ sentiment_task = pipeline("sentiment-analysis", model=model_path, tokenizer=model_path)
12
 
13
+ @spaces.GPU
14
+ def classify(query):
15
+ torch_device = 0 if torch.cuda.is_available() else -1
16
+ sentiment_task = pipeline("sentiment-analysis", model=model_path, tokenizer=model_path, device=torch_device)
17
+
18
+ request_type = type(query)
19
+ try:
20
+ data = json.loads(query)
21
+ if type(data) != list:
22
+ data = [query]
23
+ else:
24
+ request_type = type(data)
25
+ except Exception as e:
26
+ print(e)
27
+ data = [query]
28
+ pass
29
+
30
+ start_time = datetime.datetime.now()
31
+
32
+ result = sentiment_task(data, batch_size=128)
33
+
34
+ end_time = datetime.datetime.now()
35
+ elapsed_time = end_time - start_time
36
+
37
+ logging.debug("elapsed predict time: %s", str(elapsed_time))
38
+ print("elapsed predict time:", str(elapsed_time))
39
+
40
+ output = {}
41
+ output["time"] = str(elapsed_time)
42
+ output["device"] = torch_device
43
+ output["result"] = result
44
+
45
+ return json.dumps(output)
46
+
47
+ demo = gr.Interface(fn=classify, inputs="text", outputs="text")
48
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohttp==3.9.5
3
+ aiosignal==1.3.1
4
+ altair==5.3.0
5
+ annotated-types==0.7.0
6
+ anyio==4.4.0
7
+ asttokens==2.4.1
8
+ async-timeout==4.0.3
9
+ attrs==23.2.0
10
+ Authlib==1.3.1
11
+ certifi==2024.6.2
12
+ cffi==1.16.0
13
+ charset-normalizer==3.3.2
14
+ click==8.0.4
15
+ contourpy==1.2.1
16
+ cryptography==42.0.8
17
+ cycler==0.12.1
18
+ datasets==2.19.2
19
+ decorator==5.1.1
20
+ dill==0.3.8
21
+ dnspython==2.6.1
22
+ email_validator==2.2.0
23
+ exceptiongroup==1.2.1
24
+ executing==2.0.1
25
+ fastapi==0.111.0
26
+ fastapi-cli==0.0.4
27
+ ffmpy==0.3.2
28
+ filelock==3.14.0
29
+ fonttools==4.53.0
30
+ frozenlist==1.4.1
31
+ fsspec==2024.3.1
32
+ gradio==4.37.2
33
+ gradio_client==1.0.2
34
+ h11==0.14.0
35
+ hf_transfer==0.1.6
36
+ httpcore==1.0.5
37
+ httptools==0.6.1
38
+ httpx==0.27.0
39
+ huggingface-hub==0.23.3
40
+ idna==3.7
41
+ importlib_resources==6.4.0
42
+ ipython==8.26.0
43
+ itsdangerous==2.2.0
44
+ jedi==0.19.1
45
+ Jinja2==3.1.4
46
+ jsonschema==4.22.0
47
+ jsonschema-specifications==2023.12.1
48
+ kiwisolver==1.4.5
49
+ markdown-it-py==3.0.0
50
+ MarkupSafe==2.1.5
51
+ matplotlib==3.9.0
52
+ matplotlib-inline==0.1.7
53
+ mdurl==0.1.2
54
+ mpmath==1.3.0
55
+ multidict==6.0.5
56
+ multiprocess==0.70.16
57
+ networkx==3.3
58
+ numpy==1.26.4
59
+ nvidia-cublas-cu12==12.1.3.1
60
+ nvidia-cuda-cupti-cu12==12.1.105
61
+ nvidia-cuda-nvrtc-cu12==12.1.105
62
+ nvidia-cuda-runtime-cu12==12.1.105
63
+ nvidia-cudnn-cu12==8.9.2.26
64
+ nvidia-cufft-cu12==11.0.2.54
65
+ nvidia-curand-cu12==10.3.2.106
66
+ nvidia-cusolver-cu12==11.4.5.107
67
+ nvidia-cusparse-cu12==12.1.0.106
68
+ nvidia-nccl-cu12==2.19.3
69
+ nvidia-nvjitlink-cu12==12.5.40
70
+ nvidia-nvtx-cu12==12.1.105
71
+ orjson==3.10.5
72
+ packaging==24.0
73
+ pandas==2.2.2
74
+ parso==0.8.4
75
+ pexpect==4.9.0
76
+ pillow==10.3.0
77
+ prompt_toolkit==3.0.47
78
+ protobuf==3.20.3
79
+ psutil==5.9.8
80
+ ptyprocess==0.7.0
81
+ pure-eval==0.2.2
82
+ pyarrow==16.1.0
83
+ pyarrow-hotfix==0.6
84
+ pycparser==2.22
85
+ pydantic==2.7.4
86
+ pydantic_core==2.18.4
87
+ pydub==0.25.1
88
+ Pygments==2.18.0
89
+ pyparsing==3.1.2
90
+ python-dateutil==2.9.0.post0
91
+ python-dotenv==1.0.1
92
+ python-multipart==0.0.9
93
+ pytz==2024.1
94
+ PyYAML==6.0.1
95
+ referencing==0.35.1
96
+ regex==2024.5.15
97
+ requests==2.32.3
98
+ rich==13.7.1
99
+ rpds-py==0.18.1
100
+ ruff==0.5.0
101
+ safetensors==0.4.3
102
+ semantic-version==2.10.0
103
+ shellingham==1.5.4
104
+ six==1.16.0
105
+ sniffio==1.3.1
106
+ spaces==0.28.3
107
+ stack-data==0.6.3
108
+ starlette==0.37.2
109
+ sympy==1.12.1
110
+ tokenizers==0.19.1
111
+ tomlkit==0.12.0
112
+ toolz==0.12.1
113
+ torch==2.2.0
114
+ tqdm==4.66.4
115
+ traitlets==5.14.3
116
+ transformers==4.42.3
117
+ triton==2.2.0
118
+ typer==0.12.3
119
+ typing_extensions==4.12.1
120
+ tzdata==2024.1
121
+ ujson==5.10.0
122
+ urllib3==2.2.1
123
+ uvicorn==0.30.1
124
+ uvloop==0.19.0
125
+ watchfiles==0.22.0
126
+ wcwidth==0.2.13
127
+ websockets==11.0.3
128
+ xxhash==3.4.1
129
+ yarl==1.9.4