Source code for yass.templates.processor

import os

import numpy as np

from yass.templates.util import get_templates, main_channels
# from yass.templates.util import align as _align
from yass.geometry import order_channels_by_distance


# TODO: remove config
class TemplatesProcessor:
    """Provides functions for manipulating templates.

    Templates are stored in ``self.templates`` as an array laid out as
    (template, time sample, channel): axis 0 indexes templates, axis 1 is
    the waveform (time) axis and axis 2 indexes channels.

    Methods that accept an ``inplace`` flag either replace the templates of
    this instance (``inplace=True``, returning None) or return a new
    processor built from the result (``inplace=False``).
    """

    def __init__(self, templates):
        self._update_templates(templates)

    @classmethod
    def from_spike_train(cls, CONFIG, half_waveform_length, spike_train,
                         path_to_data):
        """Build a processor by computing templates from a spike train.

        Parameters
        ----------
        CONFIG
            Configuration object; only ``CONFIG.resources.max_memory`` is
            read here.
        half_waveform_length : int
            Forwarded to ``get_templates``.
        spike_train : numpy.ndarray
            2D array of shape (n_spikes, n_columns); presumably
            (spike time, unit id) - TODO confirm against ``get_templates``.
        path_to_data : str
            Path to the standardized recording.

        Returns
        -------
        TemplatesProcessor

        Raises
        ------
        ValueError
            If ``path_to_data`` does not exist.
        """
        # make sure standarized data already exists
        if not os.path.exists(path_to_data):
            raise ValueError('Standarized data does not exist in: {}, this is '
                             'needed to generate training data, run the '
                             'preprocesor first to generate it'
                             .format(path_to_data))

        n_spikes, _ = spike_train.shape

        # add weight of one to every spike (extra trailing column)
        weighted_spike_train = np.hstack((spike_train,
                                          np.ones((n_spikes, 1), 'int32')))

        # get templates
        templates, _ = get_templates(weighted_spike_train,
                                     path_to_data,
                                     CONFIG.resources.max_memory,
                                     half_waveform_length)

        # NOTE(review): assumes get_templates returns
        # (n_channels, waveform_length, n_templates); reversing the axes gives
        # the (template, time, channel) layout used by this class.
        templates = np.transpose(templates, (2, 1, 0))

        return cls(templates)

    def _update_templates(self, templates):
        """Replace the templates and recompute the derived attributes."""
        self.templates = templates
        # peak absolute value of each template over all times and channels
        self.amplitudes = np.max(np.abs(templates), axis=(1, 2))
        self.main_channels = main_channels(templates)

    def _check_half_waveform_length(self, half_waveform_length):
        """Check that a centered window of the given half length fits.

        ``crop_temporally`` keeps ``2 * half_waveform_length + 1`` samples
        centered at ``waveform_length // 2``. That window lies fully inside
        the current templates only when
        ``0 <= half_waveform_length <= (waveform_length - 1) // 2``.

        Raises
        ------
        ValueError
            If the requested window is negative or would extend past either
            end of the current waveforms.
        """
        _, current_waveform_length, _ = self.templates.shape
        # largest half length whose centered window still fits; for even
        # lengths the center (length // 2) sits right of the middle, so the
        # right edge is the binding constraint
        current_half_waveform_length = (current_waveform_length - 1) // 2

        if half_waveform_length < 0:
            raise ValueError('New half_waveform_length ({}) must be '
                             'non-negative'.format(half_waveform_length))

        if half_waveform_length > current_half_waveform_length:
            raise ValueError('New half_waveform_length ({}) must be at most '
                             'the current half_waveform_length ({})'
                             .format(half_waveform_length,
                                     current_half_waveform_length))

    def choose_with_indexes(self, indexes, inplace=False):
        """Keep only the templates selected by ``indexes``.

        Parameters
        ----------
        indexes
            Index applied to the first (template) axis, e.g. a list of
            template ids, a slice or a boolean mask.
        inplace : bool
            If True, replace this instance's templates and return None.

        Returns
        -------
        TemplatesProcessor or None
            A new processor with the chosen templates, or None when
            ``inplace`` is True.

        Raises
        ------
        IndexError
            If any of the requested template ids does not exist.
        """
        try:
            chosen_templates = self.templates[indexes]
        except IndexError as err:
            raise IndexError('Error getting chosen_templates, make sure '
                             'the ids exist') from err

        if inplace:
            self._update_templates(chosen_templates)
        else:
            return type(self)(chosen_templates)

    def choose_with_minimum_amplitude(self, minimum_amplitude, inplace=False):
        """Keep only templates whose amplitude is strictly above a threshold.

        Amplitude is the peak absolute value of a template (see
        ``self.amplitudes``). Returns a new processor, or None when
        ``inplace`` is True.
        """
        chosen_templates = self.templates[self.amplitudes > minimum_amplitude]

        if inplace:
            self._update_templates(chosen_templates)
        else:
            return type(self)(chosen_templates)

    def crop_temporally(self, half_waveform_length, inplace=False):
        """Keep ``2 * half_waveform_length + 1`` samples around the center.

        The window is centered at ``waveform_length // 2``.

        Raises
        ------
        ValueError
            If the window does not fit inside the current waveforms.
        """
        self._check_half_waveform_length(half_waveform_length)

        _, current_waveform_length, _ = self.templates.shape
        mid_point = current_waveform_length // 2
        window = slice(mid_point - half_waveform_length,
                       mid_point + half_waveform_length + 1)
        new_templates = self.templates[:, window, :]

        if inplace:
            self._update_templates(new_templates)
        else:
            return type(self)(new_templates)

    def align(self, half_waveform_length, inplace=False):
        """Align templates - currently deactivated.

        This method does nothing: templates are left untouched and None is
        returned regardless of ``inplace``. A RuntimeWarning is emitted so
        callers notice that no alignment happened.
        """
        import warnings
        warnings.warn('TemplatesProcessor.align is deactivated and does '
                      'nothing', RuntimeWarning, stacklevel=2)
        # TODO: re-enable once yass.templates.util.align is fixed: validate
        # with self._check_half_waveform_length, compute
        # _align(self.templates, half_waveform_length), then update in place
        # or return a new processor, as crop_temporally does.

    # FIXME: this needs a better name
    # FIXME: should we order by ptp instead of amplitude?
    def crop_spatially(self, neighbors, geometry, inplace=False):
        """Reorder and crop channels around each template's main channel.

        For every template, keep only the neighbors of its main channel
        (``neighbors[main_channel]`` is a row whose nonzero entries mark
        neighboring channels). Those channels are ordered with
        ``order_channels_by_distance`` relative to the main channel and
        written into the first columns of the channel axis. Templates with
        fewer neighbors than the maximum are zero-padded on the right.

        Parameters
        ----------
        neighbors : numpy.ndarray
            Channel adjacency matrix; presumably (n_channels, n_channels)
            boolean - TODO confirm.
        geometry
            Channel positions, forwarded to ``order_channels_by_distance``.
        inplace : bool
            If True, replace this instance's templates and return None.
        """
        n_templates, waveform_length, _ = self.templates.shape

        # spatially crop (only keep neighbors): output width is the largest
        # neighborhood size
        n_neigh_to_keep = np.max(np.sum(neighbors, 0))
        new_templates = np.zeros((n_templates, waveform_length,
                                  n_neigh_to_keep))

        for k, main_channel in enumerate(self.main_channels[:n_templates]):
            # get neighbors for the main channel in the kth template
            ch_idx = np.where(neighbors[main_channel])[0]
            # order channels
            ch_idx, _ = order_channels_by_distance(main_channel, ch_idx,
                                                   geometry)
            # new kth template is the old kth template keeping only the
            # ordered neighboring channels
            new_templates[k, :, :ch_idx.shape[0]] = \
                self.templates[k][:, ch_idx]

        if inplace:
            self._update_templates(new_templates)
        else:
            return type(self)(new_templates)

    @property
    def values(self):
        """The underlying templates array."""
        return self.templates