Neural network
Neural Network Detector
class yass.neuralnetwork.NeuralNetDetector(path_to_model, filters_size, waveform_length, n_neighbors, threshold, channel_index, n_iter=50000, n_batch=512, l2_reg_scale=5e-08, train_step_size=0.001, load_test_set=False)
Class for training and running a convolutional neural network detector for spike detection.
Parameters:
- C: int
spatial filter size of the spatial convolutional layer.
- R1: int
temporal filter sizes for the temporal convolutional layers.
- K1, K2: int
number of filters for each convolutional layer.
- W1, W11, W2: tf.Variable
weight tensors of shape [temporal_filter_size, spatial_filter_size, input_filter_number, output_filter_number] for the convolutional layers.
- b1, b11, b2: tf.Variable
bias variables for the convolutional layers.
- saver: tf.train.Saver
saver object for the neural network detector.
- threshold: int
threshold for neural net detection
- channel_index: np.array (n_channels, n_neigh)
Each row lists the neighbors of one channel: channel_index[c] contains the indices of the channels neighboring channel c, including c itself. A value equal to n_channels is only a placeholder, used when a channel has fewer than n_neigh neighboring channels.
Attributes:
- SPIKE: int
Label assigned to the spike class (1)
- NOT_SPIKE: int
Label assigned to the not spike class (0)
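The padded layout of channel_index can be illustrated with NumPy. The sketch below is an assumption, not YASS's own code: it builds channel_index from a hypothetical boolean neighbor matrix `neighbors` using a hypothetical helper `make_channel_index`, and then shows one common way to use the placeholder value (appending an all-zero channel so that indexing with n_channels returns zeros).

```python
import numpy as np


def make_channel_index(neighbors, n_neigh):
    """Build a padded channel_index array from a boolean neighbor matrix.

    Hypothetical helper: neighbors[c, k] is True when channel k is a
    neighbor of channel c (the diagonal should be True so that each
    channel includes itself). Rows with fewer than n_neigh neighbors are
    padded with the placeholder value n_channels.
    """
    n_channels = neighbors.shape[0]
    # Start with every slot set to the placeholder value n_channels
    channel_index = np.full((n_channels, n_neigh), n_channels, dtype=np.int64)
    for c in range(n_channels):
        idx = np.flatnonzero(neighbors[c])[:n_neigh]
        channel_index[c, :len(idx)] = idx
    return channel_index


# Four channels on a line: each channel neighbors itself and adjacent channels
neighbors = np.eye(4, dtype=bool)
neighbors |= np.eye(4, k=1, dtype=bool) | np.eye(4, k=-1, dtype=bool)
channel_index = make_channel_index(neighbors, n_neigh=3)
print(channel_index.tolist())  # [[0, 1, 4], [0, 1, 2], [1, 2, 3], [2, 3, 4]]

# Assumed usage of the placeholder: append an all-zero channel so that
# indexing with n_channels yields zeros instead of raising IndexError
recording = np.random.default_rng(0).standard_normal((100, 4))
padded = np.concatenate([recording, np.zeros((100, 1))], axis=1)
local = padded[:, channel_index[0]]  # shape (100, 3); last column is all zeros
```

The edge channels 0 and 3 have only two neighbors each, so their last slot holds the placeholder 4 (n_channels).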
fit(x_train, y_train, test_size=0.3, save_test_set=False)
Trains the neural network detector for spike detection.
Parameters:
- x_train: np.array
[number of training data, temporal length, number of channels] augmented training data consisting of isolated spikes, noise, and misaligned spikes.
- y_train: np.array
[number of training data] labels for x_train: 1 denotes an isolated spike and 0 denotes noise or a misaligned spike.
- test_size: float, optional
Proportion of the data held out as a test set; the data is shuffled before splitting. Defaults to 0.3.
Returns:
- dict
Dictionary with network parameters and metrics
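The shapes expected by fit and the effect of test_size can be sketched with NumPy. This is an illustrative assumption only: `shuffle_split` is a hypothetical helper, not the splitting routine YASS uses, and the temporal length of 31 samples and 7 channels are arbitrary example sizes.

```python
import numpy as np


def shuffle_split(x_train, y_train, test_size=0.3, seed=0):
    """Shuffle (x_train, y_train) and split off a test set.

    Hypothetical helper illustrating the documented test_size behavior;
    YASS's own splitting code may differ (for example in rounding).
    """
    # x_train: [n_data, temporal length, n_channels]; y_train: [n_data]
    if x_train.shape[0] != y_train.shape[0]:
        raise ValueError("x_train and y_train must have the same length")
    n_data = x_train.shape[0]
    order = np.random.default_rng(seed).permutation(n_data)
    n_test = int(round(test_size * n_data))
    test_idx, train_idx = order[:n_test], order[n_test:]
    return (x_train[train_idx], x_train[test_idx],
            y_train[train_idx], y_train[test_idx])


# 100 examples, 31 time samples, 7 channels (arbitrary example sizes)
x = np.zeros((100, 31, 7), dtype=np.float32)
y = np.random.default_rng(1).integers(0, 2, size=100)  # 1 = spike, 0 = not spike
x_tr, x_te, y_tr, y_te = shuffle_split(x, y, test_size=0.3)
print(x_tr.shape, x_te.shape)  # (70, 31, 7) (30, 31, 7)
```

Shuffling before splitting matters because augmented training sets are often assembled class by class; an unshuffled split could leave one class missing from the test set.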
predict_recording(recording, output_names=('spike_index', ), sess=None)
Makes predictions on a recording.
Parameters:
- output_names: tuple
Which output layers to return. Valid options are spike_index, waveform, and probability.
Returns:
- tuple
A tuple of numpy.ndarrays, one for every element in output_names
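Because output_names is a tuple, the trailing comma in the default ('spike_index', ) matters: ('spike_index') without it is just a string. The sketch below mimics the documented return contract (one array per requested name, in the requested order) with a hypothetical helper `select_outputs` operating on a dictionary of placeholder arrays; it does not call YASS, and the array contents and shapes are made up.

```python
import numpy as np

VALID_OUTPUTS = ("spike_index", "waveform", "probability")


def select_outputs(results, output_names=("spike_index",)):
    """Return one array per requested name, in the order requested.

    Hypothetical helper mimicking the documented contract of
    predict_recording; results maps output names to arrays.
    """
    if isinstance(output_names, str):
        # ('spike_index') without a trailing comma is a str, not a tuple
        raise TypeError("output_names must be a tuple, e.g. ('spike_index',)")
    for name in output_names:
        if name not in VALID_OUTPUTS:
            raise ValueError(f"invalid output name: {name!r}")
    return tuple(results[name] for name in output_names)


# Placeholder arrays; contents and shapes are illustrative only
results = {
    "spike_index": np.array([[10, 0], [42, 3]]),
    "waveform": np.zeros((2, 31, 7)),
    "probability": np.array([0.9, 0.8]),
}

# A single requested output still comes back as a 1-tuple: unpack with a comma
(spike_index,) = select_outputs(results)
spike_index, proba = select_outputs(results, ("spike_index", "probability"))
```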
Neural Network Triage
class yass.neuralnetwork.NeuralNetTriage(path_to_model, filters_size, waveform_length, n_neighbors, threshold, n_iter=50000, n_batch=512, l2_reg_scale=5e-08, train_step_size=0.001, input_tensor=None, load_test_set=False)
Convolutional neural network for spike triage, which classifies spikes as clean or collided.
Parameters:
- path_to_model: str
Where to save the trained model
- threshold: float
Threshold between 0 and 1, values higher than the threshold are considered spikes
- input_tensor
Attributes:
- C: int
spatial filter size of the spatial convolutional layer.
- R1: int
temporal filter sizes for the temporal convolutional layers.
- K1, K2: int
number of filters for each convolutional layer.
- W1, W11, W2: tf.Variable
weight tensors of shape [temporal_filter_size, spatial_filter_size, input_filter_number, output_filter_number] for the convolutional layers.
- b1, b11, b2: tf.Variable
bias variable for the convolutional layers.
- saver: tf.train.Saver
saver object for the neural network detector.
- detector: NeuralNetDetector
Instance of detector
- threshold: float
threshold for neural network triage, between 0 and 1
- CLEAN: int
Label assigned to the clean spike class (1)
- COLLIDED: int
Label assigned to the collided spike class (0)
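The threshold rule can be sketched as a mapping from network outputs to the CLEAN and COLLIDED labels. This is an assumption-laden illustration: `triage_labels` is a hypothetical helper, and it assumes that an output above the threshold corresponds to the CLEAN class (1) while everything else is COLLIDED (0).

```python
import numpy as np

CLEAN, COLLIDED = 1, 0  # labels documented for NeuralNetTriage


def triage_labels(probability, threshold):
    """Map network outputs in [0, 1] to CLEAN/COLLIDED labels.

    Hypothetical helper: outputs strictly above threshold are assumed to
    be CLEAN; all other outputs are labeled COLLIDED.
    """
    probability = np.asarray(probability, dtype=float)
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must lie between 0 and 1")
    return np.where(probability > threshold, CLEAN, COLLIDED)


labels = triage_labels([0.95, 0.5, 0.1, 0.51], threshold=0.5)
print(labels.tolist())  # [1, 0, 0, 1]
```

An output exactly equal to the threshold is not "higher than the threshold", so it falls into the COLLIDED class here.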
fit(x_train, y_train, test_size=0.3, save_test_set=False)
Trains the triage network.
Parameters:
- x_train: np.array
[number of data, temporal length, number of channels] training data for the triage network.
- y_train: np.array
[number of data] training labels for the triage network.
- test_size: float, optional
Proportion of the data held out as a test set; the data is shuffled before splitting. Defaults to 0.3.
Returns:
- dict
Dictionary with network parameters and metrics
Notes
Size is determined by the second dimension of x_train (the temporal length).
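A quick way to see which dimension sets the size is to validate the triage training inputs before calling fit. The sketch below is hypothetical (`check_triage_inputs` is not part of YASS) and assumes labels are 1 for clean and 0 for collided spikes, matching the CLEAN and COLLIDED attributes; the 61-sample temporal length is an arbitrary example.

```python
import numpy as np


def check_triage_inputs(x_train, y_train):
    """Validate triage training shapes and return the temporal length.

    Hypothetical helper: x_train is [n_data, temporal length, n_channels]
    and y_train is [n_data] with labels 1 (clean) or 0 (collided).
    """
    x_train = np.asarray(x_train)
    y_train = np.asarray(y_train)
    if x_train.ndim != 3:
        raise ValueError("x_train must have shape [n_data, temporal length, n_channels]")
    if y_train.shape != (x_train.shape[0],):
        raise ValueError("y_train must have shape [n_data]")
    if not np.isin(y_train, (0, 1)).all():
        raise ValueError("labels must be 0 (collided) or 1 (clean)")
    # The size is read from the second dimension of x_train
    return x_train.shape[1]


x = np.zeros((8, 61, 7))
y = np.array([1, 0, 1, 1, 0, 0, 1, 0])
print(check_triage_inputs(x, y))  # 61
```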