# Source code for yass.neuralnetwork.model_detector

try:
    from pathlib2 import Path
except ImportError:
    from pathlib import Path

import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tqdm import trange
import logging

from yass.neuralnetwork.utils import (weight_variable, bias_variable, conv2d,
                                      conv2d_VALID, max_pool)
from yass.util import load_yaml, change_extension
from yass.neuralnetwork.model import Model


class NeuralNetDetector(Model):
    """
    Class for training and running convolutional neural network detector
    for spike detection

    Parameters
    ----------
    C: int
        spatial filter size of the spatial convolutional layer.
    R1: int
        temporal filter sizes for the temporal convolutional layers.
    K1, K2: int
        number of filters for each convolutional layer.
    W1, W11, W2: tf.Variable
        [temporal_filter_size, spatial_filter_size, input_filter_number,
        output_filter_number] weight matrices for the convolutional layers.
    b1, b11, b2: tf.Variable
        bias variable for the convolutional layers.
    saver: tf.train.Saver
        saver object for the neural network detector.
    threshold: int
        threshold for neural net detection
    channel_index: np.array (n_channels, n_neigh)
        Each row indexes its neighboring channels. For example,
        channel_index[c] is the index of neighboring channels (including
        itself). If any value is equal to n_channels, it is nothing but a
        placeholder in a case that a channel has less than n_neigh
        neighboring channels

    Attributes
    ----------
    SPIKE: int
        Label assigned to the spike class (1)
    NOT_SPIKE: int
        Label assigned to the not spike class (0)
    """

    # class labels used when training the detector
    SPIKE = 1
    NOT_SPIKE = 0

    def __init__(self, path_to_model, filters_size, waveform_length,
                 n_neighbors, threshold, channel_index, n_iter=50000,
                 n_batch=512, l2_reg_scale=0.00000005, train_step_size=0.001,
                 load_test_set=False):
        """Initializes the attributes for the class NeuralNetDetector.

        Parameters
        ----------
        path_to_model: str
            location of trained neural net detector
        """
        self.logger = logging.getLogger(__name__)

        self.path_to_model = path_to_model
        # model name is the checkpoint file name without the .ckpt suffix
        self.model_name = Path(path_to_model).name.replace('.ckpt', '')
        self.filters_size = filters_size
        self.n_neighbors = n_neighbors
        self.waveform_length = waveform_length
        self.threshold = threshold
        self.n_batch = n_batch
        self.l2_reg_scale = l2_reg_scale
        self.train_step_size = train_step_size
        self.n_iter = n_iter

        # variables
        K1, K2 = filters_size

        # first temporal convolution: waveform_length x 1 filters, K1 outputs
        W1 = weight_variable([waveform_length, 1, 1, K1])
        b1 = bias_variable([K1])

        # second (1x1) temporal convolution: K1 -> K2 feature maps
        W11 = weight_variable([1, 1, K1, K2])
        b11 = bias_variable([K2])

        # spatial convolution across the n_neighbors channels, single output
        W2 = weight_variable([1, self.n_neighbors, K2, 1])
        b2 = bias_variable([1])

        self.vars_dict = {"W1": W1, "W11": W11, "W2": W2, "b1": b1,
                          "b11": b11, "b2": b2}

        # graphs: one for predicting on whole recordings, one for training
        # on (already extracted) waveforms; both share the same variables
        (self.x_tf, self.spike_index_tf,
         self.probability_tf,
         self.waveform_tf) = (NeuralNetDetector
                              ._make_recordings_graph(threshold,
                                                      channel_index,
                                                      waveform_length,
                                                      filters_size,
                                                      n_neighbors,
                                                      self.vars_dict))

        (self.x_tf_tr, self.y_tf_tr,
         self.o_layer_tr,
         self.sigmoid_tr) = (NeuralNetDetector
                             ._make_training_graph(self.waveform_length,
                                                   self.n_neighbors,
                                                   self.vars_dict))

        # create saver variables
        self.saver = tf.train.Saver(self.vars_dict)

        if load_test_set:
            self._load_test_set()

    @classmethod
    def load(cls, path_to_model, threshold, channel_index,
             load_test_set=False):
        """Alternate constructor: build a detector from a saved checkpoint.

        Reads the network hyperparameters from the YAML file stored next to
        the checkpoint and instantiates the class with them.
        """
        if not path_to_model.endswith('.ckpt'):
            path_to_model = path_to_model+'.ckpt'

        # load nn parameter files
        path_to_params = change_extension(path_to_model, 'yaml')
        params = load_yaml(path_to_params)

        return cls(path_to_model, params['filters_size'],
                   params['waveform_length'], params['n_neighbors'],
                   threshold, channel_index, load_test_set=load_test_set)

    @classmethod
    def _make_network(cls, input_layer, vars_dict, padding):
        """Build the two shared temporal convolution layers.

        Returns the (unchanged) vars_dict and the output tensor of the
        second temporal layer.
        """
        # FIXME: old training code was using conv2d_VALID, old graph building
        # for prediction was using conv2d, that's why I need to add the
        # padding parameter, otherwise it breaks. we need to fix it
        # NOTE(review): the padding argument is only applied to the first
        # convolution; the second call uses conv2d's default padding —
        # confirm this is intended
        # first temporal layer
        layer1 = tf.nn.relu(conv2d(input_layer, vars_dict['W1'], padding)
                            + vars_dict['b1'])

        # second temporal layer
        layer11 = tf.nn.relu(conv2d(layer1, vars_dict['W11'])
                             + vars_dict['b11'])

        return vars_dict, layer11

    @classmethod
    def _make_recordings_graph(cls, threshold, channel_index,
                               waveform_length, filters_size, n_neigh,
                               vars_dict):
        """Build tensorflow graph with input and two output layers

        used for predicting on recordings

        Parameters
        ----------
        x_tf: tf.tensors (n_observations, n_channels)
            placeholder of recording for running tensorflow
        channel_index: np.array (n_channels, n_neigh)
            Each row indexes its neighboring channels. For example,
            channel_index[c] is the index of neighboring channels (including
            itself). If any value is equal to n_channels, it is nothing but a
            placeholder in a case that a channel has less than n_neigh
            neighboring channels
        threshold: int
            threshold on a probability to determine location of spikes

        Returns
        -------
        spike_index_tf: tf tensor (n_spikes, 2)
            tensorflow tensor that produces spike_index
        """
        ######################
        # Loading parameters #
        ######################

        # FIXME: Eduardo: CONFIG.channel_index (which is the one passed here)
        # has steps=2, so they are passing :n_neigh which is the same
        # as steps=1, it is unclear why we are creating it with steps=2
        # in the first place, we need to check if at any point in the
        # pipeline we need it
        small_channel_index = channel_index[:, :n_neigh]

        # placeholder for input recording
        x_tf = tf.placeholder("float", [None, None])

        # Temporal shape of input
        T = tf.shape(x_tf)[0]

        ####################
        # Building network #
        ####################

        # input tensor into CNN - add one dimension at the beginning and
        # at the end
        x_cnn_tf = tf.expand_dims(tf.expand_dims(x_tf, -1), 0)

        vars_dict, layer11 = cls._make_network(x_cnn_tf, vars_dict,
                                               padding='SAME')
        W2 = vars_dict['W2']
        b2 = vars_dict['b2']

        K1, K2 = filters_size

        # first spatial layer: append a zero channel so that the placeholder
        # index (== n_channels) in channel_index gathers zeros
        zero_added_layer11 = tf.concat((tf.transpose(layer11, [2, 0, 1, 3]),
                                        tf.zeros((1, 1, T, K2))),
                                       axis=0)
        temp = tf.transpose(tf.gather(zero_added_layer11,
                                      small_channel_index),
                            [0, 2, 3, 1, 4])
        temp2 = conv2d_VALID(tf.reshape(temp, [-1, T, n_neigh, K2]),
                             W2) + b2
        o_layer = tf.transpose(temp2, [2, 1, 0, 3])

        ################################
        # Output layer transformations #
        ################################

        o_layer_val = tf.squeeze(o_layer)

        # probability output - just sigmoid of output layer
        probability_tf = tf.sigmoid(o_layer_val)

        # spike index output (local maximum crossing a threshold);
        # the -1e-8 breaks ties so a flat maximum still passes the >= test
        temporal_max = tf.squeeze(max_pool(o_layer, [1, 3, 1, 1]) - 1e-8)
        higher_than_max_pool = o_layer_val >= temporal_max
        # compare logits against the logit of the probability threshold
        higher_than_threshold = (o_layer_val >
                                 np.log(threshold / (1 - threshold)))
        both_higher = tf.logical_and(higher_than_max_pool,
                                     higher_than_threshold)
        index_all = tf.cast(tf.where(both_higher), 'int32')
        spike_index_tf = cls._remove_edge_spikes(x_tf, index_all,
                                                 waveform_length)

        # waveform output from spike index output
        waveform_tf = cls._make_waveform_tf(x_tf, spike_index_tf,
                                            channel_index, waveform_length)

        return x_tf, spike_index_tf, probability_tf, waveform_tf

    @classmethod
    def _make_training_graph(cls, waveform_length, n_neighbors, vars_dict):
        """Make graph for training

        Returns
        -------
        x_tf: tf.tensor
            Input tensor
        y_tf: tf.tensor
            Labels tensor
        o_layer: tf.tensor
            Output tensor
        """
        # x and y input tensors
        x_tf = tf.placeholder("float", [None, waveform_length, n_neighbors])
        y_tf = tf.placeholder("float", [None])

        # add a trailing channel dimension for the convolutions
        input_tf = tf.expand_dims(x_tf, -1)

        vars_dict, layer11 = (NeuralNetDetector
                              ._make_network(input_tf, vars_dict,
                                             padding='VALID'))
        W2 = vars_dict['W2']
        b2 = vars_dict['b2']

        # third layer: spatial convolution
        o_layer = tf.squeeze(conv2d_VALID(layer11, W2) + b2)

        # sigmoid
        sigmoid = tf.sigmoid(o_layer)

        return x_tf, y_tf, o_layer, sigmoid

    @classmethod
    def _remove_edge_spikes(cls, x_tf, spike_index_tf, waveform_length):
        """It removes spikes at edge times.

        Spikes closer than half a waveform length to either end of the
        recording are dropped, since a full waveform cannot be extracted
        around them.

        Parameters
        ----------
        x_tf: tf.tensors (n_observations, n_channels)
            placeholder of recording for running tensorflow
        spike_index_tf: tf tensor (n_spikes, 2)
            a tf tensor holding spike index. The first column is time and
            the second column is the main channel
        waveform_length: int
            temporal length of waveform

        Returns
        -------
        tf tensor (n_spikes, 2)
        """
        R = int((waveform_length-1)/2)
        min_spike_time = R
        max_spike_time = tf.shape(x_tf)[0] - R

        idx_middle = tf.logical_and(spike_index_tf[:, 0] > min_spike_time,
                                    spike_index_tf[:, 0] < max_spike_time)

        return tf.boolean_mask(spike_index_tf, idx_middle)

    @classmethod
    def _make_waveform_tf(cls, x_tf, spike_index_tf, channel_index,
                          wf_length):
        """Extract waveforms from a recording given spike indexes.

        It produces a tf tensor holding waveforms given recording and spike
        index. It does not hold waveforms on all channels but channels
        around their main channels specified in channel_index

        Parameters
        ----------
        x_tf: tf.tensors (n_observations, n_channels)
            placeholder of recording for running tensorflow
        spike_index_tf: tf tensor (n_spikes, 2)
            a tf tensor holding spike index. The first column is time and
            the second column is the main channel
        channel_index: np.array (n_channels, n_neigh)
            refer above
        wf_length: int
            temporal length of waveform

        Returns
        -------
        tf tensor (n_spikes, wf_length, n_neigh)
        """
        R = int((wf_length-1)/2)  # half waveform length
        T = tf.shape(x_tf)[0]  # length of recording

        # get waveform temporally

        # make indexes with the appropriate waveform length, centered at zero
        # shape: [1, wf_length]
        waveform_indexes = tf.expand_dims(tf.range(-R, R+1), 0)
        # get all spike times, shape: [n_spikes, 1]
        spike_times = tf.expand_dims(spike_index_tf[:, 0], -1)
        # shift indexes and add two dimensions, shape: [n_spikes, wf_length]
        _ = tf.add(spike_times, waveform_indexes)
        # add two trailing extra dimensions, shape: [n_spikes, wf_length, 1, 1]
        wf_temporal = tf.expand_dims(tf.expand_dims(_, -1), -1)

        # get waveform spatially

        # get neighbors for main channels in the spike index
        # shape: [n_spikes, n_neigh]
        _ = tf.gather(channel_index, spike_index_tf[:, 1])
        # add one dimension to the left and one to the right
        # shape: [n_spikes, 1, n_neigh, 1]
        wf_spatial = tf.expand_dims(tf.expand_dims(_, 1), -1)

        # build spatio-temporal index

        # tile temporal indexes on the number of channels and spatial indexes
        # on the waveform length, then concatenate
        # FIXME: there is a mismatch here, why aren't we using self.n_neigh?
        n_neigh = channel_index.shape[1]
        _ = (tf.tile(wf_temporal, (1, 1, n_neigh, 1)),
             tf.tile(wf_spatial, (1, wf_length, 1, 1)))
        idx = tf.concat(_, 3)

        # add one extra value in the channels dimension so that placeholder
        # channel indexes (== n_channels) gather zeros
        x_tf_zero_added = tf.concat([x_tf, tf.zeros((T, 1))], axis=1)

        return tf.gather_nd(x_tf_zero_added, idx)
