{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"nMPIMsBXnowt",
"zWz1-JCKnudh",
"gFUsQMWP87EE",
"6TbYU2UKn0DJ",
"qyMS8mQnn2Dx"
]
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"### Data Preparation"
],
"metadata": {
"id": "nMPIMsBXnowt"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nrcgcY0HWd3u",
"outputId": "998fc695-b11e-4648-ac5f-d2b73f88e306"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting opendatasets\n",
" Downloading opendatasets-0.1.22-py3-none-any.whl.metadata (9.2 kB)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from opendatasets) (4.66.5)\n",
"Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (from opendatasets) (1.6.17)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from opendatasets) (8.1.7)\n",
"Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (1.16.0)\n",
"Requirement already satisfied: certifi>=2023.7.22 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2024.7.4)\n",
"Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.8.2)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.32.3)\n",
"Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (8.0.4)\n",
"Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.0.7)\n",
"Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (6.1.0)\n",
"Requirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle->opendatasets) (0.5.1)\n",
"Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle->opendatasets) (1.3)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle->opendatasets) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle->opendatasets) (3.7)\n",
"Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)\n",
"Installing collected packages: opendatasets\n",
"Successfully installed opendatasets-0.1.22\n"
]
}
],
"source": [
"# Colab setup: %pip installs into the running kernel's environment; version pinned for reproducible runs\n",
"%pip install -q opendatasets==0.1.22"
]
},
{
"cell_type": "code",
"source": [
"# Fetch the Kaggle healthcare dataset; per the output below it is saved under ./healthcare-dataset/\n",
"import opendatasets as od\n",
"\n",
"dataset_url = \"https://www.kaggle.com/datasets/prasad22/healthcare-dataset\"\n",
"od.download(dataset_url)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "I6P_5cbGWnYv",
"outputId": "fbdbc44f-1ab0-49be-a17c-7bfe0aa77c12"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Dataset URL: https://www.kaggle.com/datasets/prasad22/healthcare-dataset\n",
"Downloading healthcare-dataset.zip to ./healthcare-dataset\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 2.91M/2.91M [00:00<00:00, 46.8MB/s]"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"# od.download saved the dataset under ./healthcare-dataset/ (see the download cell's output),\n",
"# so read it from there; the old hard-coded /content/healthcare_dataset.csv is not created by that cell.\n",
"DATA_PATH = \"healthcare-dataset/healthcare_dataset.csv\"\n",
"\n",
"# Keep only the modelling columns. The explicit .copy() makes df an independent frame,\n",
"# so the later in-place label encoding never writes through to a view of the raw CSV frame.\n",
"df = pd.read_csv(DATA_PATH)[['Age', 'Gender', 'Blood Type', 'Medical Condition', 'Test Results', 'Medication']].copy()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 423
},
"collapsed": true,
"id": "q-CvLDIMWrs5",
"outputId": "303ca4ec-f55a-4ffc-9466-c7b241521a4c"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Age Gender Blood Type Medical Condition Test Results Medication\n",
"0 30 Male B- Cancer Normal Paracetamol\n",
"1 62 Male A+ Obesity Inconclusive Ibuprofen\n",
"2 76 Female A- Obesity Normal Aspirin\n",
"3 28 Female O+ Diabetes Abnormal Ibuprofen\n",
"4 43 Female AB+ Cancer Abnormal Penicillin\n",
"... ... ... ... ... ... ...\n",
"55495 42 Female O+ Asthma Abnormal Penicillin\n",
"55496 61 Female AB- Obesity Normal Aspirin\n",
"55497 38 Female B+ Hypertension Abnormal Ibuprofen\n",
"55498 43 Male O- Arthritis Abnormal Ibuprofen\n",
"55499 53 Female O+ Arthritis Abnormal Ibuprofen\n",
"\n",
"[55500 rows x 6 columns]"
],
"text/html": [
"\n",
"
\n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Age | \n",
" Gender | \n",
" Blood Type | \n",
" Medical Condition | \n",
" Test Results | \n",
" Medication | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 30 | \n",
" Male | \n",
" B- | \n",
" Cancer | \n",
" Normal | \n",
" Paracetamol | \n",
"
\n",
" \n",
" 1 | \n",
" 62 | \n",
" Male | \n",
" A+ | \n",
" Obesity | \n",
" Inconclusive | \n",
" Ibuprofen | \n",
"
\n",
" \n",
" 2 | \n",
" 76 | \n",
" Female | \n",
" A- | \n",
" Obesity | \n",
" Normal | \n",
" Aspirin | \n",
"
\n",
" \n",
" 3 | \n",
" 28 | \n",
" Female | \n",
" O+ | \n",
" Diabetes | \n",
" Abnormal | \n",
" Ibuprofen | \n",
"
\n",
" \n",
" 4 | \n",
" 43 | \n",
" Female | \n",
" AB+ | \n",
" Cancer | \n",
" Abnormal | \n",
" Penicillin | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 55495 | \n",
" 42 | \n",
" Female | \n",
" O+ | \n",
" Asthma | \n",
" Abnormal | \n",
" Penicillin | \n",
"
\n",
" \n",
" 55496 | \n",
" 61 | \n",
" Female | \n",
" AB- | \n",
" Obesity | \n",
" Normal | \n",
" Aspirin | \n",
"
\n",
" \n",
" 55497 | \n",
" 38 | \n",
" Female | \n",
" B+ | \n",
" Hypertension | \n",
" Abnormal | \n",
" Ibuprofen | \n",
"
\n",
" \n",
" 55498 | \n",
" 43 | \n",
" Male | \n",
" O- | \n",
" Arthritis | \n",
" Abnormal | \n",
" Ibuprofen | \n",
"
\n",
" \n",
" 55499 | \n",
" 53 | \n",
" Female | \n",
" O+ | \n",
" Arthritis | \n",
" Abnormal | \n",
" Ibuprofen | \n",
"
\n",
" \n",
"
\n",
"
55500 rows × 6 columns
\n",
"
\n",
"
\n",
"
\n"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "dataframe",
"variable_name": "df",
"summary": "{\n \"name\": \"df\",\n \"rows\": 55500,\n \"fields\": [\n {\n \"column\": \"Age\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 19,\n \"min\": 13,\n \"max\": 89,\n \"num_unique_values\": 77,\n \"samples\": [\n 43,\n 22,\n 72\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Gender\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"Female\",\n \"Male\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Blood Type\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 8,\n \"samples\": [\n \"A+\",\n \"AB-\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Medical Condition\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 6,\n \"samples\": [\n \"Cancer\",\n \"Obesity\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Test Results\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 3,\n \"samples\": [\n \"Normal\",\n \"Inconclusive\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Medication\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 5,\n \"samples\": [\n \"Ibuprofen\",\n \"Lipitor\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
}
},
"metadata": {},
"execution_count": 1
}
]
},
{
"cell_type": "code",
"source": [
"# Class balance of the lab-result categories\n",
"df.loc[:, 'Test Results'].value_counts()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 209
},
"id": "PUq7tSYfWzTl",
"outputId": "897d2af8-e1d5-4864-f88f-7c99fd14fed8"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Test Results\n",
"Abnormal 18627\n",
"Normal 18517\n",
"Inconclusive 18356\n",
"Name: count, dtype: int64"
],
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" count | \n",
"
\n",
" \n",
" Test Results | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" Abnormal | \n",
" 18627 | \n",
"
\n",
" \n",
" Normal | \n",
" 18517 | \n",
"
\n",
" \n",
" Inconclusive | \n",
" 18356 | \n",
"
\n",
" \n",
"
\n",
"
"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"source": [
"# Number of rows recorded under each medical condition\n",
"df.loc[:, 'Medical Condition'].value_counts()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 303
},
"id": "LprQu4JKXg5v",
"outputId": "3d467787-7747-471b-9bc0-7151943dbef5"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Medical Condition\n",
"Arthritis 9308\n",
"Diabetes 9304\n",
"Hypertension 9245\n",
"Obesity 9231\n",
"Cancer 9227\n",
"Asthma 9185\n",
"Name: count, dtype: int64"
],
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" count | \n",
"
\n",
" \n",
" Medical Condition | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" Arthritis | \n",
" 9308 | \n",
"
\n",
" \n",
" Diabetes | \n",
" 9304 | \n",
"
\n",
" \n",
" Hypertension | \n",
" 9245 | \n",
"
\n",
" \n",
" Obesity | \n",
" 9231 | \n",
"
\n",
" \n",
" Cancer | \n",
" 9227 | \n",
"
\n",
" \n",
" Asthma | \n",
" 9185 | \n",
"
\n",
" \n",
"
\n",
"
"
]
},
"metadata": {},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"source": [
"# Frequency of each blood type\n",
"df.loc[:, 'Blood Type'].value_counts()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 366
},
"id": "PklqIg_QX1LK",
"outputId": "928ad232-a538-44dd-8568-36c3de23886b"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Blood Type\n",
"A- 6969\n",
"A+ 6956\n",
"AB+ 6947\n",
"AB- 6945\n",
"B+ 6945\n",
"B- 6944\n",
"O+ 6917\n",
"O- 6877\n",
"Name: count, dtype: int64"
],
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" count | \n",
"
\n",
" \n",
" Blood Type | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" A- | \n",
" 6969 | \n",
"
\n",
" \n",
" A+ | \n",
" 6956 | \n",
"
\n",
" \n",
" AB+ | \n",
" 6947 | \n",
"
\n",
" \n",
" AB- | \n",
" 6945 | \n",
"
\n",
" \n",
" B+ | \n",
" 6945 | \n",
"
\n",
" \n",
" B- | \n",
" 6944 | \n",
"
\n",
" \n",
" O+ | \n",
" 6917 | \n",
"
\n",
" \n",
" O- | \n",
" 6877 | \n",
"
\n",
" \n",
"
\n",
"
"
]
},
"metadata": {},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"source": [
"# Target distribution: rows per medication (near-uniform in the output below)\n",
"df.loc[:, 'Medication'].value_counts()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 272
},
"id": "QGBA7xBnX8zA",
"outputId": "4bfd0a84-e68a-4553-d918-2d684dde6dc9"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Medication\n",
"Lipitor 11140\n",
"Ibuprofen 11127\n",
"Aspirin 11094\n",
"Paracetamol 11071\n",
"Penicillin 11068\n",
"Name: count, dtype: int64"
],
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" count | \n",
"
\n",
" \n",
" Medication | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" Lipitor | \n",
" 11140 | \n",
"
\n",
" \n",
" Ibuprofen | \n",
" 11127 | \n",
"
\n",
" \n",
" Aspirin | \n",
" 11094 | \n",
"
\n",
" \n",
" Paracetamol | \n",
" 11071 | \n",
"
\n",
" \n",
" Penicillin | \n",
" 11068 | \n",
"
\n",
" \n",
"
\n",
"
"
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"source": [
"# Row count per gender\n",
"df.loc[:, 'Gender'].value_counts()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 178
},
"id": "kZV7YperYB4s",
"outputId": "c0a79c87-9f71-4fd0-dbb7-542b946f4490"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Gender\n",
"Male 27774\n",
"Female 27726\n",
"Name: count, dtype: int64"
],
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" count | \n",
"
\n",
" \n",
" Gender | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" Male | \n",
" 27774 | \n",
"
\n",
" \n",
" Female | \n",
" 27726 | \n",
"
\n",
" \n",
"
\n",
"
"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"source": [
"# Missing values per column (all zero in the output below)\n",
"df.isna().sum()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 272
},
"id": "inBn2HEPYKBk",
"outputId": "6fd328f2-e84d-47db-df61-3468983ce528"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Age 0\n",
"Gender 0\n",
"Blood Type 0\n",
"Medical Condition 0\n",
"Test Results 0\n",
"Medication 0\n",
"dtype: int64"
],
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" 0 | \n",
"
\n",
" \n",
" \n",
" \n",
" Age | \n",
" 0 | \n",
"
\n",
" \n",
" Gender | \n",
" 0 | \n",
"
\n",
" \n",
" Blood Type | \n",
" 0 | \n",
"
\n",
" \n",
" Medical Condition | \n",
" 0 | \n",
"
\n",
" \n",
" Test Results | \n",
" 0 | \n",
"
\n",
" \n",
" Medication | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
]
},
"metadata": {},
"execution_count": 11
}
]
},
{
"cell_type": "code",
"source": [
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"# Replace each categorical feature's string labels with integer codes, keeping the\n",
"# fitted encoder per column so the same mapping can be reapplied later.\n",
"label_encoders = {}\n",
"for col_name in ('Gender', 'Blood Type', 'Medical Condition', 'Test Results'):\n",
"    encoder = LabelEncoder()\n",
"    df[col_name] = encoder.fit_transform(df[col_name])\n",
"    label_encoders[col_name] = encoder\n",
"\n",
"# The prediction target gets its own encoder so codes can be decoded back to medication names.\n",
"target_encoder = LabelEncoder()\n",
"df['Medication'] = target_encoder.fit_transform(df['Medication'])"
],
"metadata": {
"id": "5EDh_scLZF_N"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Five input features; 'Medication' is the label to predict.\n",
"feature_cols = ['Age', 'Gender', 'Blood Type', 'Medical Condition', 'Test Results']\n",
"X, y = df[feature_cols], df['Medication']\n",
"\n",
"# Hold out 20% of the rows for evaluation, with a fixed seed for a reproducible split.\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X, y, test_size=0.2, random_state=42\n",
")"
],
"metadata": {
"id": "NRwGc4aQZMP0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Row counts of the four split pieces (train/test features, train/test labels)\n",
"tuple(len(part) for part in (X_train, X_test, y_train, y_test))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rpcJAbA_ZeN-",
"outputId": "01fcf1b0-5b45-4dbb-ee95-57e9361e2f91"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(44400, 11100, 44400, 11100)"
]
},
"metadata": {},
"execution_count": 4
}
]
},
{
"cell_type": "markdown",
"source": [
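"### Model Training\n",
"\n",
"Both baselines below (random forest, then KNN) score about 0.20 accuracy. There are five medication classes of roughly equal size (see the value counts above), so this is chance level. The selected features appear to carry almost no signal about the prescribed medication."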
],
"metadata": {
"id": "zWz1-JCKnudh"
}
},
{
"cell_type": "code",
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.metrics import classification_report, accuracy_score\n",
"\n",
"# Baseline: seeded 100-tree random forest on the label-encoded features.\n",
"model = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)\n",
"\n",
"# Score on the held-out split; target_names turns class codes back into medication names.\n",
"y_pred = model.predict(X_test)\n",
"rf_accuracy = accuracy_score(y_test, y_pred)\n",
"print(f\"Accuracy: {rf_accuracy}\")\n",
"print(classification_report(y_test, y_pred, target_names=target_encoder.classes_))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aOoTscBpZYVF",
"outputId": "b4b4b35d-1e42-4457-ddbb-2ea03f0183c8"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Accuracy: 0.2036036036036036\n",
" precision recall f1-score support\n",
"\n",
" Aspirin 0.20 0.20 0.20 2211\n",
" Ibuprofen 0.21 0.20 0.21 2271\n",
" Lipitor 0.21 0.21 0.21 2224\n",
" Paracetamol 0.21 0.21 0.21 2207\n",
" Penicillin 0.19 0.19 0.19 2187\n",
"\n",
" accuracy 0.20 11100\n",
" macro avg 0.20 0.20 0.20 11100\n",
"weighted avg 0.20 0.20 0.20 11100\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Fit the Age scaler on the training split only. Fitting it on all of X leaks\n",
"# test-set statistics into any model trained on the scaled features.\n",
"scaler = StandardScaler().fit(X_train[['Age']])\n",
"\n",
"# Scaled copies of the train/test features (the unscaled X_train/X_test stay untouched).\n",
"X_train_scaled = X_train.copy()\n",
"X_train_scaled['Age'] = scaler.transform(X_train[['Age']]).ravel()\n",
"X_test_scaled = X_test.copy()\n",
"X_test_scaled['Age'] = scaler.transform(X_test[['Age']]).ravel()\n",
"\n",
"# Whole feature matrix with Age scaled by the train-fitted scaler (same columns, order and index as X).\n",
"X_final = X.copy()\n",
"X_final['Age'] = scaler.transform(X[['Age']]).ravel()\n",
"\n",
"# One-hot encode the target as float32 rows (one column per class code 0..max).\n",
"# This replaces keras' to_categorical, which imported all of TensorFlow for a one-line NumPy operation.\n",
"y_final = np.eye(y.max() + 1, dtype='float32')[y.to_numpy()]"
],
"metadata": {
"id": "T_kRZhaQat3s"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.metrics import classification_report, accuracy_score\n",
"\n",
"# k-nearest-neighbours baseline on the unscaled split; n_neighbors is the main knob to tune.\n",
"# NOTE(review): Age (13-89) is on a much larger scale than the label codes (0-7), so it likely\n",
"# dominates the distance here; the final pipeline below scales Age before fitting.\n",
"knn = KNeighborsClassifier(n_neighbors=5).fit(X_train, y_train)\n",
"\n",
"# Held-out evaluation with medication names in the report\n",
"y_pred = knn.predict(X_test)\n",
"accuracy = accuracy_score(y_test, y_pred)\n",
"print(f\"Test Accuracy: {accuracy}\")\n",
"print(classification_report(y_test, y_pred, target_names=target_encoder.classes_))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6jp6Gcqth5cB",
"outputId": "b34b1aba-d90b-4b5a-f0bd-a51a8a0c6015"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Test Accuracy: 0.2018018018018018\n",
" precision recall f1-score support\n",
"\n",
" Aspirin 0.19 0.29 0.23 2211\n",
" Ibuprofen 0.21 0.23 0.22 2271\n",
" Lipitor 0.21 0.20 0.20 2224\n",
" Paracetamol 0.20 0.16 0.18 2207\n",
" Penicillin 0.20 0.13 0.16 2187\n",
"\n",
" accuracy 0.20 11100\n",
" macro avg 0.20 0.20 0.20 11100\n",
"weighted avg 0.20 0.20 0.20 11100\n",
"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Final pipeline: encode, split, scale Age, train and evaluate KNN"
],
"metadata": {
"id": "gFUsQMWP87EE"
}
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score, classification_report\n",
"import joblib\n",
"\n",
"# Load the dataset from where the download cell saved it (./healthcare-dataset/);\n",
"# the old hard-coded /content/healthcare_dataset.csv is not created by that cell.\n",
"data = pd.read_csv('healthcare-dataset/healthcare_dataset.csv')\n",
"\n",
"# Encode categorical features, keeping each fitted encoder for reuse at inference time\n",
"label_encoders = {}\n",
"for column in ['Gender', 'Blood Type', 'Medical Condition', 'Test Results']:\n",
"    le = LabelEncoder()\n",
"    data[column] = le.fit_transform(data[column])\n",
"    label_encoders[column] = le\n",
"\n",
"# Encode the target variable 'Medication'\n",
"medication_encoder = LabelEncoder()\n",
"data['Medication'] = medication_encoder.fit_transform(data['Medication'])\n",
"\n",
"# Code -> medication name lookup, derived from the fitted encoder instead of being hard-coded,\n",
"# so it cannot drift out of sync with the encoding above.\n",
"medication_mapping = dict(enumerate(medication_encoder.classes_))\n",
"\n",
"# Define features and target\n",
"X = data[['Age', 'Gender', 'Blood Type', 'Medical Condition', 'Test Results']]\n",
"y = data['Medication']\n",
"\n",
"# Split the dataset into training and testing sets (80/20, seeded)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"\n",
"# Work on explicit copies so the in-place 'Age' assignment below cannot raise\n",
"# SettingWithCopyWarning or write through to a view of X.\n",
"X_train = X_train.copy()\n",
"X_test = X_test.copy()\n",
"\n",
"# Normalize ONLY the 'Age' column; fit on the training split so test rows do not leak into the scaler\n",
"age_scaler = StandardScaler()\n",
"X_train['Age'] = age_scaler.fit_transform(X_train[['Age']]).ravel()\n",
"X_test['Age'] = age_scaler.transform(X_test[['Age']]).ravel()"
],
"metadata": {
"id": "dMaqw6Ao7iJC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Train the final KNN on the Age-scaled training split\n",
"knn = KNeighborsClassifier(n_neighbors=5).fit(X_train, y_train)\n",
"\n",
"# Held-out accuracy\n",
"y_pred = knn.predict(X_test)\n",
"accuracy = accuracy_score(y_test, y_pred)\n",
"print(f\"Test Accuracy: {accuracy}\")\n",
"\n",
"# Per-class precision/recall, labelled with medication names instead of integer codes\n",
"print(\"Classification Report:\")\n",
"print(classification_report(y_test, y_pred, target_names=medication_encoder.classes_))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ux4L1tsX9CS2",
"outputId": "52e4b74c-22ec-4934-f80d-f5e37b893326"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Test Accuracy: 0.20306306306306307\n",
"Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" Aspirin 0.20 0.29 0.24 2211\n",
" Ibuprofen 0.21 0.23 0.22 2271\n",
" Lipitor 0.22 0.21 0.21 2224\n",
" Paracetamol 0.20 0.16 0.17 2207\n",
" Penicillin 0.18 0.13 0.15 2187\n",
"\n",
" accuracy 0.20 11100\n",
" macro avg 0.20 0.20 0.20 11100\n",
"weighted avg 0.20 0.20 0.20 11100\n",
"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Prediction on a new example"
],
"metadata": {
"id": "6TbYU2UKn0DJ"
}
},
{
"cell_type": "code",
"source": [
"# Raw (unencoded) feature values for one new patient, in the same column order as training\n",
"new_patient = {\n",
"    'Age': [62],\n",
"    'Gender': ['Male'],\n",
"    'Blood Type': ['A+'],\n",
"    'Medical Condition': ['Obesity'],\n",
"    'Test Results': ['Normal'],\n",
"}\n",
"new_data = pd.DataFrame(new_patient)\n",
"\n",
"# Reapply the encoders fitted in the final pipeline, column by column\n",
"for col_name, encoder in label_encoders.items():\n",
"    new_data[col_name] = encoder.transform(new_data[col_name])\n",
"\n",
"# Scale Age with the training-set scaler\n",
"new_data['Age'] = age_scaler.transform(new_data[['Age']])\n",
"\n",
"# Predict the class code, then decode it back to a medication name\n",
"predictions = knn.predict(new_data)\n",
"predicted_medications = medication_encoder.inverse_transform(predictions)\n",
"print(f\"Predicted Medication: {predicted_medications[0]}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ubmJkLPj9ELT",
"outputId": "aff25ba4-1459-47a1-e813-257a0faad04a"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Predicted Medication: Ibuprofen\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Saving"
],
"metadata": {
"id": "qyMS8mQnn2Dx"
}
},
{
"cell_type": "code",
"source": [
"# Persist every artifact inference needs: model, feature encoders, Age scaler, target encoder\n",
"artifacts = {\n",
"    'knn_model.pkl': knn,\n",
"    'label_encoders.pkl': label_encoders,\n",
"    'age_scaler.pkl': age_scaler,\n",
"    'medication_encoder.pkl': medication_encoder,\n",
"}\n",
"for filename, obj in artifacts.items():\n",
"    joblib.dump(obj, filename)\n",
"\n",
"print(\"Model and encoders saved successfully.\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WOitiTRa9Gxa",
"outputId": "61bbdb60-b67f-4719-e0be-bf78b88df92b"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model and encoders saved successfully.\n"
]
}
]
}
]
}