# Author: Yubo "Paul" Yang
# Email: yubo.paul.yang@gmail.com
# Kyrt is a versatile fabric exclusive to the planet Florina of Sark.
# The fluorescent and mutable kyrt is ideal for artsy decorations.
# OK, this is a library of reasonable defaults for matplotlib figures.
# May this library restore elegance to your plots.
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
# ======================== library of defaults =========================
# expose some default colors for convenience
# `matplotlib.cm.get_cmap` was deprecated in matplotlib 3.7 and removed in
# 3.9, so importing it unconditionally breaks this module on current
# matplotlib. Prefer the colormap registry and only fall back to the legacy
# function on old installs that lack `matplotlib.colormaps`.
try:
    cmap = mpl.colormaps['viridis']
except AttributeError:  # old matplotlib without the `colormaps` registry
    from matplotlib.cm import get_cmap
    cmap = get_cmap('viridis')
colors = cmap.colors  # 256 default colors
dark8 = [  # Colors from www.ColorBrewer.org by Cynthia A. Brewer, Geography, Pennsylvania State University.
    '#1b9e77',
    '#d95f02',
    '#7570b3',
    '#e7298a',
    '#66a61e',
    '#e6ab02',
    '#a6761d',
    '#666666'
]
# ready-made keyword arguments for ax.errorbar, keyed by style name
errorbar_style = {
    'cyq': {
        'linestyle': 'none',  # do 1 thing
        'markersize': 3.5,  # readable
        'markeredgecolor': 'black',  # accentuate
        'markeredgewidth': 0.3,
        'capsize': 4,
        'elinewidth': 0.5
    }
}
# ======================== level 0: basic color =========================
def get_cmap(name='viridis'):
    """ return color map by name

    Args:
      name (str, optional): name of color map, default 'viridis'
    Return:
      matplotlib.colors.Colormap: requested colormap
    """
    import matplotlib as mpl
    registry = getattr(mpl, 'colormaps', None)
    # ColormapRegistry.get_cmap (matplotlib >= 3.6) replaces cm.get_cmap,
    # which was removed in matplotlib 3.9
    if registry is not None and hasattr(registry, 'get_cmap'):
        return registry.get_cmap(name)
    from matplotlib import cm  # legacy path for old matplotlib
    return cm.get_cmap(name)
def get_norm(vmin, vmax):
    """ build a normalization that maps vmin to 0 and vmax to 1

    Args:
      vmin (float): value mapped to the bottom of the color scale
      vmax (float): value mapped to the top of the color scale
    Return:
      matplotlib.colors.Normalize: color normalization function
    """
    return plt.Normalize(vmin=vmin, vmax=vmax)
def scalar_colormap(vmin, vmax, name='viridis'):
    """ make a callable that turns a scalar into an RGBA color

    Args:
      vmin (float): scalar mapped to the first color of the map
      vmax (float): scalar mapped to the last color of the map
      name (str, optional): color map name, default is 'viridis'
    Return:
      function: float -> (float,)*4 RGBA color space
    """
    norm = get_norm(vmin, vmax)
    cmap = get_cmap(name)
    return lambda v: cmap(norm(v))  # normalize first, then look up the color
def scalar_colorbar(vmin, vmax, name='viridis', **kwargs):
    """ return a colorbar for scalar_colormap()

    Args:
      vmin (float): minimum scalar value
      vmax (float): maximum scalar value
      name (str, optional): color map name, default is 'viridis'
      kwargs (dict): passed to plt.colorbar, e.g. ax or cax; when neither
        is given the colorbar takes space from the current Axes
    Return:
      matplotlib.colorbar.Colorbar: colorbar
    """
    cmap = get_cmap(name)
    norm = get_norm(vmin, vmax)
    # see matplotlib issue 3644: a free-standing ScalarMappable needs an
    # array before it can back a colorbar on older matplotlib
    sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])
    # `sm` belongs to no Axes; newer matplotlib no longer guesses where to
    # put the colorbar in that case, so keep the historical behavior of
    # stealing space from the current Axes unless the caller chose.
    if kwargs.get('ax') is None and kwargs.get('cax') is None:
        kwargs['ax'] = plt.gca()
    cbar = plt.colorbar(sm, **kwargs)
    return cbar
