Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
ECE69500/plotters.py
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
124 lines (102 sloc)
3.84 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# third-party libraries | |
import numpy as np | |
from scipy.stats import multivariate_normal | |
import bqplot as bq | |
import bqplot.pyplot as plt | |
from ipywidgets import FloatSlider | |
# local libraries | |
from utils_jgm.widgetizer import Widgetizer | |
class NormalizingFlowWidgetizer(Widgetizer):
    """Slider front-end for ``planar_flow``.

    One vertical ``FloatSlider`` is created per flow parameter. The slider
    names match ``planar_flow``'s keyword arguments, and the initial slider
    values equal that function's defaults.
    """

    def __init__(self):
        # (name, initial value, lower bound, upper bound, step)
        slider_specs = (
            ('b', 0, -15, 15, 0.25),
            ('w_angle', 5*np.pi/4, 0, 2*np.pi, np.pi/16),
            ('w_mag', 3, -5, 5, 0.25),
            ('u_angle', np.pi/2, 0, 2*np.pi, np.pi/16),
            ('u_mag', 1, -5, 5, 0.25),
        )
        self.independent_sliders = {
            name: FloatSlider(
                description=name, value=start, min=lower, max=upper,
                step=step, readout_format='.2e', orientation='vertical',
            )
            for name, start, lower, upper, step in slider_specs
        }
        # The sliders are assigned before the base initializer runs;
        # presumably Widgetizer reads self.independent_sliders there
        # (TODO confirm against utils_jgm.widgetizer).
        super().__init__()

    @staticmethod
    def local_plotter(**plot_kwargs):
        """Forward all keyword arguments unchanged to ``planar_flow``."""
        return planar_flow(**plot_kwargs)
def planar_flow(
    b=0, w_angle=5*np.pi/4, w_mag=3, u_angle=np.pi/2, u_mag=1, plots_dict=None,
):
    """Push 2-D standard-normal samples through one planar-flow layer and plot.

    Each source point ``z`` is mapped to ``y = z + u * tanh(w . z + b)``.
    The projection axis ``w`` and the translation axis ``u`` are given in
    polar form: an angle in radians and a (possibly negative) magnitude.

    Parameters
    ----------
    b : float
        Bias added to the projection ``w . z``.
    w_angle, w_mag : float
        Polar coordinates of the projection axis ``w``.
    u_angle, u_mag : float
        Polar coordinates of the translation axis ``u``.
    plots_dict : dict or None
        Marks keyed by 'source data', 'flowed data', 'projection axis' and
        'translation axis'. ``None`` is treated as a dict whose four entries
        are all None.
        - Source points: if a 'source data' mark exists, its x/y values are
          reused as the source points. Otherwise 500 points are drawn from a
          2-D standard normal.
        - Figure rebuild: if any of the four marks is missing or None, a new
          figure with fresh marks is created.
        - The dict is updated in place.

    Returns
    -------
    figures : list
        ``[fig]`` if a new figure was created on this call, else ``[]``.
    plots_dict : dict
        The same dict, with every mark's x/y data refreshed.
    extra : dict
        Always empty.
    """
    n_samples = 500
    plot_keys = (
        'source data', 'flowed data', 'projection axis', 'translation axis',
    )
    if plots_dict is None:
        plots_dict = dict.fromkeys(plot_keys)

    # Reuse the points already held by the source mark so repeated calls
    # transform the same sample; draw a new sample only if there is none.
    source = plots_dict.get('source data')
    if source is None:
        Z = multivariate_normal(mean=np.zeros(2), cov=np.eye(2)).rvs(n_samples)
    else:
        Z = np.column_stack([source.x, source.y])

    # Polar -> Cartesian for the projection (w) and translation (u) axes.
    w = w_mag * np.array([np.cos(w_angle), np.sin(w_angle)])
    u = u_mag * np.array([np.cos(u_angle), np.sin(u_angle)])

    # Planar flow: move each point along u by tanh of its biased projection
    # onto w. The [:, None] turns the per-point scalars into a column so it
    # broadcasts against u, giving one (x, y) shift per row of Z.
    Y = Z + u * np.tanh(Z @ w + b)[:, None]

    figures = []
    # The update loop below needs all four marks, so rebuild whenever any
    # of them is absent, not only when every one is.
    if any(plots_dict.get(key) is None for key in plot_keys):
        figures.append(_new_planar_flow_figure(Y, plots_dict))

    # Refresh the data on the (existing or freshly created) marks.
    new_data = {
        'source data': (Z[:, 0], Z[:, 1]),
        'flowed data': (Y[:, 0], Y[:, 1]),
        'projection axis': ([0, w[0]], [0, w[1]]),
        'translation axis': ([0, u[0]], [0, u[1]]),
    }
    for key, (xs, ys) in new_data.items():
        plots_dict[key].x = xs
        plots_dict[key].y = ys
    return figures, plots_dict, {}


def _padded_limits(values):
    """Return per-column (lower, upper) plot limits for a 2-D point array.

    Each bound is pushed outward to twice its value. The origin is always
    included, so the axis lines drawn from (0, 0) stay visible. Plain
    doubling would instead shrink a positive minimum or a negative maximum
    and clip the data.
    """
    lower = np.minimum(2 * values.min(axis=0), 0)
    upper = np.maximum(2 * values.max(axis=0), 0)
    return lower, upper


def _new_planar_flow_figure(Y, plots_dict):
    """Create the planar-flow figure with empty marks stored in plots_dict.

    Scales and the fixed aspect ratio are fitted to the flowed points ``Y``.
    The marks' data are filled in by the caller.
    """
    lower, upper = _padded_limits(Y)
    span = upper - lower
    aspect = span[0] / span[1]
    fig = plt.figure(
        title='Planar flow',
        scales={
            'x': bq.LinearScale(min=lower[0], max=upper[0]),
            'y': bq.LinearScale(min=lower[1], max=upper[1]),
        },
        x_axis_label='Re',
        y_axis_label='Im',
        min_aspect_ratio=aspect,
        max_aspect_ratio=aspect,
    )
    # These pyplot calls add marks to the figure just created (pyplot's
    # current figure).
    n_points = Y.shape[0]
    plots_dict['source data'] = plt.scatter(
        [], [], colors=['blue'], opacity=[0.1]*n_points
    )
    plots_dict['flowed data'] = plt.scatter(
        [], [], colors=['orange'], opacity=[0.1]*n_points
    )
    plots_dict['projection axis'] = plt.plot(
        [], [], 'r', stroke_width=6
    )
    plots_dict['translation axis'] = plt.plot(
        [], [], 'g', stroke_width=6
    )
    return fig