File size: 4,015 Bytes
a011ccd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UXKT8SDQQ1tI"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "import torch\n",
        "import re\n",
        "from pprint import pprint\n",
        "major_version, minor_version = torch.cuda.get_device_capability()\n",
        "if major_version >= 8:\n",
        "    # Compute capability 8.0+: Ampere and newer GPUs (RTX 30xx, RTX 40xx, A100, H100, L40)\n",
        "    !pip install \"unsloth[colab-ampere] @ git+https://github.com/unslothai/unsloth.git\"\n",
        "else:\n",
        "    # Use this for older GPUs (V100, Tesla T4, RTX 20xx)\n",
        "    !pip install \"unsloth[colab] @ git+https://github.com/unslothai/unsloth.git\"\n",
        "pass"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from unsloth import FastLanguageModel\n",
        "import torch\n",
        "max_seq_length = 2048\n",
        "# Choose any! We auto support RoPE Scaling internally!\n",
        "dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+\n",
        "load_in_4bit = True"
      ],
      "metadata": {
        "id": "Q6gVomWzQ7hU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "model, tokenizer = FastLanguageModel.from_pretrained(\n",
        "    model_name = \"neuralwebtech/mental_health_counseling_gemma_7b_4bit_q\", # the model you trained, loaded here for inference\n",
        "    max_seq_length = max_seq_length,\n",
        "    dtype = dtype,\n",
        "    load_in_4bit = load_in_4bit,\n",
        ")\n",
        "FastLanguageModel.for_inference(model) # Enable native 2x faster inference\n",
        "\n",
        "alpaca_prompt = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context.\n",
        " Write a response that appropriately completes the request.\n",
        "\n",
        "### Context:\n",
        "{}\n",
        "\n",
        "### Response:\n",
        "{}\"\"\""
      ],
      "metadata": {
        "id": "_ItV-FhgRC5t"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "inputs = tokenizer(\n",
        "[\n",
        "    alpaca_prompt.format(\n",
        "        text, # fills the Context slot of the prompt; text must be assigned before this cell runs (no visible cell defines it)\n",
        "        \"\", # output - leave this blank for generation!\n",
        "    )\n",
        "], return_tensors = \"pt\").to(\"cuda\")\n",
        "\n",
        "outputs = model.generate(**inputs, max_new_tokens = 128, use_cache = True)\n",
        "final_out=tokenizer.batch_decode(outputs)\n"
      ],
      "metadata": {
        "id": "8eTx88KiRDiL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def print_response(lines):\n",
        "    text = '\\n'.join(lines)\n",
        "    response_match = re.search(r'### Response:\\s*(.*)', text)\n",
        "    if response_match:\n",
        "        response = response_match.group(1)\n",
        "        return response\n",
        "    else:\n",
        "        return \"No response\""
      ],
      "metadata": {
        "id": "z5s-5_0MRHPt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "pprint(print_response(final_out))"
      ],
      "metadata": {
        "id": "_DlE2xjBRHUk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "xHwuwJ-6RHck"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}