# ======================== level 0: basic ax edits =========================
def figaxad(labelsize=12):
    """ construct an absolute/difference (ad) figure

    The top 3/4 of the figure compares data on an absolute scale; the
    bottom 1/4 compares differences. Both panels share the x axis.

    Args:
      labelsize (int, optional): tick label size, default 12
    Return:
      (fig, axa, axd): figure and axes for absolute and difference plots
    """
    from matplotlib.gridspec import GridSpec
    fig = plt.figure()
    grid = GridSpec(4, 4)
    ax_abs = fig.add_subplot(grid[:3, :])
    ax_diff = fig.add_subplot(grid[3, :], sharex=ax_abs)
    # x is shared, so only the bottom panel keeps its x tick labels
    plt.setp(ax_abs.get_xticklabels(), visible=False)
    ax_abs.tick_params(axis='y', labelsize=labelsize)
    ax_diff.tick_params(labelsize=labelsize)
    fig.subplots_adjust(hspace=0)  # panels touch, no vertical gap
    return fig, ax_abs, ax_diff
def set_tick_font(ax, xsize=14, ysize=14,
                  xweight='bold', yweight='bold', **kwargs):
    """ restyle the tick labels of both axes

    Args:
      ax (plt.Axes): matplotlib axes
      xsize (int,optional): xtick fontsize, default is 14
      ysize (int,optional): ytick fontsize, default is 14
      xweight (str,optional): xtick fontweight, default is 'bold'
      yweight (str,optional): ytick fontweight, default is 'bold'
      kwargs (dict): other text properties applied to both axes' ticks
    """
    per_axis = (
        (ax.get_xticklabels, xsize, xweight),
        (ax.get_yticklabels, ysize, yweight),
    )
    for get_labels, size, weight in per_axis:
        plt.setp(get_labels(), fontsize=size, fontweight=weight, **kwargs)
def set_label_font(ax, xsize=14, ysize=14,
                   xweight='bold', yweight='bold', **kwargs):
    """ restyle the x and y axis labels

    Args:
      ax (plt.Axes): matplotlib axes
      xsize (int,optional): xlabel fontsize, default is 14
      ysize (int,optional): ylabel fontsize, default is 14
      xweight (str,optional): xlabel fontweight, default is 'bold'
      yweight (str,optional): ylabel fontweight, default is 'bold'
      kwargs (dict): other text properties applied to both labels
    """
    for axis, size, weight in ((ax.xaxis, xsize, xweight),
                               (ax.yaxis, ysize, yweight)):
        plt.setp(axis.label, fontsize=size, fontweight=weight, **kwargs)
def xtop(ax):
    """ put the x ticks and the x label above the plot

    Args:
      ax (plt.Axes): matplotlib axes
    """
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')
def yright(ax):
    """ put the y ticks and the y label on the right of the plot

    Args:
      ax (plt.Axes): matplotlib axes
    """
    ax.yaxis.tick_right()
    ax.yaxis.set_label_position('right')
# ======================= level 1: advanced ax edits ========================
def cox(ax, x, xtlabels):
    """ add a second row of x tick labels along the top, e.g. another unit

    Args:
      ax (plt.Axes): matplotlib axes
      x (list): xtick locations, applied to both the bottom and top axis
      xtlabels (list): labels for the top ticks
    Return:
      plt.Axes: the twin axes that carries the top labels
    """
    top = ax.twiny()
    top.set_xlim(ax.get_xlim())  # align the twin with the original
    for axes in (ax, top):
        axes.set_xticks(x)
    top.set_xticklabels(xtlabels)
    xtop(top)
    return top
def coy(ax, y, ytlabels):
    """ add a second column of y tick labels on the right, e.g. another unit

    Args:
      ax (plt.Axes): matplotlib axes
      y (list): ytick locations, applied to both the left and right axis
      ytlabels (list): labels for the right-hand ticks
    Return:
      plt.Axes: the twin axes that carries the right-hand labels
    """
    right = ax.twinx()
    right.set_ylim(ax.get_ylim())  # align the twin with the original
    for axes in (ax, right):
        axes.set_yticks(y)
    right.set_yticklabels(ytlabels)
    yright(right)
    return right
# ====================== level 0: basic legend edits =======================
def set_legend_marker_size(leg, ms=10):
    """ set the marker size of every legend handle

    Handles without markers (e.g. patches) are left untouched.

    Args:
      leg (matplotlib.legend.Legend): legend to edit
      ms (float, optional): marker size, default 10
    """
    # `legendHandles` was renamed `legend_handles` in matplotlib 3.7 and the
    # old name was later removed; support both.
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = leg.legendHandles
    for hand in handles:
        # older matplotlib drew legend markers on a private `_legmarker`
        # Line2D; newer versions draw them on the handle itself
        marker = getattr(hand, '_legmarker', hand)
        if hasattr(marker, 'set_markersize'):
            marker.set_markersize(ms)
