bisoye committed on
Commit
d64d27f
·
verified ·
1 Parent(s): 44e579b

Upload 7 files

Browse files
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from gradio import ChatMessage
from utils import audio_to_text, text_to_speech
from main import get_chatbot_response, load_peft_model_and_tokenizer


# NOTE(review): the original had these two values swapped. Per main.py and the
# training notebook, 'microsoft/phi-2' is the BASE causal LM and
# 'bisoye/phi-2-for-mental-health-2' is the fine-tuned PEFT adapter; with the
# swapped values PeftModel.from_pretrained would be handed the base repo as an
# adapter id and fail.
PEFT_MODEL = 'bisoye/phi-2-for-mental-health-2'
BASE_MODEL = 'microsoft/phi-2'


# Loaded once at startup and shared by every request handler below.
tokenizer, model = load_peft_model_and_tokenizer(PEFT_MODEL, BASE_MODEL)
14
def respond_to_audio(audio):
    """Answer a spoken question.

    Transcribes *audio*, generates a model reply, and returns values matching
    the ``send_audio_button.click`` wiring, which declares
    ``outputs=[chatbot, output_audio]``. The original returned a single
    string, which broke that two-output wiring and put raw text into the
    Chatbot component instead of chat messages.

    Args:
        audio: filepath of the recorded question (gr.Audio, type='filepath').

    Returns:
        (messages, audio_path): chat messages for the Chatbot component and
        the path of the synthesised spoken reply.
    """
    text_from_audio = audio_to_text(audio)
    response = get_chatbot_response(model, tokenizer, text_from_audio)

    messages = [
        ChatMessage(role='user', content=text_from_audio),
        ChatMessage(role='assistant', content=response),
    ]
    audio_filename = text_to_speech(response)  # convert response to audio
    return messages, audio_filename
21
+
22
+
23
def chat_history(message: str, history: list):
    """Append the user's turn and the model's reply to *history*.

    Generates a response with the module-level model/tokenizer, records both
    sides of the exchange as ChatMessage entries, and synthesises the reply
    as speech.

    Returns:
        The updated history list and the path of the generated audio file.
    """
    response = get_chatbot_response(model, tokenizer, message)

    history.extend([
        ChatMessage(role='user', content=message),
        ChatMessage(role='assistant', content=response),
    ])

    # Speak the assistant's reply as well.
    return history, text_to_speech(response)
33
+
34
+
35
# Load the stylesheet shipped with this app. The original passed the literal
# string 'custom_css' to gr.Blocks(css=...), which Gradio treats as CSS text,
# so none of app_design.css was ever applied.
with open('app_design.css') as _css_file:
    _custom_css = _css_file.read()

with gr.Blocks(css=_custom_css) as demo:

    with gr.Row():

        with gr.Column():
            chatbot = gr.Chatbot(label='Mental Health Chatbot',
                                 type='messages')

            input_text = gr.Textbox(label='Enter your question here: ')
            input_audio = gr.Audio(label='Send question as audio: ',
                                   sources='microphone',
                                   type='filepath')

            send_audio_button = gr.Button(value='Send Audio')

        with gr.Column():
            output_audio = gr.Audio(label='AI audio response: ',
                                    sources='upload',
                                    type='filepath',
                                    interactive=False,
                                    autoplay=True)

            clear_button = gr.ClearButton(components=[
                input_text,
                input_audio,
                output_audio]
            )

    # chat_history returns (history, audio_path), so BOTH outputs must be
    # declared; the original listed only [chatbot], dropping the audio reply
    # on text submit and causing an output-count mismatch.
    input_text.submit(chat_history,
                      inputs=[input_text, chatbot],
                      outputs=[chatbot, output_audio])
    send_audio_button.click(respond_to_audio,
                            inputs=[input_audio],
                            outputs=[chatbot, output_audio])

demo.launch()
app_design.css ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Dark theme base: near-black page background with white text. */
body {
    background-color: #1e1e1e;
    color: white;
}

/* Enlarge text in all native form controls for readability. */
input, textarea, button, select {
    font-size: 20px !important;
}

/* Baseline font size for the Gradio app container. */
.gradio-container {
    font-size: 18px;
}

/* Uniform sizing for Gradio's input, button, and audio widgets. */
.gr-input, .gr-textbox, .gr-button, .gr-audio {
    font-size: 18px !important;
    height: 50px;
}

/* Fixed-height chat panel so the transcript scrolls instead of growing. */
.gr-chatbot {
    font-size: 18px !important;
    height: 400px;
}
counsel_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
main.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from huggingface_hub import login


# Authenticate against the Hugging Face Hub. A bare login() prompts
# interactively, which hangs in a headless deployment (e.g. a Space), so
# prefer a token from the environment and only fall back to the original
# interactive behavior when none is set.
_hf_token = os.environ.get('HF_TOKEN')
if _hf_token:
    login(token=_hf_token)
else:
    login()


# System prompt the model was fine-tuned against (typo "pyschologist" fixed).
SYSTEM_PROMPT = """You are a mental health therapist/psychologist/licensed professional counsellor.
People will talk to you about their personal life/mental health issues and you will reply them the same way a professional
therapist would, while being empathetic."""
13
+
14
+
15
def load_peft_model_and_tokenizer(peft_model, base_model):
    """Load the 4-bit quantised base model and attach the PEFT adapter.

    Args:
        peft_model: Hub repo id of the fine-tuned PEFT adapter. The tokenizer
            is loaded from this repo because fine-tuning added tokens to it.
        base_model: Hub repo id of the base causal LM the adapter was trained
            on.

    Returns:
        (tokenizer, model): the tokenizer and the adapter-wrapped model,
        ready for generation.
    """
    tokenizer = AutoTokenizer.from_pretrained(peft_model)

    # NF4 4-bit quantisation with double quantisation keeps the model small
    # enough for a single GPU while computing in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Load base model into a new name instead of shadowing the repo-id
    # parameter, as the original did.
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )

    # Grow the embedding matrix to cover the tokens added during fine-tuning;
    # the adapter weights were trained against this resized vocabulary.
    base.resize_token_embeddings(len(tokenizer))
    model = PeftModel.from_pretrained(model=base, model_id=peft_model)

    return tokenizer, model
39
+
40
+
41
def get_chatbot_response(model, tokenizer, message):
    """Generate a counselling-style reply to *message*.

    Builds the '### System / ### Human / ### Assistant' prompt format the
    model was fine-tuned on, samples up to 100 new tokens, and returns only
    the assistant portion of the decoded text.

    Args:
        model: causal LM with a ``generate`` method (PEFT-wrapped).
        tokenizer: matching tokenizer.
        message: the user's message.

    Returns:
        The assistant's reply as a plain string.
    """
    prompt = f"### System: {SYSTEM_PROMPT}. ### Human: {message} ### Assistant: "
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
    ).input_ids
    # Follow the model's own device rather than hard-coding .cuda(): the
    # original crashed on CPU-only hosts.
    input_ids = input_ids.to(model.device)

    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.5,
        top_p=0.9,
        top_k=50,
        repetition_penalty=1.1,
    )

    # Decode the generated response (prompt echo included).
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract the chatbot's reply: the text after the first assistant marker
    # (truncated at a repeated marker, as before). The original indexed [1]
    # unconditionally, raising IndexError whenever the marker was missing
    # from the decoded output.
    parts = response.split("### Assistant: ")
    response = (parts[1] if len(parts) > 1 else parts[0]).strip()

    return response
