diff --git a/python/swiftest/swiftest/init_cond.py b/python/swiftest/swiftest/init_cond.py index a5bf89c2b..ea28bcb8c 100644 --- a/python/swiftest/swiftest/init_cond.py +++ b/python/swiftest/swiftest/init_cond.py @@ -369,7 +369,7 @@ def vec2xr(param: Dict, vec_cb = np.expand_dims(vec_cb.T,axis=0) # Make way for the time dimension! ds_cb = xr.DataArray(vec_cb, dims=dims, coords={'time': [t], 'id': idvals[icb], 'vec': lab_cb}).to_dataset(dim='vec') else: - ds_cb = xr.Dataset() + ds_cb = None if ispl: lab_pl = plab.copy() vec_pl = np.vstack([vec[:,ipl], GMpl[ipl]]) @@ -383,16 +383,21 @@ def vec2xr(param: Dict, vec_pl = np.expand_dims(vec_pl.T,axis=0) # Make way for the time dimension! ds_pl = xr.DataArray(vec_pl, dims=dims, coords={'time': [t], 'id': idvals[ipl], 'vec': lab_pl}).to_dataset(dim='vec') else: - ds_pl = xr.Dataset() + ds_pl = None if istp: lab_tp = tlab.copy() vec_tp = np.expand_dims(vec[:,itp].T,axis=0) # Make way for the time dimension! ds_tp = xr.DataArray(vec_tp, dims=dims, coords={'time': [t], 'id': idvals[itp], 'vec': lab_tp}).to_dataset(dim='vec') particle_type[itp] = np.repeat("Test Particle",idvals[itp].size) else: - ds_tp = xr.Dataset() + ds_tp = None ds_info = xr.DataArray(np.vstack([namevals,particle_type]).T, dims=infodims, coords={'id': idvals, 'vec' : ["name", "particle_type"]}).to_dataset(dim='vec') - ds = xr.combine_by_coords([ds_cb, ds_pl, ds_tp, ds_info]) + ds = [d for d in [ds_cb, ds_pl, ds_tp] if d is not None] + if len(ds) > 1: + ds = xr.combine_by_coords(ds) + else: + ds = ds[0] + ds = xr.merge([ds_info,ds]) return ds \ No newline at end of file diff --git a/python/swiftest/swiftest/simulation_class.py b/python/swiftest/swiftest/simulation_class.py index 9ecbc1e51..8e1ac24e0 100644 --- a/python/swiftest/swiftest/simulation_class.py +++ b/python/swiftest/swiftest/simulation_class.py @@ -1622,10 +1622,15 @@ def add_solar_system_body(self, if all(np.isnan(J4)): J4 = None + t = self.param['TSTART'] - dsnew = init_cond.vec2xr(self.param,name,v1,v2,v3,v4,v5,v6,ephemeris_id,GMpl,Rpl,rhill,Ip1,Ip2,Ip3,rotx,roty,rotz,J2,J4) + dsnew = init_cond.vec2xr(self.param,name,v1,v2,v3,v4,v5,v6,ephemeris_id, + GMpl=GMpl, Rpl=Rpl, rhill=rhill, + Ip1=Ip1, Ip2=Ip2, Ip3=Ip3, + rotx=rotx, roty=roty, rotz=rotz, + J2=J2, J4=J4, t=t) - self.ds = xr.combine_by_coords([self.ds,dsnew]) + dsnew = self._combine_and_fix_dsnew(dsnew) return dsnew @@ -1831,19 +1836,50 @@ def input_to_array(val,t,n=None): t = self.param['TSTART'] - dsnew = init_cond.vec2xr(self.param, idvals, name, v1, v2, v3, v4, v5, v6, + dsnew = init_cond.vec2xr(self.param, name, v1, v2, v3, v4, v5, v6, idvals, GMpl=GMpl, Rpl=Rpl, rhill=rhill, Ip1=Ip1, Ip2=Ip2, Ip3=Ip3, rotx=rotx, roty=roty, rotz=rotz, - J2=J2, J4=J4, t=t) - if dsnew is not None: - self.ds = xr.combine_by_coords([self.ds, dsnew]) - self.ds['ntp'] = self.ds['id'].where(np.isnan(self.ds['Gmass'])).count(dim="id") - self.ds['npl'] = self.ds['id'].where(np.invert(np.isnan(self.ds['Gmass']))).count(dim="id") - 1 + J2=J2, J4=J4,t=t) + + dsnew = self._combine_and_fix_dsnew(dsnew) + + return dsnew + + def _combine_and_fix_dsnew(self,dsnew): + """ + Combines the new Dataset with the old one. Also computes the values of ntp and npl and sets the proper types. + Parameters + ---------- + dsnew : xarray Dataset + Dataset with new bodies + + Returns + ------- + dsnew : xarray Dataset + Updated Dataset with ntp, npl values and types fixed. + + """ + + self.ds = xr.combine_by_coords([self.ds, dsnew]) + + def get_nvals(ds): + if "Gmass" in dsnew: + ds['ntp'] = ds['id'].where(np.isnan(ds['Gmass'])).count(dim="id") + ds['npl'] = ds['id'].where(np.invert(np.isnan(ds['Gmass']))).count(dim="id") - 1 + else: + ds['ntp'] = ds['id'].count(dim="id") + ds['npl'] = xr.full_like(ds['ntp'],0) + return ds + + dsnew = get_nvals(dsnew) + self.ds = get_nvals(self.ds) if self.param['OUT_TYPE'] == "NETCDF_DOUBLE": + dsnew = io.fix_types(dsnew, ftype=np.float64) self.ds = io.fix_types(self.ds, ftype=np.float64) elif self.param['OUT_TYPE'] == "NETCDF_FLOAT": + dsnew = io.fix_types(dsnew, ftype=np.float32) self.ds = io.fix_types(self.ds, ftype=np.float32) return dsnew