diff --git a/swiftest/simulation_class.py b/swiftest/simulation_class.py index 42682396b..688dc07fa 100644 --- a/swiftest/simulation_class.py +++ b/swiftest/simulation_class.py @@ -2265,7 +2265,7 @@ def add_solar_system_body(self, dsnew = init_cond.vec2xr(self.param,**kwargs) - dsnew = self._combine_and_fix_dsnew(dsnew) + dsnew = self._combine_and_fix_dsnew(dsnew,**kwargs) if dsnew['id'].max(dim='name') > 0 and dsnew['name'].size > 0: self.save(verbose=False) @@ -2421,7 +2421,8 @@ def add_body(self, J2: float | List[float] | npt.NDArray[np.float_] | None=None, J4: float | List[float] | npt.NDArray[np.float_] | None=None, c_lm: List[float] | List[npt.NDArray[np.float_]] | npt.NDArray[np.float_] | None = None, - rotphase: float | List[float] | npt.NDArray[np.float_] | None=None + rotphase: float | List[float] | npt.NDArray[np.float_] | None=None, + **kwargs: Any ) -> None: """ Adds a body (test particle or massive body) to the internal DataSet given a set up 6 vectors (orbital elements @@ -2601,13 +2602,14 @@ def input_to_clm_array(val, n): dsnew = init_cond.vec2xr(self.param, name=name, a=a, e=e, inc=inc, capom=capom, omega=omega, capm=capm, id=id, Gmass=Gmass, radius=radius, rhill=rhill, Ip=Ip, rh=rh, vh=vh,rot=rot, j2rp2=J2, j4rp4=J4, c_lm=c_lm, rotphase=rotphase, time=time) - dsnew = self._combine_and_fix_dsnew(dsnew) + dsnew = self._combine_and_fix_dsnew(dsnew,**kwargs) self.save(verbose=False) return def _combine_and_fix_dsnew(self, - dsnew: xr.Dataset + dsnew: xr.Dataset, + **kwargs: Any ) -> xr.Dataset: """ Combines the new Dataset with the old one. Also computes the values of ntp and npl and sets the proper types. @@ -2642,6 +2644,7 @@ def _combine_and_fix_dsnew(self, dsnew = io.fix_types(dsnew, ftype=np.float32) self.data = io.fix_types(self.data, ftype=np.float32) + self.set_central_body(**kwargs) def get_nvals(ds): if "name" in ds.dims: count_dim = "name" @@ -2661,6 +2664,7 @@ def get_nvals(ds): dsnew = get_nvals(dsnew) self.data = get_nvals(self.data) + self.data = self.data.sortby("id") self.data = io.reorder_dims(self.data) @@ -3064,6 +3068,7 @@ def save(self, if not self.simdir.exists(): self.simdir.mkdir(parents=True, exist_ok=True) + self.init_cond = self.data.copy(deep=True) if codename == "Swiftest": @@ -3184,3 +3189,62 @@ def clean(self): os.remove(f) return + def set_central_body(self, + align_to_rotation_pole: bool = False, + **kwargs: Any): + """ + Sets the central body to be the most massive body in the dataset. Cartesian position and velocity Cartesian coordinates are rotated If align_to_rotation_pole is True, the rotation pole is set to the z-axis. + + Parameters + ---------- + align_to_rotation_pole : bool, default False + If True, the rotation pole is set to the z-axis. + + Returns + ------- + None + + """ + + + if "Gmass" not in self.data: + warnings.warn("No bodies with Gmass values found in dataset. Cannot set central body.",stacklevel=2) + return + + cbid = self.data.Gmass.argmax().values[()] + if 'name' in self.data.dims: + cbidx = self.data.id.isel(name=cbid).values[()] + cbname = self.data.name.isel(name=cbid).values[()] + elif 'id' in self.data.dims: + cbidx = self.data.id.isel(id=cbid).values[()] + cbname = self.data.name.isel(id=cbid).values[()] + else: + raise ValueError("No 'name' or 'id' dimensions found in dataset.") + + if cbidx != 0: + if 'name' in self.data.dims: + if 0 in self.data.id.values: + name_0 = self.data.name.where(self.data.id == 0, drop=True).values[()] + self.data['id'].loc[dict(name=name_0)] = cbidx + self.data['id'].loc[dict(name=cbname)] = 0 + else: + if 0 in self.data.id.values: + self.data['id'].loc[dict(id=0)] = cbidx + self.data['id'].loc[dict(id=cbidx)] = 0 + + # Ensure that the central body is at the origin + if 'name' in self.data.dims: + cbda = self.data.sel(name=cbname) + else: + cbda = self.data.sel(id=cbidx) + + pos_skip = ['space','Ip','rot'] + for var in self.data.variables: + if var not in pos_skip: + self.data[var] -= cbda[var] + + if align_to_rotation_pole and 'rot' in cbda: + self.data = tool.rotate_to_vector(self.data,cbda.rot) + + + return \ No newline at end of file diff --git a/swiftest/tool.py b/swiftest/tool.py index 1f4b959b1..c275eff86 100644 --- a/swiftest/tool.py +++ b/swiftest/tool.py @@ -11,6 +11,8 @@ import numpy as np import xarray as xr +from scipy.spatial.transform import Rotation as R + def magnitude(ds,x): """ Computes the magnitude of a vector quantity from a Dataset. @@ -507,4 +509,71 @@ def xv2el_vec(mu, rvec, vvec): """ vecfunc = np.vectorize(xv2el_one, signature='(),(3),(3)->(),(),(),(),(),(),(),(),()') - return vecfunc(mu, rvec, vvec) \ No newline at end of file + return vecfunc(mu, rvec, vvec) + + +def rotate_to_vector(ds, new_pole, skip_vars=['space','Ip']): + """ + Rotates the coordinate system such that the z-axis is aligned with an input pole. The new pole is defined by the input vector. + This will change all variables in the Dataset that have the "space" dimension, except for those passed to the skip_vars parameter. + + Parameters + ---------- + ds : Xarray Dataset + Dataset containing the vector quantity + new_pole : (3) float array + New pole vector + skip_vars : list of str, optional + List of variable names to skip. The default is ['space','Ip']. + + Returns + ------- + ds : Xarray Dataset + Dataset with the new pole vector applied to all variables with the "space" dimension + """ + + if 'space' not in ds.dims: + print("No space dimension in Dataset") + return ds + + # Verify that the new pole is a 3-element array + if len(new_pole) != 3: + print("New pole must be a 3-element array") + return ds + + # Normalize the new pole vector to ensure it is a unit vector + pole_mag = np.linalg.norm(new_pole) + unit_pole = new_pole / pole_mag + + # Define the original and target vectors + target_vector = np.array([0, 0, 1]) # Rotate so that the z-axis is aligned with the new pole + original_vector = unit_pole.reshape(1, 3) + + # Use align_vectors to get the rotation that aligns the z-axis with Mars_rot + rotation, _ = R.align_vectors(target_vector, original_vector.reshape(1, 3)) + + # Define a function to apply the rotation, which will be used with apply_ufunc + def apply_rotation(vector, rotation): + return rotation.apply(vector) + + # Function to apply rotation to a DataArray + def rotate_dataarray(da, rotation): + return xr.apply_ufunc( + apply_rotation, + da, + kwargs={'rotation': rotation}, + input_core_dims=[['space']], + output_core_dims=[['space']], + vectorize=True, + dask='parallelized', + output_dtypes=[da.dtype] + ) + + # Loop through each variable in the dataset and apply the rotation if 'space' dimension is present + for var in ds.variables: + if 'space' in ds[var].dims and var not in skip_vars: + ds[var] = rotate_dataarray(ds[var], rotation) + + return ds + + \ No newline at end of file