{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "collapsed_sections": [ "nMPIMsBXnowt", "zWz1-JCKnudh", "gFUsQMWP87EE", "6TbYU2UKn0DJ", "qyMS8mQnn2Dx" ] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "### Data Preparation" ], "metadata": { "id": "nMPIMsBXnowt" } }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "nrcgcY0HWd3u", "outputId": "998fc695-b11e-4648-ac5f-d2b73f88e306" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting opendatasets\n", " Downloading opendatasets-0.1.22-py3-none-any.whl.metadata (9.2 kB)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from opendatasets) (4.66.5)\n", "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (from opendatasets) (1.6.17)\n", "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from opendatasets) (8.1.7)\n", "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (1.16.0)\n", "Requirement already satisfied: certifi>=2023.7.22 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2024.7.4)\n", "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.8.2)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.32.3)\n", "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (8.0.4)\n", "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.0.7)\n", "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) 
(6.1.0)\n", "Requirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle->opendatasets) (0.5.1)\n", "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle->opendatasets) (1.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle->opendatasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle->opendatasets) (3.7)\n", "Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)\n", "Installing collected packages: opendatasets\n", "Successfully installed opendatasets-0.1.22\n" ] } ], "source": [
"# Pin the version recorded in this cell's output; %pip installs into the running kernel's environment\n",
"%pip install opendatasets==0.1.22"
] },
{ "cell_type": "code", "source": [ "import opendatasets as od\n", "od.download(\"https://www.kaggle.com/datasets/prasad22/healthcare-dataset\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "I6P_5cbGWnYv", "outputId": "fbdbc44f-1ab0-49be-a17c-7bfe0aa77c12" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Dataset URL: https://www.kaggle.com/datasets/prasad22/healthcare-dataset\n", "Downloading healthcare-dataset.zip to ./healthcare-dataset\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "100%|██████████| 2.91M/2.91M [00:00<00:00, 46.8MB/s]" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "\n" ] } ] },
{ "cell_type": "code", "source": [
"import pandas as pd\n",
"\n",
"# od.download() above puts the Kaggle dataset in ./healthcare-dataset/ (see its output),\n",
"# so read the CSV from there instead of assuming it was uploaded to the working directory by hand\n",
"df = pd.read_csv(\"healthcare-dataset/healthcare_dataset.csv\")\n",
"df = df[['Age','Gender','Blood Type','Medical Condition','Test Results','Medication']]"
], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 423 }, "collapsed": true, "id": "q-CvLDIMWrs5", "outputId": "303ca4ec-f55a-4ffc-9466-c7b241521a4c" }, "execution_count": null, "outputs": [ {
"output_type": "execute_result", "data": { "text/plain": [ " Age Gender Blood Type Medical Condition Test Results Medication\n", "0 30 Male B- Cancer Normal Paracetamol\n", "1 62 Male A+ Obesity Inconclusive Ibuprofen\n", "2 76 Female A- Obesity Normal Aspirin\n", "3 28 Female O+ Diabetes Abnormal Ibuprofen\n", "4 43 Female AB+ Cancer Abnormal Penicillin\n", "... ... ... ... ... ... ...\n", "55495 42 Female O+ Asthma Abnormal Penicillin\n", "55496 61 Female AB- Obesity Normal Aspirin\n", "55497 38 Female B+ Hypertension Abnormal Ibuprofen\n", "55498 43 Male O- Arthritis Abnormal Ibuprofen\n", "55499 53 Female O+ Arthritis Abnormal Ibuprofen\n", "\n", "[55500 rows x 6 columns]" ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeGenderBlood TypeMedical ConditionTest ResultsMedication
030MaleB-CancerNormalParacetamol
162MaleA+ObesityInconclusiveIbuprofen
276FemaleA-ObesityNormalAspirin
328FemaleO+DiabetesAbnormalIbuprofen
443FemaleAB+CancerAbnormalPenicillin
.....................
5549542FemaleO+AsthmaAbnormalPenicillin
5549661FemaleAB-ObesityNormalAspirin
5549738FemaleB+HypertensionAbnormalIbuprofen
5549843MaleO-ArthritisAbnormalIbuprofen
5549953FemaleO+ArthritisAbnormalIbuprofen
\n", "

55500 rows × 6 columns

\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", " \n", " \n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "df", "summary": "{\n \"name\": \"df\",\n \"rows\": 55500,\n \"fields\": [\n {\n \"column\": \"Age\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 19,\n \"min\": 13,\n \"max\": 89,\n \"num_unique_values\": 77,\n \"samples\": [\n 43,\n 22,\n 72\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Gender\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"Female\",\n \"Male\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Blood Type\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 8,\n \"samples\": [\n \"A+\",\n \"AB-\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Medical Condition\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 6,\n \"samples\": [\n \"Cancer\",\n \"Obesity\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Test Results\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 3,\n \"samples\": [\n \"Normal\",\n \"Inconclusive\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Medication\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 5,\n \"samples\": [\n \"Ibuprofen\",\n \"Lipitor\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 1 } ] }, { "cell_type": "code", "source": [ "df['Test Results'].value_counts()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 209 }, "id": "PUq7tSYfWzTl", "outputId": "897d2af8-e1d5-4864-f88f-7c99fd14fed8" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Test Results\n", "Abnormal 18627\n", "Normal 18517\n", "Inconclusive 18356\n", "Name: count, 
dtype: int64" ], "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
count
Test Results
Abnormal18627
Normal18517
Inconclusive18356
\n", "

" ] }, "metadata": {}, "execution_count": 6 } ] }, { "cell_type": "code", "source": [ "df['Medical Condition'].value_counts()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 303 }, "id": "LprQu4JKXg5v", "outputId": "3d467787-7747-471b-9bc0-7151943dbef5" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Medical Condition\n", "Arthritis 9308\n", "Diabetes 9304\n", "Hypertension 9245\n", "Obesity 9231\n", "Cancer 9227\n", "Asthma 9185\n", "Name: count, dtype: int64" ], "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
count
Medical Condition
Arthritis9308
Diabetes9304
Hypertension9245
Obesity9231
Cancer9227
Asthma9185
\n", "

" ] }, "metadata": {}, "execution_count": 7 } ] }, { "cell_type": "code", "source": [ "df['Blood Type'].value_counts()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 366 }, "id": "PklqIg_QX1LK", "outputId": "928ad232-a538-44dd-8568-36c3de23886b" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Blood Type\n", "A- 6969\n", "A+ 6956\n", "AB+ 6947\n", "AB- 6945\n", "B+ 6945\n", "B- 6944\n", "O+ 6917\n", "O- 6877\n", "Name: count, dtype: int64" ], "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
count
Blood Type
A-6969
A+6956
AB+6947
AB-6945
B+6945
B-6944
O+6917
O-6877
\n", "

" ] }, "metadata": {}, "execution_count": 8 } ] }, { "cell_type": "code", "source": [ "df['Medication'].value_counts()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 272 }, "id": "QGBA7xBnX8zA", "outputId": "4bfd0a84-e68a-4553-d918-2d684dde6dc9" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Medication\n", "Lipitor 11140\n", "Ibuprofen 11127\n", "Aspirin 11094\n", "Paracetamol 11071\n", "Penicillin 11068\n", "Name: count, dtype: int64" ], "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
count
Medication
Lipitor11140
Ibuprofen11127
Aspirin11094
Paracetamol11071
Penicillin11068
\n", "

" ] }, "metadata": {}, "execution_count": 9 } ] }, { "cell_type": "code", "source": [ "df['Gender'].value_counts()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 178 }, "id": "kZV7YperYB4s", "outputId": "c0a79c87-9f71-4fd0-dbb7-542b946f4490" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Gender\n", "Male 27774\n", "Female 27726\n", "Name: count, dtype: int64" ], "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
count
Gender
Male27774
Female27726
\n", "

