# Source code for dye_score.plotting

import os
from bokeh.models import (
    CDSView,
    ColumnDataSource,
    IndexFilter,
    LinearAxis,
    NumeralTickFormatter,
    Range1d,
)
from bokeh.palettes import inferno
from bokeh.plotting import figure
from numpy import histogram
from pandas import (
    read_csv as pd_read_csv
)


def plot_hist(title, hist, edges, y_axis_type='linear', bottom=0):
    """Draw pre-computed histogram bins as bars on a new Bokeh figure.

    Args:
        title: Title shown on the figure.
        hist: Bar heights, used as the top of each bar.
        edges: Bin edges. ``edges[:-1]`` gives the left side of each bar and
            ``edges[1:]`` the right side.
        y_axis_type: Forwarded to ``figure`` as ``y_axis_type``.
        bottom: Baseline of every bar; also used as the start of the y range.

    Returns:
        The 800x400 figure with the bars drawn on it.
    """
    figure_opts = dict(
        title=title,
        tools='',
        background_fill_color="#fafafa",
        y_axis_type=y_axis_type,
        width=800,
        height=400,
    )
    plot = figure(**figure_opts)

    # Consecutive edge pairs delimit one bar each.
    left_edges = edges[:-1]
    right_edges = edges[1:]
    plot.quad(
        top=hist,
        bottom=bottom,
        left=left_edges,
        right=right_edges,
        fill_color="navy",
        line_color="white",
        alpha=0.5,
    )

    # Pin the y axis to the bar baseline so the bars sit on the axis.
    plot.y_range.start = bottom
    plot.grid.grid_line_color = "white"
    return plot
def plot_key_leaky(percent_to_dye, key, y_axis_type='linear', bottom=0, bins=40):
    """Plot the distribution of ``percent_to_dye`` values for one key.

    The values are binned with ``numpy.histogram`` and drawn with
    ``plot_hist``; the resulting figure is then resized to 400x300 and given
    axis labels.

    Args:
        percent_to_dye: Values to bin.
        key: Label interpolated into the figure title.
        y_axis_type: Forwarded to ``plot_hist``.
        bottom: Forwarded to ``plot_hist`` as the bar baseline.
        bins: Forwarded to ``numpy.histogram`` as ``bins``.

    Returns:
        The configured figure.
    """
    counts, bin_edges = histogram(percent_to_dye, bins=bins)
    plot = plot_hist(
        f'Distribution of leaky - {key}',
        counts,
        bin_edges,
        y_axis_type=y_axis_type,
        bottom=bottom,
    )

    # Override the default size chosen by plot_hist.
    plot.width = 400
    plot.height = 300

    plot.xaxis.axis_label = '%age dye_site wants to dye'
    plot.yaxis.axis_label = 'Frequency'
    return plot
def get_pr_plot(
    pr_df,
    title,
    n_scripts_range,
    y_range=(0, 1),
    recall_color='black',
    n_scripts_color='firebrick',
    **extra_plot_opts,
):
    """Example code for plotting dye score threshold plots.

    Draws two lines against the ``dye_score_threshold`` column of ``pr_df``:
    ``recall`` on the primary y axis and ``n_over_threshold`` on a secondary
    y axis added on the right. The x axis is flipped.

    Args:
        pr_df: Data wrapped in a ``ColumnDataSource``; the columns
            ``dye_score_threshold``, ``recall`` and ``n_over_threshold`` are
            referenced by name.
        title: Figure title.
        n_scripts_range: Indexable whose items 0 and 1 are the start and end
            of the secondary (script count) axis.
        y_range: Range of the primary axis, forwarded to ``figure``.
        recall_color: Colour of the recall line.
        n_scripts_color: Colour of the script-count line and of the right
            hand axis' label and tick text.
        **extra_plot_opts: Additional keyword arguments for ``figure``.

    Returns:
        The figure with both lines and the extra axis.
    """
    data = ColumnDataSource(pr_df)

    plot = figure(title=title, y_range=y_range, **extra_plot_opts)
    plot.x_range.flipped = True

    # Recall is drawn against the default (left) y axis.
    plot.line(
        x='dye_score_threshold',
        y='recall',
        color=recall_color,
        source=data,
        line_width=3,
    )

    # Script counts live on their own named range so they can use a
    # different scale than recall.
    secondary_range = Range1d(start=n_scripts_range[0], end=n_scripts_range[1])
    plot.extra_y_ranges = {"n_scripts": secondary_range}
    plot.line(
        x='dye_score_threshold',
        y='n_over_threshold',
        y_range_name='n_scripts',
        color=n_scripts_color,
        source=data,
    )

    # Right-hand axis for the secondary range, coloured like its line.
    right_axis = LinearAxis(
        y_range_name="n_scripts",
        formatter=NumeralTickFormatter(format="0a"),
        axis_label_text_color=n_scripts_color,
        major_label_text_color=n_scripts_color,
    )
    plot.add_layout(right_axis, 'right')
    return plot
