Source code for molyso.mm.highlevel_interactive_ground_truth

# -*- coding: utf-8 -*-
"""
documentation
"""

from __future__ import division, unicode_literals, print_function

import numpy as np

from .tracking_output import s_to_h
from ..generic.etc import QuickTableDumper

try:
    # noinspection PyUnresolvedReferences
    import cPickle

    pickle = cPickle
except ImportError:
    import pickle


[docs]def interactive_ground_truth_main(args, tracked_results): """ Ground truth mode entry function. :param args: :param tracked_results: :return: :raise SystemExit: """ acceptable_pos_chans = \ {p: list(range(len(tracked_results[list(tracked_results.keys())[p]].channel_accumulator.keys()))) for p in range(len(tracked_results.keys())) if len(tracked_results[list(tracked_results.keys())[p]].channel_accumulator.keys()) > 0} def plots_info(): """ Outputs some information about the data set. """ print("Positions " + str(list(tracked_results.keys()))) print("Acceptable channels per position " + repr(acceptable_pos_chans)) plots_info() ground_truth_data = args.ground_truth # noinspection PyUnresolvedReferences try: with open(ground_truth_data, 'rb') as fp: all_envs = pickle.load(fp) except FileNotFoundError: print("File did not exist, starting anew") all_envs = {} def save_data(): """ Saves the ground truth data to the file specified. (pickled data) """ with open(ground_truth_data, 'wb+') as inner_fp: pickle.dump(all_envs, inner_fp, protocol=pickle.HIGHEST_PROTOCOL) print("Saved data to %s" % (ground_truth_data,)) lowest_position = min(acceptable_pos_chans.keys()) highest_position = max(acceptable_pos_chans.keys()) next_dataset = [lowest_position, next(iter(acceptable_pos_chans[lowest_position]))] def perform_it(): """ Runs the ground truth mode. :return: :raise SystemExit: """ next_pos, next_chan = next_dataset def empty_env(): """ Generates an empty environment. :return: """ return { 'points': np.ma.array(np.zeros((1024, 3))), 'points_empty': np.ma.array(np.zeros((1024, 3))), 'used': 0, 'last_point_x': None, 'last_point_y': None, } if (next_pos, next_chan) not in all_envs: all_envs[(next_pos, next_chan)] = empty_env() env = all_envs[(next_pos, next_chan)] pos = list(tracked_results.keys())[next_pos] tracking = tracked_results[pos] chan_num = list(tracking.channel_accumulator.keys())[next_chan] channels = tracking.channel_accumulator[chan_num] print("Opening position %d, channel %d" % (pos, chan_num,)) data = np.zeros((len(channels), 6)) n_timepoint, n_width, n_height, n_top, n_bottom, n_width_cumsum = 0, 1, 2, 3, 4, 5 some_channel_image = None for n, cc in enumerate(channels): data[n, n_timepoint] = cc.image.timepoint data[n, n_width] = cc.channel_image.shape[1] data[n, n_height] = cc.channel_image.shape[0] data[n, n_top] = cc.top data[n, n_bottom] = cc.bottom some_channel_image = cc.channel_image data[:, n_width_cumsum] = np.cumsum(data[:, n_width]) max_top, min_top = data[:, n_top].max(), data[:, n_top].min() max_bottom, min_bottom = data[:, n_bottom].max(), data[:, n_bottom].min() low, high = int(np.floor(min_top)), int(np.ceil(max_bottom)) large_image = np.zeros((high - low, int(data[-1, n_width_cumsum])), dtype=some_channel_image.dtype) for n, cc in enumerate(channels): lower_border = int(np.floor(data[n, n_top] - low)) large_image[ lower_border:int(lower_border + data[n, n_height]), int(data[n, n_width_cumsum] - data[n, n_width]):int(data[n, n_width_cumsum]) ] = cc.channel_image import matplotlib.pyplot as plt fig, ax = plt.subplots() plt.subplots_adjust(left=0.25, bottom=0.25) fig.canvas.set_window_title("Image Viewer") channels_per_inch = 5.0 plt.rcParams['figure.figsize'] = (len(channels) / channels_per_inch, 4.0) plt.rcParams['figure.dpi'] = 150 plt.rcParams['figure.subplot.top'] = 0.8 plt.rcParams['figure.subplot.bottom'] = 0.2 plt.rcParams['figure.subplot.left'] = 0.2 plt.rcParams['figure.subplot.right'] = 0.8 plt.rcParams['image.cmap'] = 'gray' plt.imshow(large_image) plt.title("Ground Truth — Position %d, channel %d" % (pos, chan_num,)) plt.xlabel("x [Pixel]") plt.ylabel("y [Pixel]") fig.tight_layout() o_scatter = ax.scatter(0, 0) o_lines, = plt.plot(0, 0) def refresh(): """ Refreshes the overlay. """ o_lines.set_data(env['points'][:env['used'], 0], env['points'][:env['used'], 1]) o_scatter.set_offsets(env['points'][:env['used'], :2]) fig.canvas.draw() def show_help(): """ Shows a help text for the ground truth mode. """ print(""" Ground Truth Mode: = Mouse ===================================== Mark division events by right click: First a division, then a child's division. = Keys ====================================== h show this help p print growth rates (last is based on mean division time) d delete last division event n/N next/previous multipoint m/M next/previous channel o/O output tabular data to console/file w write data (to previously specified filename) i start interactive python console q quit ground truth mode """) refresh() show_help() def click(e): """ :param e: :return: """ x, y = e.xdata, e.ydata if x is None or y is None: return if e.button == 3: last_point_x, last_point_y = env['last_point_x'], env['last_point_y'] if last_point_x is not None: if env['used'] + 3 >= env['points'].shape[0]: oldmask = env['points'].mask[:env['used']] env['points'] = np.ma.array(np.r_[env['points'], env['points_empty']]) env['points'].mask = np.zeros_like(env['points']).astype(np.dtype(bool)) # [:env['used']] env['points'].mask[:env['used']] = oldmask n_x = np.searchsorted(data[:, n_width_cumsum], x, side='right') n_last_x = np.searchsorted(data[:, n_width_cumsum], last_point_x, side='right') if x < last_point_x: x, y, last_point_x, last_point_y = last_point_x, last_point_y, x, y n_x, n_last_x = n_last_x, n_x env['points'][env['used'], 0] = last_point_x env['points'][env['used'], 1] = last_point_y env['points'][env['used'], 2] = data[n_last_x, n_timepoint] env['used'] += 1 env['points'][env['used'], 0] = x env['points'][env['used'], 1] = y env['points'][env['used'], 2] = data[n_x, n_timepoint] env['used'] += 1 env['points'][env['used'], :] = np.ma.masked env['used'] += 1 refresh() # print(n, data[n, n_timepoint]) env['last_point_x'], env['last_point_y'] = None, None else: env['last_point_x'], env['last_point_y'] = x, y def key_press(event): """ :param event: :return: :raise SystemExit: """ def show_stats(): """ Shows statistics. """ inner_times = env['points'][:env['used'], 2].compressed() inner_times = inner_times.reshape(inner_times.size // 2, 2) inner_deltas = inner_times[:, 1] - inner_times[:, 0] inner_deltas /= 60.0 * 60.0 inner_mu = np.log(2) / inner_deltas print(inner_mu, np.mean(inner_mu)) def try_new_poschan(p, c): """ :param p: :param c: :return: """ next_pos, next_chan = next_dataset if p == 1: while (next_pos + p) not in acceptable_pos_chans and (next_pos + p) < highest_position: p += 1 elif p == -1: while (next_pos + p) not in acceptable_pos_chans and (next_pos + p) > lowest_position: p -= 1 if (next_pos + p) not in acceptable_pos_chans: print("Position does not exist") return if p != 0: c = 0 next_chan = acceptable_pos_chans[next_pos + p][0] if c == 1: while (next_chan + c) not in acceptable_pos_chans[next_pos + p] and \ (next_chan + c) < max(acceptable_pos_chans[next_pos + p]): c += 1 elif c == -1: while (next_chan + c) not in acceptable_pos_chans[next_pos + p] and \ (next_chan + c) > min(acceptable_pos_chans[next_pos + p]): c -= 1 if (next_chan + c) not in acceptable_pos_chans[next_pos + p]: print("Channel does not exist") return next_dataset[0] = next_pos + p next_dataset[1] = next_chan + c plt.close() # if event.key == 'left': # timepoint.set_val(max(1, int(timepoint.val) - 1)) # elif event.key == 'right': # timepoint.set_val(min(tp_max, int(timepoint.val) + 1)) # elif event.key == 'ctrl+left': # timepoint.set_val(max(1, int(timepoint.val) - 10)) # elif event.key == 'ctrl+right': # timepoint.set_val(min(tp_max, int(timepoint.val) + 10)) # elif event.key == 'down': # multipoint.set_val(max(1, int(multipoint.val) - 1)) # elif event.key == 'up': # multipoint.set_val(min(mp_max, int(multipoint.val) + 1)) if event.key == 'h': show_help() elif event.key == 'p': show_stats() elif event.key == 'd': env['used'] -= 3 show_stats() refresh() # n next position, m next channel elif event.key == 'n': try_new_poschan(1, 0) elif event.key == 'N': try_new_poschan(-1, 0) elif event.key == 'm': try_new_poschan(0, 1) elif event.key == 'M': try_new_poschan(0, -1) elif event.key == 'o' or event.key == 'O': recipient = None if event.key == 'O': print("Please enter file name for tabular output [will be overwritten if exists]:") file_name = input() recipient = open(file_name, 'w+') out = QuickTableDumper(recipient=recipient) for (t_pos, t_chan), t_env in all_envs.items(): x_pos = list(tracked_results.keys())[t_pos] x_chan = list(tracked_results[x_pos].channel_accumulator.keys())[t_chan] times = t_env['points'][:t_env['used'], 2].compressed() times = times.reshape(times.size // 2, 2) deltas = times[:, 1] - times[:, 0] mu = np.log(2) / (deltas / (60.0 * 60.0)) for num in range(len(mu)): out.add({ 'position': x_pos, 'channel': x_chan, 'growth_rate': mu[num], 'growth_rate_channel_mean': np.mean(mu), 'division_age': s_to_h(deltas[num]), 'growth_start': s_to_h(times[num, 0]), 'growth_end': s_to_h(times[num, 1]), }) if event.key == 'O': recipient.close() print("File written.") elif event.key == 'w': save_data() elif event.key == 'i': import code code.InteractiveConsole(locals=globals()).interact() elif event.key == 'q': raise SystemExit fig.canvas.mpl_connect('key_press_event', key_press) fig.canvas.mpl_connect('button_press_event', click) plt.show() while True: perform_it()