{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training a new model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**If the pretrained models do not give satisfactory performance on your data, it is easy to train your own classifier through SpikeInterface!**"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of cores set to: 23\n"
]
}
],
"source": [
"from pathlib import Path\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import spikeinterface as si\n",
"import spikeinterface.extractors as se\n",
"import spikeinterface.postprocessing as spost\n",
"import spikeinterface.qualitymetrics as sqm\n",
"import os\n",
"from os import cpu_count\n",
"import json\n",
"# Use all but one CPU core for parallel jobs (set globally for SpikeInterface)\n",
"n_cores = cpu_count() - 1\n",
"si.set_global_job_kwargs(n_jobs = n_cores)\n",
"print(f\"Number of cores set to: {n_cores}\")\n",
"\n",
"# SET OUTPUT FOLDER (change this to a location on your machine)\n",
"output_folder = Path(r'E:\\spikeinterface_repository_stuff')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load data "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 1: Load the recording and sorting objects"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For this tutorial, we use simulated data to create the recording and sorting objects."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"recording = si.generate_recording(num_channels=50, sampling_frequency=30000.,\n",
" durations=[30], set_probe=True)\n",
"# Load your recording depending on the acquisition software you used, for example:\n",
"# recording = se.read_spikeglx(recording_path, stream_name='imec0.ap')\n",
"\n",
"labelled_sorting = si.generate_sorting(num_units=100, sampling_frequency=30000., durations=[30],\n",
" firing_rates=15, refractory_period_ms=1.5)\n",
"# Load your sorting depending on which spike sorter you used, for example:\n",
"# sorting = se.read_kilosort(folder_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 2: Create a SortingAnalyzer\n",
"\n",
"Create a SpikeInterface SortingAnalyzer object. In this example, we use simulated data and generate random labels for the units, which serve as the training targets.\n",
"\n",
"Important: because these labels are random, the model's performance is expected to be at chance level.\n",
"\n",
"**For real applications, replace this with your own data and curated labels to achieve meaningful results.**\n",
"\n",
"To learn more about the SortingAnalyzer, see: https://spikeinterface.readthedocs.io/en/latest/modules/postprocessing.html\n",
"\n",
"If you already have a WaveformExtractor from a previous run, you can use it to create a SortingAnalyzer.\n",
"See: https://spikeinterface.readthedocs.io/en/latest/tutorials/waveform_extractor_to_sorting_analyzer.html"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analyzer = si.create_sorting_analyzer(sorting = labelled_sorting, recording = recording, sparse = True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 3: Compute quality metrics\n",
"Ideally, compute metrics for several SortingAnalyzer objects, each corresponding to a different recording, and pass them together as a list to the model training function."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Define a function to compute all required analyzer extensions and metrics\n",
"# Note: this can be a time-consuming step, especially computing PCA-based metrics for long recordings\n",
"\n",
"def compute_all_metrics(analyzer):\n",
"\n",
"    # Compute the extensions required by the quality and template metrics\n",
"    analyzer.compute({\n",
"        'noise_levels': {},\n",
"        'random_spikes': {'max_spikes_per_unit': 1_000},\n",
"        'templates': {'ms_before': 1.5, 'ms_after': 3.5},\n",
"        'spike_amplitudes': {},\n",
"        'waveforms': {},\n",
"        'principal_components': {},\n",
"        'spike_locations': {},\n",
"        'unit_locations': {},\n",
"    })\n",
"\n",
"    # Compute all available quality metrics (including PCA-based ones) and template metrics\n",
"    analyzer.compute(\"quality_metrics\", metric_names=sqm.get_quality_metric_list() + sqm.get_quality_pca_metric_list())\n",
"    analyzer.compute(\"template_metrics\", metric_names=spost.get_template_metric_names())\n",
"\n",
"    # Combine both metric tables into one DataFrame (one row per unit)\n",
"    quality_metrics = analyzer.extensions['quality_metrics'].data[\"metrics\"]\n",
"    template_metrics = analyzer.extensions['template_metrics'].data[\"metrics\"]\n",
"    calculated_metrics = pd.concat([quality_metrics, template_metrics], axis=1)\n",
"\n",
"    return calculated_metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compute all metrics\n",
"metrics = compute_all_metrics(analyzer)\n",
"metrics.index.name = 'cluster_id'\n",
"\n",
"# save the analyzer\n",
"analyzer.save_as(folder=output_folder / 'sorting_analyzer', format=\"binary_folder\")"
]
},
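{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional check: some metrics may be undefined (NaN) for some units, for example when a unit has too few spikes for a metric to be computed. The training pipeline handles these with an imputation step, but it can still be useful to know how often values are missing. The sketch below uses a small made-up DataFrame in place of `metrics`; the column names and values are illustrative assumptions, not real output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Made-up stand-in for the 'metrics' DataFrame returned by compute_all_metrics()\n",
"toy_metrics = pd.DataFrame(\n",
"    {'snr': [5.1, 4.2, 3.3], 'isolation_distance': [40.0, np.nan, np.nan]},\n",
"    index=pd.Index([0, 1, 2], name='cluster_id'),\n",
")\n",
"\n",
"# Fraction of units with a missing value, for each metric\n",
"nan_fraction = toy_metrics.isna().mean()\n",
"print(nan_fraction)\n",
"\n",
"# On real data, try: metrics.isna().mean().sort_values(ascending=False)"
]
},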
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Add the labels your model will be trained on\n",
"\n",
"For each analyzer, provide a list of labels in the same order as its units (i.e. the rows of its metrics table).\n",
"\n",
"Flexible labeling: the set of unique labels is up to your requirements.\n",
"\n",
"This approach can be used for any cluster categorization task, whether you sort clusters into custom categories such as 'true' and 'false', or use standard labels such as 'good', 'mua' and 'noise'."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"labelled_analyzer = si.load_sorting_analyzer(folder= output_folder / 'sorting_analyzer', format=\"binary_folder\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"label_conversion = {'noise': 0, 'good': 1}\n",
"\n",
"# These are assigned randomly here, but you could load them from phy's 'cluster_group.tsv', from the 'quality' property of the sorting, or similar\n",
"human_labels = np.random.choice(list(label_conversion.values()), labelled_analyzer.get_num_units())\n",
"labelled_analyzer.sorting.set_property('quality', human_labels)\n",
"\n",
"# One list of labels per analyzer passed to train_model (here the same analyzer is used twice)\n",
"labels = [human_labels.tolist(), human_labels.tolist()]\n",
"\n",
"# Get labels from a phy sorting (if loaded) using:\n",
"# human_labels = labelled_analyzer.sorting.get_property('quality')"
]
},
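{
"cell_type": "markdown",
"metadata": {},
"source": [
"If your labels come from phy, its `cluster_group.tsv` file stores one string label per cluster. The sketch below shows one way to turn such a file into integer labels ordered like the analyzer's units, using only the Python standard library.\n",
"\n",
"Assumptions: the file is tab-separated with `cluster_id` and `group` columns, and every label in it appears in `label_conversion` (add e.g. `'mua'` if you use it). `labels_from_cluster_group` is a hypothetical helper written for this tutorial, not a SpikeInterface function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import csv\n",
"import io\n",
"\n",
"\n",
"def labels_from_cluster_group(tsv_text, unit_ids, label_conversion):\n",
"    # Hypothetical helper: assumes tab-separated 'cluster_id' and 'group' columns\n",
"    reader = csv.DictReader(io.StringIO(tsv_text), delimiter='\\t')\n",
"    group_by_id = {int(row['cluster_id']): row['group'] for row in reader}\n",
"    # Follow the analyzer's unit order so the labels line up with the metrics rows\n",
"    return [label_conversion[group_by_id[int(unit_id)]] for unit_id in unit_ids]\n",
"\n",
"\n",
"# In-memory stand-in for a real cluster_group.tsv\n",
"example_tsv = 'cluster_id\\tgroup\\n0\\tgood\\n1\\tnoise\\n2\\tgood\\n'\n",
"example_labels = labels_from_cluster_group(\n",
"    example_tsv, unit_ids=[2, 0, 1], label_conversion={'noise': 0, 'good': 1}\n",
")\n",
"print(example_labels)\n",
"\n",
"# On real data (the folder path is a placeholder):\n",
"# with open(Path('path/to/phy_folder') / 'cluster_group.tsv') as f:\n",
"#     human_labels = labels_from_cluster_group(f.read(), labelled_analyzer.unit_ids, label_conversion)"
]
},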
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 4: Train with your data\n",
"Load your data and the corresponding curation labels, and use them to train the classifier. Experiment to see how well the model performs on your data!\n",
"\n",
"Note: for better generalizability, you will likely need multiple labeled recordings for training.\n",
"The best model is saved as a 'best_model.skops' file. You can use this file to predict labels on other recordings with `auto_label_units()`."
]
},
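{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training, it can help to check that `labels` contains one list per analyzer and that each list is as long as that analyzer's number of units. `check_labels` below is a hypothetical helper written for this tutorial, not part of SpikeInterface: it raises an error on a mismatch and otherwise returns the distinct label values it found."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def check_labels(labels, num_units_per_analyzer):\n",
"    # Hypothetical helper: one label list per analyzer, one label per unit\n",
"    if len(labels) != len(num_units_per_analyzer):\n",
"        raise ValueError(f'Got {len(labels)} label lists for {len(num_units_per_analyzer)} analyzers')\n",
"    for i, (unit_labels, n_units) in enumerate(zip(labels, num_units_per_analyzer)):\n",
"        if len(unit_labels) != n_units:\n",
"            raise ValueError(f'Analyzer {i}: {len(unit_labels)} labels for {n_units} units')\n",
"    # Distinct label values across all analyzers\n",
"    return sorted({label for unit_labels in labels for label in unit_labels})\n",
"\n",
"\n",
"# Toy example: two analyzers with 3 and 2 units\n",
"print(check_labels([[0, 1, 1], [1, 0]], [3, 2]))\n",
"\n",
"# In this notebook:\n",
"# check_labels(labels, [labelled_analyzer.get_num_units()] * 2)"
]
},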
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\jain\\Documents\\Github_extend\\spikeinterface\\src\\spikeinterface\\curation\\train_manual_curation.py:168: UserWarning: No metric_names provided, using all metrics calculated by the analyzers\n",
" warnings.warn(\"No metric_names provided, using all metrics calculated by the analyzers\")\n"
]
},
{
"data": {
"text/plain": [
"Pipeline(steps=[('imputer', SimpleImputer(strategy='median')),\n",
" ('scaler', StandardScaler()),\n",
" ('classifier',\n",
" RandomForestClassifier(class_weight='balanced_subsample',\n",
" min_samples_leaf=4, min_samples_split=3,\n",
" random_state=1127402010))])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load labelled metrics and train model\n",
"from spikeinterface.curation.train_manual_curation import train_model\n",
"\n",
"# We use a list of two (identical) analyzers here; in practice, use several different labelled recordings to improve model performance\n",
"trainer = train_model(mode = \"analyzers\",\n",
" labels = labels,\n",
" analyzers = [labelled_analyzer, labelled_analyzer],\n",
" output_folder = str(output_folder), # Optional, can be set to save the model and model_info.json file\n",
" metric_names = None, # Can be set to specify which metrics to use for training\n",
" imputation_strategies = None, # Default to all\n",
" scaling_techniques = None, # Default to all\n",
" classifiers = None, # Default to Random Forest only. Other classifiers you can try [ \"AdaBoostClassifier\",\"GradientBoostingClassifier\",\n",
" # \"LogisticRegression\",\"MLPClassifier\"]\n",
" seed = None)\n",
"\n",
"best_model = trainer.best_pipeline\n",
"best_model\n",
"\n",
" \n",
"# OR load model from file\n",
"# import skops.io\n",
"# pipeline_path = Path(output_folder) / Path(\"best_model_label.skops\")\n",
"# unknown_types = skops.io.get_untrusted_types(file=pipeline_path)\n",
"# best_model = skops.io.load(pipeline_path, trusted=unknown_types)"
]
},
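{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fitted model is a standard scikit-learn `Pipeline` (imputer, scaler, random forest), so it exposes `predict` and `predict_proba`. The sketch below builds a small pipeline of the same shape on made-up data to show one detail worth knowing: scikit-learn expects the same feature columns, in the same order, as during fitting, so selecting columns with `feature_names_in_` keeps new metrics aligned. The data, column names and hyperparameters are illustrative assumptions, and scikit-learn is assumed to be installed (the training step above already relies on it)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Toy pipeline with the same structure as the one printed above; data are made up\n",
"rng = np.random.default_rng(0)\n",
"train_X = pd.DataFrame(rng.normal(size=(40, 2)), columns=['snr', 'firing_rate'])\n",
"train_y = (train_X['snr'] > 0).astype(int)\n",
"toy_model = Pipeline([\n",
"    ('imputer', SimpleImputer(strategy='median')),\n",
"    ('scaler', StandardScaler()),\n",
"    ('classifier', RandomForestClassifier(random_state=0)),\n",
"])\n",
"toy_model.fit(train_X, train_y)\n",
"\n",
"# New metrics with columns in a different order and a missing value\n",
"new_X = pd.DataFrame({'firing_rate': [1.0, -0.5], 'snr': [2.0, np.nan]})\n",
"\n",
"# Select columns by the names seen during fitting so they line up; NaNs are imputed\n",
"aligned_X = new_X[toy_model.feature_names_in_]\n",
"predictions = toy_model.predict(aligned_X)\n",
"probabilities = toy_model.predict_proba(aligned_X)\n",
"print(predictions)\n",
"print(probabilities)"
]
},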
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>classifier name</th>\n",
" <th>imputation_strategy</th>\n",
" <th>scaling_strategy</th>\n",
" <th>accuracy</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>model_id</th>\n",
" <th>best_params</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>RandomForestClassifier</td>\n",
" <td>median</td>\n",
" <td>StandardScaler()</td>\n",
" <td>0.899</td>\n",
" <td>0.899</td>\n",
" <td>0.899</td>\n",
" <td>0</td>\n",
" <td>OrderedDict([('class_weight', 'balanced_subsam...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>RandomForestClassifier</td>\n",
" <td>median</td>\n",
" <td>MinMaxScaler()</td>\n",
" <td>0.899</td>\n",
" <td>0.899</td>\n",
" <td>0.899</td>\n",
" <td>1</td>\n",
" <td>OrderedDict([('class_weight', 'balanced'), ('c...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>RandomForestClassifier</td>\n",
" <td>median</td>\n",
" <td>RobustScaler()</td>\n",
" <td>0.899</td>\n",
" <td>0.899</td>\n",
" <td>0.899</td>\n",
" <td>2</td>\n",
" <td>OrderedDict([('class_weight', 'balanced_subsam...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>RandomForestClassifier</td>\n",
" <td>most_frequent</td>\n",
" <td>StandardScaler()</td>\n",
" <td>0.899</td>\n",
" <td>0.899</td>\n",
" <td>0.899</td>\n",
" <td>3</td>\n",
" <td>OrderedDict([('class_weight', 'balanced_subsam...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>RandomForestClassifier</td>\n",
" <td>most_frequent</td>\n",
" <td>MinMaxScaler()</td>\n",
" <td>0.899</td>\n",
" <td>0.899</td>\n",
" <td>0.899</td>\n",
" <td>4</td>\n",
" <td>OrderedDict([('class_weight', 'balanced'), ('c...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" classifier name imputation_strategy scaling_strategy accuracy \\\n",
"0 RandomForestClassifier median StandardScaler() 0.899 \n",
"1 RandomForestClassifier median MinMaxScaler() 0.899 \n",
"2 RandomForestClassifier median RobustScaler() 0.899 \n",
"3 RandomForestClassifier most_frequent StandardScaler() 0.899 \n",
"4 RandomForestClassifier most_frequent MinMaxScaler() 0.899 \n",
"\n",
" precision recall model_id \\\n",
"0 0.899 0.899 0 \n",
"1 0.899 0.899 1 \n",
"2 0.899 0.899 2 \n",
"3 0.899 0.899 3 \n",
"4 0.899 0.899 4 \n",
"\n",
" best_params \n",
"0 OrderedDict([('class_weight', 'balanced_subsam... \n",
"1 OrderedDict([('class_weight', 'balanced'), ('c... \n",
"2 OrderedDict([('class_weight', 'balanced_subsam... \n",
"3 OrderedDict([('class_weight', 'balanced_subsam... \n",
"4 OrderedDict([('class_weight', 'balanced'), ('c... "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the accuracies of all tested pipelines and display the top 5 by accuracy\n",
"accuracies = pd.read_csv(Path(output_folder) / \"model_accuracies.csv\", index_col = 0)\n",
"accuracies.sort_values(\"accuracy\", ascending=False).head()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot feature importances\n",
"importances = best_model.named_steps['classifier'].feature_importances_\n",
"indices = np.argsort(importances)[::-1]\n",
"features = best_model.feature_names_in_\n",
"n_features = best_model.n_features_in_\n",
"\n",
"plt.figure(figsize=(12, 6))\n",
"plt.title(\"Feature Importances\")\n",
"plt.bar(range(n_features), importances[indices], align=\"center\")\n",
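"# Reorder the tick labels with the same indices as the bars so names match their importances\n",
"plt.xticks(range(n_features), features[indices], rotation=90)\n",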
"plt.xlim([-1, n_features])\n",
"plt.show()"
]
},
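{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the plot does not render (for example when viewing the notebook on GitHub), a text summary of the most important features can be easier to read. `top_features` below is a hypothetical helper; it is shown on made-up feature names and values, and the commented line shows how it could be called on `best_model`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def top_features(feature_names, importances, n=5):\n",
"    # Indices of the n largest importances, most important first\n",
"    order = np.argsort(importances)[::-1][:n]\n",
"    return [(str(feature_names[i]), float(importances[i])) for i in order]\n",
"\n",
"\n",
"# Made-up names and importances for illustration\n",
"toy_names = np.array(['snr', 'isi_violations_ratio', 'firing_rate', 'amplitude_median'])\n",
"toy_importances = np.array([0.4, 0.1, 0.3, 0.2])\n",
"toy_top = top_features(toy_names, toy_importances, n=2)\n",
"print(toy_top)\n",
"\n",
"# In this notebook:\n",
"# top_features(best_model.feature_names_in_, best_model.named_steps['classifier'].feature_importances_)"
]
},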
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "spike_interface",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}