diff --git a/examples/symba_mars_disk/aescattermovie.py b/examples/symba_mars_disk/aescattermovie.py index 0171b4a4b..909b674fb 100755 --- a/examples/symba_mars_disk/aescattermovie.py +++ b/examples/symba_mars_disk/aescattermovie.py @@ -13,6 +13,7 @@ ymin = 1e-6 ymax = 1.0 framejump = 1 +ncutoff = 1e20 class AnimatedScatter(object): """An animated scatter plot using matplotlib.animations.FuncAnimation.""" @@ -22,8 +23,10 @@ def __init__(self, ds, param): nframes = int(ds['time'].size / framejump) self.ds = ds self.param = param + self.Rcb = self.ds['radius'].sel(id=0, time=0).values self.ds['radmarker'] = self.ds['radius'].fillna(0) - self.ds['radmarker'] = self.ds['radmarker'] / self.ds['radmarker'].max() * radscale + np.where(self.ds['radmarker'] > ncutoff, 0, self.ds['radmarker']) + self.ds['radmarker'] = (self.ds['radmarker'] / self.Rcb) * radscale self.clist = {'Initial conditions' : 'xkcd:faded blue', 'Disruption' : 'xkcd:marigold', @@ -38,14 +41,14 @@ def __init__(self, ds, param): self.ax.set_xlim(xmin, xmax) self.ax.set_ylim(ymin, ymax) fig.add_axes(self.ax) - self.ani = animation.FuncAnimation(fig, self.update, interval=1, frames=nframes, init_func=self.setup_plot, blit=False) - self.ani.save('aescatter.mp4', fps=30, dpi=300, extra_args=['-vcodec', 'mpeg4']) + self.ani = animation.FuncAnimation(fig, self.update, interval=1, frames=nframes, init_func=self.setup_plot, blit=True) + self.ani.save('aescatter.mp4', fps=30, dpi=300, extra_args=['-vcodec', 'libx264']) print('Finished writing aescattter.mp4') def scatters(self, pl, radmarker, origin): scat = [] for key, value in self.clist.items(): - idx = origin == value + idx = origin == key s = self.ax.scatter(pl[idx, 0], pl[idx, 1], marker='o', s=radmarker[idx], c=value, alpha=0.75, label=key) scat.append(s) return scat @@ -62,7 +65,7 @@ def setup_plot(self): self.ax.set_yscale('log') self.title = self.ax.text(0.50, 1.05, "", bbox={'facecolor': 'w', 'alpha': 0.5, 'pad': 5}, transform=self.ax.transAxes, - ha="center") + ha="center", animated=True) self.title.set_text(f"{titletext} - Time = ${t / 24 / 3600:4.1f}$ days with ${npl:f}$ particles") slist = self.scatters(pl, radmarker, origin) @@ -76,15 +79,18 @@ def setup_plot(self): def data_stream(self, frame=0): while True: d = self.ds.isel(time=frame) - d = d.where(np.invert(np.isnan(d['a'])), drop=True) + d = d.where(d['a'] < ncutoff, drop=True) + radius = d['radmarker'].values + d = d.where(d['a'] > self.Rcb, drop=True) + Gmass = d['Gmass'].values - a = d['a'].values / RMars + a = d['a'].values / self.Rcb e = d['e'].values name = d['id'].values npl = d.id.count().values radmarker = d['radmarker'] - origin = d['origin_type'] + origin = d['origin_type'].values t = self.ds.coords['time'].values[frame]