diff --git a/python/swiftest/swiftest/init_cond.py b/python/swiftest/swiftest/init_cond.py index 641fa2b68..c90f1d59b 100644 --- a/python/swiftest/swiftest/init_cond.py +++ b/python/swiftest/swiftest/init_cond.py @@ -8,17 +8,26 @@ You should have received a copy of the GNU General Public License along with Swiftest. If not, see: https://www.gnu.org/licenses. """ +from __future__ import annotations import swiftest import numpy as np +import numpy.typing as npt from astroquery.jplhorizons import Horizons import astropy.units as u from astropy.coordinates import SkyCoord import datetime -from datetime import date import xarray as xr - -def solar_system_horizons(plname, idval, param, ephemerides_start_date): +from typing import ( + Literal, + Dict, + List, + Any +) +def solar_system_horizons(plname: str, + param: Dict, + ephemerides_start_date: str, + idval: int | None = None): """ Initializes a Swiftest dataset containing the major planets of the Solar System at a particular data from JPL/Horizons @@ -118,10 +127,10 @@ def solar_system_horizons(plname, idval, param, ephemerides_start_date): THIRDLONG = np.longdouble(1.0) / np.longdouble(3.0) # Central body value vectors - GMcb = np.array([swiftest.GMSun * param['TU2S'] ** 2 / param['DU2M'] ** 3]) - Rcb = np.array([swiftest.RSun / param['DU2M']]) - J2RP2 = np.array([swiftest.J2Sun * (swiftest.RSun / param['DU2M']) ** 2]) - J4RP4 = np.array([swiftest.J4Sun * (swiftest.RSun / param['DU2M']) ** 4]) + GMcb = swiftest.GMSun * param['TU2S'] ** 2 / param['DU2M'] ** 3 + Rcb = swiftest.RSun / param['DU2M'] + J2RP2 = swiftest.J2Sun * (swiftest.RSun / param['DU2M']) ** 2 + J4RP4 = swiftest.J4Sun * (swiftest.RSun / param['DU2M']) ** 4 solarpole = SkyCoord(ra=286.13 * u.degree, dec=63.87 * u.degree) solarrot = planetrot['Sun'] * param['TU2S'] @@ -145,12 +154,12 @@ def solar_system_horizons(plname, idval, param, ephemerides_start_date): J2 = J2RP2 J4 = J4RP4 if param['ROTATION']: - Ip1 = [Ipsun[0]] - Ip2 = [Ipsun[1]] - Ip3 = [Ipsun[2]] - rotx = [rotcb.x] - roty = [rotcb.y] - rotz = [rotcb.z] + Ip1 = Ipsun[0] + Ip2 = Ipsun[1] + Ip3 = Ipsun[2] + rotx = rotcb.x.value + roty = rotcb.y.value + rotz = rotcb.z.value else: Ip1 = None Ip2 = None @@ -168,12 +177,6 @@ def solar_system_horizons(plname, idval, param, ephemerides_start_date): ephemerides_end_date = tend.isoformat() ephemerides_step = '1d' - v1 = [] - v2 = [] - v3 = [] - v4 = [] - v5 = [] - v6 = [] J2 = None J4 = None @@ -183,42 +186,33 @@ def solar_system_horizons(plname, idval, param, ephemerides_start_date): 'step': ephemerides_step}) if param['IN_FORM'] == 'XV': - v1.append(pldata[plname].vectors()['x'][0] * DCONV) - v2.append(pldata[plname].vectors()['y'][0] * DCONV) - v3.append(pldata[plname].vectors()['z'][0] * DCONV) - v4.append(pldata[plname].vectors()['vx'][0] * VCONV) - v5.append(pldata[plname].vectors()['vy'][0] * VCONV) - v6.append(pldata[plname].vectors()['vz'][0] * VCONV) + v1 = pldata[plname].vectors()['x'][0] * DCONV + v2 = pldata[plname].vectors()['y'][0] * DCONV + v3 = pldata[plname].vectors()['z'][0] * DCONV + v4 = pldata[plname].vectors()['vx'][0] * VCONV + v5 = pldata[plname].vectors()['vy'][0] * VCONV + v6 = pldata[plname].vectors()['vz'][0] * VCONV elif param['IN_FORM'] == 'EL': - v1.append(pldata[plname].elements()['a'][0] * DCONV) - v2.append(pldata[plname].elements()['e'][0]) - v3.append(pldata[plname].elements()['incl'][0]) - v4.append(pldata[plname].elements()['Omega'][0]) - v5.append(pldata[plname].elements()['w'][0]) - v6.append(pldata[plname].elements()['M'][0]) + v1 = pldata[plname].elements()['a'][0] * DCONV + v2 = pldata[plname].elements()['e'][0] + v3 = pldata[plname].elements()['incl'][0] + v4 = pldata[plname].elements()['Omega'][0] + v5 = pldata[plname].elements()['w'][0] + v6 = pldata[plname].elements()['M'][0] if ispl: - GMpl = [] - GMpl.append(GMcb[0] / MSun_over_Mpl[plname]) + GMpl = GMcb / MSun_over_Mpl[plname] if param['CHK_CLOSE']: - Rpl = [] - Rpl.append(planetradius[plname] * DCONV) + Rpl = planetradius[plname] * DCONV else: Rpl = None # Generate planet value vectors if (param['RHILL_PRESENT']): - rhill = [] - rhill.append(pldata[plname].elements()['a'][0] * DCONV * (3 * MSun_over_Mpl[plname]) ** (-THIRDLONG)) + rhill = pldata[plname].elements()['a'][0] * DCONV * (3 * MSun_over_Mpl[plname]) ** (-THIRDLONG) else: rhill = None if (param['ROTATION']): - Ip1 = [] - Ip2 = [] - Ip3 = [] - rotx = [] - roty = [] - rotz = [] RA = pldata[plname].ephemerides()['NPole_RA'][0] DEC = pldata[plname].ephemerides()['NPole_DEC'][0] @@ -226,12 +220,12 @@ def solar_system_horizons(plname, idval, param, ephemerides_start_date): rotrate = planetrot[plname] * param['TU2S'] rot = rotpole.cartesian * rotrate Ip = np.array([0.0, 0.0, planetIpz[plname]]) - Ip1.append(Ip[0]) - Ip2.append(Ip[1]) - Ip3.append(Ip[2]) - rotx.append(rot.x) - roty.append(rot.y) - rotz.append(rot.z) + Ip1 = Ip[0] + Ip2 = Ip[1] + Ip3 = Ip[2] + rotx = rot.x.value + roty = rot.y.value + rotz = rot.z.value else: Ip1 = None Ip2 = None @@ -243,13 +237,33 @@ def solar_system_horizons(plname, idval, param, ephemerides_start_date): GMpl = None if idval is None: - plid = np.array([planetid[plname]], dtype=int) + plid = planetid[plname] else: - plid = np.array([idval], dtype=int) + plid = idval - return plid,[plname],v1,v2,v3,v4,v5,v6,GMpl,Rpl,rhill,Ip1,Ip2,Ip3,rotx,roty,rotz,J2,J4 + return plname,v1,v2,v3,v4,v5,v6,idval,GMpl,Rpl,rhill,Ip1,Ip2,Ip3,rotx,roty,rotz,J2,J4 -def vec2xr(param, idvals, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None, rhill=None, Ip1=None, Ip2=None, Ip3=None, rotx=None, roty=None, rotz=None, J2=None, J4=None,t=0.0): +def vec2xr(param: Dict, + namevals: npt.NDArray[np.str_], + v1: npt.NDArray[np.float_], + v2: npt.NDArray[np.float_], + v3: npt.NDArray[np.float_], + v4: npt.NDArray[np.float_], + v5: npt.NDArray[np.float_], + v6: npt.NDArray[np.float_], + idvals: npt.NDArray[np.int_], + GMpl: npt.NDArray[np.float_] | None=None, + Rpl: npt.NDArray[np.float_] | None=None, + rhill: npt.NDArray[np.float_] | None=None, + Ip1: npt.NDArray[np.float_] | None=None, + Ip2: npt.NDArray[np.float_] | None=None, + Ip3: npt.NDArray[np.float_] | None=None, + rotx: npt.NDArray[np.float_] | None=None, + roty: npt.NDArray[np.float_] | None=None, + rotz: npt.NDArray[np.float_] | None=None, + J2: npt.NDArray[np.float_] | None=None, + J4: npt.NDArray[np.float_] | None=None, + t: float=0.0): """ Converts and stores the variables of all bodies in an xarray dataset. @@ -297,11 +311,6 @@ def vec2xr(param, idvals, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None, ------- ds : xarray dataset """ - if v1 is None: # This is the central body - iscb = True - else: - iscb = False - if param['ROTATION']: if Ip1 is None: Ip1 = np.full_like(v1, 0.4) @@ -318,11 +327,18 @@ def vec2xr(param, idvals, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None, dims = ['time', 'id', 'vec'] infodims = ['id', 'vec'] - if not iscb and GMpl is not None: + + # The central body is always given id 0 + icb = idvals == 0 + iscb = any(icb) + + if GMpl is not None: ispl = True + ipl = ~np.isnan(GMpl) + itp = np.isnan(GMpl) else: ispl = False - + if ispl and param['CHK_CLOSE'] and Rpl is None: print("Massive bodies need a radius value.") return None @@ -337,83 +353,33 @@ def vec2xr(param, idvals, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None, param['OUT_FORM'] = old_out_form vec_str = np.vstack([namevals]) label_str = ["name"] + particle_type = np.empty_like(namevals) if iscb: label_float = clab.copy() - vec_float = np.vstack([GMpl,Rpl,J2,J4]) + vec_float = np.vstack([GMpl[icb],Rpl[icb],J2[icb],J4[icb]]) + if param['ROTATION']: + vec_float = np.vstack([vec_float, Ip1[icb], Ip2[icb], Ip3[icb], rotx[icb], roty[icb], rotz[icb]]) + particle_type[icb] = "Central Body" + # vec_float = np.vstack([v1, v2, v3, v4, v5, v6]) + if ispl: + label_float = plab.copy() + vec_float = np.vstack([vec_float, GMpl]) + if param['CHK_CLOSE']: + vec_float = np.vstack([vec_float, Rpl]) + if param['RHILL_PRESENT']: + vec_float = np.vstack([vec_float, rhill]) if param['ROTATION']: vec_float = np.vstack([vec_float, Ip1, Ip2, Ip3, rotx, roty, rotz]) - particle_type = "Central Body" + particle_type[ipl] = np.repeat("Massive Body",idvals.size) else: - vec_float = np.vstack([v1, v2, v3, v4, v5, v6]) - if ispl: - label_float = plab.copy() - vec_float = np.vstack([vec_float, GMpl]) - if param['CHK_CLOSE']: - vec_float = np.vstack([vec_float, Rpl]) - if param['RHILL_PRESENT']: - vec_float = np.vstack([vec_float, rhill]) - if param['ROTATION']: - vec_float = np.vstack([vec_float, Ip1, Ip2, Ip3, rotx, roty, rotz]) - particle_type = np.repeat("Massive Body",idvals.size) - else: - label_float = tlab.copy() - particle_type = np.repeat("Test Particle",idvals.size) - origin_type = np.repeat("User Added Body",idvals.size) - origin_time = np.full_like(v1,t) - collision_id = np.full_like(idvals,0) - origin_xhx = v1 - origin_xhy = v2 - origin_xhz = v3 - origin_vhx = v4 - origin_vhy = v5 - origin_vhz = v6 - discard_time = np.full_like(v1,-1.0) - status = np.repeat("ACTIVE",idvals.size) - discard_xhx = np.zeros_like(v1) - discard_xhy = np.zeros_like(v1) - discard_xhz = np.zeros_like(v1) - discard_vhx = np.zeros_like(v1) - discard_vhy = np.zeros_like(v1) - discard_vhz = np.zeros_like(v1) - discard_body_id = np.full_like(idvals,-1) - info_vec_float = np.vstack([ - origin_time, - origin_xhx, - origin_xhy, - origin_xhz, - origin_vhx, - origin_vhy, - origin_vhz, - discard_time, - discard_xhx, - discard_xhy, - discard_xhz, - discard_vhx, - discard_vhy, - discard_vhz]) - info_vec_int = np.vstack([collision_id, discard_body_id]) - info_vec_str = np.vstack([particle_type, origin_type, status]) - frame_float = info_vec_float.T - frame_int = info_vec_int.T - frame_str = info_vec_str.T - if param['IN_TYPE'] == 'NETCDF_FLOAT': - ftype=np.float32 - elif param['IN_TYPE'] == 'NETCDF_DOUBLE' or param['IN_TYPE'] == 'ASCII': - ftype=np.float64 - da_float = xr.DataArray(frame_float, dims=infodims, coords={'id': idvals, 'vec': infolab_float}).astype(ftype) - da_int = xr.DataArray(frame_int, dims=infodims, coords={'id': idvals, 'vec': infolab_int}) - da_str = xr.DataArray(frame_str, dims=infodims, coords={'id': idvals, 'vec': infolab_str}) - ds_float = da_float.to_dataset(dim="vec") - ds_int = da_int.to_dataset(dim="vec") - ds_str = da_str.to_dataset(dim="vec") - info_ds = xr.combine_by_coords([ds_float, ds_int, ds_str]) - + label_float = tlab.copy() + particle_type[itp] = np.repeat("Test Particle",idvals.size) frame_float = np.expand_dims(vec_float.T, axis=0) frame_str = vec_str.T - da_float = xr.DataArray(frame_float, dims=dims, coords={'time': [t], 'id': idvals, 'vec': label_float}).astype(ftype) + da_float = xr.DataArray(frame_float, dims=dims, coords={'time': [t], 'id': idvals, 'vec': label_float}) da_str= xr.DataArray(frame_str, dims=infodims, coords={'id': idvals, 'vec': label_str}) ds_float = da_float.to_dataset(dim="vec") ds_str = da_str.to_dataset(dim="vec") - ds = xr.combine_by_coords([ds_float, ds_str,info_ds]) + ds = xr.combine_by_coords([ds_float, ds_str]) return ds \ No newline at end of file diff --git a/python/swiftest/swiftest/simulation_class.py b/python/swiftest/swiftest/simulation_class.py index 0d26ebdf0..01c3b8414 100644 --- a/python/swiftest/swiftest/simulation_class.py +++ b/python/swiftest/swiftest/simulation_class.py @@ -14,10 +14,11 @@ from swiftest import init_cond from swiftest import tool from swiftest import constants +import os import datetime import xarray as xr import numpy as np -import os +import numpy.typing as npt import shutil from typing import ( Literal, @@ -1506,39 +1507,58 @@ def get_distance_range(self, arg_list: str | List[str] | None = None, verbose: b return range_dict def add_solar_system_body(self, - name: str | List[str] | None = None, - id: int | List[int] | None = None, + name: str | List[str], + ephemeris_id: int | List[int] | None = None, date: str | None = None, - origin_type: str = "initial_conditions", source: str = "HORIZONS"): """ Adds a solar system body to an existing simulation Dataset from the JPL Horizons ephemeris service. - + + The following are name/ephemeris_id pairs that are currently known to Swiftest, and therefore have + physical properties that can be used to make massive bodies. + + Sun : 0 + Mercury : 1 + Venus : 2 + Earth : 3 + Mars : 4 + Jupiter : 5 + Saturn : 6 + Uranus : 7 + Neptune : 8 + Pluto : 9 + Parameters ---------- - name : str | List[str], optional - Add solar system body by name. - Currently bodies from the following list will result in fully-massive bodies (they include mass, radius, - and rotation parameters). "Sun" (added as a central body), "Mercury", "Venus", "Earth", "Mars", - "Jupiter", "Saturn", "Uranus", "Neptune", "Pluto" - - Bodies not on this list will be added as test particles, but additional properties can be added later if - desired. - id : int | List[int], optional - Add solar system body by id number. - date : str, optional - ISO-formatted date sto use when obtaining the ephemerides in the format YYYY-MM-DD. Defaults to value - set by `set_ephemeris_date`. - origin_type : str, default "initial_conditions" - The string that will be added to the `origin_type` variable for all bodies added to the list - source : str, default "Horizons" - The source of the ephemerides. - >*Note.* Currently only the JPL Horizons ephemeris is implemented, so this is ignored. + name : str | List[str] + Add solar system body by name. + Bodies not on this list will be added as test particles, but additional properties can be added later if + desired. + ephemeris_id : int | List[int], optional but must be the same length as `name` if passed. + Use id if the body you wish to add is recognized by Swiftest. In that case, the id is passed to the + ephemeris service and the name is used. The body specified by `id` supercedes that given by `name`. + date : str, optional + ISO-formatted date sto use when obtaining the ephemerides in the format YYYY-MM-DD. Defaults to value + set by `set_ephemeris_date`. + source : str, default "Horizons" + The source of the ephemerides. + >*Note.* Currently only the JPL Horizons ephemeris is implemented, so this is ignored. Returns ------- ds : Xarray dataset with body or bodies added. """ + if type(name) is str: + name = [name] + if ephemeris_id is not None: + if type(ephemeris_id) is int: + ephemeris_id = [ephemeris_id] + if len(ephemeris_id) != len(name): + print(f"Error! The length of ephemeris_id ({len(ephemeris_id)}) does not match the length of name ({len(name)})") + return None + else: + ephemeris_id = [None] * len(name) + if self.ephemeris_date is None: self.set_ephemeris_date() @@ -1553,32 +1573,59 @@ def add_solar_system_body(self, if source.upper() != "HORIZONS": print("Currently only the JPL Horizons ephemeris service is supported") - if id is not None and name is not None: - print("Warning! Requesting both id and name could lead to duplicate bodies.") - dsnew = [] - if name is not None: - if type(name) is str: - name = [name] - - if origin_type is None: - origin_type = ['initial_conditions'] * len(name) - - for n in name: - dsnew.append(self.addp(*init_cond.solar_system_horizons(n, self.param, date))) - - - if id is not None: - if type(id) is str: - id = [id] - - if origin_type is None: - origin_type = ['initial_conditions'] * len(id) - - for i in id: - dsnew.append(self.addp(*init_cond.solar_system_horizons(i, self.param, date))) - - - return + body_list = [] + for i,n in enumerate(name): + body_list.append(init_cond.solar_system_horizons(n, self.param, date, idval=ephemeris_id[i])) + + #Convert the list receieved from the solar_system_horizons output and turn it into arguments to vec2xr + name,v1,v2,v3,v4,v5,v6,ephemeris_id,GMpl,Rpl,rhill,Ip1,Ip2,Ip3,rotx,roty,rotz,J2,J4 = tuple(np.squeeze(np.hsplit(np.array(body_list),19))) + + v1 = v1.astype(np.float64) + v2 = v2.astype(np.float64) + v3 = v3.astype(np.float64) + v4 = v4.astype(np.float64) + v5 = v5.astype(np.float64) + v6 = v6.astype(np.float64) + ephemeris_id = ephemeris_id.astype(int) + GMpl = GMpl.astype(np.float64) + Rpl = Rpl.astype(np.float64) + rhill = rhill.astype(np.float64) + Ip1 = Ip1.astype(np.float64) + Ip2 = Ip2.astype(np.float64) + Ip3 = Ip3.astype(np.float64) + rotx = rotx.astype(np.float64) + roty = roty.astype(np.float64) + rotz = rotz.astype(np.float64) + J2 = J2.astype(np.float64) + J4 = J4.astype(np.float64) + + if all(np.isnan(GMpl)): + GMpl = None + if all(np.isnan(Rpl)): + Rpl = None + if all(np.isnan(rhill)): + rhill = None + if all(np.isnan(Ip1)): + Ip1 = None + if all(np.isnan(Ip2)): + Ip2 = None + if all(np.isnan(Ip3)): + Ip3 = None + if all(np.isnan(rotx)): + rotx = None + if all(np.isnan(roty)): + roty = None + if all(np.isnan(rotz)): + rotz = None + if all(np.isnan(J2)): + J2 = None + if all(np.isnan(J4)): + J4 = None + + + dsnew = init_cond.vec2xr(self.param,name,v1,v2,v3,v4,v5,v6,ephemeris_id,GMpl,Rpl,rhill,Ip1,Ip2,Ip3,rotx,roty,rotz,J2,J4) + + return body_list def set_ephemeris_date(self, @@ -1662,48 +1709,131 @@ def get_ephemeris_date(self, verbose: bool | None = None, **kwargs: Any): return self.ephemeris_date - def add_body(self, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None, rhill=None, Ip1=None, Ip2=None, - Ip3=None, rotx=None, roty=None, rotz=None, J2=None, J4=None, t=None): + def add_body(self, + name: str | List[str] | npt.NDArray[np.str_], + v1: float | List[float] | npt.NDArray[np.float_], + v2: float | List[float] | npt.NDArray[np.float_], + v3: float | List[float] | npt.NDArray[np.float_], + v4: float | List[float] | npt.NDArray[np.float_], + v5: float | List[float] | npt.NDArray[np.float_], + v6: float | List[float] | npt.NDArray[np.float_], + idvals: int | list[int] | npt.NDArray[np.int_] | None=None, + GMpl: float | List[float] | npt.NDArray[np.float_] | None=None, + Rpl: float | List[float] | npt.NDArray[np.float_] | None=None, + rhill: float | List[float] | npt.NDArray[np.float_] | None=None, + Ip1: float | List[float] | npt.NDArray[np.float_] | None=None, + Ip2: float | List[float] | npt.NDArray[np.float_] | None=None, + Ip3: float | List[float] | npt.NDArray[np.float_] | None=None, + rotx: float | List[float] | npt.NDArray[np.float_] | None=None, + roty: float | List[float] | npt.NDArray[np.float_] | None=None, + rotz: float | List[float] | npt.NDArray[np.float_] | None=None, + J2: float | List[float] | npt.NDArray[np.float_] | None=None, + J4: float | List[float] | npt.NDArray[np.float_] | None=None): """ Adds a body (test particle or massive body) to the internal DataSet given a set up 6 vectors (orbital elements - or cartesian state vectors, depending on the value of self.param). Input all angles in degress + or cartesian state vectors, depending on the value of self.param). Input all angles in degress. + + This method will update self.ds with the new body or bodies added to the existing Dataset. Parameters ---------- - v1 : float - xh for param['IN_FORM'] == "XV"; a for param['IN_FORM'] == "EL" - v2 : float - yh for param['IN_FORM'] == "XV"; e for param['IN_FORM'] == "EL" - v3 : float - zh for param['IN_FORM'] == "XV"; inc for param['IN_FORM'] == "EL" - v4 : float - vhxh for param['IN_FORM'] == "XV"; capom for param['IN_FORM'] == "EL" - v5 : float - vhyh for param['IN_FORM'] == "XV"; omega for param['IN_FORM'] == "EL" - v6 : float - vhzh for param['IN_FORM'] == "XV"; capm for param['IN_FORM'] == "EL" - Gmass : float - Optional: Array of G*mass values if these are massive bodies - radius : float - Optional: Array radius values if these are massive bodies - rhill : float - Optional: Array rhill values if these are massive bodies - Ip1,y,z : float - Optional: Principal axes moments of inertia - rotx,y,z: float - Optional: Rotation rate vector components - t : float - Optional: Time at start of simulation + name : str or array-like of str + Name or names of + v1 : float or array-like of float + xhx for param['IN_FORM'] == "XV"; a for param['IN_FORM'] == "EL" + v2 : float or array-like of float + xhy for param['IN_FORM'] == "XV"; e for param['IN_FORM'] == "EL" + v3 : float or array-like of float + xhz for param['IN_FORM'] == "XV"; inc for param['IN_FORM'] == "EL" + v4 : float or array-like of float + vhx for param['IN_FORM'] == "XV"; capom for param['IN_FORM'] == "EL" + v5 : float or array-like of float + vhy for param['IN_FORM'] == "XV"; omega for param['IN_FORM'] == "EL" + v6 : float or array-like of float + vhz for param['IN_FORM'] == "XV"; capm for param['IN_FORM'] == "EL" + idvals : int or array-like of int, optional + Unique id values. If not passed, this will be computed based on the pre-existing Dataset ids. + Gmass : float or array-like of float, optional + G*mass values if these are massive bodies + radius : float or array-like of float, optional + Radius values if these are massive bodies + rhill : float, optional + Hill's radius values if these are massive bodies + Ip1,y,z : float, optional + Principal axes moments of inertia these are massive bodies with rotation enabled + rotx,y,z: float, optional + Rotation rate vector components if these are massive bodies with rotation enabled Returns ------- - self.ds : xarray dataset + ds : Xarray Dataset + Dasaset containing the body or bodies that were added + """ - if t is None: - t = self.param['T0'] - dsnew = init_cond.vec2xr(self.param, idvals, namevals, v1, v2, v3, v4, v5, v6, GMpl, Rpl, rhill, Ip1, Ip2, Ip3, - rotx, roty, rotz, J2, J4, t) + #convert all inputs to numpy arrays + def input_to_array(val,t,n=None): + if t == "f": + t = np.float64 + elif t == "i": + t = np.int64 + elif t == "s": + t = np.str + if val is None: + return None + elif np.isscalar(val): + val = np.array([val],dtype=t) + elif type(val) is list: + val = np.array(val,dtype=t) + + if n is None: + return val, len(val) + else: + if n != len(val): + raise ValueError(f"Error! Mismatched array lengths in add_body. Got {len(val)} when expecting {n}") + return val + + + name,nbodies = input_to_array(name,"s") + v1 = input_to_array(v1,"f",nbodies) + v2 = input_to_array(v2,"f",nbodies) + v3 = input_to_array(v3,"f",nbodies) + v4 = input_to_array(v4,"f",nbodies) + v5 = input_to_array(v5,"f",nbodies) + v6 = input_to_array(v6,"f",nbodies) + idvals = input_to_array(idvals,"i",nbodies) + GMpl = input_to_array(GMpl,"f",nbodies) + rhill = input_to_array(rhill,"f",nbodies) + Rpl = input_to_array(Rpl,"f",nbodies) + Ip1 = input_to_array(Ip1,"f",nbodies) + Ip2 = input_to_array(Ip2,"f",nbodies) + Ip3 = input_to_array(Ip3,"f",nbodies) + rotx = input_to_array(rotx,"f",nbodies) + roty = input_to_array(roty,"f",nbodies) + rotz = input_to_array(rotz,"f",nbodies) + J2 = input_to_array(J2,"f",nbodies) + J4 = input_to_array(J4,"f",nbodies) + + if len(self.ds) == 0: + maxid = -1 + else: + maxid = self.ds.id.max().values[()] + + if idvals is None: + idvals = np.arange(start=maxid+1,stop=maxid+1+nbodies,dtype=int) + + if len(self.ds) > 0: + dup_id = np.in1d(idvals,self.ds.id) + if any(dup_id): + raise ValueError(f"Duplicate ids detected: ", *idvals[dup_id]) + + t = self.param['TSTART'] + + dsnew = init_cond.vec2xr(self.param, idvals, name, v1, v2, v3, v4, v5, v6, + GMpl=GMpl, Rpl=Rpl, rhill=rhill, + Ip1=Ip1, Ip2=Ip2, Ip3=Ip3, + rotx=rotx, roty=roty, rotz=rotz, + J2=J2, J4=J4, t=t) if dsnew is not None: self.ds = xr.combine_by_coords([self.ds, dsnew]) self.ds['ntp'] = self.ds['id'].where(np.isnan(self.ds['Gmass'])).count(dim="id") @@ -1714,7 +1844,7 @@ def add_body(self, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None, rhill= elif self.param['OUT_TYPE'] == "NETCDF_FLOAT": self.ds = io.fix_types(self.ds, ftype=np.float32) - return + return dsnew def read_param(self, param_file, codename="Swiftest", verbose=True): """