def create_legend(ax, styles, labels, **kwargs):
    """ build a legend from style dicts instead of plotted artists

    Each entry is drawn with an empty Line2D carrying the given style,
    following matplotlib's "Composing Custom Legends" recipe.

    Args:
      ax (plt.Axes): matplotlib axes
      styles (list of dict): Line2D keyword arguments, one per entry
      labels (list of str): legend text, one per entry
      kwargs (dict): passed to ax.legend
    Return:
      matplotlib.legend.Legend: legend artist
    """
    from matplotlib.lines import Line2D
    handles = []
    for style in styles:
        handles.append(Line2D([], [], **style))
    return ax.legend(handles, labels, **kwargs)
# ====================== level 0: global edits =======================
def set_style(style='ticks', context='talk', **kwargs):
    """ set the global seaborn style and plotting context

    Args:
      style (str, optional): seaborn style name, default 'ticks'
      context (str, optional): seaborn context name, default 'talk'
      kwargs (dict): passed to seaborn.set_context; with the 'talk'
        context, font_scale defaults to 0.7 unless given
    """
    import seaborn as sns
    if context == 'talk':
        kwargs.setdefault('font_scale', 0.7)  # only when caller gave none
    sns.set_style(style)
    sns.set_context(context, **kwargs)
# ====================== level 0: basic Line2D edits =======================
def get_style(line):
    """ get plot styles from Line2D object

    Mostly copied from "Line2D.update_from".

    Args:
      line (Line2D): source of style
    Return:
      dict: line styles readily usable for another plot
    """
    # each property appears exactly once (the original listed 'linestyle'
    # twice, silently overwriting the first entry)
    return {
        'linestyle': line.get_linestyle(),
        'linewidth': line.get_linewidth(),
        'color': line.get_color(),
        'markersize': line.get_markersize(),
        'marker': line.get_marker(),
    }
# ====================== level 0: basic Line2D =======================
def errorshade(ax, x, ym, ye, fb_kwargs=None, **kwargs):
    """ plot a line with a shaded band of +/- ye around it

    Args:
      ax (plt.Axes): matplotlib axes
      x (array-like): x values
      ym (array-like): y values (band center)
      ye (array-like or float): band half-width
      fb_kwargs (dict, optional): extra ax.fill_between arguments; they
        override the defaults color=<line color>, alpha=0.4
      kwargs (dict): passed to ax.plot
    Return:
      (list, PolyCollection): lines returned by ax.plot, and the band
    """
    import numpy as np
    # arrays so that ym - ye also works when lists are passed
    ym = np.asarray(ym)
    ye = np.asarray(ye)
    line = ax.plot(x, ym, **kwargs)
    fill_style = {'color': line[0].get_color(), 'alpha': 0.4}
    if fb_kwargs is not None:
        fill_style.update(fb_kwargs)
    eline = ax.fill_between(x, ym - ye, ym + ye, **fill_style)
    return line, eline
