Skip to content
This repository was archived by the owner on Aug 28, 2024. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
Made improvements to the solar system and body add methods. Still not quite finished, but getting close.
  • Loading branch information
daminton committed Nov 10, 2022
1 parent a8a2423 commit d07c603
Show file tree
Hide file tree
Showing 2 changed files with 306 additions and 210 deletions.
222 changes: 94 additions & 128 deletions python/swiftest/swiftest/init_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,26 @@
You should have received a copy of the GNU General Public License along with Swiftest.
If not, see: https://www.gnu.org/licenses.
"""
from __future__ import annotations

import swiftest
import numpy as np
import numpy.typing as npt
from astroquery.jplhorizons import Horizons
import astropy.units as u
from astropy.coordinates import SkyCoord
import datetime
from datetime import date
import xarray as xr

def solar_system_horizons(plname, idval, param, ephemerides_start_date):
from typing import (
Literal,
Dict,
List,
Any
)
def solar_system_horizons(plname: str,
param: Dict,
ephemerides_start_date: str,
idval: int | None = None):
"""
Initializes a Swiftest dataset containing the major planets of the Solar System at a particular data from JPL/Horizons
Expand Down Expand Up @@ -118,10 +127,10 @@ def solar_system_horizons(plname, idval, param, ephemerides_start_date):
THIRDLONG = np.longdouble(1.0) / np.longdouble(3.0)

# Central body value vectors
GMcb = np.array([swiftest.GMSun * param['TU2S'] ** 2 / param['DU2M'] ** 3])
Rcb = np.array([swiftest.RSun / param['DU2M']])
J2RP2 = np.array([swiftest.J2Sun * (swiftest.RSun / param['DU2M']) ** 2])
J4RP4 = np.array([swiftest.J4Sun * (swiftest.RSun / param['DU2M']) ** 4])
GMcb = swiftest.GMSun * param['TU2S'] ** 2 / param['DU2M'] ** 3
Rcb = swiftest.RSun / param['DU2M']
J2RP2 = swiftest.J2Sun * (swiftest.RSun / param['DU2M']) ** 2
J4RP4 = swiftest.J4Sun * (swiftest.RSun / param['DU2M']) ** 4

solarpole = SkyCoord(ra=286.13 * u.degree, dec=63.87 * u.degree)
solarrot = planetrot['Sun'] * param['TU2S']
Expand All @@ -145,12 +154,12 @@ def solar_system_horizons(plname, idval, param, ephemerides_start_date):
J2 = J2RP2
J4 = J4RP4
if param['ROTATION']:
Ip1 = [Ipsun[0]]
Ip2 = [Ipsun[1]]
Ip3 = [Ipsun[2]]
rotx = [rotcb.x]
roty = [rotcb.y]
rotz = [rotcb.z]
Ip1 = Ipsun[0]
Ip2 = Ipsun[1]
Ip3 = Ipsun[2]
rotx = rotcb.x.value
roty = rotcb.y.value
rotz = rotcb.z.value
else:
Ip1 = None
Ip2 = None
Expand All @@ -168,12 +177,6 @@ def solar_system_horizons(plname, idval, param, ephemerides_start_date):
ephemerides_end_date = tend.isoformat()
ephemerides_step = '1d'

v1 = []
v2 = []
v3 = []
v4 = []
v5 = []
v6 = []
J2 = None
J4 = None

Expand All @@ -183,55 +186,46 @@ def solar_system_horizons(plname, idval, param, ephemerides_start_date):
'step': ephemerides_step})

if param['IN_FORM'] == 'XV':
v1.append(pldata[plname].vectors()['x'][0] * DCONV)
v2.append(pldata[plname].vectors()['y'][0] * DCONV)
v3.append(pldata[plname].vectors()['z'][0] * DCONV)
v4.append(pldata[plname].vectors()['vx'][0] * VCONV)
v5.append(pldata[plname].vectors()['vy'][0] * VCONV)
v6.append(pldata[plname].vectors()['vz'][0] * VCONV)
v1 = pldata[plname].vectors()['x'][0] * DCONV
v2 = pldata[plname].vectors()['y'][0] * DCONV
v3 = pldata[plname].vectors()['z'][0] * DCONV
v4 = pldata[plname].vectors()['vx'][0] * VCONV
v5 = pldata[plname].vectors()['vy'][0] * VCONV
v6 = pldata[plname].vectors()['vz'][0] * VCONV
elif param['IN_FORM'] == 'EL':
v1.append(pldata[plname].elements()['a'][0] * DCONV)
v2.append(pldata[plname].elements()['e'][0])
v3.append(pldata[plname].elements()['incl'][0])
v4.append(pldata[plname].elements()['Omega'][0])
v5.append(pldata[plname].elements()['w'][0])
v6.append(pldata[plname].elements()['M'][0])
v1 = pldata[plname].elements()['a'][0] * DCONV
v2 = pldata[plname].elements()['e'][0]
v3 = pldata[plname].elements()['incl'][0]
v4 = pldata[plname].elements()['Omega'][0]
v5 = pldata[plname].elements()['w'][0]
v6 = pldata[plname].elements()['M'][0]

