Augment

yass.augment.make.load_templates(data_folder, spike_train, CONFIG, chosen_templates_indexes)
Parameters:
data_folder: str

Folder storing the standardized data (if it does not exist, run preprocess to generate it automatically)

spike_train: numpy.ndarray

[number of spikes, 2] Ground truth for training. The first column is the spike time and the second column is the spike ID

chosen_templates_indexes: list

List of the IDs of the chosen templates
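The expected layout of spike_train and the role of chosen_templates_indexes can be illustrated with a small sketch. It only builds an input array in the documented [number of spikes, 2] format and keeps the spikes whose ID is among the chosen templates; the helper spikes_for_templates and the sample values are hypothetical, and this is not a description of how load_templates itself is implemented.

```python
import numpy as np

# Hypothetical ground-truth spike train with shape [number of spikes, 2]:
# column 0 holds spike times (in samples), column 1 holds spike IDs
spike_train = np.array([
    [120, 0],
    [340, 2],
    [515, 1],
    [900, 2],
    [1210, 0],
])

# Hypothetical choice of template IDs
chosen_templates_indexes = [0, 2]


def spikes_for_templates(spike_train, template_ids):
    """Return the rows of spike_train whose ID (column 1) is in template_ids.

    Hypothetical helper for illustration; not part of yass.
    """
    mask = np.isin(spike_train[:, 1], template_ids)
    return spike_train[mask]


selected = spikes_for_templates(spike_train, chosen_templates_indexes)
spike_times = selected[:, 0]
print(spike_times.tolist())  # [120, 340, 900, 1210]
```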

yass.augment.make.spikes(templates, min_amplitude, max_amplitude, n_per_template, spatial_sig, temporal_sig, make_from_templates=True, make_spatially_misaligned=True, make_temporally_misaligned=True, make_collided=True, make_noise=True, return_metadata=True, templates_kwargs={}, collided_kwargs={'min_shift': 5, 'n_per_spike': 1}, temporally_misaligned_kwargs={'n_per_spike': 1}, spatially_misaligned_kwargs={'force_first_channel_shuffle': True, 'n_per_spike': 1}, add_noise_kwargs={'reject_cancelling_noise': False})

Make spikes: creates several types of spikes from templates, covering a range of amplitudes

Parameters:
templates: numpy.ndarray, (n_templates, waveform_length, n_channels)

Templates used to generate the spikes

min_amplitude: float

Minimum amplitude for the spikes

max_amplitude: float

Maximum amplitude for the spikes

n_per_template: int

How many spikes to generate per template. Together with min_amplitude and max_amplitude, this determines how the generated spikes cover the desired amplitude range

make_from_templates: bool

Whether to return spikes generated from the templates (these are the same as the templates but with different amplitudes)

make_spatially_misaligned: bool

Whether to return spatially misaligned spikes (by shuffling channels)

make_temporally_misaligned: bool

Whether to return temporally misaligned spikes (by shifting along the temporal axis)

make_collided: bool

Whether to return collided spikes

make_noise: bool

Whether to return pure noise

return_metadata: bool, optional

Whether to return metadata along with the generated spikes
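The parameters above describe three transformations: rescaling templates over an amplitude range, shuffling channels (spatial misalignment) and shifting along time (temporal misalignment). The sketch below is a minimal illustration under stated assumptions, not the yass implementation: it assumes amplitudes are evenly spaced with np.linspace and measured as the peak absolute value of each spike, that channel shuffling uses an independent random permutation per spike, and that temporal shifts are zero-padded rather than wrapped. All function names are hypothetical.

```python
import numpy as np


def scale_to_amplitudes(templates, min_amplitude, max_amplitude, n_per_template):
    """Hypothetical: return n_per_template rescaled copies of each template,
    with peak absolute amplitudes evenly spaced in [min_amplitude, max_amplitude]."""
    n_templates = templates.shape[0]
    amplitudes = np.linspace(min_amplitude, max_amplitude, n_per_template)
    # Normalize every template so its peak absolute value is 1
    peaks = np.abs(templates).max(axis=(1, 2))
    unit = templates / peaks[:, None, None]
    # Broadcast to (n_templates, n_per_template, waveform_length, n_channels)
    scaled = unit[:, None, :, :] * amplitudes[None, :, None, None]
    spikes = scaled.reshape(n_templates * n_per_template, *templates.shape[1:])
    # Amplitude of each output spike, in the same order as spikes
    the_amplitudes = np.tile(amplitudes, n_templates)
    return spikes, the_amplitudes


def spatially_misalign(spikes, rng):
    """Hypothetical: shuffle the channel axis of every spike independently."""
    out = np.empty_like(spikes)
    for i, spike in enumerate(spikes):
        out[i] = spike[:, rng.permutation(spike.shape[1])]
    return out


def temporally_misalign(spikes, shift):
    """Hypothetical: shift every spike along time by `shift` samples, zero-padding."""
    out = np.zeros_like(spikes)
    if shift > 0:
        out[:, shift:, :] = spikes[:, :-shift, :]
    elif shift < 0:
        out[:, :shift, :] = spikes[:, -shift:, :]
    else:
        out[:] = spikes
    return out


# Usage with random stand-in templates: (n_templates, waveform_length, n_channels)
rng = np.random.default_rng(0)
templates = rng.normal(size=(3, 40, 7))
spikes, the_amplitudes = scale_to_amplitudes(templates, 5.0, 20.0, 4)
shuffled = spatially_misalign(spikes, rng)
shifted = temporally_misalign(spikes, 3)
print(spikes.shape)  # (12, 40, 7)
```

Normalizing each template to a unit peak before rescaling makes the amplitude of every generated spike exactly the requested value, which is one simple way to cover the range evenly.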

Returns:
x_all: numpy.ndarray, (n_templates * n_per_template, waveform_length, n_channels)

All generated spikes

x_all_noisy: numpy.ndarray, (n_templates * n_per_template, waveform_length, n_channels)

Noisy versions of all generated spikes

the_amplitudes: numpy.ndarray, (n_templates * n_per_template,)

Amplitudes for all generated spikes

slices: dictionary

Dictionary where the keys are the kinds of spikes (‘from templates’, ‘spatially misaligned’, ‘temporally misaligned’, ‘collided’, ‘noise’) and the values are slice objects giving the location of each kind of spike

spatial_SIG: numpy.ndarray

Spatial noise covariance matrix

temporal_SIG: numpy.ndarray

Temporal noise covariance matrix
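The slices dictionary lets callers pull one kind of spike out of the stacked outputs. The sketch below assumes, purely for illustration, that each kind occupies one contiguous block of the first axis; the per-kind counts are hypothetical (the real counts depend on the make_* flags and *_kwargs), and x_all here is a zero-filled stand-in rather than real output of yass.augment.make.spikes.

```python
import numpy as np

# Hypothetical number of spikes of each kind
counts = {'from templates': 12, 'spatially misaligned': 12,
          'temporally misaligned': 12, 'collided': 12, 'noise': 20}
waveform_length, n_channels = 40, 7

# Build one slice per kind, assuming each kind is a contiguous block
slices = {}
start = 0
for kind, count in counts.items():
    slices[kind] = slice(start, start + count)
    start += count

# Stand-in for x_all: every kind stacked along the first axis
x_all = np.zeros((start, waveform_length, n_channels))

# Retrieve all collided spikes and all noise examples through their slices
collided = x_all[slices['collided']]
noise = x_all[slices['noise']]
print(collided.shape, noise.shape)  # (12, 40, 7) (20, 40, 7)
```

The same slice can index x_all_noisy and the_amplitudes as well, since slicing only touches the first axis.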
yass.augment.make.training_data(CONFIG, templates_uncropped, min_amp, max_amp, n_isolated_spikes, path_to_standarized, noise_ratio=10, collision_ratio=1, misalign_ratio=1, misalign_ratio2=1, multi_channel=True, return_metadata=False)

Make training sets for the detector, triage and autoencoder networks

Parameters:
CONFIG: yaml file

Configuration file

min_amp: float

Minimum value allowed for the maximum absolute amplitude of the isolated spike on its main channel

max_amp: float

Maximum value allowed for the maximum absolute amplitude of the isolated spike on its main channel

n_isolated_spikes: int

Number of isolated spikes to generate. This is not the total number of examples in x_detect, which also contains noise, collided and misaligned examples

path_to_standarized: str

Folder storing the standardized data (if it does not exist, run preprocess to generate it automatically)

noise_ratio: int

Ratio of the number of noise examples to isolated spikes. For example, if n_isolated_spikes=1000 and noise_ratio=5, then n_noise=5000

collision_ratio: int

Ratio of the number of collided spikes to isolated spikes

misalign_ratio: int

Ratio of the number of spatially and temporally misaligned spikes to isolated spikes

misalign_ratio2: int

Ratio of the number of only spatially misaligned spikes to isolated spikes

multi_channel: bool

If True, generate training data for a multi-channel neural network; otherwise, generate single-channel data
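The ratio parameters scale the number of each kind of example relative to n_isolated_spikes, as the noise_ratio example states. The sketch below assumes every count is simply ratio × n_isolated_spikes; the helper example_counts and its key names are hypothetical, and how the examples are then split between x_detect and x_triage is not shown.

```python
def example_counts(n_isolated_spikes, noise_ratio=10, collision_ratio=1,
                   misalign_ratio=1, misalign_ratio2=1):
    """Hypothetical helper: number of examples of each kind, assuming each
    count is the corresponding ratio times n_isolated_spikes."""
    return {
        'isolated': n_isolated_spikes,
        'noise': noise_ratio * n_isolated_spikes,
        'collided': collision_ratio * n_isolated_spikes,
        'spatially and temporally misaligned': misalign_ratio * n_isolated_spikes,
        'only spatially misaligned': misalign_ratio2 * n_isolated_spikes,
    }


# Reproduces the example from the noise_ratio description
counts = example_counts(1000, noise_ratio=5)
print(counts['noise'])         # 5000
print(sum(counts.values()))    # 9000
```

The total of 9000 illustrates why n_isolated_spikes is much smaller than the overall number of generated examples.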

Returns:
x_detect: numpy.ndarray

[number of detection examples, temporal length, number of channels] Training data for the detector network

y_detect: numpy.ndarray

[number of detection examples] Labels for x_detect

x_triage: numpy.ndarray

[number of triage examples, temporal length, number of channels] Training data for the triage network

y_triage: numpy.ndarray

[number of triage examples] Labels for x_triage

x_ae: numpy.ndarray

[number of autoencoder examples, temporal length] Training data for the autoencoder: noisy spikes

y_ae: numpy.ndarray

[number of autoencoder examples, temporal length] Denoised versions of x_ae

Notes

  • Detection training data
    • Multi-channel
      • Positive examples: Clean spikes + noise, Collided spikes + noise
      • Negative examples: Temporally misaligned spikes + noise, Noise
  • Triage training data
    • Multi-channel
      • Positive examples: Clean spikes + noise
      • Negative examples: Collided spikes + noise
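The positive/negative split in these notes can be turned into labeled arrays by stacking the examples and assigning one label per group. The sketch below assumes labels 1 for positive and 0 for negative examples, which the notes do not state; the helper labeled_set is hypothetical, and the random arrays stand in for the noisy clean, collided and misaligned spikes and the pure noise.

```python
import numpy as np


def labeled_set(positives, negatives):
    """Hypothetical helper: stack positive and negative examples and label
    them 1 and 0 respectively (assumed label encoding)."""
    x = np.concatenate([positives, negatives], axis=0)
    y = np.concatenate([np.ones(len(positives), dtype=int),
                        np.zeros(len(negatives), dtype=int)])
    return x, y


rng = np.random.default_rng(0)
shape = (40, 7)  # (temporal length, number of channels)
clean = rng.normal(size=(10, *shape))       # stand-in: clean spikes + noise
collided = rng.normal(size=(10, *shape))    # stand-in: collided spikes + noise
misaligned = rng.normal(size=(10, *shape))  # stand-in: temporally misaligned + noise
noise = rng.normal(size=(30, *shape))       # stand-in: pure noise

# Detection: clean and collided are positive; misaligned and noise are negative
x_detect, y_detect = labeled_set(np.concatenate([clean, collided]),
                                 np.concatenate([misaligned, noise]))

# Triage: clean spikes are positive; collided spikes are negative
x_triage, y_triage = labeled_set(clean, collided)
print(x_detect.shape, x_triage.shape)  # (60, 40, 7) (20, 40, 7)
```

Note that collided spikes flip role between the two sets: the detector should fire on them, while triage should reject them.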
yass.augment.make.training_data_detect(templates, minimum_amplitude, maximum_amplitude, n_clean_per_template, n_collided_per_spike, n_temporally_misaligned_per_spike, n_spatially_misaliged_per_spike, n_noise, spatial_SIG, temporal_SIG, from_templates_kwargs={}, collided_kwargs={}, temporally_misaligned_kwargs={}, add_noise_kwargs={'reject_cancelling_noise': False})

Make training data for the detector network

Notes

Recordings are passed through the detector network, which identifies spikes (clean and collided) and rejects noise and misaligned spikes (temporally and spatially misaligned)
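training_data_detect takes n_noise together with spatial_SIG and temporal_SIG. A common way to draw noise snippets from a spatial and a temporal covariance is a separable (Kronecker) Gaussian model, where each snippet is L_t Z L_s^T with Z standard normal and L_t, L_s Cholesky factors of the two covariances. The sketch below assumes that model; it is a swapped-in illustration, not a statement of how yass generates noise, and the function make_noise and the example covariances are hypothetical.

```python
import numpy as np


def make_noise(n, spatial_SIG, temporal_SIG, rng):
    """Hypothetical: draw n noise snippets of shape (waveform_length, n_channels)
    whose covariance is temporal_SIG along time and spatial_SIG across channels."""
    # Cholesky factors: SIG = L @ L.T
    L_t = np.linalg.cholesky(temporal_SIG)
    L_s = np.linalg.cholesky(spatial_SIG)
    z = rng.standard_normal((n, temporal_SIG.shape[0], spatial_SIG.shape[0]))
    # Each snippet is L_t @ z_i @ L_s.T (matmul broadcasts over the first axis)
    return L_t @ z @ L_s.T


rng = np.random.default_rng(0)
waveform_length, n_channels = 40, 7
t = np.arange(waveform_length)
c = np.arange(n_channels)
# Hypothetical covariances with exponentially decaying correlations
temporal_SIG = 0.9 ** np.abs(t[:, None] - t[None, :])
spatial_SIG = 0.5 ** np.abs(c[:, None] - c[None, :])

noise = make_noise(5000, spatial_SIG, temporal_SIG, rng)
print(noise.shape)  # (5000, 40, 7)
```

Under this model the covariance between samples (t, c) and (t', c') is temporal_SIG[t, t'] * spatial_SIG[c, c'], so the two matrices can be estimated and stored separately instead of as one large covariance.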

yass.augment.make.training_data_triage(templates, minimum_amplitude, maximum_amplitude, n_clean_per_template, n_collided_per_spike, max_shift, min_shift, spatial_SIG, temporal_SIG, from_templates_kwargs, collided_kwargs)

Make training data for the triage network
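The triage network keeps clean spikes and rejects collided ones, and training_data_triage takes min_shift and max_shift. The sketch below assumes these bound the temporal offset, in samples, between a spike and a second spike overlapping it; the functions shift_time and make_collided are hypothetical, the random arrays stand in for clean spikes, and zero-padded shifting is an assumption rather than the yass behavior.

```python
import numpy as np


def shift_time(spike, shift):
    """Hypothetical: shift a (waveform_length, n_channels) array along time, zero-padding."""
    out = np.zeros_like(spike)
    if shift > 0:
        out[shift:] = spike[:-shift]
    elif shift < 0:
        out[:shift] = spike[-shift:]
    else:
        out[:] = spike
    return out


def make_collided(spikes, min_shift, max_shift, rng):
    """Hypothetical: add to each spike a different, randomly chosen spike
    shifted by a random offset with min_shift <= |offset| <= max_shift."""
    n = len(spikes)
    collided = np.empty_like(spikes)
    for i in range(n):
        j = rng.choice([k for k in range(n) if k != i])  # a different spike
        magnitude = rng.integers(min_shift, max_shift + 1)  # upper bound inclusive
        shift = magnitude * rng.choice([-1, 1])
        collided[i] = spikes[i] + shift_time(spikes[j], shift)
    return collided


rng = np.random.default_rng(0)
clean = rng.normal(size=(5, 40, 7))  # stand-in clean spikes
collided = make_collided(clean, min_shift=5, max_shift=10, rng=rng)
print(collided.shape)  # (5, 40, 7)
```

Enforcing a minimum shift keeps the second spike from sitting almost exactly on top of the first, where the collision would look like a single larger spike rather than an overlap.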