# ===================== level 1: fit line ======================
def show_fit(ax, line, model, sel=None, nx=64, popt=None,
             xmin=None, xmax=None, circle=True, circle_style=None,
             cross=False, cross_style=None, **kwargs):
    """ fit a segment of (x, y) data and show fit

    Gets x, y data from line; sel selects the points that enter the fit.

    Args:
      ax (Axes): matplotlib axes
      line (Line2D): line with data
      model (callable): model function model(x, *params)
      sel (np.array, optional): boolean mask (or integer indices) of the
        points to fit, default all points
      nx (int, optional): grid size, default 64
      popt (array-like, optional): use these parameters instead of fitting
      xmin (float, optional): grid min, default min of selected x
      xmax (float, optional): grid max, default max of selected x
      circle (bool, optional): circle selected points, default True
      circle_style (dict, optional): overrides for the circle markers
      cross (bool, optional): cross out deselected points, default False
      cross_style (dict, optional): overrides for the cross markers
      kwargs (dict): passed to ax.plot for the fit curve; the color
        defaults to the data line's color
    Return:
      (np.array, np.array, list): (popt, perr, lines); perr is None when
      popt is given; lines holds the circle/cross markers (if drawn)
      followed by the fit curve
    """
    import numpy as np
    from scipy.optimize import curve_fit
    # get_xdata returns whatever was plotted, possibly a plain list
    myx = np.asarray(line.get_xdata())
    myy = np.asarray(line.get_ydata())
    if sel is None:
        sel = np.ones(len(myx), dtype=bool)
    else:
        sel = np.asarray(sel)
        if sel.dtype != bool:  # integer indices -> boolean mask
            mask = np.zeros(len(myx), dtype=bool)
            mask[sel] = True
            sel = mask
    myx1, myy1 = myx[sel], myy[sel]  # points used in the fit
    myx11, myy11 = myx[~sel], myy[~sel]  # points left out
    if xmin is None:
        xmin = myx1.min()
    if xmax is None:
        xmax = myx1.max()
    lines = []
    if circle:
        styles = get_style(line)
        styles['linestyle'] = ''
        styles['marker'] = 'o'
        styles['fillstyle'] = 'none'
        if circle_style is not None:
            styles.update(circle_style)
        lines.append(ax.plot(myx1, myy1, **styles)[0])
    if cross:
        styles = get_style(line)
        styles['linestyle'] = ''
        styles['marker'] = 'x'
        if cross_style is not None:
            styles.update(cross_style)
        lines.append(ax.plot(myx11, myy11, **styles)[0])
    if popt is None:  # perform fit
        popt, pcov = curve_fit(model, myx1, myy1)
        perr = np.sqrt(np.diag(pcov))
    else:
        perr = None
    # let callers choose the fit color without colliding with a hard-coded c=
    if 'c' not in kwargs and 'color' not in kwargs:
        kwargs['color'] = line.get_color()
    finex = np.linspace(xmin, xmax, nx)
    line2 = ax.plot(finex, model(finex, *popt), **kwargs)
    lines.append(line2[0])
    return popt, perr, lines
def smooth_bspline(myx, myy, nxmult=10, **spl_kws):
    """ evaluate a B-spline through (myx, myy) on a finer uniform grid

    Args:
      myx (array-like): x data, need not be sorted
      myy (array-like): y data
      nxmult (int, optional): the grid has len(myx)*nxmult points, default 10
      spl_kws (dict): keyword arguments to scipy.interpolate.splrep
    Return:
      (np.array, np.array): (finex, finey) spanning [min(myx), max(myx)]
    """
    import numpy as np
    from scipy.interpolate import splrep, splev
    # arrays so that fancy indexing and .min()/.max() also work for lists
    myx = np.asarray(myx)
    myy = np.asarray(myy)
    nx = len(myx) * nxmult
    idx = np.argsort(myx)  # splrep expects x in ascending order
    tck = splrep(myx[idx], myy[idx], **spl_kws)
    finex = np.linspace(myx.min(), myx.max(), nx)
    finey = splev(finex, tck)
    return finex, finey
def show_spline(ax, line, spl_kws=None, sel=None, **kwargs):
    """ show a smooth spline through given line x y

    Args:
      ax (plt.Axes): matplotlib axes
      line (Line2D): matplotlib line object holding the data
      spl_kws (dict, optional): keyword arguments to smooth_bspline, i.e.
        nxmult (grid-size multiplier) and/or scipy splrep options;
        default is empty
      sel (np.array, optional): selector for the points to use, default all
      kwargs (dict): passed to ax.plot; the color defaults to the data
        line's color
    Return:
      list: the Line2D list returned by ax.plot for the interpolating line
    """
    import numpy as np
    # get_xdata returns whatever was plotted, possibly a plain list
    myx = np.asarray(line.get_xdata())
    myy = np.asarray(line.get_ydata())
    if sel is not None:
        myx = myx[sel]
        myy = myy[sel]
    if spl_kws is None:  # avoid a shared mutable default
        spl_kws = {}
    finex, finey = smooth_bspline(myx, myy, **spl_kws)
    if 'c' not in kwargs and 'color' not in kwargs:
        kwargs['color'] = line.get_color()
    line1 = ax.plot(finex, finey, **kwargs)
    return line1
