# FIXME: remove this
try:
from pathlib2 import Path
except ImportError:
from pathlib import Path
import warnings
import logging
import tensorflow as tf
from sklearn.model_selection import train_test_split
import numpy as np
from tqdm import trange
from yass.neuralnetwork.utils import (weight_variable, bias_variable, conv2d,
conv2d_VALID)
from yass.util import load_yaml, change_extension
from yass.neuralnetwork.model import Model
class NeuralNetTriage(Model):
"""Convolutional Neural Network for spike detection
Parameters
----------
path_to_model: str
Where to save the trained model
threshold: float
Threshold between 0 and 1, values higher than the threshold are
considered spikes
input_tensor: tf tensor (n_spikes, waveform_length, n_neighbors), optional
tf tensor that produces the input waveforms; a placeholder is created
when it is None
Attributes
-----------
n_neighbors: int
spatial filter size of the spatial convolutional layer.
waveform_length: int
temporal filter size of the temporal convolutional layer.
K1, K2: int
number of filters for each convolutional layer (filters_size).
W1, W11, W2: tf.Variable
[temporal_filter_size, spatial_filter_size, input_filter_number,
output_filter_number] weight matrices for the convolutional layers.
b1, b11, b2: tf.Variable
bias variable for the convolutional layers.
saver: tf.train.Saver
saver object for the triage neural network.
detector: NeuralNetDetector
Instance of detector
threshold: float
threshold for neural net triage
CLEAN: int
Label assigned to the clean spike class (1)
COLLIDED: int
Label assigned to the collided spike class (0)
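Examples
--------
A minimal construction sketch; the checkpoint path, filter sizes,
waveform length and number of neighboring channels below are
illustrative placeholders (use `load` to restore an already trained
network instead):
>>> model = NeuralNetTriage('triage.ckpt', filters_size=[16, 8],
...                         waveform_length=61, n_neighbors=7,
...                         threshold=0.5)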
"""
CLEAN = 1
COLLIDED = 0
def __init__(self, path_to_model, filters_size,
waveform_length, n_neighbors, threshold,
n_iter=50000, n_batch=512, l2_reg_scale=0.00000005,
train_step_size=0.001, input_tensor=None,
load_test_set=False):
self.logger = logging.getLogger(__name__)
if input_tensor is not None:
if n_neighbors != input_tensor.shape[2]:
warnings.warn('Network n_neighbors ({}) does not match '
'n_neighbors on input_tensor ({}), using '
'only the first n_neighbors from the '
'input_tensor'.format(n_neighbors,
input_tensor.shape[2]))
self.path_to_model = path_to_model
self.model_name = Path(path_to_model).name.replace('.ckpt', '')
self.filters_size = filters_size
self.n_neighbors = n_neighbors
self.waveform_length = waveform_length
self.threshold = threshold
self.n_batch = n_batch
self.l2_reg_scale = l2_reg_scale
self.train_step_size = train_step_size
self.n_iter = n_iter
self.idx_clean = self._make_graph(threshold, input_tensor,
filters_size,
waveform_length,
n_neighbors)
if load_test_set:
self._load_test_set()
@classmethod
def load(cls, path_to_model, threshold, input_tensor=None,
load_test_set=False):
"""Load a model from a file
"""
if not path_to_model.endswith('.ckpt'):
path_to_model = path_to_model+'.ckpt'
# load necessary parameters
path_to_params = change_extension(path_to_model, 'yaml')
params = load_yaml(path_to_params)
return cls(path_to_model=path_to_model,
filters_size=params['filters_size'],
waveform_length=params['waveform_length'],
n_neighbors=params['n_neighbors'],
threshold=threshold,
input_tensor=input_tensor, load_test_set=load_test_set)
@classmethod
def _make_network(cls, input_tensor, filters_size, waveform_length,
n_neighbors):
"""Mates tensorflow network, from first layer to output layer
"""
K1, K2 = filters_size
# initialize and save nn weights
W1 = weight_variable([waveform_length, 1, 1, K1])
b1 = bias_variable([K1])
W11 = weight_variable([1, 1, K1, K2])
b11 = bias_variable([K2])
W2 = weight_variable([1, n_neighbors, K2, 1])
b2 = bias_variable([1])
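# expected shapes, assuming input_tensor is
# [n_spikes, waveform_length, n_neighbors] and the conv helpers use unit
# strides: expand_dims adds a channel axis, the VALID temporal
# convolution with W1 reduces the time axis to 1, the 1x1 convolution
# with W11 maps K1 features to K2, and the VALID spatial convolution
# with W2 reduces the neighbor axis, so o_layer is [n_spikes, 1, 1, 1]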
# first layer: temporal feature
layer1 = tf.nn.relu(conv2d_VALID(tf.expand_dims(input_tensor, -1),
W1) + b1)
# second layer: feature mapping
layer11 = tf.nn.relu(conv2d(layer1, W11) + b11)
# third layer: spatial convolution
o_layer = conv2d_VALID(layer11, W2) + b2
vars_dict = {"W1": W1, "W11": W11, "W2": W2, "b1": b1, "b11": b11,
"b2": b2}
return o_layer, vars_dict
def _make_graph(self, threshold, input_tensor, filters_size,
waveform_length, n_neighbors):
"""Builds graph for triage
Parameters
----------
input_tensor: tf tensor (n_spikes, n_temporal_length, n_neighbors)
tf tensor that produces the spike waveforms
threshold: float
threshold used on the probability obtained from the network to
determine whether a waveform is a clean spike
Returns
-------
tf tensor (n_spikes,)
a boolean tensorflow tensor marking the clean spikes
"""
# input tensor (waveforms)
if input_tensor is None:
self.x_tf = tf.placeholder("float", [None, None, n_neighbors])
else:
self.x_tf = input_tensor
(self.o_layer,
vars_dict) = NeuralNetTriage._make_network(self.x_tf,
filters_size,
waveform_length,
n_neighbors)
self.saver = tf.train.Saver(vars_dict)
# threshold the output: comparing the logit against
# log(threshold / (1 - threshold)) is equivalent to
# sigmoid(logit) > threshold
return self.o_layer[:, 0, 0, 0] > np.log(threshold / (1 - threshold))
def restore(self, sess):
"""Restore tensor values
"""
self.logger.debug('Restoring tensorflow session from: %s',
self.path_to_model)
self.saver.restore(sess, self.path_to_model)
def predict(self, waveforms):
"""Triage waveforms
"""
_, waveform_length, n_neighbors = waveforms.shape
# self._validate_dimensions(waveform_length, n_neighbors)
with tf.Session() as sess:
self.restore(sess)
idx_clean = sess.run(self.idx_clean,
feed_dict={self.x_tf: waveforms})
return idx_clean
def fit(self, x_train, y_train, test_size=0.3, save_test_set=False):
"""Trains the triage network
Parameters
----------
x_train: np.array
[number of data, temporal length, number of channels] training data
for the triage network.
y_train: np.array
[number of data] training label for the triage network.
test_size: float, optional
Proportion of the data used as the test set; data is shuffled before
splitting. Defaults to 0.3
Returns
-------
dict
Dictionary with network parameters and metrics
Notes
-----
Waveform length is determined by the second dimension of x_train
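Examples
--------
A training sketch; `x_train` and `y_train` are placeholder arrays whose
labels are NeuralNetTriage.CLEAN (1) and NeuralNetTriage.COLLIDED (0):
>>> metrics = model.fit(x_train, y_train, test_size=0.3)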
"""
#####################
# Splitting dataset #
#####################
(self.x_train, self.x_test,
self.y_train, self.y_test) = train_test_split(x_train, y_train,
test_size=test_size)
# get parameters
n_data, waveform_length_train, n_neighbors_train = self.x_train.shape
self._validate_dimensions(waveform_length_train, n_neighbors_train)
# x and y input tensors
x_tf = tf.placeholder("float", [None, self.waveform_length,
self.n_neighbors])
y_tf = tf.placeholder("float", [None])
o_layer, vars_dict = (NeuralNetTriage
._make_network(x_tf,
self.filters_size,
self.waveform_length,
self.n_neighbors))
logits = tf.squeeze(o_layer)
# cross entropy
cross_entropy = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
labels=y_tf))
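# y_tf holds binary labels (CLEAN = 1, COLLIDED = 0), so sigmoid cross
# entropy on the single logit per waveform is the binary classification
# loss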
# regularization term
weights = tf.trainable_variables()
l2_regularizer = (tf.contrib.layers
.l2_regularizer(scale=self.l2_reg_scale))
regularization_penalty = tf.contrib.layers.apply_regularization(
l2_regularizer, weights)
regularized_loss = cross_entropy + regularization_penalty
# train step
train_step = tf.train.AdamOptimizer(self.train_step_size).minimize(
regularized_loss)
# saver
saver = tf.train.Saver(vars_dict)
############
# training #
############
self.logger.debug('Training triage network...')
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
pbar = trange(self.n_iter)
for i in pbar:
idx_batch = np.random.choice(n_data, self.n_batch,
replace=False)
res = sess.run([train_step, regularized_loss],
feed_dict={x_tf: self.x_train[idx_batch],
y_tf: self.y_train[idx_batch]})
if i % 100 == 0:
# compute validation loss and metrics
output = sess.run({'val loss': regularized_loss},
feed_dict={x_tf: self.x_test,
y_tf: self.y_test})
pbar.set_description('Tr loss: %s, '
'Val loss: %s' % (res[1],
output['val loss']))
self.logger.debug('Saving network: %s', self.path_to_model)
saver.save(sess, self.path_to_model)
path_to_params = change_extension(self.path_to_model, 'yaml')
self.logger.debug('Saving network parameters: %s', path_to_params)
params = dict(filters_size=self.filters_size,
waveform_length=self.waveform_length,
n_neighbors=self.n_neighbors,
name=self.model_name)
# compute metrics (print them and return them)
metrics = self._evaluate()
params.update(metrics)
# save parameters to disk
self._save_params(path=path_to_params, params=params)
if save_test_set:
self._save_test_set()
return params