Source code for yass.batch.batch

from functools import partial
import time
import logging
import os.path
from copy import copy
import os

try:
    from pathlib2 import Path
except ImportError:
    from pathlib import Path

from multiprocess import Pool, Manager
import yaml
from tqdm import tqdm

from yass.util import function_path, human_readable_time
from yass.batch import util
from yass.batch.generator import IndexGenerator
from yass.batch.reader import RecordingsReader


[docs]class BatchProcessor(object): """ Batch processing for large numpy matrices Parameters ---------- path_to_recordings: str Path to recordings file dtype: str Numpy dtype n_channels: int Number of channels data_order: str Recordings order, one of ('channels', 'samples'). In a dataset with k observations per channel and j channels: 'channels' means first k contiguous observations come from channel 0, then channel 1, and so on. 'sample' means first j contiguous data are the first observations from all channels, then the second observations from all channels and so on max_memory: int or str Max memory to use in each batch, interpreted as bytes if int, if string, it can be any of {N}KB, {N}MB or {N}GB buffer_size: int, optional Buffer size, defaults to 0. Only relevant when performing multi-channel operations loader: str ('memmap', 'array' or 'python'), optional How to load the data. memmap loads the data using a wrapper around np.memmap (see :class:`~yass.batch.MemoryMap` for details), 'array' using numpy.fromfile and 'python' loads it using a wrapper around Python file API. Defaults to 'python'. Beware that the Python loader has limited indexing capabilities, see :class:`~yass.batch.BinaryReader` for details show_progress_bar: bool, optional Show progress bar when running operations, defaults to True Raises ------ ValueError If dimensions do not match according to the file size, dtype and number of channels """ def __init__(self, path_to_recordings, dtype=None, n_channels=None, data_order=None, max_memory='1GB', buffer_size=0, loader='memmap', show_progress_bar=True): self.data_order = data_order self.buffer_size = buffer_size self.path_to_recordings = path_to_recordings self.dtype = dtype self.n_channels = n_channels self.data_order = data_order self.loader = loader self.show_progress_bar = show_progress_bar self.reader = RecordingsReader(self.path_to_recordings, self.dtype, self.n_channels, self.data_order, loader=self.loader, buffer_size=buffer_size, return_data_index=True) self.indexer = IndexGenerator(self.reader.observations, self.reader.channels, self.reader.dtype, max_memory) self.logger = logging.getLogger(__name__)
[docs] def single_channel(self, force_complete_channel_batch=True, from_time=None, to_time=None, channels='all'): """ Generate batches where each index has observations from a single channel Returns ------- A generator that yields batches, if force_complete_channel_batch is False, each generated value is a tuple with the batch and the channel for the index for the corresponding channel Examples -------- .. literalinclude:: ../../examples/batch/single_channel.py """ indexes = self.indexer.single_channel(force_complete_channel_batch, from_time, to_time, channels) if force_complete_channel_batch: for idx in indexes: subset, _ = self.reader[idx] yield subset else: for idx in indexes: channel_idx = idx[1] subset, _ = self.reader[idx] yield subset, channel_idx
[docs] def multi_channel(self, from_time=None, to_time=None, channels='all', return_data=True): """ Generate indexes where each index has observations from more than one channel Returns ------- generator: A tuple of size three: the first element is the subset of the data for the ith batch, second element is the slice object with the limits of the data in [observations, channels] format (excluding the buffer), the last element is the absolute index of the data again in [observations, channels] format Examples -------- .. literalinclude:: ../../examples/batch/multi_channel.py """ indexes = self.indexer.multi_channel(from_time, to_time, channels) for idx in indexes: if return_data: subset, _ = self.reader[idx] yield subset else: yield idx
[docs] def single_channel_apply(self, function, mode, output_path=None, force_complete_channel_batch=True, from_time=None, to_time=None, channels='all', if_file_exists='overwrite', cast_dtype=None, **kwargs): """ Apply a transformation where each batch has observations from a single channel Parameters ---------- function: callable Function to be applied, must accept a 1D numpy array as its first parameter mode: str 'disk' or 'memory', if 'disk', a binary file is created at the beginning of the operation and each partial result is saved (ussing numpy.ndarray.tofile function), at the end of the operation two files are generated: the binary file and a yaml file with some file parameters (useful if you want to later use RecordingsReader to read the file). If 'memory', partial results are kept in memory and returned as a list output_path: str, optional Where to save the output, required if 'disk' mode force_complete_channel_batch: bool, optional If True, every index generated will correspond to all the observations in a single channel, hence n_batches = n_selected_channels, defaults to True. If True from_time and to_time must be None from_time: int, optional Starting time, defaults to None to_time: int, optional Ending time, defaults to None channels: int, tuple or str, optional A tuple with the channel indexes or 'all' to traverse all channels, defaults to 'all' if_file_exists: str, optional One of 'overwrite', 'abort', 'skip'. If 'overwrite' it replaces the file if it exists, if 'abort' if raise a ValueError exception if the file exists, if 'skip' if skips the operation if the file exists. Only valid when mode = 'disk' cast_dtype: str, optional Output dtype, defaults to None which means no cast is done **kwargs kwargs to pass to function Examples -------- .. literalinclude:: ../../examples/batch/single_channel_apply.py Notes ----- When applying functions in 'disk' mode will incur in memory overhead, which depends on the function implementation, this is an important thing to consider if the transformation changes the data's dtype (e.g. converts int16 to float64), which means that a chunk of 1MB in int16 will have a size of 4MB in float64. Take that into account when setting max_memory. For performance reasons in 'disk' mode, output data is in 'channels' order """ if mode not in ['disk', 'memory']: raise ValueError('Mode should be disk or memory, received: {}' .format(mode)) if mode == 'disk' and output_path is None: raise ValueError('output_path is required in "disk" mode') if (mode == 'disk' and if_file_exists == 'abort' and os.path.exists(output_path)): raise ValueError('{} already exists'.format(output_path)) if (mode == 'disk' and if_file_exists == 'skip' and os.path.exists(output_path)): # load params... path_to_yaml = output_path.replace('.bin', '.yaml') if not os.path.exists(path_to_yaml): raise ValueError("if_file_exists = 'skip', but {}" " is missing, aborting..." .format(path_to_yaml)) with open(path_to_yaml) as f: params = yaml.load(f) self.logger.info('{} exists, skiping...'.format(output_path)) return output_path, params self.logger.info('Applying function {}...' .format(function_path(function))) if mode == 'disk': fn = self._single_channel_apply_disk start = time.time() res = fn(function, output_path, force_complete_channel_batch, from_time, to_time, channels, cast_dtype, **kwargs) elapsed = time.time() - start self.logger.info('{} took {}' .format(function_path(function), human_readable_time(elapsed))) return res else: fn = self._single_channel_apply_memory start = time.time() res = fn(function, force_complete_channel_batch, from_time, to_time, channels, cast_dtype, **kwargs) elapsed = time.time() - start self.logger.info('{} took {}' .format(function_path(function), human_readable_time(elapsed))) return res
def _single_channel_apply_disk(self, function, output_path, force_complete_channel_batch, from_time, to_time, channels, cast_dtype, **kwargs): f = open(output_path, 'wb') indexes = self.indexer.single_channel(force_complete_channel_batch, from_time, to_time, channels) indexes = list(indexes) iterator = enumerate(indexes) if self.show_progress_bar: iterator = tqdm(iterator, total=len(indexes)) for i, idx in iterator: self.logger.debug('Processing channel {}...'.format(i)) self.logger.debug('Reading batch...') subset, _ = self.reader[idx] self.logger.debug('Executing function...') if cast_dtype is None: res = function(subset, **kwargs) else: res = function(subset, **kwargs).astype(cast_dtype) self.logger.debug('Writing to disk...') res.tofile(f) f.close() params = util.make_metadata(channels, self.n_channels, str(res.dtype), output_path) return output_path, params def _single_channel_apply_memory(self, function, force_complete_channel_batch, from_time, to_time, channels, cast_dtype, **kwargs): indexes = self.indexer.single_channel(force_complete_channel_batch, from_time, to_time, channels) indexes = list(indexes) iterator = enumerate(indexes) if self.show_progress_bar: iterator = tqdm(iterator, total=len(indexes)) results = [] for i, idx in iterator: self.logger.debug('Processing channel {}...'.format(i)) self.logger.debug('Reading batch...') subset, _ = self.reader[idx] if cast_dtype is None: res = function(subset, **kwargs) else: res = function(subset, **kwargs).astype(cast_dtype) self.logger.debug('Appending partial result...') results.append(res) return results
[docs] def multi_channel_apply(self, function, mode, cleanup_function=None, output_path=None, from_time=None, to_time=None, channels='all', if_file_exists='overwrite', cast_dtype=None, pass_batch_info=False, pass_batch_results=False, processes=1, **kwargs): """ Apply a function where each batch has observations from more than one channel Parameters ---------- function: callable Function to be applied, first parameter passed will be a 2D numpy array in 'long' shape (number of observations, number of channels). If pass_batch_info is True, another two keyword parameters will be passed to function: 'idx_local' is the slice object with the limits of the data in [observations, channels] format (excluding the buffer), 'idx' is the absolute index of the data again in [observations, channels] format mode: str 'disk' or 'memory', if 'disk', a binary file is created at the beginning of the operation and each partial result is saved (ussing numpy.ndarray.tofile function), at the end of the operation two files are generated: the binary file and a yaml file with some file parameters (useful if you want to later use RecordingsReader to read the file). If 'memory', partial results are kept in memory and returned as a list cleanup_function: callable, optional A function to be executed after `function` and before adding the partial result to the list of results (if `memory` mode) or to the biinary file (if in `disk mode`). `cleanup_function` will be called with the following parameters (in that order): result from applying `function` to the batch, slice object with the idx where the data is located (exludes buffer), slice object with the absolute location of the data and buffer size output_path: str, optional Where to save the output, required if 'disk' mode force_complete_channel_batch: bool, optional If True, every index generated will correspond to all the observations in a single channel, hence n_batches = n_selected_channels, defaults to True. If True from_time and to_time must be None from_time: int, optional Starting time, defaults to None to_time: int, optional Ending time, defaults to None channels: int, tuple or str, optional A tuple with the channel indexes or 'all' to traverse all channels, defaults to 'all' if_file_exists: str, optional One of 'overwrite', 'abort', 'skip'. If 'overwrite' it replaces the file if it exists, if 'abort' if raise a ValueError exception if the file exists, if 'skip' if skips the operation if the file exists. Only valid when mode = 'disk' cast_dtype: str, optional Output dtype, defaults to None which means no cast is done pass_batch_info: bool, optional Whether to call the function with batch info or just call it with the batch data (see description in the function) parameter pass_batch_results: bool, optional Whether to pass results from the previous batch to the next one, defaults to False. Only relevant when mode='memory'. If True, function will be called with the keyword parameter 'previous_batch' which contains the computation for the last batch, it is set to None in the first batch **kwargs kwargs to pass to function Returns ------- output_path, params (when mode is 'disk') Path to output binary file, Binary file params list (when mode is 'memory' and pass_batch_results is False) List where every element is the result of applying the function to one batch. When pass_batch_results is True, it returns the output of the function for the last batch Examples -------- .. literalinclude:: ../../examples/batch/multi_channel_apply_disk.py .. literalinclude:: ../../examples/batch/multi_channel_apply_memory.py Notes ----- Applying functions will incur in memory overhead, which depends on the function implementation, this is an important thing to consider if the transformation changes the data's dtype (e.g. converts int16 to float64), which means that a chunk of 1MB in int16 will have a size of 4MB in float64. Take that into account when setting max_memory For performance reasons, outputs data in 'samples' order. """ if mode not in ['disk', 'memory']: raise ValueError('Mode should be disk or memory, received: {}' .format(mode)) if mode == 'disk' and output_path is None: raise ValueError('output_path is required in "disk" mode') if (mode == 'disk' and if_file_exists == 'abort' and os.path.exists(output_path)): raise ValueError('{} already exists'.format(output_path)) self.logger.info('Applying function {}...' .format(function_path(function))) if (mode == 'disk' and if_file_exists == 'skip' and os.path.exists(output_path)): # load params... path_to_yaml = output_path.replace('.bin', '.yaml') if not os.path.exists(path_to_yaml): raise ValueError("if_file_exists = 'skip', but {}" " is missing, aborting..." .format(path_to_yaml)) with open(path_to_yaml) as f: params = yaml.load(f) self.logger.info('{} exists, skiping...'.format(output_path)) return output_path, params if mode == 'disk': if processes == 1: fn = self._multi_channel_apply_disk else: fn = partial(self._multi_channel_apply_disk_parallel, processes=processes) start = time.time() res = fn(function, cleanup_function, output_path, from_time, to_time, channels, cast_dtype, pass_batch_info, pass_batch_results, **kwargs) elapsed = time.time() - start self.logger.info('{} took {}' .format(function_path(function), human_readable_time(elapsed))) return res else: fn = self._multi_channel_apply_memory start = time.time() res = fn(function, cleanup_function, from_time, to_time, channels, cast_dtype, pass_batch_info, pass_batch_results, **kwargs) elapsed = time.time() - start self.logger.info('{} took {}' .format(function_path(function), human_readable_time(elapsed))) return res
def _multi_channel_apply_disk(self, function, cleanup_function, output_path, from_time, to_time, channels, cast_dtype, pass_batch_info, pass_batch_results, **kwargs): if pass_batch_results: raise NotImplementedError("pass_batch_results is not " "implemented on 'disk' mode") f = open(output_path, 'wb') output_path = Path(output_path) data = self.multi_channel(from_time, to_time, channels, return_data=False) n_batches = self.indexer.n_batches(from_time, to_time, channels) iterator = enumerate(data) if self.show_progress_bar: iterator = tqdm(iterator, total=n_batches) for i, idx in iterator: res = util.batch_runner((i, idx), function, self.reader, pass_batch_info, cast_dtype, kwargs, cleanup_function, self.buffer_size, save_chunks=False) res.tofile(f) f.close() params = util.make_metadata(channels, self.n_channels, str(res.dtype), output_path) return output_path, params def _multi_channel_apply_disk_parallel(self, function, cleanup_function, output_path, from_time, to_time, channels, cast_dtype, pass_batch_info, pass_batch_results, processes, **kwargs): self.logger.debug('Starting parallel operation...') if pass_batch_results: raise NotImplementedError("pass_batch_results is not " "implemented on 'disk' mode") # need to convert to a list, oherwise cannot be pickled data = list(self.multi_channel(from_time, to_time, channels, return_data=False)) n_batches = self.indexer.n_batches(from_time, to_time, channels) self.logger.info('Data will be splitted in %s batches', n_batches) output_path = Path(output_path) # create local variables to avoid pickling problems _path_to_recordings = copy(self.path_to_recordings) _dtype = copy(self.dtype) _n_channels = copy(self.n_channels) _data_order = copy(self.data_order) _loader = copy(self.loader) _buffer_size = copy(self.buffer_size) reader = partial(RecordingsReader, path_to_recordings=_path_to_recordings, dtype=_dtype, n_channels=_n_channels, data_order=_data_order, loader=_loader, return_data_index=True) m = Manager() mapping = m.dict() next_to_write = m.Value('i', 0) def parallel_runner(element): i, _ = element res = util.batch_runner(element, function, reader, pass_batch_info, cast_dtype, kwargs, cleanup_function, _buffer_size, save_chunks=False, output_path=output_path) if i == 0: mapping['dtype'] = str(res.dtype) while True: if next_to_write.value == i: with open(str(output_path), 'wb' if i == 0 else 'ab') as f: res.tofile(f) next_to_write.value += 1 break # run jobs self.logger.debug('Creating processes pool...') p = Pool(processes) res = p.map_async(parallel_runner, enumerate(data)) finished = 0 if self.show_progress_bar: pbar = tqdm(total=n_batches) if self.show_progress_bar: while True: if next_to_write.value > finished: update = next_to_write.value - finished pbar.update(update) finished = next_to_write.value if next_to_write.value == n_batches: break pbar.close() else: res.get() # save metadata params = util.make_metadata(channels, self.n_channels, mapping['dtype'], output_path) return output_path, params def _multi_channel_apply_memory(self, function, cleanup_function, from_time, to_time, channels, cast_dtype, pass_batch_info, pass_batch_results, **kwargs): data = self.multi_channel(from_time, to_time, channels, return_data=False) n_batches = self.indexer.n_batches(from_time, to_time, channels) results = [] if pass_batch_results: kwargs['previous_batch'] = None iterator = enumerate(data) if self.show_progress_bar: iterator = tqdm(iterator, total=n_batches) for i, idx in iterator: res = util.batch_runner((i, idx), function, self.reader, pass_batch_info, cast_dtype, kwargs, cleanup_function, self.buffer_size, save_chunks=False) if pass_batch_results: kwargs['previous_batch'] = res else: results.append(res) return res if pass_batch_results else results