Skip to content
This repository was archived by the owner on Aug 28, 2024. It is now read-only.

Commit

Permalink
Refactored the variable ds to be data throughout the Simulation class…
Browse files Browse the repository at this point in the history
… definition. This makes it more obvious what it is and more intuitive to use (IMHO)
  • Loading branch information
daminton committed Nov 16, 2022
1 parent ed5d09d commit 0aedbe6
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 33 deletions.
9 changes: 1 addition & 8 deletions examples/Basic_Simulation/initial_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,7 @@
Returns
-------
param.in : ASCII text file
Swiftest parameter input file.
pl.in : ASCII text file
Swiftest massive body input file.
tp.in : ASCII text file
Swiftest test particle input file.
cb.in : ASCII text file
Swiftest central body input file.
Updates sim.data with the simulation data
"""

import swiftest
Expand Down
2 changes: 1 addition & 1 deletion examples/Basic_Simulation/run_simulation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
}
],
"source": [
"sim.ds.where(sim.ds['particle_type'] == 'Massive Body',drop=True)['a'].plot(hue=\"name\")"
"sim.data.where(sim.data['particle_type'] == 'Massive Body',drop=True)['a'].plot(hue=\"name\")"
]
},
{
Expand Down
48 changes: 24 additions & 24 deletions python/swiftest/swiftest/simulation_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(self,read_param: bool = True, **kwargs: Any):
The stopping time for a simulation. `tstop` must be greater than `tstart`.
Parameter input file equivalent: `TSTOP`
dt : float, optional
The step size of the simulation. `dt` must be less than or equal to `tstop-dstart`.
The step size of the simulation. `dt` must be less than or equal to `tstop-tstart`.
Parameter input file equivalent: `DT`
istep_out : int, optional
The number of time steps between outputs to file. *Note*: only `istep_out` or `toutput` can be set.
Expand Down Expand Up @@ -276,7 +276,7 @@ def __init__(self,read_param: bool = True, **kwargs: Any):
self._getter_column_width = '32'

self.param = {}
self.ds = xr.Dataset()
self.data = xr.Dataset()

# Parameters are set in reverse priority order. First the defaults, then values from a pre-existing input file,
# then using the arguments passed via **kwargs.
Expand Down Expand Up @@ -1959,7 +1959,7 @@ def add_solar_system_body(self,
>*Note.* Currently only the JPL Horizons ephemeris is implemented, so this is ignored.
Returns
-------
ds : Xarray dataset with body or bodies added.
data : Xarray dataset with body or bodies added.
"""

