classgraphic package

Interactive classification diagnostic plots for scikit-learn.

coin sorting machine

We classify things for the purpose of doing something to them. Any classification which does not assist manipulation is worse than useless. - Randolph S. Bourne, “Education and Living”, The Century Co (April 1917)

Module contents

Top-level package for Classgraphic.

Submodules

classgraphic.common

common functions for table and viz.

classgraphic.common.complementary_hex_color(rgb)

Return Complementary hexadecimal color of hex or rgb color.

Parameters:

rgb (str) – color to convert, formatted as either #aabbcc or rgb(111,222,123)

Returns:

hex color (str)

classgraphic.common.get_labels(model, y, as_str=True)

Get labels from the model.

We use a one hot strategy and ultimately get the resulting column names.

Parameters:
  • model – scikit-learn model

  • y (Series or array) – the target column for classification

  • as_str (bool) – default to list of string, else return one hot encoded dataframe

Returns:

(list(str) or dataframe)

classgraphic.common.heatmap_colorscale(colorscale=None, template=None)

Get heatmap colorscale from colorscale or template.

Parameters:
  • colorscale (str or list) – plotly named colorscale (blues, burgyl_r…) or 2d list mapping

  • template (str) – template name to use to get heatmap colorscale from.

Returns:

list (colorscale)

classgraphic.common.light_theme(colorscale_list)

Evaluate if a theme has a colorscale going from dark to light.

Parameters:

colorscale_list (list) – list of rgb or hex colors.

Returns:

bool

classgraphic.common.missing_matrix(X)

Find all missing (None or NaN) values.

For each cell, returns True if value is missing, False if not.

Parameters:

X (DataFrame or 2d array) – data array to evaluate for missing values.

Returns:

array of bool

classgraphic.common.table_template(template=None)

Get template to use with a table.

Parameters:

template (str) – template name. Template has to be valid. If not specified, will use plotly default template.

Returns:

template (str)

classgraphic.common.value_count(y_list)

Get class counts.

Parameters:

y_list (DataFrame or array) – usually the target data (y, y_test etc)

Returns:

counts (dict)

classgraphic.essential

By starting a python script or a notebook with from classgraphic.essential import *, all the basic elements are loaded to start exploring data, build linear classification models and evaluate them.

Specifically, essential exposes the following tables:

  • class_imbalance_table

  • classification_table

  • confusion_matrix_table

  • describe

  • prediction_table

  • table

And the following charts:

  • class_imbalance

  • class_error

  • det

  • feature_importance

  • missing

  • precision_recall

  • roc

  • prediction_histogram

  • threshold

For templates:

  • set_default_template

classgraphic.essential also sets the default template to “dion”, which unifies the look of all visualizations and tables around a gradient of blue, with light background and lines.

Finally, it also imports numpy as np, pandas as pd, plotly.express as px, plotly.graph_objects as go and make_classification, train_test_split and LogisticRegression from sklearn.

classgraphic.table

Classification tables, in plotly and pandas format.

classgraphic.table.class_imbalance_table(y, y2=None, condition=None, sort=False, ascending=False, extended=False, colorscale=None, template=None, **kwargs)

Class imbalance table.

Given a target Series or array (eg. y, y_test, y_train), this visualize the balance/imbalance of classes.

Parameters:
  • y (series or array) – usually, the target column for classification

  • y2 (series or array) – optional. If provided an additional count column will allow for comparison

  • condition (str) – added to title. Something like ‘train’, ‘test’, ‘train and test’, ‘real and test’

  • sort (bool) – if True, data will be sorted by class count

  • ascending (bool) – if True, sort ascending, else descending

  • extended (bool) – return a dataframe, besides the Plotly table figure

  • colorscale (str or list) – plotly named colorscale (blues, burgyl_r…) or 2d list mapping

  • template (str) – template name to use - uses heatmap colorscale, not table colors to provide alternating colors

  • **kwargs – any Plotly kwargs for go.Table

Returns:

Figure (, Dataframe if extended)

classgraphic.table.classification_table(model, y, y_pred, extended=False, colorscale=None, template=None, **kwargs)

Classification report table.