if ispl:
GMpl = []
GMpl.append(GMcb[0] / MSun_over_Mpl[plname])
GMpl = GMcb / MSun_over_Mpl[plname]
if param['CHK_CLOSE']:
Rpl = []
Rpl.append(planetradius[plname] * DCONV)
Rpl = planetradius[plname] * DCONV
else:
Rpl = None

# Generate planet value vectors
if (param['RHILL_PRESENT']):
rhill = []
rhill.append(pldata[plname].elements()['a'][0] * DCONV * (3 * MSun_over_Mpl[plname]) ** (-THIRDLONG))
rhill = pldata[plname].elements()['a'][0] * DCONV * (3 * MSun_over_Mpl[plname]) ** (-THIRDLONG)
else:
rhill = None
if (param['ROTATION']):
Ip1 = []
Ip2 = []
Ip3 = []
rotx = []
roty = []
rotz = []
RA = pldata[plname].ephemerides()['NPole_RA'][0]
DEC = pldata[plname].ephemerides()['NPole_DEC'][0]

rotpole = SkyCoord(ra=RA * u.degree, dec=DEC * u.degree)
rotrate = planetrot[plname] * param['TU2S']
rot = rotpole.cartesian * rotrate
Ip = np.array([0.0, 0.0, planetIpz[plname]])
Ip1.append(Ip[0])
Ip2.append(Ip[1])
Ip3.append(Ip[2])
rotx.append(rot.x)
roty.append(rot.y)
rotz.append(rot.z)
Ip1 = Ip[0]
Ip2 = Ip[1]
Ip3 = Ip[2]
rotx = rot.x.value
roty = rot.y.value
rotz = rot.z.value
else:
Ip1 = None
Ip2 = None
Expand All @@ -243,13 +237,33 @@ def solar_system_horizons(plname, idval, param, ephemerides_start_date):
GMpl = None

if idval is None:
plid = np.array([planetid[plname]], dtype=int)
plid = planetid[plname]
else:
plid = np.array([idval], dtype=int)
plid = idval

return plid,[plname],v1,v2,v3,v4,v5,v6,GMpl,Rpl,rhill,Ip1,Ip2,Ip3,rotx,roty,rotz,J2,J4
return plname,v1,v2,v3,v4,v5,v6,idval,GMpl,Rpl,rhill,Ip1,Ip2,Ip3,rotx,roty,rotz,J2,J4

def vec2xr(param, idvals, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None, rhill=None, Ip1=None, Ip2=None, Ip3=None, rotx=None, roty=None, rotz=None, J2=None, J4=None,t=0.0):
def vec2xr(param: Dict,
namevals: npt.NDArray[np.str_],
v1: npt.NDArray[np.float_],
v2: npt.NDArray[np.float_],
v3: npt.NDArray[np.float_],
v4: npt.NDArray[np.float_],
v5: npt.NDArray[np.float_],
v6: npt.NDArray[np.float_],
idvals: npt.NDArray[np.int_],
GMpl: npt.NDArray[np.float_] | None=None,
Rpl: npt.NDArray[np.float_] | None=None,
rhill: npt.NDArray[np.float_] | None=None,
Ip1: npt.NDArray[np.float_] | None=None,
Ip2: npt.NDArray[np.float_] | None=None,
Ip3: npt.NDArray[np.float_] | None=None,
rotx: npt.NDArray[np.float_] | None=None,
roty: npt.NDArray[np.float_] | None=None,
rotz: npt.NDArray[np.float_] | None=None,
J2: npt.NDArray[np.float_] | None=None,
J4: npt.NDArray[np.float_] | None=None,
t: float=0.0):
"""
Converts and stores the variables of all bodies in an xarray dataset.
Expand Down Expand Up @@ -297,11 +311,6 @@ def vec2xr(param, idvals, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None,
-------
ds : xarray dataset
"""
if v1 is None: # This is the central body
iscb = True
else:
iscb = False

