From 0aedbe63339405d15115b2aa77c301cf3101a7d9 Mon Sep 17 00:00:00 2001 From: David A Minton Date: Wed, 16 Nov 2022 16:53:58 -0500 Subject: [PATCH] Refactored the variable ds to be data throughout the Simulation class definition. This makes it more obvious what it is and more intuitive to use (IMHO) --- .../Basic_Simulation/initial_conditions.py | 9 +--- .../Basic_Simulation/run_simulation.ipynb | 2 +- python/swiftest/swiftest/simulation_class.py | 48 +++++++++---------- 3 files changed, 26 insertions(+), 33 deletions(-) diff --git a/examples/Basic_Simulation/initial_conditions.py b/examples/Basic_Simulation/initial_conditions.py index c14cdd931..78900ce1f 100644 --- a/examples/Basic_Simulation/initial_conditions.py +++ b/examples/Basic_Simulation/initial_conditions.py @@ -15,14 +15,7 @@ Returns ------- -param.in : ASCII text file - Swiftest parameter input file. -pl.in : ASCII text file - Swiftest massive body input file. -tp.in : ASCII text file - Swiftest test particle input file. -cb.in : ASCII text file - Swiftest central body input file. +Updates sim.data with the simulation data """ import swiftest diff --git a/examples/Basic_Simulation/run_simulation.ipynb b/examples/Basic_Simulation/run_simulation.ipynb index fa47bcd56..8f2ab51b3 100644 --- a/examples/Basic_Simulation/run_simulation.ipynb +++ b/examples/Basic_Simulation/run_simulation.ipynb @@ -123,7 +123,7 @@ } ], "source": [ - "sim.ds.where(sim.ds['particle_type'] == 'Massive Body',drop=True)['a'].plot(hue=\"name\")" + "sim.data.where(sim.data['particle_type'] == 'Massive Body',drop=True)['a'].plot(hue=\"name\")" ] }, { diff --git a/python/swiftest/swiftest/simulation_class.py b/python/swiftest/swiftest/simulation_class.py index 5df0c3edb..7c17c3d32 100644 --- a/python/swiftest/swiftest/simulation_class.py +++ b/python/swiftest/swiftest/simulation_class.py @@ -81,7 +81,7 @@ def __init__(self,read_param: bool = True, **kwargs: Any): The stopping time for a simulation. `tstop` must be greater than `tstart`. Parameter input file equivalent: `TSTOP` dt : float, optional - The step size of the simulation. `dt` must be less than or equal to `tstop-dstart`. + The step size of the simulation. `dt` must be less than or equal to `tstop-tstart`. Parameter input file equivalent: `DT` istep_out : int, optional The number of time steps between outputs to file. *Note*: only `istep_out` or `toutput` can be set. @@ -276,7 +276,7 @@ def __init__(self,read_param: bool = True, **kwargs: Any): self._getter_column_width = '32' self.param = {} - self.ds = xr.Dataset() + self.data = xr.Dataset() # Parameters are set in reverse priority order. First the defaults, then values from a pre-existing input file, # then using the arguments passed via **kwargs. @@ -1959,7 +1959,7 @@ def add_solar_system_body(self, >*Note.* Currently only the JPL Horizons ephemeris is implemented, so this is ignored. Returns ------- - ds : Xarray dataset with body or bodies added. + data : Xarray dataset with body or bodies added. """ if type(name) is str: @@ -2195,7 +2195,7 @@ def add_body(self, Adds a body (test particle or massive body) to the internal DataSet given a set up 6 vectors (orbital elements or cartesian state vectors, depending on the value of self.param). Input all angles in degress. - This method will update self.ds with the new body or bodies added to the existing Dataset. + This method will update self.data with the new body or bodies added to the existing Dataset. Parameters ---------- @@ -2228,7 +2228,7 @@ def add_body(self, Returns ------- - ds : Xarray Dataset + data : Xarray Dataset Dasaset containing the body or bodies that were added """ @@ -2276,16 +2276,16 @@ def input_to_array(val,t,n=None): J2 = input_to_array(J2,"f",nbodies) J4 = input_to_array(J4,"f",nbodies) - if len(self.ds) == 0: + if len(self.data) == 0: maxid = -1 else: - maxid = self.ds.id.max().values[()] + maxid = self.data.id.max().values[()] if idvals is None: idvals = np.arange(start=maxid+1,stop=maxid+1+nbodies,dtype=int) - if len(self.ds) > 0: - dup_id = np.in1d(idvals,self.ds.id) + if len(self.data) > 0: + dup_id = np.in1d(idvals, self.data.id) if any(dup_id): raise ValueError(f"Duplicate ids detected: ", *idvals[dup_id]) @@ -2317,7 +2317,7 @@ def _combine_and_fix_dsnew(self,dsnew): """ - self.ds = xr.combine_by_coords([self.ds, dsnew]) + self.data = xr.combine_by_coords([self.data, dsnew]) def get_nvals(ds): if "Gmass" in ds: @@ -2329,14 +2329,14 @@ def get_nvals(ds): return ds dsnew = get_nvals(dsnew) - self.ds = get_nvals(self.ds) + self.data = get_nvals(self.data) if self.param['OUT_TYPE'] == "NETCDF_DOUBLE": dsnew = io.fix_types(dsnew, ftype=np.float64) - self.ds = io.fix_types(self.ds, ftype=np.float64) + self.data = io.fix_types(self.data, ftype=np.float64) elif self.param['OUT_TYPE'] == "NETCDF_FLOAT": dsnew = io.fix_types(dsnew, ftype=np.float32) - self.ds = io.fix_types(self.ds, ftype=np.float32) + self.data = io.fix_types(self.data, ftype=np.float32) return dsnew @@ -2481,7 +2481,7 @@ def bin2xr(self): Returns ------- - self.ds : xarray dataset + self.data : xarray dataset """ # Make a temporary copy of the parameter dictionary so we can supply the absolute path of the binary file @@ -2490,11 +2490,11 @@ def bin2xr(self): param_tmp = self.param.copy() param_tmp['BIN_OUT'] = os.path.join(self.sim_dir, self.param['BIN_OUT']) if self.codename == "Swiftest": - self.ds = io.swiftest2xr(param_tmp, verbose=self.verbose) - if self.verbose: print('Swiftest simulation data stored as xarray DataSet .ds') + self.data = io.swiftest2xr(param_tmp, verbose=self.verbose) + if self.verbose: print('Swiftest simulation data stored as xarray DataSet .data') elif self.codename == "Swifter": - self.ds = io.swifter2xr(param_tmp, verbose=self.verbose) - if self.verbose: print('Swifter simulation data stored as xarray DataSet .ds') + self.data = io.swifter2xr(param_tmp, verbose=self.verbose) + if self.verbose: print('Swifter simulation data stored as xarray DataSet .data') elif self.codename == "Swift": warnings.warn("Reading Swift simulation data is not implemented yet") else: @@ -2514,7 +2514,7 @@ def follow(self, codestyle="Swifter"): ------- fol : xarray dataset """ - if self.ds is None: + if self.data is None: self.bin2xr() if codestyle == "Swift": try: @@ -2532,7 +2532,7 @@ def follow(self, codestyle="Swifter"): warnings.warn('No follow.in file found') ifol = None nskp = None - fol = tool.follow_swift(self.ds, ifol=ifol, nskp=nskp) + fol = tool.follow_swift(self.data, ifol=ifol, nskp=nskp) else: fol = None @@ -2574,14 +2574,14 @@ def save(self, param = self.param if codename == "Swiftest": - io.swiftest_xr2infile(ds=self.ds, param=param, in_type=self.param['IN_TYPE'], framenum=framenum) + io.swiftest_xr2infile(ds=self.data, param=param, in_type=self.param['IN_TYPE'], framenum=framenum) self.write_param(param_file=param_file) elif codename == "Swifter": if codename == "Swiftest": swifter_param = io.swiftest2swifter_param(param) else: swifter_param = param - io.swifter_xr2infile(self.ds, swifter_param, framenum) + io.swifter_xr2infile(self.data, swifter_param, framenum) self.write_param(param_file, param=swifter_param) else: warnings.warn(f'Saving to {codename} not supported') @@ -2623,7 +2623,7 @@ def initial_conditions_from_bin(self, framenum=-1, new_param=None, new_param_fil if codename == "Swiftest": if restart: - new_param['T0'] = self.ds.time.values[framenum] + new_param['T0'] = self.data.time.values[framenum] if self.param['OUT_TYPE'] == 'NETCDF_DOUBLE': new_param['IN_TYPE'] = 'NETCDF_DOUBLE' elif self.param['OUT_TYPE'] == 'NETCDF_FLOAT': @@ -2646,7 +2646,7 @@ def initial_conditions_from_bin(self, framenum=-1, new_param=None, new_param_fil new_param.pop('TP_IN', None) new_param.pop('CB_IN', None) print(f"Extracting data from dataset at time frame number {framenum} and saving it to {new_param['NC_IN']}") - frame = io.swiftest_xr2infile(self.ds, self.param, infile_name=new_param['NC_IN'], framenum=framenum) + frame = io.swiftest_xr2infile(self.data, self.param, infile_name=new_param['NC_IN'], framenum=framenum) print(f"Saving parameter configuration file to {new_param_file}") self.write_param(new_param_file, param=new_param)