Neural network

Neural Network Detector

class yass.neuralnetwork.NeuralNetDetector(path_to_model, filters_size, waveform_length, n_neighbors, threshold, channel_index, n_iter=50000, n_batch=512, l2_reg_scale=5e-08, train_step_size=0.001, load_test_set=False)

Class for training and running a convolutional neural network detector for spike detection.

C: int

spatial filter size of the spatial convolutional layer.

R1: int

temporal filter size of the temporal convolutional layers.

K1, K2: int

number of filters for each convolutional layer.

W1, W11, W2: tf.Variable

[temporal_filter_size, spatial_filter_size, input_filter_number, output_filter_number] weight tensors for the convolutional layers.

b1, b11, b2: tf.Variable

bias variables for the convolutional layers.

saver: tf.train.Saver

saver object for the neural network detector.

threshold: float

threshold for neural net detection

channel_index: np.array (n_channels, n_neigh)

Each row indexes the neighbors of one channel: channel_index[c] holds the indices of the channels neighboring channel c, including c itself. A value equal to n_channels is only a placeholder, used when a channel has fewer than n_neigh neighboring channels.

SPIKE: int

Label assigned to the spike class (1)


Label assigned to the not spike class (0)
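The placeholder convention for channel_index can be illustrated with NumPy. This is a minimal sketch. It assumes neighborhoods come from a distance cutoff on a probe geometry. The helpers `make_channel_index` and `neighbor_data`, the `geometry` array and the `max_distance` cutoff are hypothetical and not part of the documented YASS API. The sketch pads the recording with one all-zero channel, so the placeholder index n_channels can be used directly in fancy indexing.

```python
import numpy as np


def make_channel_index(geometry, max_distance):
    """Hypothetical helper: build an (n_channels, n_neigh) neighbor index.

    Channels within max_distance of channel c (including c itself) are its
    neighbors; rows with fewer than n_neigh neighbors are padded with the
    placeholder value n_channels.
    """
    n_channels = geometry.shape[0]
    # pairwise Euclidean distances between channel positions
    dist = np.linalg.norm(geometry[:, None, :] - geometry[None, :, :], axis=2)
    neighbors = [np.flatnonzero(row <= max_distance) for row in dist]
    n_neigh = max(len(nb) for nb in neighbors)

    # start with every slot set to the placeholder, then fill real neighbors
    channel_index = np.full((n_channels, n_neigh), n_channels, dtype=int)
    for c, nb in enumerate(neighbors):
        channel_index[c, :len(nb)] = nb
    return channel_index


def neighbor_data(recording, channel_index, c):
    """Hypothetical helper: slice channel c's neighborhood from a recording.

    recording has shape (n_samples, n_channels). A zero column is appended,
    so placeholder entries (== n_channels) read as silence.
    """
    padded = np.concatenate(
        [recording, np.zeros((recording.shape[0], 1), recording.dtype)], axis=1)
    return padded[:, channel_index[c]]
```

For four channels spaced 20 units apart on a line with a 25-unit cutoff, the end channels have only two neighbors. Their third slot therefore holds the placeholder 4.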

fit(x_train, y_train, test_size=0.3, save_test_set=False)

Trains the neural network detector for spike detection

x_train: np.array

[number of training examples, temporal length, number of channels] augmented training data consisting of isolated spikes, noise and misaligned spikes.

y_train: np.array

[number of training examples] labels for x_train. ‘1’ denotes the presence of an isolated spike and ‘0’ denotes noise or a misaligned spike.

test_size: float, optional

Proportion of the training data to hold out as a test set; the data is shuffled before splitting. Defaults to 0.3.


Returns: dictionary with network parameters and metrics
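The inputs that fit expects, and the role of test_size, can be sketched in NumPy. The sketch assumes the three kinds of examples are already extracted as arrays of shape (n, temporal length, n_channels). The helpers `make_training_set` and `shuffle_split` are hypothetical. The split follows the usual "shuffle, then hold out a fraction" convention and is not the YASS implementation.

```python
import numpy as np


def make_training_set(isolated, noise, misaligned):
    """Hypothetical helper: stack examples into x_train and build y_train.

    Isolated spikes get label 1; noise and misaligned spikes get label 0,
    matching the y_train convention documented for fit.
    """
    x_train = np.concatenate([isolated, noise, misaligned], axis=0)
    y_train = np.concatenate([
        np.ones(len(isolated), dtype=int),     # isolated spike -> 1
        np.zeros(len(noise), dtype=int),       # noise -> 0
        np.zeros(len(misaligned), dtype=int),  # misaligned spike -> 0
    ])
    return x_train, y_train


def shuffle_split(x, y, test_size=0.3, seed=0):
    """Hypothetical helper: shuffle, then hold out a test_size fraction."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(x))
    n_test = int(round(test_size * len(x)))
    test, train = order[:n_test], order[n_test:]
    return x[train], x[test], y[train], y[test]
```

With 10 examples and the default test_size of 0.3, this holds out 3 examples for testing and trains on the remaining 7.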


Predict classes (probability greater than or equal to the threshold)


Predict probabilities
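The class and probability predictions are related by the threshold rule documented above. This is a minimal sketch, assuming the network outputs one spike probability per example. `classes_from_proba` is a hypothetical helper, not YASS code. Under the documented rule, an example is labeled SPIKE (1) when its probability is greater than or equal to threshold.

```python
import numpy as np


def classes_from_proba(proba, threshold):
    """Hypothetical helper: 1 where proba >= threshold, else 0."""
    return (np.asarray(proba) >= threshold).astype(int)
```

For example, `classes_from_proba([0.2, 0.5, 0.9], 0.5)` labels the last two examples as spikes. A probability exactly at the threshold counts as a spike.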

predict_recording(recording, output_names=('spike_index', ), sess=None)

Make predictions on a recording

output_names: tuple

Names of the output layers to return; valid options are spike_index, waveform and probability.


Returns: a tuple of numpy.ndarrays, one for every element in output_names


Restore tensor values

Neural Network Triage

class yass.neuralnetwork.NeuralNetTriage(path_to_model, filters_size, waveform_length, n_neighbors, threshold, n_iter=50000, n_batch=512, l2_reg_scale=5e-08, train_step_size=0.001, input_tensor=None, load_test_set=False)

Convolutional Neural Network for spike triage (classifying detected spikes as clean or collided)

path_to_model: str

Where to save the trained model

threshold: float

Threshold between 0 and 1; values higher than the threshold are considered spikes.

C: int

spatial filter size of the spatial convolutional layer.

R1: int

temporal filter size of the temporal convolutional layers.

K1, K2: int

number of filters for each convolutional layer.

W1, W11, W2: tf.Variable

[temporal_filter_size, spatial_filter_size, input_filter_number, output_filter_number] weight tensors for the convolutional layers.

b1, b11, b2: tf.Variable

bias variables for the convolutional layers.

saver: tf.train.Saver

saver object for the neural network triage.

detector: NeuralNetDetector

Instance of detector

threshold: float

threshold for neural net triage

CLEAN: int

Label assigned to the clean spike class (1)


Label assigned to the collided spike class (0)

fit(x_train, y_train, test_size=0.3, save_test_set=False)

Trains the triage network

x_train: np.array

[number of examples, temporal length, number of channels] training data for the triage network.

y_train: np.array

[number of examples] training labels for the triage network.

test_size: float, optional

Proportion of the training data to hold out as a test set; the data is shuffled before splitting. Defaults to 0.3.


Returns: dictionary with network parameters and metrics


Size is determined by the second dimension of x_train

classmethod load(path_to_model, threshold, input_tensor=None, load_test_set=False)

Load a model from a file


Triage waveforms
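Triage keeps the detected spikes that the network scores as clean. This is a minimal sketch under two assumptions. First, the triage network yields one clean-spike probability per waveform. Second, the rule is "higher than the threshold", as in the threshold parameter description. `triage_mask` is a hypothetical helper, and the probabilities are passed in directly rather than computed by the network.

```python
import numpy as np


def triage_mask(clean_proba, threshold):
    """Hypothetical helper: True where a waveform is considered clean (CLEAN = 1)."""
    return np.asarray(clean_proba) > threshold
```

The mask can then filter any per-spike array, for example `spike_index[triage_mask(proba, 0.5)]` keeps only the rows of spikes judged clean.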


Restore tensor values