if param['ROTATION']:
if Ip1 is None:
Ip1 = np.full_like(v1, 0.4)
Expand All @@ -318,11 +327,18 @@ def vec2xr(param, idvals, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None,

dims = ['time', 'id', 'vec']
infodims = ['id', 'vec']
if not iscb and GMpl is not None:

# The central body is always given id 0
icb = idvals == 0
iscb = any(icb)

if GMpl is not None:
ispl = True
ipl = ~np.isnan(GMpl)
itp = np.isnan(GMpl)
else:
ispl = False

if ispl and param['CHK_CLOSE'] and Rpl is None:
print("Massive bodies need a radius value.")
return None
Expand All @@ -337,83 +353,33 @@ def vec2xr(param, idvals, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None,
param['OUT_FORM'] = old_out_form
vec_str = np.vstack([namevals])
label_str = ["name"]
particle_type = np.empty_like(namevals)
if iscb:
label_float = clab.copy()
vec_float = np.vstack([GMpl,Rpl,J2,J4])
vec_float = np.vstack([GMpl[icb],Rpl[icb],J2[icb],J4[icb]])
if param['ROTATION']:
vec_float = np.vstack([vec_float, Ip1[icb], Ip2[icb], Ip3[icb], rotx[icb], roty[icb], rotz[icb]])
particle_type[icb] = "Central Body"
# vec_float = np.vstack([v1, v2, v3, v4, v5, v6])
if ispl:
label_float = plab.copy()
vec_float = np.vstack([vec_float, GMpl])
if param['CHK_CLOSE']:
vec_float = np.vstack([vec_float, Rpl])
if param['RHILL_PRESENT']:
vec_float = np.vstack([vec_float, rhill])
if param['ROTATION']:
vec_float = np.vstack([vec_float, Ip1, Ip2, Ip3, rotx, roty, rotz])
particle_type = "Central Body"
particle_type[ipl] = np.repeat("Massive Body",idvals.size)
else:
vec_float = np.vstack([v1, v2, v3, v4, v5, v6])
if ispl:
label_float = plab.copy()
vec_float = np.vstack([vec_float, GMpl])
if param['CHK_CLOSE']:
vec_float = np.vstack([vec_float, Rpl])
if param['RHILL_PRESENT']:
vec_float = np.vstack([vec_float, rhill])
if param['ROTATION']:
vec_float = np.vstack([vec_float, Ip1, Ip2, Ip3, rotx, roty, rotz])
particle_type = np.repeat("Massive Body",idvals.size)
else:
label_float = tlab.copy()
particle_type = np.repeat("Test Particle",idvals.size)
origin_type = np.repeat("User Added Body",idvals.size)
origin_time = np.full_like(v1,t)
collision_id = np.full_like(idvals,0)
origin_xhx = v1
origin_xhy = v2
origin_xhz = v3
origin_vhx = v4
origin_vhy = v5
origin_vhz = v6
discard_time = np.full_like(v1,-1.0)
status = np.repeat("ACTIVE",idvals.size)
discard_xhx = np.zeros_like(v1)
discard_xhy = np.zeros_like(v1)
discard_xhz = np.zeros_like(v1)
discard_vhx = np.zeros_like(v1)
discard_vhy = np.zeros_like(v1)
discard_vhz = np.zeros_like(v1)
discard_body_id = np.full_like(idvals,-1)
info_vec_float = np.vstack([
origin_time,
origin_xhx,
origin_xhy,
origin_xhz,
origin_vhx,
origin_vhy,
origin_vhz,
discard_time,
discard_xhx,
discard_xhy,
discard_xhz,
discard_vhx,
discard_vhy,
discard_vhz])
info_vec_int = np.vstack([collision_id, discard_body_id])
info_vec_str = np.vstack([particle_type, origin_type, status])
frame_float = info_vec_float.T
frame_int = info_vec_int.T
frame_str = info_vec_str.T
if param['IN_TYPE'] == 'NETCDF_FLOAT':
ftype=np.float32
elif param['IN_TYPE'] == 'NETCDF_DOUBLE' or param['IN_TYPE'] == 'ASCII':
ftype=np.float64
da_float = xr.DataArray(frame_float, dims=infodims, coords={'id': idvals, 'vec': infolab_float}).astype(ftype)
da_int = xr.DataArray(frame_int, dims=infodims, coords={'id': idvals, 'vec': infolab_int})
da_str = xr.DataArray(frame_str, dims=infodims, coords={'id': idvals, 'vec': infolab_str})
ds_float = da_float.to_dataset(dim="vec")
ds_int = da_int.to_dataset(dim="vec")
ds_str = da_str.to_dataset(dim="vec")
info_ds = xr.combine_by_coords([ds_float, ds_int, ds_str])

label_float = tlab.copy()
particle_type[itp] = np.repeat("Test Particle",idvals.size)
frame_float = np.expand_dims(vec_float.T, axis=0)
frame_str = vec_str.T
da_float = xr.DataArray(frame_float, dims=dims, coords={'time': [t], 'id': idvals, 'vec': label_float}).astype(ftype)
da_float = xr.DataArray(frame_float, dims=dims, coords={'time': [t], 'id': idvals, 'vec': label_float})
da_str= xr.DataArray(frame_str, dims=infodims, coords={'id': idvals, 'vec': label_str})
ds_float = da_float.to_dataset(dim="vec")
ds_str = da_str.to_dataset(dim="vec")
ds = xr.combine_by_coords([ds_float, ds_str,info_ds])
ds = xr.combine_by_coords([ds_float, ds_str])

return ds
Loading

0 comments on commit d07c603

Please sign in to comment.