neuralweb committed
Commit
a011ccd
1 Parent(s): 6c68156

Upload Gemma_inference.ipynb

Files changed (1)
  1. Gemma_inference.ipynb +131 -0
Gemma_inference.ipynb ADDED
@@ -0,0 +1,131 @@
+ {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "UXKT8SDQQ1tI"
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "import torch\n",
+ "import re\n",
+ "from pprint import pprint\n",
+ "major_version, minor_version = torch.cuda.get_device_capability()\n",
+ "if major_version >= 8:\n",
+ "    # Newer GPUs: Ampere or Hopper (RTX 30xx, RTX 40xx, A100, H100, L40)\n",
+ "    !pip install \"unsloth[colab-ampere] @ git+https://github.com/unslothai/unsloth.git\"\n",
+ "else:\n",
+ "    # Older GPUs (V100, Tesla T4, RTX 20xx)\n",
+ "    !pip install \"unsloth[colab] @ git+https://github.com/unslothai/unsloth.git\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from unsloth import FastLanguageModel\n",
+ "import torch\n",
+ "max_seq_length = 2048 # Choose any! Unsloth supports RoPE scaling internally.\n",
+ "dtype = None # None for auto detection; float16 for Tesla T4/V100, bfloat16 for Ampere+\n",
+ "load_in_4bit = True # Use 4-bit quantization to reduce memory usage"
+ ],
+ "metadata": {
+ "id": "Q6gVomWzQ7hU"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ "    model_name = \"neuralwebtech/mental_health_counseling_gemma_7b_4bit_q\", # the model you used for training\n",
+ "    max_seq_length = max_seq_length,\n",
+ "    dtype = dtype,\n",
+ "    load_in_4bit = load_in_4bit,\n",
+ ")\n",
+ "FastLanguageModel.for_inference(model) # Enable native 2x faster inference\n",
+ "\n",
+ "alpaca_prompt = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context.\n",
+ "Write a response that appropriately completes the request.\n",
+ "\n",
+ "### Context:\n",
+ "{}\n",
+ "\n",
+ "### Response:\n",
+ "{}\"\"\""
+ ],
+ "metadata": {
+ "id": "_ItV-FhgRC5t"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Example context; replace with your own input.\n",
+ "text = \"I feel anxious all the time and I can't sleep. What should I do?\"\n",
+ "\n",
+ "inputs = tokenizer(\n",
+ "[\n",
+ "    alpaca_prompt.format(\n",
+ "        text, # context\n",
+ "        \"\", # response - leave this blank for generation!\n",
+ "    )\n",
+ "], return_tensors = \"pt\").to(\"cuda\")\n",
+ "\n",
+ "outputs = model.generate(**inputs, max_new_tokens = 128, use_cache = True)\n",
+ "final_out = tokenizer.batch_decode(outputs)"
+ ],
+ "metadata": {
+ "id": "8eTx88KiRDiL"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def print_response(lines):\n",
+ "    \"\"\"Extract the generated text that follows the '### Response:' marker.\"\"\"\n",
+ "    text = '\\n'.join(lines)\n",
+ "    response_match = re.search(r'### Response:\\s*(.*)', text, re.DOTALL)\n",
+ "    if response_match:\n",
+ "        return response_match.group(1)\n",
+ "    return \"No response\""
+ ],
+ "metadata": {
+ "id": "z5s-5_0MRHPt"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "pprint(print_response(final_out))"
+ ],
+ "metadata": {
+ "id": "_DlE2xjBRHUk"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+ }