Skip to content
This repository was archived by the owner on Aug 28, 2024. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
Added in collision movie generator (but it won't work until I get the particle info file methods finished)
  • Loading branch information
daminton committed Aug 10, 2021
1 parent d657281 commit 44df3c5
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions examples/symba_energy_momentum/collision_movie.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env python3
import swiftest
import numpy as np
import matplotlib.pyplot as plt
Expand All @@ -17,8 +18,7 @@ def scale_sim(ds, param):

dsscale = ds

dsscale['Mass'] = ds['Mass'] / param['GU']
Mtot = dsscale['Mass'].sum(skipna=True, dim="id").isel(time=0)
GMtot = dsscale['GMass'].sum(skipna=True, dim="id").isel(time=0)
rscale = sum(ds['Radius'].sel(id=[2, 3], time=0)).item()
ds['Radius'] /= rscale

Expand All @@ -28,19 +28,19 @@ def scale_sim(ds, param):
dsscale['py'] /= rscale
dsscale['pz'] /= rscale

mpx = dsscale['Mass'] * dsscale['px']
mpy = dsscale['Mass'] * dsscale['py']
mpz = dsscale['Mass'] * dsscale['pz']
xbsys = mpx.sum(skipna=True, dim="id") / Mtot
ybsys = mpy.sum(skipna=True, dim="id") / Mtot
zbsys = mpz.sum(skipna=True, dim="id") / Mtot
mpx = dsscale['GMass'] * dsscale['px']
mpy = dsscale['GMass'] * dsscale['py']
mpz = dsscale['GMass'] * dsscale['pz']
xbsys = mpx.sum(skipna=True, dim="id") / GMtot
ybsys = mpy.sum(skipna=True, dim="id") / GMtot
zbsys = mpz.sum(skipna=True, dim="id") / GMtot

mvx = dsscale['Mass'] * dsscale['vx']
mvy = dsscale['Mass'] * dsscale['vy']
mvz = dsscale['Mass'] * dsscale['vz']
vxbsys = mvx.sum(skipna=True, dim="id") / Mtot
vybsys = mvy.sum(skipna=True, dim="id") / Mtot
vzbsys = mvz.sum(skipna=True, dim="id") / Mtot
mvx = dsscale['GMass'] * dsscale['vx']
mvy = dsscale['GMass'] * dsscale['vy']
mvz = dsscale['GMass'] * dsscale['vz']
vxbsys = mvx.sum(skipna=True, dim="id") / GMtot
vybsys = mvy.sum(skipna=True, dim="id") / GMtot
vzbsys = mvz.sum(skipna=True, dim="id") / GMtot

dsscale['pxb'] = dsscale['px'] - xbsys
dsscale['pyb'] = dsscale['py'] - ybsys
Expand Down Expand Up @@ -184,7 +184,7 @@ def spin_arrows(self, pl, id, len):
def setup_plot(self):
# First frame
"""Initial drawing of the scatter plot."""
t, name, Mass, Radius, npl, pl, radmarker, origin = next(self.data_stream(0))
t, name, GMass, Radius, npl, pl, radmarker, origin = next(self.data_stream(0))

cval = self.origin_to_color(origin)
# set up the figure
Expand Down Expand Up @@ -217,7 +217,7 @@ def setup_plot(self):

def update(self,frame):
"""Update the scatter plot."""
t, name, Mass, Radius, npl, pl, radmarker, origin = next(self.data_stream(frame))
t, name, GMass, Radius, npl, pl, radmarker, origin = next(self.data_stream(frame))
cval = self.origin_to_color(origin)
#varrowend, varrowtip = self.velocity_vectors(pl, radmarker)
sarrowend, sarrowtip = self.spin_arrows(pl, name, radmarker)
Expand All @@ -237,13 +237,13 @@ def data_stream(self, frame=0):
while True:
d = self.ds.isel(time=frame)
Radius = d['radmarker'].values
Mass = d['Mass'].values
GMass = d['GMass'].values
x = d['pxb'].values
y = d['pyb'].values
vx = d['vxb'].values
vy = d['vyb'].values
name = d['id'].values
npl = d['npl'].values
npl = d.id.count().values
id = d['id'].values
rotx = d['rot_x'].values
roty = d['rot_y'].values
Expand All @@ -260,7 +260,7 @@ def data_stream(self, frame=0):
vx = np.nan_to_num(vx, copy=False)
vy = np.nan_to_num(vy, copy=False)
radmarker = np.nan_to_num(radmarker, copy=False)
Mass = np.nan_to_num(Mass, copy=False)
GMass = np.nan_to_num(Mass, copy=False)
Radius = np.nan_to_num(Radius, copy=False)
rotx = np.nan_to_num(rotx, copy=False)
roty = np.nan_to_num(roty, copy=False)
Expand All @@ -278,7 +278,7 @@ def data_stream(self, frame=0):
for i in id[idxactive]:
self.rot_angle[i] = self.rot_angle[i] + dt * np.array(self.rotvec[i])
frame += 1
yield t, name, Mass, Radius, npl, np.c_[x, y, vx, vy], radmarker, origin
yield t, name, GMass, Radius, npl, np.c_[x, y, vx, vy], radmarker, origin

for case in cases:
if case == 'supercat_off':
Expand Down

0 comments on commit 44df3c5

Please sign in to comment.