66
+
67
+
mental-health-chatbot-inference.ipynb ADDED
@@ -0,0 +1 @@
 
 
1
+ {"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[],"dockerImageVersionId":30747,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"!pip install -Uqq transformers datasets evaluate bitsandbytes peft accelerate scipy einops trl","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:49:38.260628Z","iopub.execute_input":"2024-08-12T14:49:38.261217Z","iopub.status.idle":"2024-08-12T14:50:21.011903Z","shell.execute_reply.started":"2024-08-12T14:49:38.261184Z","shell.execute_reply":"2024-08-12T14:50:21.010797Z"},"trusted":true},"execution_count":1,"outputs":[{"name":"stdout","text":"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\ncuml 24.6.1 requires cupy-cuda11x>=12.0.0, which is not installed.\nkfp 2.5.0 requires google-cloud-storage<3,>=2.2.1, but you have google-cloud-storage 1.44.0 which is incompatible.\nlibpysal 4.9.2 requires packaging>=22, but you have packaging 21.3 which is incompatible.\nlibpysal 4.9.2 requires shapely>=2.0.1, but you have shapely 1.8.5.post1 which is incompatible.\nmomepy 0.7.2 requires shapely>=2, but you have shapely 1.8.5.post1 which is incompatible.\npointpats 2.5.0 requires shapely>=2, but you have shapely 1.8.5.post1 which is incompatible.\nspaghetti 1.7.6 requires shapely>=2.0.1, but you have shapely 1.8.5.post1 which is incompatible.\nspopt 0.6.1 requires shapely>=2.0.1, but you have shapely 1.8.5.post1 which is incompatible.\nydata-profiling 4.6.4 requires numpy<1.26,>=1.16.0, but you have numpy 1.26.4 which is incompatible.\nydata-profiling 4.6.4 requires scipy<1.12,>=1.4.1, but you have scipy 1.14.0 which is incompatible.\u001b[0m\u001b[31m\n\u001b[0m","output_type":"stream"}]},{"cell_type":"code","source":"import torch\nimport wandb\nimport evaluate\nimport datasets\nimport peft\nfrom peft import AutoPeftModelForCausalLM\nfrom transformers import AutoTokenizer, pipeline\nfrom huggingface_hub import login\nfrom kaggle_secrets import UserSecretsClient\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\nfrom trl import setup_chat_format","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:50:25.811088Z","iopub.execute_input":"2024-08-12T14:50:25.812019Z","iopub.status.idle":"2024-08-12T14:50:45.078565Z","shell.execute_reply.started":"2024-08-12T14:50:25.811981Z","shell.execute_reply":"2024-08-12T14:50:45.077727Z"},"trusted":true},"execution_count":2,"outputs":[{"name":"stderr","text":"2024-08-12 14:50:34.357616: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for 
plugin cuDNN when one has already been registered\n2024-08-12 14:50:34.357778: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n2024-08-12 14:50:34.484438: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n","output_type":"stream"}]},{"cell_type":"code","source":"user_secrets = UserSecretsClient()\nsecret_value_0 = user_secrets.get_secret(\"mlops_huggingface_token\")\nsecret_value_1 = user_secrets.get_secret(\"wandb_key\")\n\nwandb.login(key = secret_value_1)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:50:45.080592Z","iopub.execute_input":"2024-08-12T14:50:45.081183Z","iopub.status.idle":"2024-08-12T14:50:47.425904Z","shell.execute_reply.started":"2024-08-12T14:50:45.081141Z","shell.execute_reply":"2024-08-12T14:50:47.424945Z"},"trusted":true},"execution_count":3,"outputs":[{"name":"stderr","text":"\u001b[34m\u001b[1mwandb\u001b[0m: W&B API key is configured. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n","output_type":"stream"},{"execution_count":3,"output_type":"execute_result","data":{"text/plain":"True"},"metadata":{}}]},{"cell_type":"code","source":"login(token = secret_value_0)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:50:47.427066Z","iopub.execute_input":"2024-08-12T14:50:47.427802Z","iopub.status.idle":"2024-08-12T14:50:47.556901Z","shell.execute_reply.started":"2024-08-12T14:50:47.427771Z","shell.execute_reply":"2024-08-12T14:50:47.555544Z"},"trusted":true},"execution_count":4,"outputs":[{"name":"stdout","text":"The token has not been saved to the git credentials helper. 
Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\nToken is valid (permission: fineGrained).\nYour token has been saved to /root/.cache/huggingface/token\nLogin successful\n","output_type":"stream"}]},{"cell_type":"markdown","source":"#### Download model","metadata":{}},{"cell_type":"code","source":"base_model = 'microsoft/phi-2'\npeft_model = 'bisoye/phi-2-for-mental-health-2'","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:50:47.559103Z","iopub.execute_input":"2024-08-12T14:50:47.559939Z","iopub.status.idle":"2024-08-12T14:50:47.645370Z","shell.execute_reply.started":"2024-08-12T14:50:47.559900Z","shell.execute_reply":"2024-08-12T14:50:47.644355Z"},"trusted":true},"execution_count":5,"outputs":[]},{"cell_type":"code","source":"# Load base model\nmodel = AutoModelForCausalLM.from_pretrained(\n base_model,\n device_map=\"auto\",\n torch_dtype=torch.bfloat16,\n quantization_config=bnb_config\n)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:13:24.186945Z","iopub.execute_input":"2024-08-12T14:13:24.187935Z","iopub.status.idle":"2024-08-12T14:16:42.905124Z","shell.execute_reply.started":"2024-08-12T14:13:24.187886Z","shell.execute_reply":"2024-08-12T14:16:42.904179Z"},"trusted":true},"execution_count":9,"outputs":[{"output_type":"display_data","data":{"text/plain":"config.json: 0%| | 0.00/735 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"e157f73c9e3f49de93510cc1564bde5f"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model.safetensors.index.json: 0%| | 0.00/35.7k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"cc1f9a82f2c548908baca19438c08cd1"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Downloading shards: 0%| | 0/2 [00:00<?, 
?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"6518b226abf148cf8efe66f564141a1f"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model-00001-of-00002.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"37d03facde064524b20a25a42f056218"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model-00002-of-00002.safetensors: 0%| | 0.00/564M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"9771d3f5e3794213981727d3045a251d"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"87040aa136714bd4a55bebccf3842670"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"generation_config.json: 0%| | 0.00/124 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"e7bda228f9264158873e30c1bfb6ee64"}},"metadata":{}}]},{"cell_type":"code","source":"len(tokenizer)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:16:46.610477Z","iopub.execute_input":"2024-08-12T14:16:46.610756Z","iopub.status.idle":"2024-08-12T14:16:46.621654Z","shell.execute_reply.started":"2024-08-12T14:16:46.610730Z","shell.execute_reply":"2024-08-12T14:16:46.620778Z"},"trusted":true},"execution_count":11,"outputs":[{"execution_count":11,"output_type":"execute_result","data":{"text/plain":"50297"},"metadata":{}}]},{"cell_type":"code","source":"#resize 
embeddings\nmodel.resize_token_embeddings(len(tokenizer))","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:16:51.248213Z","iopub.execute_input":"2024-08-12T14:16:51.249157Z","iopub.status.idle":"2024-08-12T14:16:51.291341Z","shell.execute_reply.started":"2024-08-12T14:16:51.249120Z","shell.execute_reply":"2024-08-12T14:16:51.290324Z"},"trusted":true},"execution_count":12,"outputs":[{"execution_count":12,"output_type":"execute_result","data":{"text/plain":"Embedding(50297, 2560)"},"metadata":{}}]},{"cell_type":"code","source":"peft_model = PeftModel.from_pretrained(model=model, model_id=model_id)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:17:22.151940Z","iopub.execute_input":"2024-08-12T14:17:22.152501Z","iopub.status.idle":"2024-08-12T14:18:10.854284Z","shell.execute_reply.started":"2024-08-12T14:17:22.152463Z","shell.execute_reply":"2024-08-12T14:18:10.853279Z"},"trusted":true},"execution_count":13,"outputs":[{"output_type":"display_data","data":{"text/plain":"adapter_config.json: 0%| | 0.00/693 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"9d8700e2f7e041da9f0707582f235c5a"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"adapter_model.safetensors: 0%| | 0.00/1.20G [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"b35ca6348a7a4143a07d945a6befccf1"}},"metadata":{}}]},{"cell_type":"markdown","source":"#### Download data","metadata":{}},{"cell_type":"code","source":"# Load Model with PEFT adapter\nmodel = AutoPeftModelForCausalLM.from_pretrained(\n peft_model,\n device_map=\"auto\",\n torch_dtype=torch.float16\n)\ntokenizer = AutoTokenizer.from_pretrained(peft_model)\n# load into pipeline\npipe = pipeline(\"text-generation\", model=model, 
tokenizer=tokenizer)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:51:24.842716Z","iopub.execute_input":"2024-08-12T14:51:24.843521Z","iopub.status.idle":"2024-08-12T14:52:08.942582Z","shell.execute_reply.started":"2024-08-12T14:51:24.843474Z","shell.execute_reply":"2024-08-12T14:52:08.941641Z"},"trusted":true},"execution_count":6,"outputs":[{"output_type":"display_data","data":{"text/plain":"adapter_config.json: 0%| | 0.00/693 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"c64cde16592d46e7a4643090b553ce1b"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"config.json: 0%| | 0.00/735 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"b7f96e183f79490a8bbf35c1e44cc5aa"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model.safetensors.index.json: 0%| | 0.00/35.7k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"0ad5fe04ab5447ce9c2f7eae0e43c51f"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"2b7e2c970940428eb8ac0670f1aee03c"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model-00001-of-00002.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"f3ca11b96fa54d4a9a6a86ae4c84086a"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model-00002-of-00002.safetensors: 0%| | 0.00/564M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"97e0cea2f28e4686b757c35f99e5b555"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Loading checkpoint shards: 0%| | 0/2 [00:00<?, 
?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"782756a28e6d46ca8e704b0ba9c3993e"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"generation_config.json: 0%| | 0.00/124 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"bac1b0b00507400886af9623f3675047"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"tokenizer_config.json: 0%| | 0.00/8.05k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"f9b90b7b23d34947984293389000fb01"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"vocab.json: 0%| | 0.00/798k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"484f0227517b46c3ad2a7debb3440fd1"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"merges.txt: 0%| | 0.00/456k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"80062be423ef433f9c993d4e3bc945c4"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"tokenizer.json: 0%| | 0.00/2.12M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"42d6facfc932444f8edecd6cb546fa1f"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"added_tokens.json: 0%| | 0.00/1.13k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"7501abeb7eba445b944e0ca8c8fd018a"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"special_tokens_map.json: 0%| | 0.00/565 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"ea1d154c2987456f9443e585e4b2b6d6"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"adapter_model.safetensors: 0%| | 0.00/1.20G [00:00<?, 
?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"5e828bbcd11146fdb29107afdc922232"}},"metadata":{}},{"name":"stderr","text":"The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'JetMoeForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'Mamba2ForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'NemotronForCausalLM', 'OlmoForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'Phi3ForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RecurrentGemmaForCausalLM', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 
'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].\n","output_type":"stream"}]},{"cell_type":"code","source":"from datasets import load_dataset\nfrom random import randint\nfrom pprint import pprint\n \neval_dataset = load_dataset(\"bisoye/mental_health_chatbot\", split=\"train\")\nrand_idx = randint(0, len(eval_dataset))\n \n# Test on sample\nprompt = pipe.tokenizer.apply_chat_template(eval_dataset[rand_idx][\"messages\"][:2], tokenize=False, add_generation_prompt=True)\noutputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)\n \nprint(f\"Query:\\n\\t{eval_dataset[rand_idx]['messages'][1]['content']}\")\nprint()\nprint(f\"Original Answer:\\n\\t{eval_dataset[rand_idx]['messages'][2]['content']}\")\nprint()\nprint(f\"Generated Answer:\\n\\t{outputs[0]['generated_text'][len(prompt):].strip()}\")","metadata":{"execution":{"iopub.status.busy":"2024-08-12T15:04:37.118896Z","iopub.execute_input":"2024-08-12T15:04:37.119448Z","iopub.status.idle":"2024-08-12T15:04:47.943034Z","shell.execute_reply.started":"2024-08-12T15:04:37.119404Z","shell.execute_reply":"2024-08-12T15:04:47.942006Z"},"trusted":true},"execution_count":15,"outputs":[{"name":"stdout","text":"Query:\n\tIs it normal for people to cry during therapy, or is it just me?\n\nOriginal Answer:\n\tit is quite normal as conversations we have may touch on emotions, thoughts and feelings that have been covered up for a long time. Just as laughter (which may also be present in therapy), joy, sadness, reflections, these are all emotions and insights that can occur. Allowing yourself to feel and express yourself in a space of safety is freeing and enlightening. Not all sessions can have that but those moments are wonderful and continue on ones pattern of growth. Grab a Kleenex and let it out!\n\nGenerated Answer:\n\tAbsolutely!  
Therapy is a safe place to explore your feelings and emotions.  It is a place where you can be yourself and be vulnerable.  It is a place where you can be honest and open.  It is a place where you can be free.  It is a place where you can be yourself.  It is a place where you can be free.  It is a place where you can be yourself.  It is a place where you can be free.  It is a place where you can be yourself.  It is a place where you can be free.  It is a place where you can be yourself.  It is a place where you can be free.  It is a place where you can be yourself.  It is a place where you can be free.  It is a place where you can be yourself.  It is a place where you can be free.  It is a place where you can be yourself.  It is a place where you can be free.  It is a place where you can be yourself.  It is a place where you can be free.  It is a place where you can be yourself.  It is a place where you can be free.  It is a place\n","output_type":"stream"}]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]}]}
mental-health-chatbot-training.ipynb ADDED
@@ -0,0 +1 @@
 
 
1
+ {"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[],"dockerImageVersionId":30747,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"!pip install -Uqq transformers datasets evaluate bitsandbytes peft accelerate scipy einops trl","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:04:31.424906Z","iopub.execute_input":"2024-08-12T14:04:31.425286Z","iopub.status.idle":"2024-08-12T14:05:14.673029Z","shell.execute_reply.started":"2024-08-12T14:04:31.425257Z","shell.execute_reply":"2024-08-12T14:05:14.671875Z"},"trusted":true},"execution_count":3,"outputs":[{"name":"stdout","text":"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\ncuml 24.6.1 requires cupy-cuda11x>=12.0.0, which is not installed.\nkfp 2.5.0 requires google-cloud-storage<3,>=2.2.1, but you have google-cloud-storage 1.44.0 which is incompatible.\nlibpysal 4.9.2 requires packaging>=22, but you have packaging 21.3 which is incompatible.\nlibpysal 4.9.2 requires shapely>=2.0.1, but you have shapely 1.8.5.post1 which is incompatible.\nmomepy 0.7.2 requires shapely>=2, but you have shapely 1.8.5.post1 which is incompatible.\npointpats 2.5.0 requires shapely>=2, but you have shapely 1.8.5.post1 which is incompatible.\nspaghetti 1.7.6 requires shapely>=2.0.1, but you have shapely 1.8.5.post1 which is incompatible.\nspopt 0.6.1 requires shapely>=2.0.1, but you have shapely 1.8.5.post1 which is incompatible.\nydata-profiling 4.6.4 requires numpy<1.26,>=1.16.0, but you have numpy 1.26.4 which is incompatible.\nydata-profiling 4.6.4 requires scipy<1.12,>=1.4.1, but you have scipy 1.14.0 which is incompatible.\u001b[0m\u001b[31m\n\u001b[0m","output_type":"stream"}]},{"cell_type":"code","source":"import torch\nimport wandb\nimport evaluate\nimport datasets\nimport pandas as pd\nfrom datasets import load_dataset, Dataset, DatasetDict\nfrom pprint import pprint\nfrom huggingface_hub import login\nfrom kaggle_secrets import UserSecretsClient\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\nfrom trl import setup_chat_format\nfrom peft import LoraConfig, get_peft_model, PeftConfig, PeftModel, prepare_model_for_kbit_training","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:05.671162Z","iopub.execute_input":"2024-08-12T14:07:05.671582Z","iopub.status.idle":"2024-08-12T14:07:24.804667Z","shell.execute_reply.started":"2024-08-12T14:07:05.671544Z","shell.execute_reply":"2024-08-12T14:07:24.803840Z"},"trusted":true},"execution_count":4,"outputs":[{"name":"stderr","text":"2024-08-12 14:07:14.310516: E 
external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n2024-08-12 14:07:14.310628: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n2024-08-12 14:07:14.452921: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n","output_type":"stream"}]},{"cell_type":"code","source":"\nuser_secrets = UserSecretsClient()\nsecret_value_0 = user_secrets.get_secret(\"mlops_huggingface_token\")\nsecret_value_1 = user_secrets.get_secret(\"wandb_key\")\n\nwandb.login(key = secret_value_1)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:24.806645Z","iopub.execute_input":"2024-08-12T14:07:24.807318Z","iopub.status.idle":"2024-08-12T14:07:27.187735Z","shell.execute_reply.started":"2024-08-12T14:07:24.807288Z","shell.execute_reply":"2024-08-12T14:07:27.186823Z"},"trusted":true},"execution_count":5,"outputs":[{"name":"stderr","text":"\u001b[34m\u001b[1mwandb\u001b[0m: W&B API key is configured. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n","output_type":"stream"},{"execution_count":5,"output_type":"execute_result","data":{"text/plain":"True"},"metadata":{}}]},{"cell_type":"code","source":"login(token = secret_value_0)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:27.189188Z","iopub.execute_input":"2024-08-12T14:07:27.189821Z","iopub.status.idle":"2024-08-12T14:07:27.334472Z","shell.execute_reply.started":"2024-08-12T14:07:27.189780Z","shell.execute_reply":"2024-08-12T14:07:27.333536Z"},"trusted":true},"execution_count":6,"outputs":[{"name":"stdout","text":"The token has not been saved to the git credentials helper. 
Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\nToken is valid (permission: fineGrained).\nYour token has been saved to /root/.cache/huggingface/token\nLogin successful\n","output_type":"stream"}]},{"cell_type":"code","source":"ds = load_dataset(\"nbertagnolli/counsel-chat\", split = 'train')","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:27.336951Z","iopub.execute_input":"2024-08-12T14:07:27.337247Z","iopub.status.idle":"2024-08-12T14:07:29.588187Z","shell.execute_reply.started":"2024-08-12T14:07:27.337222Z","shell.execute_reply":"2024-08-12T14:07:29.587319Z"},"trusted":true},"execution_count":7,"outputs":[{"output_type":"display_data","data":{"text/plain":"Downloading readme: 0%| | 0.00/4.92k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"ca8f8474aa77408cb6bcc7f122f71b3c"}},"metadata":{}},{"name":"stderr","text":"Repo card metadata block was not found. Setting CardData to empty.\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"Downloading data: 0%| | 0.00/4.13M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"f7b486cf5be8401495ba4ef57d626c1e"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Generating train split: 0%| | 0/2775 [00:00<?, ? 
examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"45e7063b3b1c49e9b9f4d248c1d1d61f"}},"metadata":{}}]},{"cell_type":"code","source":"ds","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:29.589333Z","iopub.execute_input":"2024-08-12T14:07:29.589624Z","iopub.status.idle":"2024-08-12T14:07:29.595659Z","shell.execute_reply.started":"2024-08-12T14:07:29.589599Z","shell.execute_reply":"2024-08-12T14:07:29.594668Z"},"trusted":true},"execution_count":8,"outputs":[{"execution_count":8,"output_type":"execute_result","data":{"text/plain":"Dataset({\n features: ['questionID', 'questionTitle', 'questionText', 'questionLink', 'topic', 'therapistInfo', 'therapistURL', 'answerText', 'upvotes', 'views'],\n num_rows: 2775\n})"},"metadata":{}}]},{"cell_type":"code","source":"columns_to_check = ['questionText', 'answerText']\n\n# Define a function to check for NaN values in the specified columns\ndef drop_nans(example):\n return not any(pd.isna(example[col]) for col in columns_to_check)\n\n# Apply the filter\nds = ds.filter(drop_nans)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:29.596741Z","iopub.execute_input":"2024-08-12T14:07:29.597033Z","iopub.status.idle":"2024-08-12T14:07:29.712066Z","shell.execute_reply.started":"2024-08-12T14:07:29.597011Z","shell.execute_reply":"2024-08-12T14:07:29.711092Z"},"trusted":true},"execution_count":9,"outputs":[{"output_type":"display_data","data":{"text/plain":"Filter: 0%| | 0/2775 [00:00<?, ? 
examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"fcd9e26734c2473b835291e9307b6d77"}},"metadata":{}}]},{"cell_type":"code","source":"#split dataset into train, valid and test split\ntrain_temp_dataset = ds.train_test_split(test_size = .2)\n# val_test_dataset = train_temp_dataset['test'].train_test_split(.5)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:29.713351Z","iopub.execute_input":"2024-08-12T14:07:29.713642Z","iopub.status.idle":"2024-08-12T14:07:29.731813Z","shell.execute_reply.started":"2024-08-12T14:07:29.713617Z","shell.execute_reply":"2024-08-12T14:07:29.730835Z"},"trusted":true},"execution_count":10,"outputs":[]},{"cell_type":"code","source":"train_temp_dataset","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:29.733266Z","iopub.execute_input":"2024-08-12T14:07:29.733603Z","iopub.status.idle":"2024-08-12T14:07:29.739738Z","shell.execute_reply.started":"2024-08-12T14:07:29.733575Z","shell.execute_reply":"2024-08-12T14:07:29.738899Z"},"trusted":true},"execution_count":11,"outputs":[{"execution_count":11,"output_type":"execute_result","data":{"text/plain":"DatasetDict({\n train: Dataset({\n features: ['questionID', 'questionTitle', 'questionText', 'questionLink', 'topic', 'therapistInfo', 'therapistURL', 'answerText', 'upvotes', 'views'],\n num_rows: 2089\n })\n test: Dataset({\n features: ['questionID', 'questionTitle', 'questionText', 'questionLink', 'topic', 'therapistInfo', 'therapistURL', 'answerText', 'upvotes', 'views'],\n num_rows: 523\n })\n})"},"metadata":{}}]},{"cell_type":"code","source":"features = list(train_temp_dataset['train'].features.keys())\n\ncols_to_keep = ['questionText', 'answerText']\ncols_to_remove = [i for i in features if i not in 
cols_to_keep]\nprint(cols_to_remove)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:29.740782Z","iopub.execute_input":"2024-08-12T14:07:29.741116Z","iopub.status.idle":"2024-08-12T14:07:29.749041Z","shell.execute_reply.started":"2024-08-12T14:07:29.741085Z","shell.execute_reply":"2024-08-12T14:07:29.748138Z"},"trusted":true},"execution_count":12,"outputs":[{"name":"stdout","text":"['questionID', 'questionTitle', 'questionLink', 'topic', 'therapistInfo', 'therapistURL', 'upvotes', 'views']\n","output_type":"stream"}]},{"cell_type":"code","source":"train_temp_dataset = train_temp_dataset.remove_columns(cols_to_remove)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:29.752404Z","iopub.execute_input":"2024-08-12T14:07:29.752667Z","iopub.status.idle":"2024-08-12T14:07:29.766544Z","shell.execute_reply.started":"2024-08-12T14:07:29.752646Z","shell.execute_reply":"2024-08-12T14:07:29.765620Z"},"trusted":true},"execution_count":13,"outputs":[]},{"cell_type":"code","source":"df = train_temp_dataset['train'].to_pandas()","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:29.767911Z","iopub.execute_input":"2024-08-12T14:07:29.769001Z","iopub.status.idle":"2024-08-12T14:07:29.789542Z","shell.execute_reply.started":"2024-08-12T14:07:29.768971Z","shell.execute_reply":"2024-08-12T14:07:29.788775Z"},"trusted":true},"execution_count":14,"outputs":[]},{"cell_type":"code","source":"df.isna().sum()","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:29.790543Z","iopub.execute_input":"2024-08-12T14:07:29.790838Z","iopub.status.idle":"2024-08-12T14:07:29.799394Z","shell.execute_reply.started":"2024-08-12T14:07:29.790807Z","shell.execute_reply":"2024-08-12T14:07:29.798418Z"},"trusted":true},"execution_count":15,"outputs":[{"execution_count":15,"output_type":"execute_result","data":{"text/plain":"questionText 0\nanswerText 0\ndtype: 
int64"},"metadata":{}}]},{"cell_type":"code","source":"train_temp_dataset","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:29.800731Z","iopub.execute_input":"2024-08-12T14:07:29.801147Z","iopub.status.idle":"2024-08-12T14:07:29.809412Z","shell.execute_reply.started":"2024-08-12T14:07:29.801117Z","shell.execute_reply":"2024-08-12T14:07:29.808491Z"},"trusted":true},"execution_count":16,"outputs":[{"execution_count":16,"output_type":"execute_result","data":{"text/plain":"DatasetDict({\n train: Dataset({\n features: ['questionText', 'answerText'],\n num_rows: 2089\n })\n test: Dataset({\n features: ['questionText', 'answerText'],\n num_rows: 523\n })\n})"},"metadata":{}}]},{"cell_type":"markdown","source":"#### Preparing the dataset for conversational finetuning\n\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are...\"}, {\"role\": \"user\", \"content\": \"...\"}, {\"role\": \"assistant\", \"content\": \"...\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are...\"}, {\"role\": \"user\", \"content\": \"...\"}, {\"role\": \"assistant\", \"content\": \"...\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are...\"}, {\"role\": \"user\", \"content\": \"...\"}, {\"role\": \"assistant\", \"content\": \"...\"}]}\n\n","metadata":{}},{"cell_type":"code","source":"system_messasge = \"\"\"You are a mental health chatbot that follows mental health therapist professional guidelines.\nUsers will talk to you about their mental health issues and you will reply them the same way a professional \ntherapist would.\"\"\"\n\ndef create_conversation_data(sample):\n \n return{\n 'messages': [\n {'role': 'system',\n 'content': system_messasge},\n {'role': 'user',\n 'content': sample['questionText']},\n {'role': 'assistant',\n 'content': sample['answerText']}\n ]\n 
}\n\n","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:29.810476Z","iopub.execute_input":"2024-08-12T14:07:29.810812Z","iopub.status.idle":"2024-08-12T14:07:29.817932Z","shell.execute_reply.started":"2024-08-12T14:07:29.810762Z","shell.execute_reply":"2024-08-12T14:07:29.816977Z"},"trusted":true},"execution_count":17,"outputs":[]},{"cell_type":"code","source":"chat_dataset = train_temp_dataset.map(create_conversation_data, batched = False)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:29.819037Z","iopub.execute_input":"2024-08-12T14:07:29.819365Z","iopub.status.idle":"2024-08-12T14:07:30.215456Z","shell.execute_reply.started":"2024-08-12T14:07:29.819335Z","shell.execute_reply":"2024-08-12T14:07:30.214521Z"},"trusted":true},"execution_count":18,"outputs":[{"output_type":"display_data","data":{"text/plain":"Map: 0%| | 0/2089 [00:00<?, ? examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"66412314b66f48eeb66e34df757dec8c"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Map: 0%| | 0/523 [00:00<?, ? 
examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"1dca7a36a0bd4d3281160fecb06a34f1"}},"metadata":{}}]},{"cell_type":"code","source":"chat_dataset","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:30.216850Z","iopub.execute_input":"2024-08-12T14:07:30.217212Z","iopub.status.idle":"2024-08-12T14:07:30.224108Z","shell.execute_reply.started":"2024-08-12T14:07:30.217178Z","shell.execute_reply":"2024-08-12T14:07:30.223035Z"},"trusted":true},"execution_count":19,"outputs":[{"execution_count":19,"output_type":"execute_result","data":{"text/plain":"DatasetDict({\n train: Dataset({\n features: ['questionText', 'answerText', 'messages'],\n num_rows: 2089\n })\n test: Dataset({\n features: ['questionText', 'answerText', 'messages'],\n num_rows: 523\n })\n})"},"metadata":{}}]},{"cell_type":"code","source":"pprint(chat_dataset[\"train\"][145][\"messages\"])","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:30.225301Z","iopub.execute_input":"2024-08-12T14:07:30.225641Z","iopub.status.idle":"2024-08-12T14:07:30.246726Z","shell.execute_reply.started":"2024-08-12T14:07:30.225608Z","shell.execute_reply":"2024-08-12T14:07:30.245749Z"},"trusted":true},"execution_count":20,"outputs":[{"name":"stdout","text":"[{'content': 'You are a mental health chatbot that follows mental health '\n 'therapist professional guidelines.\\n'\n 'Users will talk to you about their mental health issues and you '\n 'will reply them the same way a professional \\n'\n 'therapist would.',\n 'role': 'system'},\n {'content': \"I am in my 30s with 2 kids and no job and I'm behind on the \"\n \"bills. I've been trying everything with no results. I feel like \"\n 'everyone else puts their needs in front on mine. I give intil it '\n \"hurts and I'm afraid that I'm going to fail my babies I've \"\n 'been struggling for 5 yrs in a relationship with no '\n 'communication and no alone time. 
How can I get some support?',\n 'role': 'user'},\n {'content': 'I would want to know where do you put your needs? Often times we '\n 'show people how to treat us based on how we treat ourselves. '\n 'First identify how to practice good self care which can only '\n 'happen when you are authentic and real about your needs and '\n 'desires. You cannot pour from an empty cup and must take find '\n 'the time to explore your purpose in life. It may be helpful to '\n 'seek out additional support to work on your intrapersonal '\n 'relationship.',\n 'role': 'assistant'}]\n","output_type":"stream"}]},{"cell_type":"code","source":"# save datasets to disk\nchat_dataset[\"train\"].to_json(\"train_dataset.json\", orient=\"records\")\nchat_dataset[\"test\"].to_json(\"test_dataset.json\", orient=\"records\")","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:30.248046Z","iopub.execute_input":"2024-08-12T14:07:30.248344Z","iopub.status.idle":"2024-08-12T14:07:30.386487Z","shell.execute_reply.started":"2024-08-12T14:07:30.248321Z","shell.execute_reply":"2024-08-12T14:07:30.385562Z"},"trusted":true},"execution_count":21,"outputs":[{"output_type":"display_data","data":{"text/plain":"Creating json from Arrow format: 0%| | 0/3 [00:00<?, ?ba/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"8fe54227d3d2456192feec63e910ddc8"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Creating json from Arrow format: 0%| | 0/1 [00:00<?, ?ba/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"be9af39824c54d80960a995b7615b968"}},"metadata":{}},{"execution_count":21,"output_type":"execute_result","data":{"text/plain":"1455984"},"metadata":{}}]},{"cell_type":"code","source":"train_dataset = load_dataset('json', data_files='/kaggle/working/train_dataset.json', 
split='train')","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:30.387599Z","iopub.execute_input":"2024-08-12T14:07:30.387909Z","iopub.status.idle":"2024-08-12T14:07:30.545776Z","shell.execute_reply.started":"2024-08-12T14:07:30.387885Z","shell.execute_reply":"2024-08-12T14:07:30.544873Z"},"trusted":true},"execution_count":22,"outputs":[{"output_type":"display_data","data":{"text/plain":"Generating train split: 0 examples [00:00, ? examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"89dc0a5891614526ac94b5fb29cf0f67"}},"metadata":{}}]},{"cell_type":"code","source":"pprint(train_dataset[0], sort_dicts=False)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:30.547032Z","iopub.execute_input":"2024-08-12T14:07:30.547389Z","iopub.status.idle":"2024-08-12T14:07:30.555549Z","shell.execute_reply.started":"2024-08-12T14:07:30.547355Z","shell.execute_reply":"2024-08-12T14:07:30.554688Z"},"trusted":true},"execution_count":23,"outputs":[{"name":"stdout","text":"{'questionText': \"I've had posttraumatic stress disorder for years without my \"\n 'parents ever finding out. I want to overcome it, but it’s so '\n \"vivid, it’s like it’s happening again. I'm scared and \"\n 'paranoid. 
I have depression, which I have been struggling '\n 'with since a young age.',\n 'answerText': 'Post traumatic stress disorder (PTSD) is a very complex '\n 'issue.\\xa0 It is the root cause of many other issues including '\n 'depression, anxiety and anger to name a few.\\xa0 '\n 'Re-experiencing the event is a telltale symptom of PTSD.\\xa0 '\n 'One of the best things you can do for yourself is to find '\n 'someone who specializes in this area\\xa0to help you work '\n 'through the symptoms associated with this disorder as well as '\n 'directly addressing the event/events that initially caused the '\n 'symptoms.\\xa0 There are various treatment methods available to '\n 'helping professionals that are successful in treating this '\n 'disorder.\\xa0 In the meantime there is a book called \"The Body '\n 'Keeps the Score\", it is\\xa0informative and\\xa0may be helpful '\n 'for you on your journey.\\xa0 It is my hope that you are not '\n 'dealing with this alone.\\xa0\\xa0 Best wishes!',\n 'messages': [{'content': 'You are a mental health chatbot that follows mental '\n 'health therapist professional guidelines.\\n'\n 'Users will talk to you about their mental health '\n 'issues and you will reply them the same way a '\n 'professional \\n'\n 'therapist would.',\n 'role': 'system'},\n {'content': \"I've had posttraumatic stress disorder for years \"\n 'without my parents ever finding out. I want to '\n 'overcome it, but it’s so vivid, it’s like it’s '\n \"happening again. I'm scared and paranoid. 
I have \"\n 'depression, which I have been struggling with since '\n 'a young age.',\n 'role': 'user'},\n {'content': 'Post traumatic stress disorder (PTSD) is a very '\n 'complex issue.\\xa0 It is the root cause of many '\n 'other issues including depression, anxiety and '\n 'anger to name a few.\\xa0 Re-experiencing the event '\n 'is a telltale symptom of PTSD.\\xa0 One of the best '\n 'things you can do for yourself is to find someone '\n 'who specializes in this area\\xa0to help you work '\n 'through the symptoms associated with this disorder '\n 'as well as directly addressing the event/events '\n 'that initially caused the symptoms.\\xa0 There are '\n 'various treatment methods available to helping '\n 'professionals that are successful in treating this '\n 'disorder.\\xa0 In the meantime there is a book '\n 'called \"The Body Keeps the Score\", it is\\xa0'\n 'informative and\\xa0may be helpful for you on your '\n 'journey.\\xa0 It is my hope that you are not dealing '\n 'with this alone.\\xa0\\xa0 Best wishes!',\n 'role': 'assistant'}]}\n","output_type":"stream"}]},{"cell_type":"code","source":"train_dataset","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:30.556679Z","iopub.execute_input":"2024-08-12T14:07:30.556992Z","iopub.status.idle":"2024-08-12T14:07:30.567293Z","shell.execute_reply.started":"2024-08-12T14:07:30.556967Z","shell.execute_reply":"2024-08-12T14:07:30.566352Z"},"trusted":true},"execution_count":24,"outputs":[{"execution_count":24,"output_type":"execute_result","data":{"text/plain":"Dataset({\n features: ['questionText', 'answerText', 'messages'],\n num_rows: 2089\n})"},"metadata":{}}]},{"cell_type":"markdown","source":"#### Modeling","metadata":{}},{"cell_type":"code","source":"model_id = \"microsoft/phi-2\"\n\n# BitsAndBytesConfig int-4 config\nbnb_config = BitsAndBytesConfig(\n load_in_4bit=True, \n bnb_4bit_use_double_quant=True, \n bnb_4bit_quant_type=\"nf4\", 
bnb_4bit_compute_dtype=torch.bfloat16\n)","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:30.890686Z","iopub.execute_input":"2024-08-12T14:07:30.891069Z","iopub.status.idle":"2024-08-12T14:07:30.897661Z","shell.execute_reply.started":"2024-08-12T14:07:30.891039Z","shell.execute_reply":"2024-08-12T14:07:30.896833Z"},"trusted":true},"execution_count":25,"outputs":[]},{"cell_type":"code","source":"# Load model and tokenizer\nmodel = AutoModelForCausalLM.from_pretrained(\n model_id,\n device_map=\"auto\",\n #attn_implementation=\"flash_attention_2\",\n torch_dtype=torch.bfloat16,\n quantization_config=bnb_config\n)\ntokenizer = AutoTokenizer.from_pretrained(model_id)\n","metadata":{"execution":{"iopub.status.busy":"2024-08-12T14:07:31.901016Z","iopub.execute_input":"2024-08-12T14:07:31.901878Z","iopub.status.idle":"2024-08-12T14:08:08.271009Z","shell.execute_reply.started":"2024-08-12T14:07:31.901845Z","shell.execute_reply":"2024-08-12T14:08:08.270057Z"},"trusted":true},"execution_count":26,"outputs":[{"output_type":"display_data","data":{"text/plain":"config.json: 0%| | 0.00/735 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"19e2dc209abe4147b7e33a615690b904"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model.safetensors.index.json: 0%| | 0.00/35.7k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"7a3d65defc2e4630b25acfe29d2a3dfa"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"722b37288c1c4313a51f8a6305cd985a"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model-00001-of-00002.safetensors: 0%| | 0.00/5.00G [00:00<?, 
?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"b7cb09f6c1c546aaa3400a7a7fa7006a"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model-00002-of-00002.safetensors: 0%| | 0.00/564M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"5f058a2dbf6f43a4a3d8d41c20f346ce"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"dbf41ddefd6b481c8856bd6f2fae70ce"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"generation_config.json: 0%| | 0.00/124 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"4add9a39afe645c288fb551a724516a6"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"tokenizer_config.json: 0%| | 0.00/7.34k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"d0e9ef621fd343d7b9f8c5cb4d79a5be"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"vocab.json: 0%| | 0.00/798k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"bd9b8cc21c78452c860aad79373f64b9"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"merges.txt: 0%| | 0.00/456k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"360dc2690f3042cda19b4c8275ed2e2c"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"tokenizer.json: 0%| | 0.00/2.11M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"768c3edf4ab248b895c9ae66e39988af"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"added_tokens.json: 0%| | 0.00/1.08k [00:00<?, 
?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"676b9a4df0874628a663d9fa27bd7bd8"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"special_tokens_map.json: 0%| | 0.00/99.0 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"b89a55e337f74f9ab00b3e1cfce5ed22"}},"metadata":{}}]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# # # set chat template to OAI chatML, remove if you start from a fine-tuned model\nmodel, tokenizer = setup_chat_format(model, tokenizer)","metadata":{"execution":{"iopub.status.busy":"2024-08-08T14:53:08.935622Z","iopub.execute_input":"2024-08-08T14:53:08.936297Z","iopub.status.idle":"2024-08-08T14:53:08.976635Z","shell.execute_reply.started":"2024-08-08T14:53:08.936260Z","shell.execute_reply":"2024-08-08T14:53:08.975779Z"},"trusted":true},"execution_count":25,"outputs":[]},{"cell_type":"code","source":"model = prepare_model_for_kbit_training(model)","metadata":{"execution":{"iopub.status.busy":"2024-08-08T14:53:08.977688Z","iopub.execute_input":"2024-08-08T14:53:08.978008Z","iopub.status.idle":"2024-08-08T14:53:13.041666Z","shell.execute_reply.started":"2024-08-08T14:53:08.977983Z","shell.execute_reply":"2024-08-08T14:53:13.040881Z"},"trusted":true},"execution_count":26,"outputs":[]},{"cell_type":"code","source":"tokenizer.pad_token = tokenizer.eos_token ","metadata":{"execution":{"iopub.status.busy":"2024-08-08T14:53:13.042814Z","iopub.execute_input":"2024-08-08T14:53:13.043103Z","iopub.status.idle":"2024-08-08T14:53:13.112117Z","shell.execute_reply.started":"2024-08-08T14:53:13.043079Z","shell.execute_reply":"2024-08-08T14:53:13.110993Z"},"trusted":true},"execution_count":27,"outputs":[]},{"cell_type":"code","source":"# LoRA config based on QLoRA paper & Sebastian 
Raschka experiment\npeft_config = LoraConfig(\n lora_alpha=128,\n lora_dropout=0.05,\n r=256,\n bias=\"none\",\n target_modules=[ # Setting names of modules in falcon-7b model that we want to apply LoRA to\n \"query_key_value\",\n \"dense\",\n \"dense_h_to_4h\",\n \"dense_4h_to_h\",\n ],\n task_type=\"CAUSAL_LM\",\n)","metadata":{"execution":{"iopub.status.busy":"2024-08-08T14:53:13.113642Z","iopub.execute_input":"2024-08-08T14:53:13.114354Z","iopub.status.idle":"2024-08-08T14:53:13.183924Z","shell.execute_reply.started":"2024-08-08T14:53:13.114325Z","shell.execute_reply":"2024-08-08T14:53:13.182813Z"},"trusted":true},"execution_count":28,"outputs":[]},{"cell_type":"code","source":"peft_model = get_peft_model(model, peft_config)","metadata":{"execution":{"iopub.status.busy":"2024-08-08T14:53:13.185198Z","iopub.execute_input":"2024-08-08T14:53:13.185476Z","iopub.status.idle":"2024-08-08T14:53:13.891050Z","shell.execute_reply.started":"2024-08-08T14:53:13.185454Z","shell.execute_reply":"2024-08-08T14:53:13.890063Z"},"trusted":true},"execution_count":29,"outputs":[]},{"cell_type":"code","source":"from transformers import TrainingArguments\n \nargs = TrainingArguments(\n output_dir=f'finetuned-{model_id}-for-mental-health-3', # directory to save and repository id\n num_train_epochs=1, # number of training epochs\n per_device_train_batch_size=3, # batch size per device during training\n gradient_accumulation_steps=2, # number of steps before performing a backward/update pass\n gradient_checkpointing=True, # use gradient checkpointing to save memory\n optim=\"adamw_torch_fused\", # use fused adamw optimizer\n logging_steps=10, # log every 10 steps\n save_strategy=\"epoch\", # save checkpoint every epoch\n learning_rate=2e-4, # learning rate, based on QLoRA paper\n fp16=True, # use tf32 precision\n max_grad_norm=0.3, # max gradient norm based on QLoRA paper\n warmup_ratio=0.03, # warmup ratio based on QLoRA paper\n lr_scheduler_type=\"constant\", # use constant learning 
rate scheduler\n push_to_hub=True, # push model to hub\n report_to=\"wandb\", # report metrics to tensorboard\n)","metadata":{"execution":{"iopub.status.busy":"2024-08-08T14:53:13.892315Z","iopub.execute_input":"2024-08-08T14:53:13.892669Z","iopub.status.idle":"2024-08-08T14:53:13.925328Z","shell.execute_reply.started":"2024-08-08T14:53:13.892638Z","shell.execute_reply":"2024-08-08T14:53:13.924280Z"},"trusted":true},"execution_count":30,"outputs":[]},{"cell_type":"code","source":"from trl import SFTTrainer\n# max sequence length for model and packing of the dataset\n \ntrainer = SFTTrainer(\n model=model,\n args=args,\n train_dataset=train_dataset,\n peft_config=peft_config,\n tokenizer=tokenizer,\n)","metadata":{"execution":{"iopub.status.busy":"2024-08-08T14:53:13.926621Z","iopub.execute_input":"2024-08-08T14:53:13.927007Z","iopub.status.idle":"2024-08-08T14:53:16.875456Z","shell.execute_reply.started":"2024-08-08T14:53:13.926974Z","shell.execute_reply":"2024-08-08T14:53:16.874675Z"},"trusted":true},"execution_count":31,"outputs":[{"name":"stderr","text":"/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:289: UserWarning: You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to 1024\n warnings.warn(\n/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:366: UserWarning: You passed a `dataset_kwargs` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n warnings.warn(\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"Map: 0%| | 0/2089 [00:00<?, ? 
examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"354ced84f44f46bf9e52c89ba53c4d95"}},"metadata":{}}]},{"cell_type":"code","source":"peft_model.config.use_cache = False\ntrainer.train()","metadata":{"execution":{"iopub.status.busy":"2024-08-08T14:53:16.876585Z","iopub.execute_input":"2024-08-08T14:53:16.876949Z"},"trusted":true},"execution_count":null,"outputs":[{"name":"stderr","text":"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.\n\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mabisoye\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"wandb version 0.17.6 is available! 
To upgrade, please run:\n $ pip install wandb --upgrade"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"Tracking run with wandb version 0.17.4"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"Run data is saved locally in <code>/kaggle/working/wandb/run-20240808_145317-mk243f2l</code>"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"Syncing run <strong><a href='https://wandb.ai/abisoye/huggingface/runs/mk243f2l' target=\"_blank\">finetuned-microsoft/phi-2-for-mental-health-3</a></strong> to <a href='https://wandb.ai/abisoye/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":" View project at <a href='https://wandb.ai/abisoye/huggingface' target=\"_blank\">https://wandb.ai/abisoye/huggingface</a>"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":" View run at <a href='https://wandb.ai/abisoye/huggingface/runs/mk243f2l' target=\"_blank\">https://wandb.ai/abisoye/huggingface/runs/mk243f2l</a>"},"metadata":{}},{"name":"stderr","text":"/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. 
Refer to docs for more details on the differences between the two variants.\n warnings.warn(\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"\n <div>\n \n <progress value='10' max='348' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 10/348 01:09 < 48:38, 0.12 it/s, Epoch 0.03/1]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Step</th>\n <th>Training Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>"},"metadata":{}}]},{"cell_type":"code","source":"trainer.push_to_hub()","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# free the memory again\ndel model\ndel trainer\ndel peft_model\ntorch.cuda.empty_cache()","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":" \n#### COMMENT IN TO MERGE PEFT AND BASE MODEL ####\nfrom peft import AutoPeftModelForCausalLM\n \n# Load PEFT model on CPU\nmodel = AutoPeftModelForCausalLM.from_pretrained(\n args.output_dir,\n torch_dtype=torch.float16,\n low_cpu_mem_usage=True,\n)\n# Merge LoRA and base model and save\nmerged_model = model.merge_and_unload()\nmerged_model.save_pretrained(args.output_dir,safe_serialization=True, max_shard_size=\"2GB\")","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"merged_model.push_to_hub(f'peft-finetuned-{model_id}-for-mental-health-3')","metadata":{},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]}]}
utils.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import whisper
2
+ from gtts import gTTS
3
+ from pydub import AudioSegment
4
+
5
+
6
+ #text to speech function
7
def text_to_speech(text, playback_speed=1.20):
    """Convert *text* to spoken audio and return the output mp3 filename.

    Synthesizes US-accent English speech with gTTS, then speeds the audio
    up with pydub (the 1.20 factor approximates ~170 wpm).

    Args:
        text: The text to synthesize.
        playback_speed: Speed-up factor applied to the raw gTTS audio.
            Defaults to 1.20, the value the original implementation
            hard-coded.

    Returns:
        The path of the exported mp3 file ('text_to_speech.mp3').
    """
    import os
    import tempfile

    # Synthesize into a unique temp file instead of a fixed 'temp.mp3'
    # (avoids collisions between concurrent calls) and always remove it
    # afterwards — the original left 'temp.mp3' behind on disk.
    with tempfile.NamedTemporaryFile(suffix='.mp3', delete=False) as tmp:
        temp_path = tmp.name
    try:
        tts = gTTS(text=text, lang='en', tld='us', slow=False)
        tts.save(temp_path)

        # Load the synthesized audio for post-processing.
        audio = AudioSegment.from_file(temp_path)
    finally:
        os.remove(temp_path)

    # Adjust the speed (default ~170 wpm).
    audio = audio.speedup(playback_speed=playback_speed)

    # Save and return the adjusted audio file.
    final_filename = 'text_to_speech.mp3'
    audio.export(final_filename, format='mp3')

    return final_filename
24
+
25
+ #speech to text function
26
def audio_to_text(audio):
    """Transcribe an audio file to text with Whisper.

    Args:
        audio: Path to the audio file to transcribe (e.g. the filepath
            produced by a gradio microphone component).

    Returns:
        The transcribed text string.
    """
    # Cache the Whisper model on the function object so repeated calls
    # don't reload 'base.en' from disk every time (the original reloaded
    # the model on each invocation).
    model = getattr(audio_to_text, '_model', None)
    if model is None:
        model = whisper.load_model("base.en")
        audio_to_text._model = model

    result = model.transcribe(audio)
    return result["text"]