Skip to content
This repository was archived by the owner on Aug 28, 2024. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
Changed animation script to plot rotations instead of velocity vectors
  • Loading branch information
daminton committed Jun 10, 2021
1 parent cc623e8 commit aa129ab
Showing 1 changed file with 72 additions and 13 deletions.
85 changes: 72 additions & 13 deletions examples/symba_energy_momentum/collision_movie.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import matplotlib.pyplot as plt
from matplotlib import animation
import matplotlib.collections as clt
from scipy.spatial.transform import Rotation as R

xmin = -20.0
xmax = 20.0
ymin = -20.0
ymax = 20.0

cases = ['supercat_head', 'supercat_off', 'disruption_head', 'disruption_off']
#cases = ['disruption_head']
#cases = ['supercat_head', 'supercat_off', 'disruption_head', 'disruption_off']
cases = ['disruption_off']

def scale_sim(ds, config):

Expand Down Expand Up @@ -69,6 +70,7 @@ def __init__(self, ds, config):
nframes = ds['time'].size
self.ds = scale_sim(ds, config)
self.config = config
self.rot_angle = {}

self.clist = {'Initial conditions' : 'xkcd:windows blue',
'Disruption' : 'xkcd:baby poop',
Expand Down Expand Up @@ -115,14 +117,25 @@ def vec_props(self, c):
return aarg

def plot_pl_vectors(self, pl, cval, r):
varrowend, varrowtip = self.arrowheads(pl, r)
varrowend, varrowtip = self.velocity_vectors(pl, r)
arrows = []
for i in range(pl.shape[0]):
aarg = self.vec_props(cval[i])
a = self.ax.annotate("",xy=varrowend[i],xytext=varrowtip[i], **aarg)
arrows.append(a)
return arrows

def plot_pl_spins(self, pl, id, cval, len):
sarrowend, sarrowtip = self.spin_arrows(pl, id, len)
arrows = []
for i in range(pl.shape[0]):
aarg = self.vec_props(cval[i])
aarg['arrowprops']['mutation_scale'] = 5
aarg['arrowprops']['arrowstyle'] = "simple"
a = self.ax.annotate("",xy=sarrowend[i],xytext=sarrowtip[i], **aarg)
arrows.append(a)
return arrows

def origin_to_color(self, origin):
cval = []
for o in origin:
Expand All @@ -131,17 +144,17 @@ def origin_to_color(self, origin):

return cval

def arrowheads(self, pl, r):
def velocity_vectors(self, pl, r):
px = pl[:, 0]
py = pl[:, 1]
vx = pl[:, 2]
vy = pl[:, 3]
vmag = np.sqrt(vx ** 2 + vy ** 2)
ux = np.zeros_like(vx)
uy = np.zeros_like(vx)
mask = vmag > 0.0
ux[mask] = vx[mask] / vmag[mask]
uy[mask] = vy[mask] / vmag[mask]
goodv = vmag > 0.0
ux[goodv] = vx[goodv] / vmag[goodv]
uy[goodv] = vy[goodv] / vmag[goodv]
varrowend = []
varrowtip = []
for i in range(pl.shape[0]):
Expand All @@ -151,6 +164,25 @@ def arrowheads(self, pl, r):
varrowtip.append(vtip)
return varrowend, varrowtip

def spin_arrows(self, pl, id, len):
px = pl[:, 0]
py = pl[:, 1]
sarrowend = []
sarrowtip = []
idxactive = np.arange(id.size)[self.mask]
for i in range(pl.shape[0]):
endrel = np.array([0.0, len[i], 0.0])
tiprel = np.array([0.0, -len[i], 0.0])
r = R.from_rotvec(self.rot_angle[id[i]])
if i in idxactive:
endrel = r.apply(endrel)
tiprel = r.apply(tiprel)
send = (px[i] + endrel[0], py[i] + endrel[1])
stip = (px[i] + tiprel[0], py[i] + tiprel[1])
sarrowend.append(send)
sarrowtip.append(stip)
return sarrowend, sarrowtip

def setup_plot(self):
# First frame
"""Initial drawing of the scatter plot."""
Expand Down Expand Up @@ -180,24 +212,28 @@ def setup_plot(self):

self.collection = UpdatablePatchCollection(self.patches, color=cval, alpha=0.5, zorder=50)
self.ax.add_collection(self.collection)
self.arrows = self.plot_pl_vectors(pl, cval, radmarker)
#self.varrows = self.plot_pl_vectors(pl, cval, radmarker)
self.sarrows = self.plot_pl_spins(pl, name, cval, radmarker)

return self.collection, self.arrows
return self.collection, self.sarrows

def update(self,frame):
"""Update the scatter plot."""
t, name, mass, radius, npl, pl, radmarker, origin = next(self.data_stream(frame))
cval = self.origin_to_color(origin)
varrowend, varrowtip = self.arrowheads(pl, radmarker)
#varrowend, varrowtip = self.velocity_vectors(pl, radmarker)
sarrowend, sarrowtip = self.spin_arrows(pl, name, radmarker)
for i, p in enumerate(self.patches):
p.set_center((pl[i, 0], pl[i,1]))
p.set_radius(radmarker[i])
p.set_color(cval[i])
self.arrows[i].set_position(varrowtip[i])
self.arrows[i].xy = varrowend[i]
#self.varrows[i].set_position(varrowtip[i])
#self.varrows[i].xy = varrowend[i]
self.sarrows[i].set_position(sarrowtip[i])
self.sarrows[i].xy = sarrowend[i]

self.collection.set_paths(self.patches)
return self.collection, self.arrows
return self.collection, self.sarrows

def data_stream(self, frame=0):
while True:
Expand All @@ -210,10 +246,16 @@ def data_stream(self, frame=0):
vy = d['vyb'].values
name = d['id'].values
npl = d['npl'].values
id = d['id'].values
rotx = d['rot_x'].values
roty = d['rot_y'].values
rotz = d['rot_z'].values

radmarker = d['radmarker'].values
origin = d['origin_type'].values

t = self.ds.coords['time'].values[frame]
self.mask = np.logical_not(np.isnan(x))

x = np.nan_to_num(x, copy=False)
y = np.nan_to_num(y, copy=False)
Expand All @@ -222,6 +264,23 @@ def data_stream(self, frame=0):
radmarker = np.nan_to_num(radmarker, copy=False)
mass = np.nan_to_num(mass, copy=False)
radius = np.nan_to_num(radius, copy=False)
rotx = np.nan_to_num(rotx, copy=False)
roty = np.nan_to_num(roty, copy=False)
rotz = np.nan_to_num(rotz, copy=False)
rotvec = np.array([rotx, roty, rotz])

if frame == 0:
tmp = np.zeros_like(rotvec)
self.rot_angle = dict(zip(id, zip(*tmp)))
else:
t0 = self.ds.coords['time'].values[frame-1]
dt = t - t0
for i in np.arange(npl):
if id[i] in self.rot_angle:
self.rot_angle[id[i]] = self.rot_angle[id[i]] + dt * rotvec[:,i]
self.rot_angle[id[i]] = self.rot_angle[id[i]] % (2 * np.pi)
else:
self.rot_angle[id[i]] = np.zeros(3)

frame += 1
yield t, name, mass, radius, npl, np.c_[x, y, vx, vy], radmarker, origin
Expand Down

0 comments on commit aa129ab

Please sign in to comment.