Parameters:
  • y (series or array) – usually, the target column for classification

  • y_pred (series or array) – predicted values

  • condition (str) – added to title. Something like ‘train’, ‘test’, ‘train and test’, ‘real and test’

  • sort (bool) – if True, data will be sorted by class count

  • ascending (bool) – if True, sort ascending, else descending

  • extended (bool) – return a dataframe, besides the Plotly table figure

  • colorscale (str or list) – plotly named colorscale (blues, burgyl_r…) or 2d list mapping

  • template (str) – template name to use - uses heatmap colorscale, not table colors to provide alternating colors

  • **kwargs – any Plotly kwargs for go.Table

Returns:

Figure (, Dataframe if extended)

classgraphic.table.confusion_matrix_table(model, y, y_pred, normalize=None, extended=False, colorscale=None, template=None, **kwargs)

Confusion matrix table.

Parameters:
  • model – scikit-learn model

  • y (series or array) – the target column for classification

  • y_pred (series or array) – predicted values

  • normalize (str)) – one of None, ‘true’ (or synonym ‘real’), ‘pred’, ‘all’

  • extended (bool) – return a dataframe, besides the Plotly table figure

  • colorscale (str or list) – plotly named colorscale (blues, burgyl_r…) or 2d list mapping

  • template (str) – template name to use - uses heatmap colorscale, not table colors to provide alternating colors

  • **kwargs – any Plotly kwargs for go.Heatmap

Returns:

Figure (table)(, Dataframe if extended)

classgraphic.table.describe(X, colorscale=None, template=None, transpose=False, **kwargs)

Dataframe description table.

This is equivalent to pandas df.describe(), with extra quantiles and non numerical feature stats, such as top and frequency.

Parameters:
  • X (dataframe or array) – tabular data

  • colorscale (str or list) – plotly named colorscale (blues, burgyl_r…) or 2d list mapping

  • template (str) – template name to use - uses heatmap colorscale, not table colors to provide alternating colors

  • transpose (bool) – columns become rows, rows become columns

  • **kwargs – any Plotly kwargs for go.Figure

Returns:

Figure (table)

classgraphic.table.prediction_table(model, y, y_pred, colorscale=None, template=None, **kwargs)

Real vs Prediction table.

Parameters:
  • model – scikit-learn model

  • y (series or array) – the target column for classification

  • y_pred (series or array) – predicted values

  • normalize (str)) – one of None, ‘true’ (or synonym ‘real’), ‘pred’, ‘all’

  • extended (bool) – return a dataframe, besides the Plotly table figure

  • colorscale (str or list) – plotly named colorscale (blues, burgyl_r…) or 2d list mapping

  • template (str) – template name to use - uses heatmap colorscale, not table colors to provide alternating colors

  • **kwargs – any Plotly kwargs for go.Heatmap

Returns:

Figure (table)

classgraphic.table.table(df, title=None, format=None, colorscale=None, template=None, named_index=True, **kwargs)

Build generic plotly table.

Parameters:
  • df (dataframe) – tabular data

  • title (str) – title for the table. Can use <b></b>, <sub></sub>, <br> etc. Font size is defined by the template.

  • colorscale (str or list) – plotly named colorscale (blues, burgyl_r…) or 2d list mapping

  • template (str) – template name to use - uses heatmap colorscale, not table colors to provide alternating colors

  • **kwargs – any Plotly kwargs for go.Figure

Returns:

Figure (table)

classgraphic.template

template.py.

Various templates in plotly format.

Logo template

