diff --git a/python/swiftest/swiftest/simulation_class.py b/python/swiftest/swiftest/simulation_class.py index 1b6509976..a731603ce 100644 --- a/python/swiftest/swiftest/simulation_class.py +++ b/python/swiftest/swiftest/simulation_class.py @@ -246,7 +246,6 @@ def __init__(self, """ self.param = {} self.ds = xr.Dataset() - self.codename = codename self.verbose = verbose self.restart = restart @@ -257,7 +256,7 @@ def __init__(self, self.sim_dir = os.path.dirname(os.path.realpath(param_file)) if read_param: if os.path.exists(param_file): - self.read_param(param_file, codename=codename, verbose=self.verbose) + self.read_param(param_file, codename=codename.title(), verbose=self.verbose) else: print(f"{param_file} not found.") @@ -543,7 +542,9 @@ def set_parameter(self, **kwargs): """ # Setters returning parameter dictionary values - param_dict = self.set_unit_system(**kwargs) + param_dict = {} + param_dict.update(self.set_integrator(**kwargs)) + param_dict.update(self.set_unit_system(**kwargs)) param_dict.update(self.set_distance_range(**kwargs)) param_dict.update(self.set_feature(**kwargs)) param_dict.update(self.set_init_cond_files(**kwargs)) @@ -552,7 +553,6 @@ def set_parameter(self, **kwargs): # Non-returning setters self.set_ephemeris_date(**kwargs) - self.set_integrator(**kwargs) return param_dict @@ -569,10 +569,11 @@ def get_parameter(self, **kwargs): """ - self.get_integrator(**kwargs) # Getters returning parameter dictionary values - param_dict = self.get_simulation_time(**kwargs) + param_dict = {} + param_dict.update(self.get_integrator(**kwargs)) + param_dict.update(self.get_simulation_time(**kwargs)) param_dict.update(self.get_init_cond_files(**kwargs)) param_dict.update(self.get_output_files(**kwargs)) param_dict.update(self.get_distance_range(**kwargs)) @@ -611,18 +612,21 @@ def set_integrator(self, if integrator is None and codename is None: return + update_list = [] + if codename is not None: - valid_codename = ["swiftest", "swifter", "swift"] - if codename.lower() not in valid_codename: + valid_codename = ["Swiftest", "Swifter", "Swift"] + if codename.title() not in valid_codename: print(f"{codename} is not a valid codename. Valid options are ",",".join(valid_codename)) try: self.codename except: self.codename = valid_codename[0] else: - self.codename = codename.lower() + self.codename = codename.title() - self.param['! VERSION'] = f"{self.codename.title()} parameter input" + self.param['! VERSION'] = f"{self.codename} parameter input" + update_list.append("codename") if integrator is not None: valid_integrator = ["symba","rmvs","whm","helio"] @@ -634,10 +638,11 @@ def set_integrator(self, self.integrator = valid_integrator[0] else: self.integrator = integrator.lower() + update_list.append("integrator") - self.get_integrator("integrator", verbose) + integrator_dict = self.get_integrator(update_list, verbose) - return + return integrator_dict def get_integrator(self,arg_list: str | List[str] | None = None, verbose: bool | None = None, **kwargs: Any): """ @@ -658,29 +663,45 @@ def get_integrator(self,arg_list: str | List[str] | None = None, verbose: bool | Returns ------- - integrator: str, - The integrator name. + integrator_dict : dict + The subset of the dictionary containing the code name if codename is selected """ + valid_var = {"codename": "! VERSION"} + + valid_instance_vars = {"integrator": self.integrator, + "codename": self.codename} + try: self.integrator except: print(f"integrator is not set") - return + return {} try: self.codename except: print(f"codename is not set") - return + return {} - valid_arg = {"integrator": self.integrator, - "codename": self.codename} if not bool(kwargs) and arg_list is None: - arg_list = list(valid_arg.keys()) - ephemeris_date = self._get_instance_var(arg_list, valid_arg, verbose, **kwargs) - return + arg_list = list(valid_instance_vars.keys()) + integrator = self._get_instance_var(arg_list, valid_instance_vars, verbose, **kwargs) + + valid_arg, integrator_dict = self._get_valid_arg_list(arg_list, valid_var) + + if verbose is None: + verbose = self.verbose + + if verbose: + for arg in arg_list: + if arg in valid_arg: + key = valid_var[arg] + print(f"{arg:<{self._getter_column_width}} {integrator_dict[key]}") + elif arg in valid_instance_vars: + print(f"{arg:<{self._getter_column_width}} {valid_instance_vars[arg]}") + return integrator_dict def set_feature(self, close_encounter_check: bool | None = None, @@ -963,7 +984,7 @@ def ascii_file_input_error_msg(codename): else: init_cond_file_type = "NETCDF_DOUBLE" - if self.codename == "Swiftest": + if self.codename.title() == "Swiftest": init_cond_keys = ["CB", "PL", "TP"] else: init_cond_keys = ["PL", "TP"] @@ -1057,6 +1078,7 @@ def get_init_cond_files(self, arg_list: str | List[str] | None = None, verbose: if self.codename == "Swifter": three_file_args.remove("init_cond_file_name['CB']") + valid_var.pop("init_cond_file_name['CB']",None) # We have to figure out which initial conditions file model we are using (1 vs. 3 files) if arg_list is None: