def _samples_nd_is_unset(value):
    """Return True if an option is None or an empty sequence.

    ``len`` is used instead of ``value == []`` so NumPy arrays are accepted:
    comparing an array against ``[]`` is elementwise, not a truth test.
    """
    return value is None or len(value) == 0


def _samples_nd_per_dim(values, dim):
    """Repeat a single-entry option (e.g. ``[[lo, hi]]``) for all ``dim`` dimensions.

    Options with more than one entry are returned unchanged.
    """
    if len(values) == 1:
        return [values[0] for _ in range(dim)]
    return values


def _samples_nd_data_limits(samples, dim):
    """Return ``[[min, max], ...]`` per dimension, taken over all sample sets."""
    limits = []
    for d in range(dim):
        lo, hi = np.inf, -np.inf
        for sample in samples:
            lo = min(lo, sample[:, d].min())
            hi = max(hi, sample[:, d].max())
        limits.append([lo, hi])
    return limits


def _samples_nd_plot_diag(ax, row, xlim, samples, points, opts):
    """Draw the 1-D marginals of dimension ``row`` on the current axes ``ax``."""
    for n, v in enumerate(samples):
        kind = opts['diag'][n]
        if kind == 'hist':
            plt.hist(v[:, row], color=opts['samples_colors'][n],
                     **opts['hist_diag'])
        elif kind == 'kde':
            # NOTE(review): opts['kde_diag']['color'] is not used here; the
            # per-sample colour is used instead (unchanged behaviour).
            density = gaussian_kde(v[:, row],
                                   bw_method=opts['kde_diag']['bw_method'])
            xs = np.linspace(xlim[0], xlim[1], opts['kde_diag']['bins'])
            plt.plot(xs, density(xs), color=opts['samples_colors'][n])
        # any other value (e.g. None) draws nothing for this sample set

    if len(points) > 0:
        # Read the y-range after the marginals are drawn so the point
        # markers (vertical lines) span the full panel height.
        extent = ax.get_ylim()
        for n, v in enumerate(points):
            plt.plot([v[:, row], v[:, row]], extent,
                     color=opts['points_colors'][n], **opts['points_diag'])


def _samples_nd_plot_offdiag(row, col, limits, samples, points, opts):
    """Draw the 2-D projection (x: dimension ``col``, y: dimension ``row``) on the current axes."""
    x_lo, x_hi = limits[col][0], limits[col][1]
    y_lo, y_hi = limits[row][0], limits[row][1]
    extent = [x_lo, x_hi, y_lo, y_hi]

    for n, v in enumerate(samples):
        kind = opts['upper'][n]
        if kind in ('hist', 'hist2d'):
            hist, xedges, yedges = np.histogram2d(
                v[:, col], v[:, row],
                range=[[x_lo, x_hi], [y_lo, y_hi]],
                **opts['hist_offdiag'])
            # histogram2d indexes counts as [x, y]; imshow rows are y
            plt.imshow(hist.T, origin='lower',
                       extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
                       aspect='auto')
        elif kind in ('kde', 'kde2d', 'contour', 'contourf'):
            density = gaussian_kde(v[:, [col, row]].T,
                                   bw_method=opts['kde_offdiag']['bw_method'])
            bins = opts['kde_offdiag']['bins']
            X, Y = np.meshgrid(np.linspace(x_lo, x_hi, bins),
                               np.linspace(y_lo, y_hi, bins))
            positions = np.vstack([X.ravel(), Y.ravel()])
            Z = np.reshape(density(positions).T, X.shape)
            if kind in ('kde', 'kde2d'):
                plt.imshow(Z, extent=extent, origin='lower', aspect='auto')
            elif kind == 'contour':
                # Rescale to [0, 1] so 'contour_offdiag' levels are fractions
                # of the density's min-max range on this grid.
                Z = (Z - Z.min()) / (Z.max() - Z.min())
                plt.contour(X, Y, Z, origin='lower', extent=extent,
                            colors=opts['samples_colors'][n],
                            **opts['contour_offdiag'])
            # NOTE(review): 'contourf' is accepted but currently draws nothing.
        elif kind == 'scatter':
            plt.scatter(v[:, col], v[:, row], color=opts['samples_colors'][n],
                        **opts['scatter_offdiag'])
        elif kind == 'plot':
            plt.plot(v[:, col], v[:, row], color=opts['samples_colors'][n],
                     **opts['plot_offdiag'])
        # any other value (e.g. None) draws nothing for this sample set

    for n, v in enumerate(points):
        plt.plot(v[:, col], v[:, row], color=opts['points_colors'][n],
                 **opts['points_offdiag'])


