diff --git a/bouter/free/__init__.py b/bouter/free/__init__.py
index 8c719bb..6424cfe 100644
--- a/bouter/free/__init__.py
+++ b/bouter/free/__init__.py
@@ -136,12 +136,13 @@ def compute_velocity(
         return fish_velocities
 
     @decorators.cache_results()
-    def get_bouts(self, scale=None, threshold=1, **kwargs):
+    def get_bouts(self, scale=None, threshold=1, conv_detection=False, **kwargs):
         """Extracts all bouts from a freely-swimming tracking experiment
 
         :param exp: the experiment object
         :param scale: mm per pixel, recalculated by default
-        :param threshold: velocity threshold in mm/s
+        :param threshold: velocity threshold in mm/s or score threshold if conv_detection=True
+        :param conv_detection: whether to use an alternative detection algorithm using convolution.
         :return: tuple: (list of single bout dataframes, list of boolean arrays marking if the
          bout i follows bout i-1)
         """
@@ -156,12 +157,25 @@ def get_bouts(self, scale=None, threshold=1, **kwargs):
 
         for i_fish in range(n_fish):
             vel2 = fish_velocities["vel_f{}".format(i_fish)]
-            (
-                bout_locations,
-                continuity,
-            ) = utilities.extract_segments_above_threshold(
-                vel2.values, threshold=threshold**2, **kwargs
-            )
+            if not conv_detection:
+                (
+                    bout_locations,
+                    continuity,
+                ) = utilities.extract_segments_above_threshold(
+                    vel2.values, threshold=threshold**2, **kwargs
+                )
+            else:
+                score = utilities.calc_bout_score(vel2.values)
+                bout_times = utilities.get_bout_times(
+                    score, min_peak_value=threshold, **kwargs
+                )
+
+                # For compatibility with the (start, end) pairs of the
+                # threshold detector. The reshape keeps the array 2-D even
+                # when no bout was found, so the column indexing cannot fail.
+                bout_locations = np.array(bout_times, dtype=int).reshape(-1, 3)
+                bout_locations = bout_locations[:, [0, 2]]
+                continuity = [False] * bout_locations.shape[0]
             all_bouts_fish = [
                 self._extract_bout(s, e, n_segments, i_fish, scale)
                 for s, e in bout_locations
diff --git a/bouter/utilities.py b/bouter/utilities.py
index f136d2e..9d8cf92 100644
--- a/bouter/utilities.py
+++ b/bouter/utilities.py
@@ -93,6 +93,195 @@ def extract_segments_above_threshold(
     return np.array(segments), np.array(connected)
 
 
+@jit(nopython=True)
+def get_score_trace(trace_pad, kernel, bias=-0.2839):
+    """
+    Used to calculate the correlation score between the squared velocity
+    trace and the kernel for bout detection.
+    :param trace_pad: squared velocity (optionally) padded
+    :param kernel: kernel used for bout detection
+    :param bias: offset added to the kernel response before the sigmoid
+    :return:
+        array of scores between 0 and 1 (high = bout-like) with
+        len(trace_pad) - len(kernel) samples, i.e. as many as the unpadded
+        trace when it was padded by half a kernel length on each side.
+    """
+
+    # For numerical stability.
+    eta = 0.0000000001
+
+    n_kernel = len(kernel)
+    n_scores = len(trace_pad) - n_kernel
+    # The kernel is applied reversed; flip it once instead of per window.
+    kernel_flipped = np.flip(kernel)
+
+    corr = np.empty(n_scores)
+    for i in range(n_scores):
+        current_values = trace_pad[i:i + n_kernel]
+        relative_values = current_values / (np.max(current_values) + eta)
+
+        # Simplified convolution, works faster and with Numba.
+        conv = np.sum(relative_values * kernel_flipped)
+
+        corr[i] = 1 / (1 + np.exp(conv + bias))
+
+    # Invert so that a high value means "bout".
+    score = (corr - 1) * -1
+
+    # Align the score trace with the actual bouts by delaying it.
+    # Found by trial and error - so can be changed as needed.
+    # The delay is applied after the inversion so that the samples shifted
+    # in at the start are 0 (no bout) rather than 1, and it is skipped when
+    # it is 0 because score[:-0] would be empty.
+    shift = min(n_kernel // 5, n_scores)
+    if shift > 0:
+        score = np.concatenate((np.zeros(shift), score[:-shift]))
+
+    return score
+
+
+def calc_bout_score(trace, kernel=None, bias=-0.2839, pad_len=None):
+    """
+    Wrapper for the get_score_trace function.
+    :param trace: the trace to detect bouts in.
+    :param kernel: kernel used for bout detection (a built-in one by default).
+    :param bias: the bias to shift the output
+    :param pad_len: the length of the padding on each side
+        (half the kernel length by default).
+    :return:
+        array of correlation scores that matches with the trace by index
+        (when the default pad_len is used).
+    """
+
+    if kernel is None:
+        # Kernel trained using a NN.
+        kernel = np.array([
+            -0.5633, -0.5888, -0.5097, -0.3899, -0.4896, -0.4326, -0.5017,
+            -0.4555, -0.4709, -0.4580, -0.4563, -0.4510, -0.3165, -0.3896,
+            -0.2522, -0.2982, -0.2503, -0.0670, 0.1450, 0.2824, 0.2420,
+            0.3270, 0.2569, 0.3174, 0.3718, 0.2896, 0.3356, 0.3905,
+            0.2844, 0.3674, 0.3158, 0.3549, 0.2847, 0.4782, 0.4236,
+            0.4416, 0.4049, 0.3622, 0.3755, 0.2569, 0.2525, 0.2626,
+            0.2875, 0.1809, 0.1314, 0.1213, 0.1023, -0.0113, 0.1013,
+            0.0844, -0.0785, -0.0316, -0.0584, -0.1613, -0.1835, -0.1843,
+            -0.1421, -0.1407, -0.0838, -0.1655, -0.2004, -0.0968, -0.1559,
+            -0.1564, -0.1867, -0.1494, -0.1192, -0.2535, -0.1645, -0.1529,
+            -0.1918, -0.1987, -0.2686, -0.2107, -0.2132, -0.2063, -0.2253,
+            -0.1670, -0.2638, -0.2669, -0.1228, -0.1679, -0.2795, -0.2066,
+            -0.1625, -0.1498, -0.1983, -0.2351, -0.2337, -0.2803, -0.3189,
+            -0.2672, -0.2565, -0.3481, -0.3722, -0.3055, -0.3174, -0.4059,
+            -0.3363, -0.4085])
+
+    if pad_len is None:
+        pad_len = int(kernel.shape[0]/2)
+
+    trace_pad = np.pad(trace, pad_len)
+    # The bias has to be forwarded explicitly, otherwise get_score_trace
+    # silently falls back to its own default.
+    return get_score_trace(trace_pad, kernel, bias)
+
+
+def get_bout_times(trace,
+                   min_peak_value=0.75,
+                   max_baseline_value=0.05,
+                   include_nan=False,
+                   max_zero_length=(55, 55),
+                   min_bout_distance=0,
+                   **kwargs):
+    """Finds bout peaks and their start and end from a convolved
+    trace (only tested on freely-swimming experiments).
+    :param trace: the squared velocity trace convolved with a bout detection kernel.
+    :param min_peak_value: the minimum peak value for a bump to count as a bout.
+        In NN terms this is the value of the sigmoid for classification, so 0.5
+        would be the cut-off value. However, one can opt for a higher or lower value
+        if they want to change the false positive or true negative rate (i.e. a higher
+        value should classify less noise as bouts but also means that more bouts will
+        not be detected).
+    :param max_baseline_value: (ab)uses the fact that a convolution will result in a smooth
+        signal. The sides of the peak more or less reflect the start and end times of the
+        bout. This is the cut-off value that determines at which points the peak start/ends
+        and thus the bout starts/ends.
+    :param include_nan: include bouts containing NaN values.
+    :param max_zero_length: pair (before, after) with the maximum number of samples that
+        it may take to reach the baseline (defined by max_baseline_value) from the peak.
+        A bout is kept when at least one of its two sides is shorter than its limit;
+        None disables this filter.
+    :param min_bout_distance: minimum distance between the end of a bout and the peak of the
+        next bout.
+    :param kwargs: accepted and ignored.
+    :return: list of (bout start, bout peak, bout end) sample-index tuples. The "peak"
+        is the first sample at which the trace exceeds min_peak_value.
+    """
+
+    last_bout_end = 0
+    bouts = []
+
+    i = 0
+    while i < trace.shape[0]:
+        # Check if the value is high enough to indicate a bout.
+        if trace[i] > min_peak_value:
+            # We found a bout!
+            bout_start = i
+            bout_end = i
+
+            # Check how far back it goes. The search starts just before the
+            # peak: starting after it would read past the end of the trace
+            # when the peak is the last sample.
+            j = i - 1
+            found_start = False
+            while j > last_bout_end:
+                if trace[j] < max_baseline_value:
+                    bout_start = j
+                    found_start = True
+                    break
+
+                j -= 1
+
+            if not found_start:
+                bout_start = last_bout_end
+
+            # Check how far the end is.
+            j = i + 1
+            found_end = False
+            while j < trace.shape[0]:
+                if trace[j] < max_baseline_value:
+                    bout_end = j
+                    found_end = True
+                    break
+                elif trace[j-1] < min_peak_value < trace[j]:
+                    # We found the next bout!
+                    # We stop this bout at the valley.
+                    bout_end = i + np.argmin(trace[i:j])
+                    found_end = True
+                    break
+
+                j += 1
+
+            if not found_end:
+                bout_end = trace.shape[0] - 1
+
+            # Update values.
+            last_bout_end = bout_end
+            # Filter out bouts containing NaNs.
+            if include_nan or not np.isnan(trace[bout_start:bout_end]).any():
+                # The convolution during noise is less steep,
+                # it might not reach the baseline value within the typical
+                # bout time window. So we can use the requirement of the 'bout'
+                # stopping within a certain time range, as a way to filter noise.
+                if (
+                    max_zero_length is None
+                    or max_zero_length[0] > i - bout_start
+                    or max_zero_length[1] > bout_end - i
+                ):
+                    bouts.append((bout_start, i, bout_end))
+
+            # Skip the remaining part of the bout and optionally some extra distance.
+            i = bout_end + min_bout_distance
+
+        i += 1
+
+    return bouts
+
 
 def log_dt(log_df, i_start=10, i_end=110):
     return np.mean(np.diff(log_df.t[i_start:i_end]))