if type(name) is str:
Expand Down Expand Up @@ -2195,7 +2195,7 @@ def add_body(self,
Adds a body (test particle or massive body) to the internal DataSet given a set up 6 vectors (orbital elements
or cartesian state vectors, depending on the value of self.param). Input all angles in degress.
This method will update self.ds with the new body or bodies added to the existing Dataset.
This method will update self.data with the new body or bodies added to the existing Dataset.
Parameters
----------
Expand Down Expand Up @@ -2228,7 +2228,7 @@ def add_body(self,
Returns
-------
ds : Xarray Dataset
data : Xarray Dataset
Dasaset containing the body or bodies that were added
"""
Expand Down Expand Up @@ -2276,16 +2276,16 @@ def input_to_array(val,t,n=None):
J2 = input_to_array(J2,"f",nbodies)
J4 = input_to_array(J4,"f",nbodies)

if len(self.ds) == 0:
if len(self.data) == 0:
maxid = -1
else:
maxid = self.ds.id.max().values[()]
maxid = self.data.id.max().values[()]

if idvals is None:
idvals = np.arange(start=maxid+1,stop=maxid+1+nbodies,dtype=int)

if len(self.ds) > 0:
dup_id = np.in1d(idvals,self.ds.id)
if len(self.data) > 0:
dup_id = np.in1d(idvals, self.data.id)
if any(dup_id):
raise ValueError(f"Duplicate ids detected: ", *idvals[dup_id])

Expand Down Expand Up @@ -2317,7 +2317,7 @@ def _combine_and_fix_dsnew(self,dsnew):
"""

self.ds = xr.combine_by_coords([self.ds, dsnew])
self.data = xr.combine_by_coords([self.data, dsnew])

def get_nvals(ds):
if "Gmass" in ds:
Expand All @@ -2329,14 +2329,14 @@ def get_nvals(ds):
return ds

dsnew = get_nvals(dsnew)
self.ds = get_nvals(self.ds)
self.data = get_nvals(self.data)

if self.param['OUT_TYPE'] == "NETCDF_DOUBLE":
dsnew = io.fix_types(dsnew, ftype=np.float64)
self.ds = io.fix_types(self.ds, ftype=np.float64)
self.data = io.fix_types(self.data, ftype=np.float64)
elif self.param['OUT_TYPE'] == "NETCDF_FLOAT":
dsnew = io.fix_types(dsnew, ftype=np.float32)
self.ds = io.fix_types(self.ds, ftype=np.float32)
self.data = io.fix_types(self.data, ftype=np.float32)

return dsnew

Expand Down Expand Up @@ -2481,7 +2481,7 @@ def bin2xr(self):
Returns
-------
self.ds : xarray dataset
self.data : xarray dataset
"""

# Make a temporary copy of the parameter dictionary so we can supply the absolute path of the binary file
Expand All @@ -2490,11 +2490,11 @@ def bin2xr(self):
param_tmp = self.param.copy()
param_tmp['BIN_OUT'] = os.path.join(self.sim_dir, self.param['BIN_OUT'])
if self.codename == "Swiftest":
self.ds = io.swiftest2xr(param_tmp, verbose=self.verbose)
if self.verbose: print('Swiftest simulation data stored as xarray DataSet .ds')
self.data = io.swiftest2xr(param_tmp, verbose=self.verbose)
if self.verbose: print('Swiftest simulation data stored as xarray DataSet .data')
elif self.codename == "Swifter":
self.ds = io.swifter2xr(param_tmp, verbose=self.verbose)
if self.verbose: print('Swifter simulation data stored as xarray DataSet .ds')
self.data = io.swifter2xr(param_tmp, verbose=self.verbose)
if self.verbose: print('Swifter simulation data stored as xarray DataSet .data')
elif self.codename == "Swift":
warnings.warn("Reading Swift simulation data is not implemented yet")
else:
Expand All @@ -2514,7 +2514,7 @@ def follow(self, codestyle="Swifter"):
-------
fol : xarray dataset
"""
if self.ds is None:
if self.data is None:
self.bin2xr()
if codestyle == "Swift":
try:
Expand All @@ -2532,7 +2532,7 @@ def follow(self, codestyle="Swifter"):
warnings.warn('No follow.in file found')
ifol = None
nskp = None
fol = tool.follow_swift(self.ds, ifol=ifol, nskp=nskp)
fol = tool.follow_swift(self.data, ifol=ifol, nskp=nskp)
else:
fol = None

Expand Down Expand Up @@ -2574,14 +2574,14 @@ def save(self,
param = self.param

if codename == "Swiftest":
io.swiftest_xr2infile(ds=self.ds, param=param, in_type=self.param['IN_TYPE'], framenum=framenum)
io.swiftest_xr2infile(ds=self.data, param=param, in_type=self.param['IN_TYPE'], framenum=framenum)
self.write_param(param_file=param_file)
elif codename == "Swifter":
if codename == "Swiftest":
swifter_param = io.swiftest2swifter_param(param)
else:
swifter_param = param
io.swifter_xr2infile(self.ds, swifter_param, framenum)
io.swifter_xr2infile(self.data, swifter_param, framenum)
self.write_param(param_file, param=swifter_param)
else:
warnings.warn(f'Saving to {codename} not supported')
Expand Down Expand Up @@ -2623,7 +2623,7 @@ def initial_conditions_from_bin(self, framenum=-1, new_param=None, new_param_fil

if codename == "Swiftest":
if restart:
new_param['T0'] = self.ds.time.values[framenum]
new_param['T0'] = self.data.time.values[framenum]
if self.param['OUT_TYPE'] == 'NETCDF_DOUBLE':
new_param['IN_TYPE'] = 'NETCDF_DOUBLE'
elif self.param['OUT_TYPE'] == 'NETCDF_FLOAT':
Expand All @@ -2646,7 +2646,7 @@ def initial_conditions_from_bin(self, framenum=-1, new_param=None, new_param_fil
new_param.pop('TP_IN', None)
new_param.pop('CB_IN', None)
print(f"Extracting data from dataset at time frame number {framenum} and saving it to {new_param['NC_IN']}")
frame = io.swiftest_xr2infile(self.ds, self.param, infile_name=new_param['NC_IN'], framenum=framenum)
frame = io.swiftest_xr2infile(self.data, self.param, infile_name=new_param['NC_IN'], framenum=framenum)
print(f"Saving parameter configuration file to {new_param_file}")
self.write_param(new_param_file, param=new_param)

Expand Down

0 comments on commit 0aedbe6

Please sign in to comment.