def samples_nd(samples, points=None, **kwargs):
    """Plot one or more sample sets as a grid of 1-D and 2-D marginals.

    The grid has one row and one column per shown dimension. Diagonal panels
    show 1-D marginals (``opts['diag']``). Upper panels show 2-D projections
    (``opts['upper']``). Lower panels are not implemented and are always
    switched off. A panel whose kind is None for every sample set is also
    switched off.

    Parameters
    ----------
    samples : array or list of arrays
        Each array is indexed as ``samples[:, d]``. The number of dimensions
        is taken from ``samples[0].shape[1]``.
    points : array, list of arrays, or None
        Reference points, each passed through ``np.atleast_2d``. They are
        drawn as vertical lines on the diagonal and as markers off the
        diagonal. ``None`` (the default) draws no points.
    **kwargs
        Overrides for the options listed in ``samples_nd.defaults`` (set on
        each call). They are merged via ``_update``, presumably a recursive
        dict update; see its definition.

    Returns
    -------
    fig : matplotlib Figure
    axes : ndarray of Axes, shape ``(len(subset), len(subset))``
        Always 2-D, also for a single panel.

    Raises
    ------
    NotImplementedError
        If ``subset`` is not None, an int, a list or a tuple.
    """
    opts = {
        # what to plot on triagonal and diagonal subplots
        'upper': 'hist',  # hist/scatter/None
        'diag': 'hist',   # hist/None
        # 'lower': None,  # hist/scatter/None  # TODO: implement

        # title and legend
        'title': None,
        'legend': False,

        # labels
        'labels': [],          # for dimensions
        'labels_points': [],   # for points
        'labels_samples': [],  # for samples

        # colors
        'samples_colors': plt.rcParams['axes.prop_cycle'].by_key()['color'],
        'points_colors': plt.rcParams['axes.prop_cycle'].by_key()['color'],

        # subset
        'subset': None,

        # axes limits
        'limits': [],

        # ticks
        'ticks': [],
        'tickformatter': mpl.ticker.FormatStrFormatter('%g'),
        'tick_labels': None,

        # options for hist
        'hist_diag': {
            'alpha': 1.,
            'bins': 25,
            'density': False,
            'histtype': 'step',
        },
        'hist_offdiag': {
            # 'edgecolor': 'none',
            # 'linewidth': 0.0,
            'bins': 25,
        },

        # options for kde
        'kde_diag': {
            'bw_method': 'scott',
            'bins': 100,
            'color': 'black',
        },
        'kde_offdiag': {
            'bw_method': 'scott',
            'bins': 25,
        },

        # options for contour
        'contour_offdiag': {
            'levels': [0.68],
        },

        # options for scatter
        'scatter_offdiag': {
            'alpha': 0.5,
            'edgecolor': 'none',
            'rasterized': False,
        },

        # options for plot
        'plot_offdiag': {},

        # formatting points (scale, markers)
        'points_diag': {},
        'points_offdiag': {
            'marker': '.',
            'markersize': 20,
        },

        # matplotlib style
        'style': os.path.join(os.path.dirname(__file__), 'matplotlibrc'),

        # other options
        'fig_size': (10, 10),
        'fig_bg_colors': {
            'upper': None,
            'diag': None,
            'lower': None,
        },
        'fig_subplots_adjust': {
            'top': 0.9,
        },
        'subplots': {},
        'despine': {
            'offset': 5,
        },
        'title_format': {
            'fontsize': 16,
        },
    }
    # TODO: add color map support
    # TODO: automatically determine good bin sizes for histograms
    # TODO: get rid of seaborn dependency for despine
    # TODO: add legend (if legend is True)

    samples_nd.defaults = opts.copy()
    opts = _update(opts, kwargs)

    # Prepare samples: a single array becomes a one-element list
    if not isinstance(samples, list):
        samples = [samples]

    # Prepare points: None means "no points"; each point set becomes 2-D
    if points is None:
        points = []
    elif not isinstance(points, list):
        points = [points]
    points = [np.atleast_2d(p) for p in points]

    # Dimensions
    dim = samples[0].shape[1]
    # TODO: add asserts checking compatiblity of dimensions

    # Prepare labels
    if _samples_nd_is_unset(opts['labels']):
        labels_dim = ['dim {}'.format(i + 1) for i in range(dim)]
    else:
        labels_dim = opts['labels']

    # Prepare limits: data range over all sample sets unless given explicitly
    if _samples_nd_is_unset(opts['limits']):
        limits = _samples_nd_data_limits(samples, dim)
    else:
        limits = _samples_nd_per_dim(opts['limits'], dim)

    # Prepare ticks
    if _samples_nd_is_unset(opts['ticks']):
        ticks = None
    else:
        ticks = _samples_nd_per_dim(opts['ticks'], dim)

    # Prepare diag/upper: one plot kind per sample set
    for key in ('diag', 'upper'):
        if not isinstance(opts[key], list):
            opts[key] = [opts[key] for _ in range(len(samples))]
    # 'lower' panels are not implemented yet and are always switched off
    opts['lower'] = None

    # Style: 'dark'/'light' select a style file shipped next to this module
    if opts['style'] in ['dark', 'light']:
        style = os.path.join(
            os.path.dirname(__file__),
            'matplotlib_{}.style'.format(opts['style']))
    else:
        style = opts['style']

    # Apply custom style as context
    with mpl.rc_context(fname=style):

        # Figure out if we subset the plot
        subset = opts['subset']
        if subset is None:
            subset = list(range(dim))
        elif isinstance(subset, int):
            subset = [subset]
        elif isinstance(subset, (list, tuple)):
            subset = list(subset)
        else:
            raise NotImplementedError(
                'subset must be None, an int, or a list of ints')
        rows = cols = len(subset)

        fig, axes = plt.subplots(rows, cols, figsize=opts['fig_size'],
                                 **opts['subplots'])
        # For a 1x1 grid plt.subplots returns a bare Axes (no .reshape);
        # wrapping it first keeps `axes` a 2-D array in every case.
        axes = np.asarray(axes, dtype=object).reshape(rows, cols)

        # Style figure
        fig.subplots_adjust(**opts['fig_subplots_adjust'])
        fig.suptitle(opts['title'], **opts['title_format'])

        # Panels are assigned to the selected dimensions in increasing order
        shown = [d for d in range(dim) if d in subset]
        for row_idx, row in enumerate(shown):
            for col_idx, col in enumerate(shown):
                if row == col:
                    current = 'diag'
                elif row < col:
                    current = 'upper'
                else:
                    current = 'lower'

                ax = axes[row_idx, col_idx]
                plt.sca(ax)

                # Background color
                if current in opts['fig_bg_colors'] and \
                        opts['fig_bg_colors'][current] is not None:
                    ax.set_facecolor(opts['fig_bg_colors'][current])

                # Axes: 'lower' is None; 'diag'/'upper' are per-sample lists,
                # so the panel is off only when every entry is None.
                kinds = opts[current]
                if kinds is None or all(kind is None for kind in kinds):
                    ax.axis('off')
                    continue

                # Limits (y limits only off the diagonal, where y is data)
                ax.set_xlim((limits[col][0], limits[col][1]))
                if current != 'diag':
                    ax.set_ylim((limits[row][0], limits[row][1]))
                xlim = ax.get_xlim()

                # Ticks: only the first two entries per dimension are used
                if ticks is not None:
                    ax.set_xticks((ticks[col][0], ticks[col][1]))
                    if current != 'diag':
                        ax.set_yticks((ticks[row][0], ticks[row][1]))

                # Despine
                despine(ax=ax, **opts['despine'])

                # Formatting axes: while 'lower' is off, the diagonal carries
                # each column's x axis; otherwise only the last row would.
                if current == 'diag':
                    show_xaxis = opts['lower'] is None or col == dim - 1
                else:
                    show_xaxis = row == dim - 1
                if show_xaxis:
                    _format_axis(ax, xhide=False, xlabel=labels_dim[col],
                                 yhide=True,
                                 tickformatter=opts['tickformatter'])
                else:
                    _format_axis(ax, xhide=True, yhide=True)
                if opts['tick_labels'] is not None:
                    ax.set_xticklabels((str(opts['tick_labels'][col][0]),
                                        str(opts['tick_labels'][col][1])))

                if current == 'diag':
                    _samples_nd_plot_diag(ax, row, xlim, samples, points, opts)
                else:
                    _samples_nd_plot_offdiag(row, col, limits, samples,
                                             points, opts)

        # When only some dimensions are shown, hint at the omitted ones with
        # an ellipsis right of each last-column panel and a rotated one
        # below-right of the bottom-right panel.
        if len(subset) < dim:
            last = len(subset) - 1
            text_kwargs = {'fontsize': plt.rcParams['font.size'] * 2.}
            for row in range(len(subset)):
                ax = axes[row, last]
                x0, x1 = ax.get_xlim()
                y0, y1 = ax.get_ylim()
                ax.text(x1 + (x1 - x0) / 8., (y0 + y1) / 2., '...',
                        **text_kwargs)
                if row == last:
                    ax.text(x1 + (x1 - x0) / 12., y0 - (y1 - y0) / 1.5, '...',
                            rotation=-45, **text_kwargs)

    return fig, axes