{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Training a new model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**If the pretrained models do not give satisfactory performance on your data, you can train your own classifier with SpikeInterface.**" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of cores set to: 23\n" ] } ], "source": [ "from pathlib import Path\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import spikeinterface as si\n", "import spikeinterface.extractors as se\n", "import spikeinterface.postprocessing as spost\n", "import spikeinterface.qualitymetrics as sqm\n", "import os\n", "from os import cpu_count\n", "import json\n", "# Set the number of CPU cores used globally (here: all cores minus one, but at least one)\n", "n_cores = max(1, (cpu_count() or 1) - 1)\n", "si.set_global_job_kwargs(n_jobs=n_cores)\n", "print(f\"Number of cores set to: {n_cores}\")\n", "\n", "# Set the output folder\n", "output_folder = Path(r'E:\\spikeinterface_repository_stuff')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Load data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 1: Load the recording and sorting objects" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this tutorial, we use simulated data to create the recording and sorting objects." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "recording = si.generate_recording(num_channels=50, sampling_frequency=30000.,\n", " durations=[30], set_probe=True)\n", "# Load your recording depending on the acquisition software you used, for example:\n", "# recording = se.read_spikeglx(recording_path, stream_name='imec0.ap')\n", "\n", "labelled_sorting = si.generate_sorting(num_units=100, sampling_frequency=30000., durations=[30],\n", " firing_rates=15, refractory_period_ms=1.5)\n", "# Load your sorting depending on which spike sorter you used, for example:\n", "# sorting = se.read_kilosort(folder_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 2: Create a SortingAnalyzer\n", "\n", "Create a SpikeInterface SortingAnalyzer object. In this example, we use simulated data and generate random labels for the units, which serve as our target for training.\n", "\n", "Important: The labels generated here are random, so the model's performance will be at chance level. \n", "\n", "**For real applications, replace this with your own data and curated labels to achieve meaningful results.**\n", "\n", "To learn more about the SortingAnalyzer, see: https://spikeinterface.readthedocs.io/en/latest/modules/postprocessing.html\n", "\n", "If you already have a WaveformExtractor from a previous run, you can use it to create a SortingAnalyzer. 
\n", "Please refer to: https://spikeinterface.readthedocs.io/en/latest/tutorials/waveform_extractor_to_sorting_analyzer.html" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "analyzer = si.create_sorting_analyzer(sorting=labelled_sorting, recording=recording, sparse=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 3: Compute quality metrics\n", "Compute metrics from multiple SortingAnalyzer objects (each corresponding to a different recording).\n", "Pass the metrics as a list to the model training function." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Define a function to compute all analyzer extensions and quality metrics\n", "# Note: this can be a time-consuming step, especially computing PCA-based metrics for long recordings\n", "\n", "def compute_all_metrics(analyzer):\n", "\n", "    # Compute the extensions required for the quality metrics\n", "    analyzer.compute({\n", "        'noise_levels': {},\n", "        'random_spikes': {'max_spikes_per_unit': 1_000},\n", "        'templates': {'ms_before': 1.5, 'ms_after': 3.5},\n", "        'spike_amplitudes': {},\n", "        'waveforms': {},\n", "        'principal_components': {},\n", "        'spike_locations': {},\n", "        'unit_locations': {},\n", "    })\n", "\n", "    # Compute all available quality metrics and template metrics\n", "    analyzer.compute(\"quality_metrics\", metric_names=sqm.get_quality_metric_list() + sqm.get_quality_pca_metric_list())\n", "    analyzer.compute(\"template_metrics\", metric_names=spost.get_template_metric_names())\n", "\n", "    # Combine the metrics into a single dataframe\n", "    quality_metrics = analyzer.extensions['quality_metrics'].data[\"metrics\"]\n", "    template_metrics = analyzer.extensions['template_metrics'].data[\"metrics\"]\n", "    calculated_metrics = pd.concat([quality_metrics, template_metrics], axis=1)\n", "\n", "    return calculated_metrics" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
"# Compute all metrics\n", "metrics = compute_all_metrics(analyzer)\n", "metrics.index.name = 'cluster_id'\n", "\n", "# Save the analyzer\n", "analyzer.save_as(folder=output_folder / 'sorting_analyzer', format=\"binary_folder\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Add the labels on which your model will be trained\n", "\n", "Provide a single list of labels in the same order as the metrics.\n", "\n", "Flexible labeling:\n", "The set of unique labels is up to you. \n", "\n", "This approach can be used for any cluster categorization task, whether you're sorting clusters into custom categories such as 'true' or 'false', or using standard labels like 'good', 'mua', and 'noise'." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "labelled_analyzer = si.load_sorting_analyzer(folder=output_folder / 'sorting_analyzer', format=\"binary_folder\")" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "label_conversion = {'noise': 0, 'good': 1}\n", "\n", "# These are assigned randomly here, but you could load them from phy 'cluster_group.tsv', from the 'quality' property of the sorting, or similar\n", "human_labels = np.random.choice(list(label_conversion.values()), labelled_analyzer.get_num_units())\n", "labelled_analyzer.sorting.set_property('quality', human_labels)\n", "\n", "labels = [human_labels.tolist(), human_labels.tolist()]\n", "\n", "# Get labels from phy sorting (if loaded) using:\n", "# human_labels = unlabelled_analyzer.sorting.get_property('quality')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 4: Train with Your Data\n", "Load your data and corresponding curation labels, and use them to train the classifier. Experiment to see how well the model performs on your data!\n", "\n", "Note: For better generalizability, you’ll likely need multiple labeled recordings for training. 
\n", "The best model is saved as a 'best_model.skops' file. You can use this file to predict labels on other recordings with auto_label_units()." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\jain\\Documents\\Github_extend\\spikeinterface\\src\\spikeinterface\\curation\\train_manual_curation.py:168: UserWarning: No metric_names provided, using all metrics calculated by the analyzers\n", " warnings.warn(\"No metric_names provided, using all metrics calculated by the analyzers\")\n" ] }, { "data": { "text/html": [ "
Pipeline(steps=[('imputer', SimpleImputer(strategy='median')),\n", " ('scaler', StandardScaler()),\n", " ('classifier',\n", " RandomForestClassifier(class_weight='balanced_subsample',\n", " min_samples_leaf=4, min_samples_split=3,\n", " random_state=1127402010))])
\n", "| | classifier name | imputation_strategy | scaling_strategy | accuracy | precision | recall | model_id | best_params |\n", "|---|---|---|---|---|---|---|---|---|\n", "| 0 | RandomForestClassifier | median | StandardScaler() | 0.899 | 0.899 | 0.899 | 0 | OrderedDict([('class_weight', 'balanced_subsam... |\n", "| 1 | RandomForestClassifier | median | MinMaxScaler() | 0.899 | 0.899 | 0.899 | 1 | OrderedDict([('class_weight', 'balanced'), ('c... |\n", "| 2 | RandomForestClassifier | median | RobustScaler() | 0.899 | 0.899 | 0.899 | 2 | OrderedDict([('class_weight', 'balanced_subsam... |\n", "| 3 | RandomForestClassifier | most_frequent | StandardScaler() | 0.899 | 0.899 | 0.899 | 3 | OrderedDict([('class_weight', 'balanced_subsam... |\n", "| 4 | RandomForestClassifier | most_frequent | MinMaxScaler() | 0.899 | 0.899 | 0.899 | 4 | OrderedDict([('class_weight', 'balanced'), ('c... |\n", "