classgraphic.template.DION_TEMPLATE = {'data': {'heatmap': [{'colorbar': {'outlinewidth': 0, 'ticks': ''}, 'colorscale': [[0.0, 'rgb(247,251,255)'], [0.125, 'rgb(222,235,247)'], [0.25, 'rgb(198,219,239)'], [0.375, 'rgb(158,202,225)'], [0.5, 'rgb(107,174,214)'], [0.625, 'rgb(66,146,198)'], [0.75, 'rgb(33,113,181)'], [0.875, 'rgb(8,81,156)'], [1.0, 'rgb(8,48,107)']], 'type': 'heatmap'}], 'heatmapgl': [{'colorbar': {'outlinewidth': 0, 'ticks': ''}, 'colorscale': [[0.0, 'rgb(247,251,255)'], [0.125, 'rgb(222,235,247)'], [0.25, 'rgb(198,219,239)'], [0.375, 'rgb(158,202,225)'], [0.5, 'rgb(107,174,214)'], [0.625, 'rgb(66,146,198)'], [0.75, 'rgb(33,113,181)'], [0.875, 'rgb(8,81,156)'], [1.0, 'rgb(8,48,107)']], 'type': 'heatmapgl'}], 'histogram': [{'marker': {'colorbar': {'outlinewidth': 0, 'ticks': ''}}, 'type': 'histogram'}], 'pie': [{'automargin': True, 'type': 'pie'}], 'table': [{'cells': {'fill': {'color': 'rgb(247,251,255)'}, 'line': {'color': 'black'}}, 'header': {'fill': {'color': 'rgb(158,202,225)'}, 'line': {'color': 'black'}}, 'type': 'table'}]}, 'layout': {'colorscale': {'sequential': [[0.0, 'rgb(247,251,255)'], [0.125, 'rgb(222,235,247)'], [0.25, 'rgb(198,219,239)'], [0.375, 'rgb(158,202,225)'], [0.5, 'rgb(107,174,214)'], [0.625, 'rgb(66,146,198)'], [0.75, 'rgb(33,113,181)'], [0.875, 'rgb(8,81,156)'], [1.0, 'rgb(8,48,107)']]}, 'font': {'color': 'rgb(8,48,107)'}, 'margin': {'b': 40, 't': 140}, 'title': {'x': 0.05}, 'title_font_size': 24}}

Dion Research blue gradient template

classgraphic.template.HIDE_LEGEND = layout.Template({     'layout': {'showlegend': False} })

Hide Legend

classgraphic.template.seq_colorscale = [[0.0, 'rgb(247,251,255)'], [0.125, 'rgb(222,235,247)'], [0.25, 'rgb(198,219,239)'], [0.375, 'rgb(158,202,225)'], [0.5, 'rgb(107,174,214)'], [0.625, 'rgb(66,146,198)'], [0.75, 'rgb(33,113,181)'], [0.875, 'rgb(8,81,156)'], [1.0, 'rgb(8,48,107)']]

sequential colorscale

classgraphic.template.set_default_template(template='dion')

set_default_template.

Parameters:

template (str) – template to use for all charts. Defaults to DEFAULT_TEMPLATE.

classgraphic.viz

Classification visualizations, in plotly format and some also in pandas format.

classgraphic.viz.class_error(model, y, y_pred, normalize=False, reference=True, **kwargs)

Classification error chart.

Given a model, a target Series or array and a prediction Series or array, this visualize the prediction errors (for each real class, a breakdown of how many predicted by class). If normalize is False, it also show balance/imbalance of classes. If normalize is True, the bars all represent 100% of the class, and each segment, the percent of that predicted class for the real class.

Parameters:
  • model (sklearn model) – A scikit learn model (displayed in title)

  • y (series or array) – the target column for classification

  • y_pred (series or array) – predicted values

  • normalize (bool) – if True, each segment is in percent and each true class is equal to 100%

  • reference (bool) – add a reference line to better compare classes

  • **kwargs – any Plotly kwargs for px.bar

Returns:

Figure

classgraphic.viz.class_imbalance(y, y2=None, condition=None, sort=False, ascending=False, reference=True, always_bar=False, barmode='stack', **kwargs)

Class imbalance visualization.

Given a target Series or array (eg. y, y_test, y_train), this visualize the balance/imbalance of classes. For binary classification, the default is a pie chart. For multiclass, it is a bar chart.

Parameters:
  • y (series or array) – usually, the target column for classification

  • condition (str) – added to title. Something like ‘train’, ‘test’

  • sort (bool) – if True, data will be sorted by class count

  • ascending (bool) – if True, sort ascending, else descending

  • reference (bool) – add a reference line to better compare classes

  • always_bar (bool) – it True, for binary classification use bar instead of pie chart

  • barmode

  • **kwargs – any Plotly kwargs for px.pie or px.bar

Returns:

Figure

classgraphic.viz.det(model, y, y_pred, **kwargs)

Plot detection error tradeoff (DET) curve.

A visualization of error rates for binary classification, plotting the probit of false negative

rate vs. false positive rate. The interval -3, 3 from the probit is mapped to a percentage.