[docs] def restore(self, sess): """Restore tensor values """ self.logger.debug('Restoring tensorflow session from: %s', self.path_to_model) self.saver.restore(sess, self.path_to_model)
[docs] def predict_recording(self, recording, output_names=('spike_index',), sess=None): """Make predictions on recordings Parameters ---------- output: tuple Which output layers to return, valid options are: spike_index, waveform and probability Returns ------- tuple A tuple of numpy.ndarrays, one for every element in output_names """ output_tensors = [getattr(self, name+'_tf') for name in output_names] if sess is None: with tf.Session() as sess: self.restore(sess) output = sess.run(output_tensors, feed_dict={self.x_tf: recording}) else: output = sess.run(output_tensors, feed_dict={self.x_tf: recording}) return output
[docs] def predict_proba(self, waveforms): """Predict probabilities """ _, waveform_length, n_neighbors = waveforms.shape self._validate_dimensions(waveform_length, n_neighbors) with tf.Session() as sess: self.restore(sess) output = sess.run(self.sigmoid_tr, feed_dict={self.x_tf_tr: waveforms}) return output
[docs] def predict(self, waveforms): """Predict classes (higher or equal than threshold) """ probas = self.predict_proba(waveforms) return (probas > self.threshold).astype('int')
    def fit(self, x_train, y_train, test_size=0.3, save_test_set=False):
        """
        Trains the neural network detector for spike detection

        Parameters
        ----------
        x_train: np.array
            [number of training data, temporal length, number of channels]
            augmented training data consisting of isolated spikes, noise and
            misaligned spikes.
        y_train: np.array
            [number of training data] label for x_train. '1' denotes presence
            of an isolated spike and '0' denotes the presence of a noise data
            or misaligned spike.
        test_size: float, optional
            Proportion of the training set to be used, data is shuffled
            before splitting, defaults to 0.3
        save_test_set: bool, optional
            Whether to persist the held-out split to disk after training

        Returns
        -------
        dict
            Dictionary with network parameters and metrics
        """
        logger = logging.getLogger(__name__)

        #####################
        # Splitting dataset #
        #####################

        (self.x_train, self.x_test,
         self.y_train, self.y_test) = train_test_split(x_train, y_train,
                                                       test_size=test_size)

        ######################
        # Loading parameters #
        ######################

        # get parameters
        n_data, waveform_length_train, n_neighbors_train = self.x_train.shape

        # abort early if the data does not match the network's configuration
        self._validate_dimensions(waveform_length_train, n_neighbors_train)

        ##########################
        # Optimization objective #
        ##########################

        # cross entropy
        _ = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.o_layer_tr,
                                                    labels=self.y_tf_tr)
        cross_entropy = tf.reduce_mean(_)

        weights = tf.trainable_variables()

        # regularization term
        l2_regularizer = (tf.contrib.layers
                          .l2_regularizer(scale=self.l2_reg_scale))
        regularization = tf.contrib.layers.apply_regularization(l2_regularizer,
                                                                weights)
        regularized_loss = cross_entropy + regularization

        # train step
        train_step = (tf.train.AdamOptimizer(self.train_step_size)
                      .minimize(regularized_loss))

        ############
        # Training #
        ############

        # saver
        saver = tf.train.Saver(self.vars_dict)

        logger.debug('Training detector network...')

        with tf.Session() as sess:
            init_op = tf.global_variables_initializer()
            sess.run(init_op)

            pbar = trange(self.n_iter)

            for i in pbar:
                # sample n_batch observations from 0, ..., n_data
                # (without replacement within a single batch)
                idx_batch = np.random.choice(n_data, self.n_batch,
                                             replace=False)

                x_train_batch = self.x_train[idx_batch]
                y_train_batch = self.y_train[idx_batch]

                # run a training step and compute training loss
                res = sess.run([train_step, regularized_loss],
                               feed_dict={self.x_tf_tr: x_train_batch,
                                          self.y_tf_tr: y_train_batch})

                if i % 100 == 0:
                    # compute validation loss and metrics
                    output = sess.run({'val loss': regularized_loss},
                                      feed_dict={self.x_tf_tr: self.x_test,
                                                 self.y_tf_tr: self.y_test})

                    pbar.set_description('Tr loss: %s, '
                                         'Val loss: %s'
                                         % (res[1], output['val loss']))

            logger.debug('Saving network: %s', self.path_to_model)
            saver.save(sess, self.path_to_model)

            path_to_params = change_extension(self.path_to_model, 'yaml')

            logger.debug('Saving network parameters: %s', path_to_params)
            params = dict(filters_size=self.filters_size,
                          waveform_length=self.waveform_length,
                          n_neighbors=self.n_neighbors,
                          name=self.model_name)

            # compute metrics (print them and return them)
            metrics = self._evaluate()
            params.update(metrics)

            # save parameters to disk
            self._save_params(path=path_to_params, params=params)

        if save_test_set:
            self._save_test_set()

        return params