def krig(finex, x0, y0, length_scale, noise_level):
    """ Gaussian-process regression of y0(x0), evaluated on finex

    The kernel is DotProduct + RBF(length_scale) + WhiteKernel(noise_level).

    Args:
      finex (array-like): 1D points at which to predict
      x0 (array-like): 1D training inputs
      y0 (array-like): training targets, one per x0
      length_scale (float): initial RBF length scale
      noise_level (float): initial white-noise level
    Return:
      (np.array, np.array): predicted mean and standard deviation on finex
    """
    import numpy as np
    # public import path; the private `sklearn.gaussian_process.gpr`
    # module was removed in scikit-learn 0.24
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.gaussian_process.kernels import DotProduct, RBF
    from sklearn.gaussian_process.kernels import WhiteKernel
    kernel = DotProduct() + RBF(length_scale=length_scale)
    kernel += WhiteKernel(noise_level=noise_level)
    gpr = GaussianProcessRegressor(kernel=kernel)
    # sklearn wants 2D (n_samples, n_features) inputs
    gpr.fit(np.asarray(x0)[:, None], y0)
    ym, ye = gpr.predict(np.asarray(finex)[:, None], return_std=True)
    return ym, ye
def gpr_errorshade(ax, x, ym, ye,
                   length_scale, noise_level, fb_kwargs=None,
                   **kwargs):
    """ errorbar plot with a smooth Gaussian-process error band

    WARNING: length_scale and noise_level are VERY DIFFICULT to tune

    The lower (ym-ye) and upper (ym+ye) edges are each smoothed with krig();
    the band is drawn from (lower mean - lower std) to (upper mean + upper
    std).

    Args:
      ax (plt.Axes): matplotlib axes
      x (array-like): x values
      ym (array-like): y values
      ye (array-like or float): symmetric y errors
      length_scale (float): passed to krig
      noise_level (float): passed to krig
      fb_kwargs (dict, optional): ax.fill_between arguments, used as-is;
        default {'color': <data color>, 'alpha': 0.4}
      kwargs (dict): passed to ax.errorbar; linestyle defaults to ''
    Return:
      (Line2D, PolyCollection): errorbar data line and the shaded band
    """
    import numpy as np
    x = np.asarray(x)
    ym = np.asarray(ym)
    ye = np.asarray(ye)
    # make errorbar plot and extract color
    if ('ls' not in kwargs) and ('linestyle' not in kwargs):
        kwargs['ls'] = ''
    line = ax.errorbar(x, ym, ye, **kwargs)
    myc = line[0].get_color()
    # fine grid at 1/10 of the average data spacing; unlike
    # np.arange(xmin, xmax, dx/10.) it includes xmax, so the band reaches
    # the last data point
    xmin, xmax = x.min(), x.max()
    finex = np.linspace(xmin, xmax, 10 * (len(x) - 1) + 1)
    ylm, yle = krig(finex, x, ym - ye,
                    length_scale=length_scale, noise_level=noise_level)
    yhm, yhe = krig(finex, x, ym + ye,
                    length_scale=length_scale, noise_level=noise_level)
    # plot fit
    if fb_kwargs is None:
        fb_kwargs = {'color': myc, 'alpha': 0.4}
    eline = ax.fill_between(finex, ylm - yle, yhm + yhe, **fb_kwargs)
    return line[0], eline
# ===================== level 2: insets ======================
def inset_zoom(fig, ax_box, xlim, ylim, draw_func, xy_label=False):
    """ add an inset axes that redraws a zoomed-in region of the figure

    Args:
      fig (plt.Figure): figure
      ax_box (tuple): inset location and size (x0, y0, dx, dy) in figure ratio
      xlim (tuple): (xmin, xmax) of the zoomed region
      ylim (tuple): (ymin, ymax) of the zoomed region
      draw_func (callable): draw_func(ax) should recreate the figure on ax
      xy_label (bool, optional): keep inset ticks, default is False
    Return:
      plt.Axes: inset axes
    Example:
      >>> axins = inset_zoom(fig, [0.15, 0.15, 0.3, 0.3], [0.1, 0.5],
      ...                    [-0.02, 0.01], lambda ax: ax.plot(x, y))
      >>> ax.indicate_inset_zoom(axins)
    """
    inset = fig.add_axes(ax_box)
    inset.set_xlim(*xlim)
    inset.set_ylim(*ylim)
    draw_func(inset)
    if xy_label:
        return inset
    # bare inset: remove the ticks on both axes
    inset.set_xticks([])
    inset.set_yticks([])
    return inset
# ======================== composition =========================
def pretty_up(ax):
    """ apply the default tick and axis-label fonts (14 pt, bold) to ax

    Args:
      ax (plt.Axes): matplotlib axes
    """
    for restyle in (set_tick_font, set_label_font):
        restyle(ax)