Skip to content
This repository was archived by the owner on Aug 28, 2024. It is now read-only.

Commit

Permalink
Finished the new add_body and add_solar_system_body methods
Browse files Browse the repository at this point in the history
  • Loading branch information
daminton committed Nov 11, 2022
1 parent 3c4c2b6 commit d166719
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 12 deletions.
13 changes: 9 additions & 4 deletions python/swiftest/swiftest/init_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def vec2xr(param: Dict,
vec_cb = np.expand_dims(vec_cb.T,axis=0) # Make way for the time dimension!
ds_cb = xr.DataArray(vec_cb, dims=dims, coords={'time': [t], 'id': idvals[icb], 'vec': lab_cb}).to_dataset(dim='vec')
else:
ds_cb = xr.Dataset()
ds_cb = None
if ispl:
lab_pl = plab.copy()
vec_pl = np.vstack([vec[:,ipl], GMpl[ipl]])
Expand All @@ -383,16 +383,21 @@ def vec2xr(param: Dict,
vec_pl = np.expand_dims(vec_pl.T,axis=0) # Make way for the time dimension!
ds_pl = xr.DataArray(vec_pl, dims=dims, coords={'time': [t], 'id': idvals[ipl], 'vec': lab_pl}).to_dataset(dim='vec')
else:
ds_pl = xr.Dataset()
ds_pl = None
if istp:
lab_tp = tlab.copy()
vec_tp = np.expand_dims(vec[:,itp].T,axis=0) # Make way for the time dimension!
ds_tp = xr.DataArray(vec_tp, dims=dims, coords={'time': [t], 'id': idvals[itp], 'vec': lab_tp}).to_dataset(dim='vec')
particle_type[itp] = np.repeat("Test Particle",idvals[itp].size)
else:
ds_tp = xr.Dataset()
ds_tp = None

ds_info = xr.DataArray(np.vstack([namevals,particle_type]).T, dims=infodims, coords={'id': idvals, 'vec' : ["name", "particle_type"]}).to_dataset(dim='vec')
ds = xr.combine_by_coords([ds_cb, ds_pl, ds_tp, ds_info])
ds = [d for d in [ds_cb, ds_pl, ds_tp] if d is not None]
if len(ds) > 1:
ds = xr.combine_by_coords(ds)
else:
ds = ds[0]
ds = xr.merge([ds_info,ds])

return ds
52 changes: 44 additions & 8 deletions python/swiftest/swiftest/simulation_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,10 +1622,15 @@ def add_solar_system_body(self,
if all(np.isnan(J4)):
J4 = None

t = self.param['TSTART']

dsnew = init_cond.vec2xr(self.param,name,v1,v2,v3,v4,v5,v6,ephemeris_id,GMpl,Rpl,rhill,Ip1,Ip2,Ip3,rotx,roty,rotz,J2,J4)
dsnew = init_cond.vec2xr(self.param,name,v1,v2,v3,v4,v5,v6,ephemeris_id,
GMpl=GMpl, Rpl=Rpl, rhill=rhill,
Ip1=Ip1, Ip2=Ip2, Ip3=Ip3,
rotx=rotx, roty=roty, rotz=rotz,
J2=J2, J4=J4, t=t)

self.ds = xr.combine_by_coords([self.ds,dsnew])
dsnew = self._combine_and_fix_dsnew(dsnew)

return dsnew

Expand Down Expand Up @@ -1831,19 +1836,50 @@ def input_to_array(val,t,n=None):

t = self.param['TSTART']

dsnew = init_cond.vec2xr(self.param, idvals, name, v1, v2, v3, v4, v5, v6,
dsnew = init_cond.vec2xr(self.param, name, v1, v2, v3, v4, v5, v6, idvals,
GMpl=GMpl, Rpl=Rpl, rhill=rhill,
Ip1=Ip1, Ip2=Ip2, Ip3=Ip3,
rotx=rotx, roty=roty, rotz=rotz,
J2=J2, J4=J4, t=t)
if dsnew is not None:
self.ds = xr.combine_by_coords([self.ds, dsnew])
self.ds['ntp'] = self.ds['id'].where(np.isnan(self.ds['Gmass'])).count(dim="id")
self.ds['npl'] = self.ds['id'].where(np.invert(np.isnan(self.ds['Gmass']))).count(dim="id") - 1
J2=J2, J4=J4,t=t)

dsnew = self._combine_and_fix_dsnew(dsnew)

return dsnew

def _combine_and_fix_dsnew(self,dsnew):
"""
Combines the new Dataset with the old one. Also computes the values of ntp and npl and sets the proper types.
Parameters
----------
dsnew : xarray Dataset
Dataset with new bodies
Returns
-------
dsnew : xarray Dataset
Updated Dataset with ntp, npl values and types fixed.
"""

self.ds = xr.combine_by_coords([self.ds, dsnew])

def get_nvals(ds):
if "Gmass" in dsnew:
ds['ntp'] = ds['id'].where(np.isnan(ds['Gmass'])).count(dim="id")
ds['npl'] = ds['id'].where(np.invert(np.isnan(ds['Gmass']))).count(dim="id") - 1
else:
ds['ntp'] = ds['id'].count(dim="id")
ds['npl'] = xr.full_like(ds['ntp'],0)
return ds

dsnew = get_nvals(dsnew)
self.ds = get_nvals(self.ds)

if self.param['OUT_TYPE'] == "NETCDF_DOUBLE":
dsnew = io.fix_types(dsnew, ftype=np.float64)
self.ds = io.fix_types(self.ds, ftype=np.float64)
elif self.param['OUT_TYPE'] == "NETCDF_FLOAT":
dsnew = io.fix_types(dsnew, ftype=np.float32)
self.ds = io.fix_types(self.ds, ftype=np.float32)

return dsnew
Expand Down

0 comments on commit d166719

Please sign in to comment.