def get_plots_for_thresholds(
    ds,
    thresholds,
    leaky_threshold,
    n_scripts_range,
    filename_suffix='dye_snippets',
    y_range=(0, 1),
    recall_color='black',
    n_scripts_color='firebrick',
    **extra_plot_opts
):
    """Build one ``get_pr_plot`` figure per threshold from saved plot data.

    For every threshold the CSV
    ``dye_score_plot_data_from_{filename_suffix}_{threshold}_leak_{leaky_threshold}.csv``
    inside the ``DYESCORE_RESULTS_DIR`` directory is validated, read with
    pandas and handed to ``get_pr_plot``. All input files are validated
    before any of them is loaded.

    Args:
        ds: Object providing ``config(name)``, ``file_in_validation(path)``
            and an ``s3`` attribute. When ``s3`` is truthy files are opened
            through ``ds.s3.open``; otherwise they are read from the path
            directly.
        thresholds: Iterable of thresholds. It is consumed exactly once, so
            generators and other one-shot iterables are supported.
        leaky_threshold: Value interpolated into the file name.
        n_scripts_range: Forwarded to ``get_pr_plot``.
        filename_suffix: Value interpolated into the file name.
        y_range: Forwarded to ``get_pr_plot``.
        recall_color: Forwarded to ``get_pr_plot``.
        n_scripts_color: Forwarded to ``get_pr_plot``.
        **extra_plot_opts: Forwarded to ``get_pr_plot``.

    Returns:
        dict mapping each threshold to its figure.
    """
    resultsdir = ds.config('DYESCORE_RESULTS_DIR')

    # Materialise the (threshold, path) pairs once: the path is built in a
    # single place, and a one-shot ``thresholds`` iterable is not exhausted
    # by the validation pass before the loading pass runs.
    inpaths = [
        (
            threshold,
            os.path.join(
                resultsdir,
                f'dye_score_plot_data_from_{filename_suffix}_{threshold}_leak_{leaky_threshold}.csv',
            ),
        )
        for threshold in thresholds
    ]

    # Infile validation - done for every file up front, before any loading.
    for _, inpath in inpaths:
        ds.file_in_validation(inpath)

    plots = {}
    for threshold, inpath in inpaths:
        if ds.s3:
            with ds.s3.open(inpath, 'r') as f:
                pr_df = pd_read_csv(f)
        else:
            pr_df = pd_read_csv(inpath)
        plots[threshold] = get_pr_plot(
            pr_df,
            f'{threshold}',
            n_scripts_range,
            y_range=y_range,
            recall_color=recall_color,
            n_scripts_color=n_scripts_color,
            **extra_plot_opts
        )
    return plots
def get_threshold_summary_plot(ds):
    """Plot scripts captured per distance threshold, one line per recall threshold.

    Reads ``recall_summary_plot_data.csv`` from the ``DYESCORE_RESULTS_DIR``
    directory, groups the rows by ``recall_threshold`` and draws, for each
    group, ``n_over_threshold`` (left axis) and ``percent`` (right axis)
    against ``distance_threshold``. Clicking a legend entry hides its lines.

    Args:
        ds: Object providing ``config(name)``, ``file_in_validation(path)``
            and an ``s3`` attribute. When ``s3`` is truthy the file is opened
            through ``ds.s3.open``; otherwise it is read from the path
            directly.

    Returns:
        The configured figure.
    """
    inpath = os.path.join(
        ds.config('DYESCORE_RESULTS_DIR'), 'recall_summary_plot_data.csv')
    ds.file_in_validation(inpath)
    if not ds.s3:
        summary_df = pd_read_csv(inpath)
    else:
        with ds.s3.open(inpath, 'r') as handle:
            summary_df = pd_read_csv(handle)

    recall_levels = sorted(summary_df.recall_threshold.unique())
    n_levels = len(recall_levels)

    # One row per recall threshold, each cell holding the list of values for
    # that group; this is the shape multi_line consumes.
    per_level_df = summary_df.groupby('recall_threshold').agg(
        lambda column: list(column))

    # One colour more than needed, so the last palette entry is never used -
    # the yellow is often a little light.
    colors = inferno(n_levels + 1)
    source = ColumnDataSource(per_level_df)

    n_scripts = summary_df.n_over_threshold
    plot = figure(
        title=f'Scripts captured by distance threshold for {n_levels} recall thresholds (colored)',
        width=800,
        toolbar_location=None,
        tools='',
        y_range=Range1d(n_scripts.min(), n_scripts.max()),
    )
    plot.xaxis.axis_label = 'distance threshold'
    plot.yaxis.axis_label = 'minimum n_scripts'
    plot.yaxis.formatter = NumeralTickFormatter(format="0a")

    # Secondary axis on the right showing the same data as a percentage.
    percent = summary_df.percent
    plot.extra_y_ranges = {'percent': Range1d(percent.min(), percent.max())}
    percent_axis = LinearAxis(
        y_range_name='percent',
        axis_label='minimum n_scripts (percent of total)',
        formatter=NumeralTickFormatter(format='0%'),
    )
    plot.add_layout(percent_axis, 'right')

    # Row ``row`` of the grouped frame is selected for the row-th sorted
    # recall threshold; this relies on groupby returning its groups in sorted
    # key order (the pandas default).
    # NOTE(review): newer Bokeh releases spell the ``legend`` keyword
    # ``legend_label`` - confirm against the Bokeh version this project pins.
    for row, recall_level in enumerate(recall_levels):
        row_view = CDSView(source=source, filters=[IndexFilter([row])])
        shared = dict(
            source=source,
            view=row_view,
            legend=str(recall_level),
            color=colors[row],
            line_width=5,
            line_alpha=0.6,
        )
        plot.multi_line(xs='distance_threshold', ys='n_over_threshold', **shared)
        plot.multi_line(
            xs='distance_threshold', ys='percent', y_range_name='percent', **shared)

    plot.legend.click_policy = 'hide'
    return plot