Skip to content
This repository was archived by the owner on Aug 28, 2024. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
Improvements to API handling of integrators and codes.
  • Loading branch information
daminton committed Nov 11, 2022
1 parent f36cb5a commit e8e3daf
Showing 1 changed file with 44 additions and 22 deletions.
66 changes: 44 additions & 22 deletions python/swiftest/swiftest/simulation_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ def __init__(self,
"""
self.param = {}
self.ds = xr.Dataset()
self.codename = codename
self.verbose = verbose
self.restart = restart

Expand All @@ -257,7 +256,7 @@ def __init__(self,
self.sim_dir = os.path.dirname(os.path.realpath(param_file))
if read_param:
if os.path.exists(param_file):
self.read_param(param_file, codename=codename, verbose=self.verbose)
self.read_param(param_file, codename=codename.title(), verbose=self.verbose)
else:
print(f"{param_file} not found.")

Expand Down Expand Up @@ -543,7 +542,9 @@ def set_parameter(self, **kwargs):
"""

# Setters returning parameter dictionary values
param_dict = self.set_unit_system(**kwargs)
param_dict = {}
param_dict.update(self.set_integrator(**kwargs))
param_dict.update(self.set_unit_system(**kwargs))
param_dict.update(self.set_distance_range(**kwargs))
param_dict.update(self.set_feature(**kwargs))
param_dict.update(self.set_init_cond_files(**kwargs))
Expand All @@ -552,7 +553,6 @@ def set_parameter(self, **kwargs):

# Non-returning setters
self.set_ephemeris_date(**kwargs)
self.set_integrator(**kwargs)

return param_dict

Expand All @@ -569,10 +569,11 @@ def get_parameter(self, **kwargs):
"""

self.get_integrator(**kwargs)

# Getters returning parameter dictionary values
param_dict = self.get_simulation_time(**kwargs)
param_dict = {}
param_dict.update(self.get_integrator(**kwargs))
param_dict.update(self.get_simulation_time(**kwargs))
param_dict.update(self.get_init_cond_files(**kwargs))
param_dict.update(self.get_output_files(**kwargs))
param_dict.update(self.get_distance_range(**kwargs))
Expand Down Expand Up @@ -611,18 +612,21 @@ def set_integrator(self,
if integrator is None and codename is None:
return

update_list = []

if codename is not None:
valid_codename = ["swiftest", "swifter", "swift"]
if codename.lower() not in valid_codename:
valid_codename = ["Swiftest", "Swifter", "Swift"]
if codename.title() not in valid_codename:
print(f"{codename} is not a valid codename. Valid options are ",",".join(valid_codename))
try:
self.codename
except:
self.codename = valid_codename[0]
else:
self.codename = codename.lower()
self.codename = codename.title()

self.param['! VERSION'] = f"{self.codename.title()} parameter input"
self.param['! VERSION'] = f"{self.codename} parameter input"
update_list.append("codename")

if integrator is not None:
valid_integrator = ["symba","rmvs","whm","helio"]
Expand All @@ -634,10 +638,11 @@ def set_integrator(self,
self.integrator = valid_integrator[0]
else:
self.integrator = integrator.lower()
update_list.append("integrator")

self.get_integrator("integrator", verbose)
integrator_dict = self.get_integrator(update_list, verbose)

return
return integrator_dict

def get_integrator(self,arg_list: str | List[str] | None = None, verbose: bool | None = None, **kwargs: Any):
"""
Expand All @@ -658,29 +663,45 @@ def get_integrator(self,arg_list: str | List[str] | None = None, verbose: bool |
Returns
-------
integrator: str,
The integrator name.
integrator_dict : dict
The subset of the dictionary containing the code name if codename is selected
"""

valid_var = {"codename": "! VERSION"}

valid_instance_vars = {"integrator": self.integrator,
"codename": self.codename}

try:
self.integrator
except:
print(f"integrator is not set")
return
return {}

try:
self.codename
except:
print(f"codename is not set")
return
return {}

valid_arg = {"integrator": self.integrator,
"codename": self.codename}

if not bool(kwargs) and arg_list is None:
arg_list = list(valid_arg.keys())
ephemeris_date = self._get_instance_var(arg_list, valid_arg, verbose, **kwargs)
return
arg_list = list(valid_instance_vars.keys())
integrator = self._get_instance_var(arg_list, valid_instance_vars, verbose, **kwargs)

valid_arg, integrator_dict = self._get_valid_arg_list(arg_list, valid_var)

if verbose is None:
verbose = self.verbose

if verbose:
for arg in arg_list:
if arg in valid_arg:
key = valid_var[arg]
print(f"{arg:<{self._getter_column_width}} {integrator_dict[key]}")
elif arg in valid_instance_vars:
print(f"{arg:<{self._getter_column_width}} {valid_instance_vars[arg]}")
return integrator_dict

def set_feature(self,
close_encounter_check: bool | None = None,
Expand Down Expand Up @@ -963,7 +984,7 @@ def ascii_file_input_error_msg(codename):
else:
init_cond_file_type = "NETCDF_DOUBLE"

if self.codename == "Swiftest":
if self.codename.title() == "Swiftest":
init_cond_keys = ["CB", "PL", "TP"]
else:
init_cond_keys = ["PL", "TP"]
Expand Down Expand Up @@ -1057,6 +1078,7 @@ def get_init_cond_files(self, arg_list: str | List[str] | None = None, verbose:

if self.codename == "Swifter":
three_file_args.remove("init_cond_file_name['CB']")
valid_var.pop("init_cond_file_name['CB']",None)

# We have to figure out which initial conditions file model we are using (1 vs. 3 files)
if arg_list is None:
Expand Down

0 comments on commit e8e3daf

Please sign in to comment.