Source code for yass.batch.reader

from __future__ import division
import os
import yaml
import numpy as np
from functools import partial, reduce
from collections import Iterable
from yass.batch.buffer import BufferGenerator


[docs]class RecordingsReader(object): """ Neural recordings reader. If a file with the same name but yaml extension exists in the directory it looks for dtype, channels and data_order, otherwise you need to pass the parameters in the constructor Parameters ---------- path_to_recordings: str Path to recordings file dtype: str Numpy dtype n_channels: int Number of channels data_order: str Recordings order, one of ('channels', 'samples'). In a dataset with k observations per channel and j channels: 'channels' means first k contiguous observations come from channel 0, then channel 1, and so on. 'sample' means first j contiguous data are the first observations from all channels, then the second observations from all channels and so on loader: str ('memmap', 'array' or 'python'), optional How to load the data. memmap loads the data using a wrapper around np.memmap (see :class:`~yass.batch.MemoryMap` for details), 'array' using numpy.fromfile and 'python' loads it using a wrapper around Python file API. Defaults to 'python'. Beware that the Python loader has limited indexing capabilities, see :class:`~yass.batch.BinaryReader` for details buffer_size: int, optional Adds buffer return_data_index: bool, optional If True, a tuple will be returned when indexing: the first element will be the data and the second the index corresponding to the actual data (excluding bufffer), when buffer is equal to zero, this just returns they original index since there is no buffer Raises ------ ValueError If dimensions do not match according to the file size, dtype and number of channels Notes ----- This is just an utility class to index binary files in a consistent way, it does not matter the order of the file ('channels' or 'samples'), indexing is performed in [observations, channels] format. This class is mainly used by other internal YASS classes to maintain a consistent indexing order. Examples -------- .. literalinclude:: ../../examples/batch/reader.py """ def __init__(self, path_to_recordings, dtype=None, n_channels=None, data_order=None, loader='memmap', buffer_size=0, return_data_index=False): path_to_recordings = str(path_to_recordings) path_to_yaml = str(path_to_recordings).replace('.bin', '.yaml') if (not os.path.isfile(path_to_yaml) and (dtype is None or n_channels is None or data_order is None)): raise ValueError('At least one of: dtype, channels or data_order ' 'are None, this is only allowed when a yaml ' 'file is present in the same location as ' 'the bin file, but no {} file exists' .format(path_to_yaml)) elif (os.path.isfile(path_to_yaml) and dtype is None and n_channels is None and data_order is None): with open(path_to_yaml) as f: params = yaml.load(f) dtype = params['dtype'] n_channels = params['n_channels'] data_order = params['data_order'] self._data_order = data_order self._n_channels = n_channels self._dtype = dtype if not isinstance(dtype, str) else np.dtype(dtype) self.buffer_size = buffer_size self.return_data_index = return_data_index filesize = os.path.getsize(path_to_recordings) if not (filesize / self._dtype.itemsize).is_integer(): raise ValueError('Wrong filesize and/or dtype, filesize {:, }' 'bytes is not divisible by the item size {}' ' bytes'.format(filesize, self._dtype.itemsize)) if int(filesize / self._dtype.itemsize) % n_channels: raise ValueError('Wrong n_channels, length of the data does not ' 'match number of n_channels (observations % ' 'n_channels != 0, verify that the number of ' 'n_channels and/or the dtype are correct') self._n_observations = int(filesize / self._dtype.itemsize / n_channels) if self.buffer_size: # data format is long since reader will return data in that format self.buffer_generator = BufferGenerator(self._n_observations, data_shape='long', buffer_size=buffer_size) if loader not in ['memmap', 'array', 'python']: raise ValueError("loader must be one of 'memmap', 'array' or " "'python'") # if data is in channels order, we will read as "columns first", # if data is ith sample order, we will read as as "rows first", # this ensures we have a consistent index array[observations, channels] order = dict(channels='F', samples='C') shape = self._n_observations, n_channels def fromfile(path, dtype, data_order, shape): if data_order == 'samples': return np.fromfile(path, dtype=dtype).reshape(shape) else: return np.fromfile(path, dtype=dtype).reshape(shape[::-1]).T if loader in ['memmap', 'array']: fn = (partial(MemoryMap, mode='r', shape=shape, order=order[data_order]) if loader == 'memmap' else partial(fromfile, data_order=data_order, shape=shape)) self._data = fn(path_to_recordings, dtype=self._dtype) if loader == 'array': self._data = self._data.reshape(shape) else: self._data = BinaryReader(path_to_recordings, dtype, shape, order=order[data_order]) def __getitem__(self, key): # this happens when doung something like # x[[1,2,3]] or x[np.array([1,2,3])] if not isinstance(key, tuple): key = (key, slice(None)) obs_idx, _ = key # index where the data is located (excluding buffer) start = obs_idx.start or 0 stop = obs_idx.stop or self.observations # build indexes for observations idx = slice(self.buffer_size, stop - start + self.buffer_size, obs_idx.step) # buffer is added to all channels ch_idx = slice(None, None, None) data_idx = (idx, ch_idx) if self.buffer_size: # modify indexes to include buffered data (idx_new, (buff_start, buff_end)) = (self.buffer_generator .update_key_with_buffer(key)) subset = self._data[idx_new] # add zeros if needed (start or end of the data) subset_buff = self.buffer_generator.add_buffer(subset, buff_start, buff_end) return ((subset_buff, data_idx) if self.return_data_index else subset_buff) else: subset = self._data[key] return (subset, data_idx) if self.return_data_index else subset def __repr__(self): return ('Reader for recordings with {:,} observations and {:,} ' 'channels in "{}" format' .format(self.observations, self.channels, self._data_order)) @property def shape(self): """Data shape in (observations, channels) format """ return self._data.shape @property def observations(self): """Number of observations """ return self._n_observations @property def channels(self): """Number of channels """ return self._n_channels @property def data_order(self): """Data order """ return self._data_order @property def dtype(self): """Numpy's dtype """ return self._dtype @property def data(self): """Underlying numpy data """ return self._data
[docs]class BinaryReader(object): """ Reading batches from large array binary files on disk, similar to numpy.memmap. It is essentially just a wrapper around Python files API to read through large array binary file using the array[:,:] syntax. Parameters ---------- order: str Array order 'C' for 'Row-major order' or 'F' for 'Column-major order' Notes ----- https://en.wikipedia.org/wiki/Row-_and_column-major_order """ def __init__(self, path_to_file, dtype, shape, order='F'): if order not in ('C', 'F'): raise ValueError('order must be either "C" or "F"') self.order = order self.dtype = dtype if not isinstance(dtype, str) else np.dtype(dtype) self.itemsize = self.dtype.itemsize self.n_row, self.n_col = shape self.f = open(path_to_file, 'rb') self.row_size_byte = self.itemsize * self.n_col self.col_size_byte = self.itemsize * self.n_row def _read_n_bytes_from(self, f, n, start): f.seek(int(start)) return f.read(n) def _read_from_starts(self, f, starts): b = [self._read_n_bytes_from(f, n=1, start=s) for s in starts] return reduce(lambda x, y: x+y, b) def _read_row_major_order(self, rows, col_start, col_end): """Data where contiguous bytes are from the same row (C, row-major) """ # compute offset to read from "col_start" start_byte = col_start * self.itemsize # number of consecutive observations to read n_cols_to_read = col_end - col_start # number of consecutive bytes to read to_read_bytes = n_cols_to_read * self.itemsize # compute bytes where reading starts in every row: # where row starts + offset due to row_start start_bytes = [row * self.row_size_byte + start_byte for row in rows] batch = [np.frombuffer(self._read_n_bytes_from(self.f, to_read_bytes, start), dtype=self.dtype) for start in start_bytes] return np.array(batch) def _read_column_major_order(self, row_start, row_end, cols): """Data where contiguous bytes are from the same column (F, column-major) """ # compute start byte position for every row start_byte = row_start * self.itemsize # how many consecutive bytes in each read rows_to_read = row_end - row_start to_read_bytes = self.itemsize * rows_to_read # compute seek poisitions (first "row_start "observation on # desired columns) start_bytes = [col * self.col_size_byte + start_byte for col in cols] batch = [np.frombuffer(self._read_n_bytes_from(self.f, to_read_bytes, start), dtype=self.dtype) for start in start_bytes] batch = np.array(batch) return batch.T def __getitem__(self, key): if not isinstance(key, tuple) or len(key) > 2: raise ValueError('Must pass two slice objects i.e. obj[:,:]') int_key = any((isinstance(k, int) for k in key)) def _int2slice(k): # when passing ints instead of slices array[0, 0] return k if not isinstance(k, int) else slice(k, k+1, None) key = [_int2slice(k) for k in key] for k in key: if isinstance(k, slice) and k.step: raise ValueError('Step size not supported') rows, cols = key # fill slices in case they are [:X] or [X:] if isinstance(rows, slice): rows = slice(rows.start or 0, rows.stop or self.n_row, None) if isinstance(cols, slice): cols = slice(cols.start or 0, cols.stop or self.n_col, None) if self.order == 'C': if isinstance(cols, Iterable): raise NotImplementedError('Column indexing with iterables ' 'is not implemented in C order') if isinstance(rows, slice): rows = range(rows.start, rows.stop) res = self._read_row_major_order(rows, cols.start, cols.stop) else: if isinstance(rows, Iterable): raise NotImplementedError('Row indexing with iterables ' 'is not implemented in F order') if isinstance(cols, slice): cols = range(cols.start, cols.stop) res = self._read_column_major_order(rows.start, rows.stop, cols) # convert to 1D array if either of keys was int return res if not int_key else res.reshape(-1) def __del__(self): self.f.close() @property def shape(self): return self.n_row, self.n_col def __len__(self): return self.n_row
# FIXME: this is a temporary solution, we need to investigate why memmap # is blowing up memory
[docs]class MemoryMap: """ Wrapper for numpy.memmap that creates a new memmap on each __getitem__ call to save memory """ def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs self._init_mmap() def _init_mmap(self): self._mmap = np.memmap(*self.args, **self.kwargs) def __getitem__(self, index): res = self._mmap[index] self._init_mmap() return res def __getattr__(self, key): return getattr(self._mmap, key)