Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 18 additions & 20 deletions src/spikeinterface/metrics/quality/misc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
import numpy as np

from spikeinterface.core.analyzer_extension_core import BaseMetric
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
from spikeinterface.core import SortingAnalyzer, get_noise_levels, NumpySorting
from spikeinterface.core import SortingAnalyzer, NumpySorting
from spikeinterface.core.template_tools import (
get_template_extremum_channel,
get_template_extremum_amplitude,
Expand Down Expand Up @@ -1239,7 +1238,8 @@ def compute_sd_ratio(
censored_period_ms: float = 4.0,
correct_for_drift: bool = True,
correct_for_template_itself: bool = True,
**kwargs,
peak_sign: str = "neg",
**job_kwargs,
):
"""
Computes the SD (Standard Deviation) of each unit's spike amplitudes, and compare it to the SD of noise.
Expand All @@ -1264,20 +1264,21 @@ def compute_sd_ratio(
correct_for_template_itself : bool, default: True
If true, will take into account that the template itself impacts the standard deviation of the noise,
and will make a rough estimation of what that impact is (and remove it).
**kwargs : dict, default: {}
Keyword arguments for computing spike amplitudes and extremum channel.
peak_sign : "neg" | "pos" | "both", default: "neg"
The peak sign used to select the template extremum channel.
**job_kwargs : dict, default: {}
Keyword arguments sent to get_noise_levels.

Returns
-------
num_spikes : dict
The number of spikes, across all segments, for each unit ID.
sd_ratio : dict
The ratio of the standard deviation of spike amplitudes to the standard deviation of noise, for each unit ID.
"""

from spikeinterface.curation.curation_tools import find_duplicated_spikes
from spikeinterface.core import get_noise_levels

check_has_required_extensions("sd_ratio", sorting_analyzer)
kwargs, job_kwargs = split_job_kwargs(kwargs)
job_kwargs = fix_job_kwargs(job_kwargs)

sorting = sorting_analyzer.sorting
sorting = sorting.select_periods(periods=periods)
Expand Down Expand Up @@ -1309,11 +1310,11 @@ def compute_sd_ratio(
noise_levels = get_noise_levels(
sorting_analyzer.recording, return_in_uV=sorting_analyzer.return_in_uV, method="std", **job_kwargs
)
best_channels = get_template_extremum_channel(sorting_analyzer, outputs="index", **kwargs)
n_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
best_channels = get_template_extremum_channel(sorting_analyzer, outputs="index", peak_sign=peak_sign)

if correct_for_template_itself:
tamplates_array = get_dense_templates_array(sorting_analyzer, return_in_uV=sorting_analyzer.return_in_uV)
n_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
templates_array = get_dense_templates_array(sorting_analyzer, return_in_uV=sorting_analyzer.return_in_uV)

sd_ratio = {}

Expand Down Expand Up @@ -1348,21 +1349,17 @@ def compute_sd_ratio(
best_channel = best_channels[unit_id]
std_noise = noise_levels[best_channel]

n_samples = sorting_analyzer.get_total_samples()

if correct_for_template_itself:
# template = sorting_analyzer.get_template(unit_id, force_dense=True)[:, best_channel]
unit_index = sorting.id_to_index(unit_id)

template = tamplates_array[unit_index, :, :][:, best_channel]
nsamples = template.shape[0]
template = templates_array[unit_index, :, best_channel]

# Computing the variance of a trace that is all 0 and n_spikes non-overlapping template.
# TODO: Take into account that templates for different segments might differ.
p = nsamples * n_spikes[unit_id] / n_samples
total_variance = p * np.mean(template**2) - p**2 * np.mean(template) ** 2
p = len(template) * n_spikes[unit_id] / sorting_analyzer.get_total_samples()
template_variance = p * np.mean(template**2) - p**2 * np.mean(template) ** 2

std_noise = np.sqrt(std_noise**2 - total_variance)
std_noise = np.sqrt(std_noise**2 - template_variance)

sd_ratio[unit_id] = unit_std / std_noise

Expand All @@ -1376,6 +1373,7 @@ class SDRatio(BaseMetric):
"censored_period_ms": 4.0,
"correct_for_drift": True,
"correct_for_template_itself": True,
"peak_sign": "neg",
}
metric_columns = {"sd_ratio": float}
metric_descriptions = {
Expand Down
Loading