Parameters:
  • model (sklearn model) – A scikit learn model (displayed in title)

  • y (series or array) – the target column for classification

  • y_pred (series or array) – predicted values

  • **kwargs – any Plotly kwargs for px.line

Returns:

Figure

classgraphic.viz.feature_importance(model, y, transpose=False, extended=False, **kwargs)

Feature Importance.

Parameters:
  • model (sklearn model) – A scikit learn model (displayed in title)

  • y (series or array) – the target column for classification

  • transpose (bool) – columns become rows, rows become columns

  • extended (bool) – return a dataframe, besides the Plotly figure

  • **kwargs – any Plotly kwargs for px.histogram

Returns:

Figure

classgraphic.viz.missing(X, warning=True, **kwargs)

Missing data heatmap.

Missing is defined here as NaN on numerical data types.

Parameters:
  • X (sklearn model) – A scikit learn model (displayed in title)

  • warning (bool) – Issue a warning if there are some missing values.

  • **kwargs – any Plotly kwargs for px.imshow

Returns:

Figure

classgraphic.viz.precision_recall(model, y, y_probs, **kwargs)

Precision / Recall curve.

Parameters:
  • model (sklearn model) – A scikit learn model (displayed in title)

  • y (series or array) – the target column for classification

  • y_probs (series or array) – predicted value probabilities

  • **kwargs – any Plotly kwargs for px.line

Returns:

Figure

classgraphic.viz.prediction_histogram(model, y, y_probs, nbins=50, barmode='overlay', reference=True, extended=False, **kwargs)

Prediction Histogram chart.

Parameters:
  • model (sklearn model) – A scikit learn model (displayed in title)

  • y (series or array) – the target column for classification

  • y_probs (series or array) – predicted value probabilities

  • nbins (int) – number of bins, default to 50

  • barmode (str) – one of [‘stack’, ‘group’, ‘overlay’, ‘relative’, None], defaults to overlay

  • reference (bool) – add a reference line to better compare classes

  • extended (bool) – return a dataframe, besides the Plotly figure

  • **kwargs – any Plotly kwargs for px.histogram

Returns:

Figure

classgraphic.viz.roc(model, y, y_probs, ovo=False, ovr=True, **kwargs)

Plot Receiver Operating Characteristic (ROC) curve.

Parameters:
  • model (sklearn model) – A scikit learn model (displayed in title)

  • y (series or array) – the target column for classification

  • y_probs (series or array) – predicted value probabilities

  • ovo (bool) – if True, display One vs One macro average and weighted AUC (multiclass)

  • ovr (bool) – if True, display One vs Rest macro average and weighted AUC (multiclass)

  • **kwargs – any Plotly kwargs for go.Figure

Returns:

Figure

classgraphic.viz.threshold(model, y, y_probs, **kwargs)

Plot True Positive Rate vs False Positive Rate for all Threshold.

Parameters:
  • model (sklearn model) – A scikit learn model (displayed in title)

  • y (series or array) – the target column for classification

  • y_probs (series or array) – predicted value probabilities

  • **kwargs – any Plotly kwargs for go.Figure

Returns:

Figure

classgraphic.viz.view(X, y_real, y_pred=None, nobs=10, extended=False, **kwargs)

Receiver Operating Characteristic.

Parameters:
  • X (DataFrame or serues or array) – feature data to view

  • y_real (series or array) – the target column for classification matching X

  • y_pred (series or array) – optional, predicted value probabilities matching X

  • nobs (int) – number of observations to display

  • extended (bool) – if True, return dataframe

  • **kwargs – any Plotly kwargs for px.scatter

Returns:

Figure

classgraphic.viz.voronoi_diagram(centroid_source, data=None, labels=None, plot_voronoi=True, plot_delauney=False, **kwargs)

Plot a 2D voronoi diagram.

Parameters:
  • centroid_source – either a fitted clustering alrorithm with cluster_centers, or cluster center coordinates

  • data – optional, original data (X) used for calculating the clusters

  • labels – labels (y) for the clusters

  • plot_voronoi – if False, will not plot boundaries, aka Voronoi tessalation

  • plot_delauney – if True, plots Delauney triangulation

  • **kwargs – for now, passed to initial scatter plot

Returns:

Plotly figure