Neural network

Neural Network Detector

class yass.neuralnetwork.NeuralNetDetector(path_to_model, filters_size, waveform_length, n_neighbors, threshold, channel_index, n_iter=50000, n_batch=512, l2_reg_scale=5e-08, train_step_size=0.001, load_test_set=False)

Class for training and running a convolutional neural network detector for spike detection

Parameters:
C: int

spatial filter size of the spatial convolutional layer.

R1: int

temporal filter size for the temporal convolutional layers.

K1, K2: int

number of filters for each convolutional layer.

W1, W11, W2: tf.Variable

Weight tensors for the convolutional layers, of shape [temporal_filter_size, spatial_filter_size, input_filter_number, output_filter_number].

b1, b11, b2: tf.Variable

Bias variables for the convolutional layers.

saver: tf.train.Saver

saver object for the neural network detector.

threshold: float

threshold for neural net detection

channel_index: np.array (n_channels, n_neigh)

Each row lists the indices of a channel's neighbors: channel_index[c] contains the indices of the channels neighboring channel c (including c itself). A value equal to n_channels is only a placeholder, used when a channel has fewer than n_neigh neighboring channels.
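The placeholder convention for channel_index can be illustrated with a short NumPy sketch. Both helpers, make_channel_index and neighbor_data, are hypothetical (not part of yass), and two choices are assumptions rather than documented behavior: neighbors are taken to be channels within a fixed distance radius of each other, and each row lists the channel itself first.

```python
import numpy as np


def make_channel_index(geometry, radius):
    """Hypothetical helper: build an (n_channels, n_neigh) channel_index.

    geometry: (n_channels, 2) array of electrode coordinates.
    radius: assumed maximum distance for two channels to be neighbors.
    Unused slots in a row are filled with the placeholder n_channels.
    """
    geometry = np.asarray(geometry, dtype=float)
    n_channels = geometry.shape[0]
    # Pairwise Euclidean distances between all channels
    diff = geometry[:, None, :] - geometry[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    neighbors = [np.flatnonzero(dist[c] <= radius) for c in range(n_channels)]
    n_neigh = max(len(nb) for nb in neighbors)
    # Start with every slot set to the placeholder value
    channel_index = np.full((n_channels, n_neigh), n_channels, dtype=int)
    for c, nb in enumerate(neighbors):
        # Assumed ordering: the channel itself first, then its other neighbors
        ordered = np.concatenate(([c], nb[nb != c]))
        channel_index[c, :len(ordered)] = ordered
    return channel_index


def neighbor_data(recording, channel_index, c):
    """Hypothetical helper: gather the neighbors of channel c.

    recording: (n_samples, n_channels) array. A zero column is appended
    so that placeholder entries (== n_channels) read as silent channels.
    """
    padded = np.concatenate(
        [recording, np.zeros((recording.shape[0], 1))], axis=1)
    return padded[:, channel_index[c]]


# Four channels on a line, 10 units apart, with a 10-unit radius
channel_index = make_channel_index([[0, 0], [10, 0], [20, 0], [30, 0]], 10)
print(channel_index.tolist())  # [[0, 1, 4], [1, 0, 2], [2, 1, 3], [3, 2, 4]]
```

The edge channels 0 and 3 have only two neighbors each, so their last slot holds the placeholder 4 (= n_channels). Appending a zero column before indexing lets those slots be used directly instead of raising an index error.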

Attributes:
SPIKE: int

Label assigned to the spike class (1)

NOT_SPIKE: int

Label assigned to the non-spike class (0)
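The documented weight shape [temporal_filter_size, spatial_filter_size, input_filter_number, output_filter_number] matches the filter layout of TensorFlow's conv2d ([filter_height, filter_width, in_channels, out_channels]), with time as the height axis and channels as the width axis. The naive NumPy function below is a hypothetical helper, not the detector's code: it shows what one such kernel computes for a single example, assuming 'VALID' padding and unit strides. The padding, strides and ordering of W1, W11 and W2 that the detector actually uses are not stated here.

```python
import numpy as np


def conv2d_valid(x, w, b):
    """Naive 'VALID' 2-D cross-correlation with TensorFlow-style layouts.

    x: (temporal_length, n_channels, in_filters) input for one example.
    w: (temporal_filter_size, spatial_filter_size, in_filters, out_filters).
    b: (out_filters,) bias.
    Returns an array of shape
    (temporal_length - R + 1, n_channels - C + 1, out_filters).
    """
    T, n_ch, k_in = x.shape
    R, C, w_in, k_out = w.shape
    if k_in != w_in:
        raise ValueError("input_filter_number must match the input depth")
    out = np.empty((T - R + 1, n_ch - C + 1, k_out))
    for t in range(T - R + 1):
        for c in range(n_ch - C + 1):
            patch = x[t:t + R, c:c + C, :]  # (R, C, in_filters)
            # Sum over the temporal, spatial and input-filter axes
            out[t, c, :] = np.tensordot(
                patch, w, axes=([0, 1, 2], [0, 1, 2])) + b
    return out


# A spatial filter as wide as the neighborhood (C == n_channels)
# collapses the channel axis to size 1.
x = np.ones((10, 7, 1))
w = np.ones((3, 7, 1, 4))
print(conv2d_valid(x, w, np.zeros(4)).shape)  # (8, 1, 4)
```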

fit(x_train, y_train, test_size=0.3, save_test_set=False)

Trains the neural network detector for spike detection

Parameters:
x_train: np.array

Augmented training data of shape [number of training examples, temporal length, number of channels], consisting of isolated spikes, noise and misaligned spikes.

y_train: np.array

Labels for x_train, of shape [number of training examples]. ‘1’ denotes the presence of an isolated spike and ‘0’ denotes noise or a misaligned spike.

test_size: float, optional

Proportion of the training data to hold out as a test set; the data are shuffled before splitting. Defaults to 0.3.

Returns:
dict

Dictionary with network parameters and metrics
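The expected array shapes and the effect of test_size can be checked with a small NumPy sketch. The shuffle_split helper and the sizes used (100 examples, 61 samples, 7 channels) are hypothetical; the source does not say which splitting routine fit uses internally (scikit-learn's train_test_split with shuffle=True behaves similarly).

```python
import numpy as np


def shuffle_split(x, y, test_size=0.3, seed=0):
    """Hypothetical helper: shuffle (x, y) together, hold out a test fraction."""
    x, y = np.asarray(x), np.asarray(y)
    if x.shape[0] != y.shape[0]:
        raise ValueError("x and y must have the same number of examples")
    rng = np.random.default_rng(seed)
    order = rng.permutation(x.shape[0])
    n_test = int(round(test_size * x.shape[0]))
    test_idx, train_idx = order[:n_test], order[n_test:]
    return x[train_idx], x[test_idx], y[train_idx], y[test_idx]


# x_train: [number of training examples, temporal length, number of channels]
x_train = np.random.default_rng(1).normal(size=(100, 61, 7))
# y_train: 1 = isolated spike, 0 = noise or misaligned spike
y_train = np.arange(100) % 2
x_tr, x_te, y_tr, y_te = shuffle_split(x_train, y_train, test_size=0.3)
print(x_tr.shape, x_te.shape)  # (70, 61, 7) (30, 61, 7)
```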

predict(waveforms)

Predict classes (waveforms whose predicted probability is greater than or equal to threshold are classified as spikes)

predict_proba(waveforms)

Predict probabilities
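The relationship between predict and predict_proba can be sketched in NumPy, assuming threshold is a probability in [0, 1] and that predict simply applies the documented "greater than or equal" rule to the probabilities. The classes_from_proba function is hypothetical; only the label values SPIKE = 1 and NOT_SPIKE = 0 come from the attributes above.

```python
import numpy as np

SPIKE, NOT_SPIKE = 1, 0  # label values documented for NeuralNetDetector


def classes_from_proba(proba, threshold):
    """Hypothetical helper: proba >= threshold -> SPIKE, else NOT_SPIKE."""
    proba = np.asarray(proba, dtype=float)
    return np.where(proba >= threshold, SPIKE, NOT_SPIKE)


print(classes_from_proba([0.2, 0.5, 0.93], threshold=0.5))  # [0 1 1]
```

A probability exactly equal to the threshold is classified as a spike under this rule.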

predict_recording(recording, output_names=('spike_index', ), sess=None)

Make predictions on recordings

Parameters:
output_names: tuple

Which output layers to return; valid options are 'spike_index', 'waveform' and 'probability'

Returns:
tuple

A tuple of numpy.ndarrays, one for every element in output_names
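The tuple return contract can be sketched with a hypothetical select_outputs helper, where a plain dictionary stands in for the values the real method evaluates; the values in the example are placeholders, not real detector output. The sketch also guards against a common pitfall: the default ('spike_index', ) needs its trailing comma, because ('spike_index') is just a string.

```python
VALID_OUTPUTS = ("spike_index", "waveform", "probability")


def select_outputs(outputs, output_names=("spike_index",)):
    """Hypothetical helper: return one entry per requested name, in order.

    outputs: dict mapping output names to already computed values
    (a stand-in for what predict_recording evaluates).
    """
    if isinstance(output_names, str):
        # ("spike_index") without a trailing comma is a str, not a tuple
        raise TypeError("output_names must be a tuple, e.g. ('spike_index',)")
    unknown = [name for name in output_names if name not in VALID_OUTPUTS]
    if unknown:
        raise ValueError(f"invalid output names {unknown}; "
                         f"valid options are {VALID_OUTPUTS}")
    return tuple(outputs[name] for name in output_names)


# Placeholder values; unpack a one-element tuple with a trailing comma
outputs = {"spike_index": [[120, 3]], "waveform": [[0.0]], "probability": [0.97]}
(spike_index,) = select_outputs(outputs)
```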

restore(sess)

Restore tensor values

Neural Network Triage

class yass.neuralnetwork.NeuralNetTriage(path_to_model, filters_size, waveform_length, n_neighbors, threshold, n_iter=50000, n_batch=512, l2_reg_scale=5e-08, train_step_size=0.001, input_tensor=None, load_test_set=False)

Convolutional neural network for spike triage (classifying spikes as clean or collided)

Parameters:
path_to_model: str

Where to save the trained model

threshold: float

Threshold between 0 and 1; values higher than the threshold are considered spikes

input_tensor

Attributes:
C: int

spatial filter size of the spatial convolutional layer.

R1: int

temporal filter size for the temporal convolutional layers.

K1, K2: int

number of filters for each convolutional layer.

W1, W11, W2: tf.Variable

Weight tensors for the convolutional layers, of shape [temporal_filter_size, spatial_filter_size, input_filter_number, output_filter_number].

b1, b11, b2: tf.Variable

Bias variables for the convolutional layers.

saver: tf.train.Saver

saver object for the neural network detector.

detector: NeuralNetDetector

Instance of detector

threshold: float

threshold for neural net triage

CLEAN: int

Label assigned to the clean spike class (1)

COLLIDED: int

Label assigned to the collided spike class (0)

fit(x_train, y_train, test_size=0.3, save_test_set=False)

Trains the triage network

Parameters:
x_train: np.array

Training data for the triage network, of shape [number of examples, temporal length, number of channels].

y_train: np.array

Training labels for the triage network, of shape [number of examples].

test_size: float, optional

Proportion of the training data to hold out as a test set; the data are shuffled before splitting. Defaults to 0.3.

Returns:
dict

Dictionary with network parameters and metrics

Notes

The size (temporal length) is determined by the second dimension of x_train
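A hypothetical pre-flight check for the triage training data is sketched below. It is not part of yass, and whether fit performs similar validation is not stated; the 0/1 label check assumes the labels follow the CLEAN (1) / COLLIDED (0) attributes documented above. It returns the second dimension of x_train, the size the network uses.

```python
import numpy as np


def check_triage_training_data(x_train, y_train):
    """Hypothetical helper: validate shapes/labels, return temporal length."""
    x_train, y_train = np.asarray(x_train), np.asarray(y_train)
    if x_train.ndim != 3:
        raise ValueError(
            "x_train must be [number of examples, temporal length, "
            "number of channels]")
    if y_train.shape != (x_train.shape[0],):
        raise ValueError("y_train must have shape [number of examples]")
    # Assumed label convention: 1 = CLEAN, 0 = COLLIDED
    if not np.isin(y_train, (0, 1)).all():
        raise ValueError("labels must be 0 (COLLIDED) or 1 (CLEAN)")
    # The size is taken from the second dimension of x_train
    return x_train.shape[1]


print(check_triage_training_data(np.zeros((5, 61, 7)), [0, 1, 1, 0, 1]))  # 61
```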

classmethod load(path_to_model, threshold, input_tensor=None, load_test_set=False)

Load a model from a file

predict(waveforms)

Triage waveforms
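Triage can be pictured as splitting waveforms into clean and collided sets. The NumPy sketch below assumes the network outputs a probability per waveform and applies the strict "higher than the threshold" rule stated for the threshold parameter; the triage function is hypothetical, and only the label values CLEAN = 1 and COLLIDED = 0 come from the attributes above.

```python
import numpy as np

CLEAN, COLLIDED = 1, 0  # label values documented for NeuralNetTriage


def triage(waveforms, proba, threshold):
    """Hypothetical helper: split waveforms by a probability threshold.

    waveforms: [number of waveforms, temporal length, number of channels].
    proba: one probability per waveform (assumed network output).
    Returns (clean waveforms, collided waveforms, labels).
    """
    waveforms = np.asarray(waveforms)
    labels = np.where(np.asarray(proba) > threshold, CLEAN, COLLIDED)
    is_clean = labels == CLEAN
    return waveforms[is_clean], waveforms[~is_clean], labels


waveforms = np.arange(24, dtype=float).reshape(4, 3, 2)
clean, collided, labels = triage(waveforms, [0.2, 0.7, 0.5, 0.95], 0.5)
print(labels)  # [0 1 0 1]
```

Under the strict rule, a probability exactly at the threshold (0.5 here) is labeled COLLIDED.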

restore(sess)

Restore tensor values