__author__ = "Simon Andermatt, Simon Pezold"
__copyright__ = "Copyright (C) 2017 Simon Andermatt"
import copy
import logging
import os
import subprocess
import sys
import numpy as np
from mdgru.helper import argget, compile_arguments, generate_defaults_info
class DataCollection(object):
    '''Abstract base class for all data handling classes.

    Subclasses are expected to override :meth:`random_sample`,
    :meth:`get_shape` and :meth:`get_data_dims`; the base implementations
    only raise :class:`NotImplementedError`.
    '''
    # NOTE(review): presumably read by compile_arguments / generate_defaults_info
    # from mdgru.helper (both receive this class) -- verify there.
    _defaults = {'seed': {'help': 'Seed to be used for deterministic random sampling, given no threading is used', 'value': 1234},
                 'nclasses': None,
                 }

    def __init__(self, kw):
        '''Configure the collection from a dict of keyword arguments.

        Parameters
        ----------
        kw : dict
            Keyword arguments. A shallow copy is kept in ``self.origargs``;
            every entry that ``compile_arguments`` resolves for this class is
            set as an instance attribute.
        '''
        self.origargs = copy.copy(kw)
        # NOTE(review): assumes compile_arguments returns (arguments for this
        # class, remaining arguments); only the first part is used here.
        data_kw, _ = compile_arguments(DataCollection, kw, transitive=False)
        for key, value in data_kw.items():
            setattr(self, key, value)
        # Relies on compile_arguments having provided a ``seed`` attribute.
        self.randomstate = np.random.RandomState(self.seed)

    def set_states(self, state):
        '''Restore the random state generator from ``state``.

        Parameters
        ----------
        state : object or None
            Random generator state, e.g. as returned by :meth:`get_states`.
            If None, a warning is logged and the generator is reseeded with a
            random (non-reproducible) seed instead.
        '''
        if state is None:
            logging.getLogger('eval').warning('could not reproduce state, setting unreproducable random seed')
            # RandomState has no ``set_seed`` method; ``seed`` reseeds it.
            self.randomstate.seed(np.random.randint(0, 1000000))
            # set_state(None) would raise, so stop after reseeding.
            return
        self.randomstate.set_state(state)

    def get_states(self):
        '''Return the current state of this collection's random generator.

        Returns
        -------
        object
            The value of ``self.randomstate.get_state()``; can be passed back
            to :meth:`set_states`.
        '''
        return self.randomstate.get_state()

    def reset_seed(self, seed=12345678):
        '''Replace the main random number generator with a new one seeded by ``seed``.

        Parameters
        ----------
        seed : int
            Seed for the new ``np.random.RandomState``.
        '''
        self.randomstate = np.random.RandomState(seed)

    def random_sample(self, **kw):
        r'''Randomly samples from our dataset. If the implementation knows
        different datasets, the dataset string can be used to choose one, if
        not, it will be ignored.

        Parameters
        ----------
        \*\*kw: keyword args
            batch_size can be set, amongst other parameters. See implementing
            methods for more detail.

        Returns
        -------
        array
            A random sample of length batch_size.

        Raises
        ------
        NotImplementedError
            Always, in this abstract base class.
        '''
        raise NotImplementedError("random_sample not implemented in {}"
                                  .format(self.__class__))

    def get_shape(self):
        '''Return the batch shape with the batch size set to None.

        Raises
        ------
        NotImplementedError
            Always, in this abstract base class.
        '''
        raise NotImplementedError("needs to be implemented. should return batch shape "
                                  "with batch size set to None")

    def get_data_dims(self):
        '''Returns the dimensionality of the whole collection (even if samples
        are returned/computed on the fly, the theoretical size is returned).
        Has between two and three entries (depending on the type of data: a
        dataset with sequences of vectors has 3, a dataset with sequences of
        indices has two, etc.)

        Returns
        -------
        list
            A shape array of the dimensionality of the data.

        Raises
        ------
        NotImplementedError
            Always, in this abstract base class.
        '''
        raise NotImplementedError("get_data_dims not implemented in {}"
                                  .format(self.__class__))

    def _one_hot_vectorize(self, indexlabels, nclasses=None, zero_out_label=None):
        '''Convert integer labels into one-hot vectors along a new last axis.

        Simplified onehotlabels method. We discourage using interpolated labels
        anyways, hence this only allows integer values in indexlabels.

        Parameters
        ----------
        indexlabels : ndarray
            Array containing labels or indices for each class, starting at 0
            until nclasses-1.
        nclasses : int, optional
            Number of classes; defaults to ``self.nclasses``.
        zero_out_label : int, optional
            Label whose channel is set to zero everywhere, so pixels carrying
            this label end up with an all-zero vector.

        Returns
        -------
        ndarray
            int32 array of shape ``indexlabels.shape + (nclasses,)`` that is 1
            at position ``indexlabels`` and 0 otherwise.
        '''
        if nclasses is None:
            nclasses = self.nclasses
        indexlabels = np.asarray(indexlabels)
        lshape = indexlabels.shape
        # .size is an int even for 0-d input (np.prod(()) would be the float 1.0).
        nelem = indexlabels.size
        # Work on a flat (elements x classes) view, then restore the shape.
        onehot = np.zeros([nelem, nclasses], dtype=np.int32)
        onehot[np.arange(nelem), indexlabels.ravel()] = 1
        if zero_out_label is not None:
            onehot[:, zero_out_label] = 0
        return onehot.reshape(list(lshape) + [nclasses])

    @staticmethod
    def get_all_tps(folder, featurefiles, maskfiles):
        '''Compute all folders below ``folder`` that contain every required file.

        ``folder`` itself is included in the search. Paths are built in
        Python rather than through a shell command, so folder names containing
        quotes or whitespace are handled correctly.

        Parameters
        ----------
        folder: str
            Location at which timepoints are searched.
        featurefiles: list of str
            Necessary featurefiles to be contained in a timepoint.
        maskfiles: list of str
            Necessary maskfiles to be contained in a timepoint.

        Returns
        -------
        sorted list
            Valid timepoints in string format. ``folder`` itself is reported
            with a trailing separator, its subfolders as ``folder/sub``.
        '''
        required = list(featurefiles) + list(maskfiles)
        # The trailing separator reproduces the path format of the former
        # ``find`` based implementation for the top folder.
        start = os.path.join(folder, '')
        timepoints = []
        # os.walk does not descend into symlinked directories, matching
        # ``find -type d`` without -L; os.path.exists follows links like ``test -e``.
        for root, _dirs, _files in os.walk(start):
            if all(os.path.exists(os.path.join(root, name)) for name in required):
                timepoints.append(root)
        return sorted(timepoints)
# Module-level call into mdgru.helper, executed once at import time right after
# DataCollection is defined.
# NOTE(review): presumably builds help/defaults information from
# DataCollection._defaults -- confirm in mdgru.helper.generate_defaults_info.
generate_defaults_info(DataCollection)