Source code for molyso.generic.etc

# -*- coding: utf-8 -*-
"""
etc.py contains various helper functions and classes, which are not directly related to data processing.
"""
from __future__ import division, unicode_literals, print_function
import os
import sys
import hashlib
import time
import logging
import sqlite3

import numpy as np

from io import BytesIO
from ..debugging import DebugPlot

try:
    import tqdm
except ImportError:
    tqdm = None


logger = logging.getLogger(__name__)


[docs]def silent_progress_bar(iterable): """ Dummy function, just returns an iterator. :param iterable: the iterable to turn into an iterable :type iterable: iterable :return: iterable :rtype: iterable >>> next(silent_progress_bar([1, 2, 3])) 1 """ return iter(iterable)
def fancy_progress_bar(iterable): """ Returns in iterator which will show progress as well. Will either use the tqdm module when available, or a simpler implementation. :param iterable: the iterable to progress-ify :type iterable: iterable :rtype: iterable :return: progress-ified iterable """ if tqdm: # weird bug: if the threading magic in tqdm is active, multiprocessing in molyso gets stuck! # should be investigated further, but for now, let us just disable it ... tqdm.tqdm.monitor_interval = 0 for item in tqdm.tqdm(iterable): yield item else: times = np.zeros(len(iterable), dtype=float) for n, i in enumerate(iterable): start_time = time.time() yield i stop_time = time.time() times[n] = stop_time - start_time eta = " ETA %.2fs" % float(np.mean(times[:n + 1]) * (len(iterable) - (n + 1))) logger.info("processed %d/%d [took %.3fs%s]" % (n + 1, len(iterable), times[n], eta))
[docs]def iter_time(iterable): """ Will print the total time elapsed during iteration of ``iterable`` afterwards. :param iterable: iterable :type iterable: iterable :rtype: iterable :return: iterable """ start_time = time.time() for n in iterable: yield n stop_time = time.time() logger.info("whole step took %.3fs" % (stop_time - start_time,))
_fancy_progress_bar = fancy_progress_bar
[docs]def fancy_progress_bar(iterable): """ :param iterable: :return: """ return iter_time(_fancy_progress_bar(iterable))
[docs]def dummy_progress_indicator(): """ :return: """ return iter(int, 1)
[docs]def ignorant_next(iterable): """ Will try to iterate to the next value, or return None if none is available. :param iterable: :return: """ try: return next(iterable) except StopIteration: return None
[docs]class QuickTableDumper(object): """ :param recipient: """ delimiter = '\t' line_end = '\n' precision = 8 def __init__(self, recipient=None): if recipient is None: recipient = sys.stdout self.recipient = recipient self.headers = []
[docs] def write_list(self, the_list): """ :param the_list: """ self.recipient.write(self.delimiter.join(map(self.stringify, the_list)) + self.line_end)
[docs] def add(self, row): """ :param row: """ if len(self.headers) == 0: self.headers = list(sorted(row.keys())) self.write_list(self.headers) self.write_list(row[k] for k in self.headers)
[docs] def stringify(self, obj): """ :param obj: :return: """ if type(obj) in (float, np.float32, np.float64) and self.precision: return str(round(obj, self.precision)) else: return str(obj)
try: # noinspection PyUnresolvedReferences import cPickle pickle = cPickle except ImportError: import pickle try: import _thread except ImportError: import thread as _thread if os.name != 'nt': def correct_windows_signal_handlers(): """ Dummy for non-windows os. """ pass else:
[docs] def correct_windows_signal_handlers(): """ Corrects Windows signal handling, otherwise multiprocessing solutions will not correctly exit if Ctrl-C is used to interrupt them. :return: """ os.environ['PATH'] += os.path.pathsep + os.path.dirname(os.path.abspath(sys.executable)) try: # noinspection PyUnresolvedReferences import win32api def _handler(_, hook=_thread.interrupt_main): hook() return 1 win32api.SetConsoleCtrlHandler(_handler, 1) except ImportError: logger.warning("Running on Windows, but module 'win32api' could not be imported to fix signal handler.\n" + "Ctrl-C might break the program ..." + "Fix: Install the module!")
[docs]def debug_init(): """ Initialized debug mode, as of now this means that DebugPlot is set to active (it will produce a debug.pdf) """ DebugPlot.force_active = True np.set_printoptions(threshold=sys.maxsize)
[docs]def parse_range(s, maximum=0): """ :param s: :param maximum: :return: """ maximum -= 1 splits = s.replace(' ', '').replace(';', ',').split(',') ranges = [] remove = [] not_values = False for frag in splits: if frag[0] == '~': not_values = not not_values frag = frag[1:] if '-' in frag: f, t = frag.split('-') interval = 1 if '%' in t: t, _interval = t.split('%') interval = int(_interval) if t == '': t = maximum f, t = int(f), int(t) t = min(t, maximum) parsed_fragment = range(f, t + 1, interval) else: parsed_fragment = [int(frag)] if not_values: remove += parsed_fragment else: ranges += parsed_fragment return list(sorted(set(ranges) - set(remove)))
[docs]def prettify_numpy_array(arr, space_or_prefix): """ Returns a properly indented string representation of a numpy array. :param arr: :param space_or_prefix: :return: """ six_spaces = ' ' * 6 prepared = repr(np.array(arr)).replace(')', '').replace('array(', six_spaces) if isinstance(space_or_prefix, int): return prepared.replace(six_spaces, ' ' * space_or_prefix) else: return space_or_prefix + prepared.replace(six_spaces, ' ' * len(space_or_prefix)).lstrip()
[docs]def bits_to_numpy_type(bits): """ Returns a numpy.dtype for one of the common image bit-depths: 8 for unsigned int, 16 for unsigned short, 32 for float :param bits: :return: """ # this is supposed to throw an error return { 8: np.uint8, 16: np.uint16, 32: np.float32 }[int(bits)]
[docs]class BaseCache(object): """ A caching class """
[docs] @staticmethod def prepare_key(key): """ :param key: :return: """ if isinstance(key, type('')): return key else: return repr(key)
[docs] @staticmethod def serialize(data): """ :param data: :return: """ try: bio = BytesIO() pickle.dump(data, bio, protocol=pickle.HIGHEST_PROTOCOL) try: # noinspection PyUnresolvedReferences pickled_data = bio.getbuffer() except AttributeError: pickled_data = bio.getvalue() except ImportError: pickled_data = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) return pickled_data
[docs] @staticmethod def deserialize(data): """ :param data: :return: """ assert data is not None bio = BytesIO(data) return pickle.load(bio)
def __init__(self, filename_to_be_hashed, ignore_cache='nothing', cache_token=None): self.logger = logging.getLogger(__name__ + '.' + self.__class__.__name__) self.filename_hash_source = filename_to_be_hashed if cache_token is None: self.cache_token = "%s.%s" % ( os.path.basename(filename_to_be_hashed).replace('.', '_').replace('?', '_').replace(',', '_'), hashlib.sha1(str(os.path.abspath(filename_to_be_hashed).lower()).encode()).hexdigest()[:8]) else: self.cache_token = cache_token if ignore_cache == 'everything': self.ignore_cache = True elif ignore_cache == 'nothing': self.ignore_cache = [] else: self.ignore_cache = ignore_cache.split(',')
[docs] def contains(self, key): """ :param key: :return: """ return False
[docs] def get(self, key): """ :param key: :return: """ return ''
[docs] def set(self, key, value): """ :param key: :param value: :return: """ return
def __contains__(self, key): if self.ignore_cache is True or key in self.ignore_cache: return False else: try: self.logger.debug("Checking whether '%s' exists", key) return self.contains(self.prepare_key(key)) except Exception as e: self.logger.exception( "While %s an Exception occurred (but continuing): %", repr(self.__contains__), repr(e) ) return False def __getitem__(self, key): try: self.logger.debug("Getting data for '%s'", key) return self.deserialize(self.get(self.prepare_key(key))) except Exception as e: self.logger.exception( "While %s an Exception occurred (but continuing): %s. Note that this will yield undefined behavior.", repr(self.__getitem__), repr(e) ) # this is technically wrong ... return None def __setitem__(self, key, value): if self.ignore_cache is True or key in self.ignore_cache: return else: try: self.logger.debug("Setting data for '%s'", key) self.set(self.prepare_key(key), self.serialize(value)) except Exception as e: self.logger.exception( "While %s an Exception occurred (but continuing): %s", repr(self.__setitem__), repr(e) )
[docs]class FileCache(BaseCache): """ A caching class which stores the data in flat files. """
[docs] def build_cache_filename(self, suffix): """ :param suffix: :return: """ return "%s.%s.cache" % (self.cache_token, suffix,)
[docs] def contains(self, key): """ :param key: :return: """ return os.path.isfile(self.build_cache_filename(key))
[docs] def get(self, key): """ :param key: :return: """ with open(self.build_cache_filename(key), 'rb') as fp: return fp.read(os.path.getsize(self.build_cache_filename(key)))
[docs] def set(self, key, value): """ :param key: :param value: """ with open(self.build_cache_filename(key), 'wb+') as fp: fp.write(value)
Cache = FileCache
[docs]class Sqlite3Cache(BaseCache): """ A caching class which stores the data in a sqlite3 database. """
[docs] def contains(self, key): """ :param key: :return: """ result = self.conn.execute('SELECT COUNT(*) FROM entries WHERE name = ?', (key,)) for row in result: return row[0] == 1 return False
[docs] def get(self, key): """ :param key: :return: """ result = self.conn.execute('SELECT value FROM entries WHERE name = ?', (key,)) for row in result: return row[0]
[docs] def keys(self): """ :return: """ result = self.conn.execute('SELECT name FROM entries') return [row[0] for row in result]
[docs] def set(self, key, value): """ :param key: :param value: """ self.conn.execute('DELETE FROM entries WHERE name = ?', (key,)) self.conn.execute( 'INSERT INTO entries (name, value) VALUES (?, ?)', (key, sqlite3.Binary(value),) ) self.conn.commit()
def __init__(self, *args, **kwargs): super(Sqlite3Cache, self).__init__(*args, **kwargs) self.conn = None if self.ignore_cache is not True: self.conn = sqlite3.connect('%s.sq3.cache' % (self.cache_token, )) self.conn.isolation_level = None self.conn.execute('PRAGMA journal_mode = WAL') self.conn.execute('PRAGMA synchronous = NORMAL') self.conn.isolation_level = 'DEFERRED' self.conn.execute('CREATE TABLE IF NOT EXISTS entries (name TEXT, value BLOB)') self.conn.execute('CREATE UNIQUE INDEX IF NOT EXISTS entries_name ON entries (name)') def __del__(self): if self.conn: self.conn.close()
[docs]class NotReallyATree(list): """ The class is a some-what duck-type compatible (it has a ``query`` method) dumb replacement for (c)KDTrees. It can be used to find the nearest matching point to a query point. (And does that by exhaustive search...) """ def __init__(self, iterable): """ :param iterable: input data :type iterable: iterable :return: the queryable object :rtype: NotReallyAtree """ super(NotReallyATree, self).__init__(self) for i in iterable: self.append(i) self.na = np.array(iterable)
[docs] def query(self, q): # w_numpy """ Finds the point which is nearest to ``q``. Uses the Euclidean distance. :param q: query point :return: distance, index :rtype: float, int >>> t = NotReallyATree([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]) >>> t.query([1.25, 1.25]) (0.3535533905932738, 0) >>> t = NotReallyATree([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]) >>> t.query([2.3535533905932737622, 2.3535533905932737622]) (0.5000000000000002, 1) """ distances = np.sqrt(np.sum(np.power(self.na - q, 2.0), 1)) pos = np.argmin(distances, 0) return distances[pos], pos