" ] }, "metadata": {}, "execution_count": 10 } ] }, { "cell_type": "code", "source": [ "df.isnull().sum()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 272 }, "id": "inBn2HEPYKBk", "outputId": "6fd328f2-e84d-47db-df61-3468983ce528" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Age 0\n", "Gender 0\n", "Blood Type 0\n", "Medical Condition 0\n", "Test Results 0\n", "Medication 0\n", "dtype: int64" ], "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0
Age0
Gender0
Blood Type0
Medical Condition0
Test Results0
Medication0
\n", "

" ] }, "metadata": {}, "execution_count": 11 } ] }, { "cell_type": "code", "source": [ "from sklearn.preprocessing import LabelEncoder\n", "\n", "# Encode categorical features\n", "label_encoders = {}\n", "for column in ['Gender', 'Blood Type', 'Medical Condition', 'Test Results']:\n", " le = LabelEncoder()\n", " df[column] = le.fit_transform(df[column])\n", " label_encoders[column] = le\n", "\n", "# Encode the target variable\n", "target_encoder = LabelEncoder()\n", "df['Medication'] = target_encoder.fit_transform(df['Medication'])" ], "metadata": { "id": "5EDh_scLZF_N" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "# Define features and target\n", "X = df[['Age', 'Gender', 'Blood Type', 'Medical Condition', 'Test Results']]\n", "y = df['Medication']\n", "\n", "# Split the data\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)" ], "metadata": { "id": "NRwGc4aQZMP0" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "len(X_train), len(X_test), len(y_train), len(y_test)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "rpcJAbA_ZeN-", "outputId": "01fcf1b0-5b45-4dbb-ee95-57e9361e2f91" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(44400, 11100, 44400, 11100)" ] }, "metadata": {}, "execution_count": 4 } ] }, { "cell_type": "markdown", "source": [ "### Model Training" ], "metadata": { "id": "zWz1-JCKnudh" } }, { "cell_type": "code", "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.metrics import classification_report, accuracy_score\n", "\n", "# Initialize and train the model\n", "model = RandomForestClassifier(n_estimators=100, random_state=42)\n", "model.fit(X_train, y_train)\n", "\n", "# Make predictions\n", "y_pred = model.predict(X_test)\n", "\n", "# Evaluate the model\n", 
"print(f\"Accuracy: {accuracy_score(y_test, y_pred)}\")\n", "print(classification_report(y_test, y_pred, target_names=target_encoder.classes_))\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aOoTscBpZYVF", "outputId": "b4b4b35d-1e42-4457-ddbb-2ea03f0183c8" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Accuracy: 0.2036036036036036\n", " precision recall f1-score support\n", "\n", " Aspirin 0.20 0.20 0.20 2211\n", " Ibuprofen 0.21 0.20 0.21 2271\n", " Lipitor 0.21 0.21 0.21 2224\n", " Paracetamol 0.21 0.21 0.21 2207\n", " Penicillin 0.19 0.19 0.19 2187\n", "\n", " accuracy 0.20 11100\n", " macro avg 0.20 0.20 0.20 11100\n", "weighted avg 0.20 0.20 0.20 11100\n", "\n" ] } ] }, { "cell_type": "code", "source": [ "from sklearn.preprocessing import StandardScaler\n", "from tensorflow.keras.utils import to_categorical\n", "\n", "# Normalize numerical features\n", "scaler = StandardScaler()\n", "X_scaled = scaler.fit_transform(X[['Age']])\n", "X_scaled = pd.DataFrame(X_scaled, columns=['Age'])\n", "\n", "# Concatenate scaled numerical features with encoded categorical features\n", "X_encoded = X.drop(columns=['Age'])\n", "X_final = pd.concat([X_scaled, X_encoded], axis=1)\n", "\n", "# One-hot encode the target variable\n", "y_final = to_categorical(y)\n" ], "metadata": { "id": "T_kRZhaQat3s" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.metrics import classification_report, accuracy_score\n", "\n", "# Initialize the KNN model\n", "knn = KNeighborsClassifier(n_neighbors=5) # You can adjust n_neighbors for better performance\n", "\n", "# Train the model\n", "knn.fit(X_train, y_train)\n", "\n", "# Predict on the test set\n", "y_pred = knn.predict(X_test)\n", "\n", "# Evaluate the model\n", "accuracy = accuracy_score(y_test, y_pred)\n", "print(f\"Test Accuracy: {accuracy}\")\n", 
"print(classification_report(y_test, y_pred, target_names=target_encoder.classes_))\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6jp6Gcqth5cB", "outputId": "b34b1aba-d90b-4b5a-f0bd-a51a8a0c6015" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Test Accuracy: 0.2018018018018018\n", " precision recall f1-score support\n", "\n", " Aspirin 0.19 0.29 0.23 2211\n", " Ibuprofen 0.21 0.23 0.22 2271\n", " Lipitor 0.21 0.20 0.20 2224\n", " Paracetamol 0.20 0.16 0.18 2207\n", " Penicillin 0.20 0.13 0.16 2187\n", "\n", " accuracy 0.20 11100\n", " macro avg 0.20 0.20 0.20 11100\n", "weighted avg 0.20 0.20 0.20 11100\n", "\n" ] } ] }, { "cell_type": "markdown", "source": [ "### FINAL" ], "metadata": { "id": "gFUsQMWP87EE" } }, { "cell_type": "code", "source": [ "\n", "import pandas as pd\n", "from sklearn.preprocessing import LabelEncoder, StandardScaler\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import accuracy_score, classification_report\n", "import joblib\n", "\n", "# Load the dataset\n", "data = pd.read_csv('/content/healthcare_dataset.csv')\n", "\n", "# If 'Medication' column is numeric, manually map them to their names\n", "medication_mapping = {\n", " 0: 'Aspirin',\n", " 1: 'Ibuprofen',\n", " 2: 'Lipitor',\n", " 3: 'Paracetamol',\n", " 4: 'Penicillin'\n", "}\n", "\n", "# Encode categorical features\n", "label_encoders = {}\n", "for column in ['Gender', 'Blood Type', 'Medical Condition', 'Test Results']:\n", " le = LabelEncoder()\n", " data[column] = le.fit_transform(data[column])\n", " label_encoders[column] = le\n", "\n", "# Encode the target variable 'Medication'\n", "medication_encoder = LabelEncoder()\n", "data['Medication'] = medication_encoder.fit_transform(data['Medication'])\n", "\n", "# Define features and target\n", "X = data[['Age', 'Gender', 'Blood Type', 'Medical Condition', 'Test 
Results']]\n", "y = data['Medication']\n", "\n", "# Split the dataset into training and testing sets\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", "\n", "# Normalize ONLY the 'Age' column\n", "age_scaler = StandardScaler()\n", "X_train['Age'] = age_scaler.fit_transform(X_train[['Age']])\n", "X_test['Age'] = age_scaler.transform(X_test[['Age']])" ], "metadata": { "id": "dMaqw6Ao7iJC" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Initialize and train the KNN model\n", "knn = KNeighborsClassifier(n_neighbors=5)\n", "knn.fit(X_train, y_train)\n", "\n", "# Evaluate the model on the test set\n", "y_pred = knn.predict(X_test)\n", "accuracy = accuracy_score(y_test, y_pred)\n", "print(f\"Test Accuracy: {accuracy}\")\n", "\n", "# Print the classification report\n", "print(\"Classification Report:\")\n", "print(classification_report(y_test, y_pred, target_names=medication_encoder.classes_))\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ux4L1tsX9CS2", "outputId": "52e4b74c-22ec-4934-f80d-f5e37b893326" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Test Accuracy: 0.20306306306306307\n", "Classification Report:\n", " precision recall f1-score support\n", "\n", " Aspirin 0.20 0.29 0.24 2211\n", " Ibuprofen 0.21 0.23 0.22 2271\n", " Lipitor 0.22 0.21 0.21 2224\n", " Paracetamol 0.20 0.16 0.17 2207\n", " Penicillin 0.18 0.13 0.15 2187\n", "\n", " accuracy 0.20 11100\n", " macro avg 0.20 0.20 0.20 11100\n", "weighted avg 0.20 0.20 0.20 11100\n", "\n" ] } ] }, { "cell_type": "markdown", "source": [ "### Testing" ], "metadata": { "id": "6TbYU2UKn0DJ" } }, { "cell_type": "code", "source": [ "# Example new data for prediction\n", "new_data = pd.DataFrame({\n", " 'Age': [62],\n", " 'Gender': ['Male'],\n", " 'Blood Type': ['A+'],\n", " 'Medical Condition': ['Obesity'],\n", " 'Test Results': ['Normal']\n", "})\n", "\n", 
"# Encode the new data using the same label encoders\n", "for column in ['Gender', 'Blood Type', 'Medical Condition', 'Test Results']:\n", " new_data[column] = label_encoders[column].transform(new_data[column])\n", "\n", "# Normalize the 'Age' column in the new data\n", "new_data['Age'] = age_scaler.transform(new_data[['Age']])\n", "\n", "# Make predictions\n", "predictions = knn.predict(new_data)\n", "\n", "# Decode the predictions back to the original medication names\n", "predicted_medications = medication_encoder.inverse_transform(predictions)\n", "\n", "print(f\"Predicted Medication: {predicted_medications[0]}\")\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ubmJkLPj9ELT", "outputId": "aff25ba4-1459-47a1-e813-257a0faad04a" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Predicted Medication: Ibuprofen\n" ] } ] }, { "cell_type": "markdown", "source": [ "### Saving" ], "metadata": { "id": "qyMS8mQnn2Dx" } }, { "cell_type": "code", "source": [ "# Save the trained model, label encoders, and age scaler\n", "joblib.dump(knn, 'knn_model.pkl')\n", "joblib.dump(label_encoders, 'label_encoders.pkl')\n", "joblib.dump(age_scaler, 'age_scaler.pkl')\n", "joblib.dump(medication_encoder, 'medication_encoder.pkl')\n", "\n", "print(\"Model and encoders saved successfully.\")\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "WOitiTRa9Gxa", "outputId": "61bbdb60-b67f-4719-e0be-bf78b88df92b" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Model and encoders saved successfully.\n" ] } ] } ] }