Skip to content
This repository was archived by the owner on Aug 28, 2024. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
Improved the getters and setters for the initial conditions files
  • Loading branch information
daminton committed Nov 9, 2022
1 parent b4c6604 commit 99c7619
Showing 1 changed file with 161 additions and 75 deletions.
236 changes: 161 additions & 75 deletions python/swiftest/swiftest/simulation_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,6 @@ def __init__(self,
self.ds = xr.Dataset()
self.param = {
'! VERSION': f"Swiftest parameter input",
'T0': 0.0,
'TSTART' : 0.0,
'TSTOP': 0.0,
'DT': 0.0,
'ISTEP_OUT': 1,
'ISTEP_DUMP': 1,
'CHK_QMIN_COORD': "HELIO",
'INTERACTION_LOOPS': interaction_loops,
'ENCOUNTER_CHECK': encounter_check_loops
Expand Down Expand Up @@ -436,18 +430,10 @@ def get_simulation_time(self, arg_list: str | List[str] | None = None):
else:
tstep_out = None

if arg_list is None:
arg_list = valid_var.keys()
elif type(arg_list) is str:
arg_list = [arg_list]
else:
arg_list = [k for k in arg_list if k in set(valid_var.keys())]

time_dict = {valid_var[k]: self.param[valid_var[k]] for k in arg_list}
valid_arg, time_dict = self._get_valid_arg_list(arg_list, valid_var)

if self.verbose:
print("\nSimulation time parameters:")
for arg in arg_list:
for arg in valid_arg:
key = valid_var[arg]
print(f"{arg:<32} {time_dict[key]} {units[arg]}")
if tstep_out is not None:
Expand Down Expand Up @@ -597,20 +583,10 @@ def get_feature(self, arg_list: str | List[str] | None = None):
"restart" : "RESTART"
}

if arg_list is None:
arg_list = valid_var.keys()
elif type(arg_list) is str:
arg_list = [arg_list]
else:
# Only allow arg_lists to be checked if they are valid. Otherwise ignore.
arg_list = [k for k in arg_list if k in set(valid_var.keys())]

# Extract the arg_list dictionary
feature_dict = {valid_var[feat]:self.param[valid_var[feat]] for feat in arg_list}
valid_arg, feature_dict = self._get_valid_arg_list(arg_list, valid_var)

if self.verbose:
print("\nSimulation feature parameters:")
for arg in arg_list:
for arg in valid_arg:
key = valid_var[arg]
print(f"{arg:<32} {feature_dict[key]}")

Expand Down Expand Up @@ -662,7 +638,9 @@ def set_init_cond_files(self,
"""

update_list = ["init_cond_file_name"]
update_list = []
if init_cond_file_name is not None:
update_list.append("init_cond_file_name")
if init_cond_file_type is not None:
update_list.append("init_cond_file_type")
if init_cond_format is not None:
Expand All @@ -678,6 +656,18 @@ def ascii_file_input_error_msg(codename):
print('}')
return

if init_cond_format is None:
if "IN_FORM" in self.param:
init_cond_format = self.param['IN_FORM']
else:
init_cond_format = "EL"

if init_cond_file_type is None:
if "IN_TYPE" in self.param:
init_cond_file_type = self.param['IN_TYPE']
else:
init_cond_file_type = "NETCDF_DOUBLE"

if self.codename == "Swiftest":
init_cond_keys = ["CB", "PL", "TP"]
else:
Expand All @@ -689,8 +679,18 @@ def ascii_file_input_error_msg(codename):
print(f"{init_cond_format} is not supported by {self.codename}. Using XV instead")
init_cond_format = "XV"

self.param["IN_TYPE"] = init_cond_file_type
self.param["IN_FORM"] = init_cond_format

valid_formats={"EL", "XV"}
if init_cond_format not in valid_formats:
print(f"{init_cond_format} is not a valid input format")
else:
self.param['IN_FORM'] = init_cond_format

valid_types = {"NETCDF_DOUBLE", "NETCDF_FLOAT", "ASCII"}
if init_cond_file_type not in valid_types:
print(f"{init_cond_file_type} is not a valid input type")
else:
self.param['IN_TYPE'] = init_cond_file_type

if init_cond_file_type == "ASCII":
if init_cond_file_name is None:
Expand Down Expand Up @@ -752,49 +752,55 @@ def get_init_cond_files(self, arg_list: str | List[str] | None = None):
valid_var = {"init_cond_file_type": "IN_TYPE",
"init_cond_format": "IN_FORM",
"init_cond_file_name" : "NC_IN",
"init_cond_file_name (cb)" : "CB_IN",
"init_cond_file_name (pl)" : "PL_IN",
"init_cond_file_name (tp)" : "TP_IN",
"init_cond_file_name['CB']" : "CB_IN",
"init_cond_file_name['PL']" : "PL_IN",
"init_cond_file_name['TP']" : "TP_IN",
}

three_file_args = ["init_cond_file_name['CB']",
"init_cond_file_name['PL']",
"init_cond_file_name['TP']"]

if self.codename == "Swifter":
three_file_args.remove("init_cond_file_name['CB']")

# We have to figure out which initial conditions file model we are using (1 vs. 3 files)
if arg_list is None:
arg_list = list(valid_var.keys())
valid_arg = None
else:
valid_arg = arg_list.copy()

if valid_arg is None:
valid_arg = list(valid_var.keys())
elif type(arg_list) is str:
arg_list = [arg_list]
valid_arg = [arg_list]
else:
arg_list = [k for k in arg_list if k in set(valid_var.keys())]
# Only allow arg_lists to be checked if they are valid. Otherwise ignore.
valid_arg = [k for k in arg_list if k in list(valid_var.keys())]

if "init_cond_file_name" in arg_list:
# Figure out which input file model we need to use
if "init_cond_file_name" in valid_arg:
if self.param["IN_TYPE"] == "ASCII":
arg_list.remove("init_cond_file_name")
if "init_cond_file_name (cb)" not in arg_list:
arg_list.append("init_cond_file_name (cb)")
if "init_cond_file_name (pl)" not in arg_list:
arg_list.append("init_cond_file_name (pl)")
if "init_cond_file_name (tp)" not in arg_list:
arg_list.append("init_cond_file_name (tp)")
valid_arg.remove("init_cond_file_name")
for key in three_file_args:
if key not in valid_arg:
valid_arg.append(key)
else:
if "init_cond_file_name (cb)" in arg_list:
arg_list.remove("init_cond_file_name (cb)")
if "init_cond_file_name (pl)" in arg_list:
arg_list.remove("init_cond_file_name (pl)")
if "init_cond_file_name (tp)" in arg_list:
arg_list.remove("init_cond_file_name (tp)")
for key in three_file_args:
if key in valid_arg:
valid_arg.remove(key)

init_cond_file_dict = {valid_var[k]: self.param[valid_var[k]] for k in arg_list}
valid_arg, init_cond_file_dict = self._get_valid_arg_list(valid_arg, valid_var)

if self.verbose:
print("\nInitial condition file parameters:")
for arg in arg_list:
for arg in valid_arg:
key = valid_var[arg]
print(f"{arg:<32} {init_cond_file_dict[key]}")

return init_cond_file_dict




def set_output_files(self,
output_file_type: Literal[ "NETCDF_DOUBLE", "NETCDF_FLOAT", "REAL4", "REAL8", "XDR4", "XDR8"] | None = None,
output_file_name: os.PathLike | str | None = None,
Expand Down Expand Up @@ -916,18 +922,10 @@ def get_output_files(self, arg_list: str | List[str] | None = None):
"output_format": "OUT_FORM",
}

if arg_list is None:
arg_list = valid_var.keys()
elif type(arg_list) is str:
arg_list = [arg_list]
else:
arg_list = [k for k in arg_list if k in set(valid_var.keys())]

output_file_dict = {valid_var[k]: self.param[valid_var[k]] for k in arg_list}
valid_arg, output_file_dict = self._get_valid_arg_list(arg_list, valid_var)

if self.verbose:
print("\nOutput file parameters:")
for arg in arg_list:
for arg in valid_arg:
key = valid_var[arg]
print(f"{arg:<32} {output_file_dict[key]}")

Expand Down Expand Up @@ -1159,18 +1157,10 @@ def get_unit_system(self, arg_list: str | List[str] | None = None):
"TU" : f"s / {TU_name}"
}


if arg_list is None:
arg_list = valid_var.keys()
elif type(arg_list) is str:
arg_list = [arg_list]
else:
arg_list = [k for k in arg_list if k in set(valid_var.keys())]

unit_dict = {valid_var[k]: self.param[valid_var[k]] for k in arg_list}
valid_arg, unit_dict = self._get_valid_arg_list(arg_list, valid_var)

if self.verbose:
for arg in arg_list:
for arg in valid_arg:
key = valid_var[arg]
print(f"{arg:<32} {unit_dict[key]} {units[arg]}")

Expand Down Expand Up @@ -1262,6 +1252,102 @@ def set_distance_range(self,

return

def _get_valid_arg_list(self, arg_list: str | List[str] | None = None, valid_var: Dict | None = None):
"""
Internal function for getters that extracts subset of arguments that is contained in the dictionary of valid
argument/parameter variable pairs.
Parameters
----------
arg_list : str | List[str], optional
A single string or list of strings containing the Simulation argument. If none are supplied,
then it will create the arg_list out of all keys in the valid_var dictionary.
valid_var : valid_var: Dict
A dictionary where the key is the argument name and the value is the equivalent key in the Simulation
parameter dictionary (i.e. the left-hand column of a param.in file)
Returns
-------
valid_arg : [str]
The list of valid arguments that match the keys in valid_var
param : dict
A dictionary that is the subset of the Simulation parameter dictionary corresponding to the arguments listed
in arg_list.
"""


if arg_list is None:
valid_arg = None
else:
valid_arg = arg_list.copy()

if valid_arg is None:
valid_arg = list(valid_var.keys())
elif type(arg_list) is str:
valid_arg = [arg_list]
else:
# Only allow arg_lists to be checked if they are valid. Otherwise ignore.
valid_arg = [k for k in arg_list if k in list(valid_var.keys())]

# Extract the arg_list dictionary
param = {valid_var[feat]:self.param[valid_var[feat]] for feat in valid_arg}

return valid_arg, param

def get_distance_range(self, arg_list: str | List[str] | None = None):
"""
Returns a subset of the parameter dictionary containing the current values of the distance range parameters.
If the verbose option is set in the Simulation object, then it will also print the values.
Parameters
----------
arg_list: str | List[str], optional
A single string or list of strings containing the names of the features to extract. Default is all of:
["rmin", "rmax"]
Returns
-------
range_dict : dict
A dictionary containing the requested parameters.
"""

valid_var = {"rmin": "CHK_RMIN",
"rmax": "CHK_RMAX",
"qmin": "CHK_QMIN",
"qminR" : "CHK_QMIN_RANGE"
}

units = {"rmin" : self.DU_name,
"rmax" : self.DU_name,
"qmin" : self.DU_name,
"qminR" : self.DU_name,
}

if "rmin" in arg_list:
arg_list.append("qmin")
if "rmax" in arg_list or "rmin" in arg_list:
arg_list.append("qminR")

valid_arg, range_dict = self._get_valid_arg_list(arg_list, valid_var)

if self.verbose:
if "rmin" in valid_arg:
key = valid_arg["rmin"]
print(f"{'rmin':<32} {range_dict[key]} {units['rmin']}")
key = valid_arg["qmin"]
print(f"{'':<32} {range_dict[key]} {units['qmin']}")
if "rmax" in valid_arg:
key = valid_arg["rmax"]
print(f"{'rmax':<32} {range_dict[key]} {units['rmax']}")
if "rmax" in valid_arg or "rmin" in valid_arg:
key = valid_arg["qminR"]
print(f"{'':<32} {range_dict[key]} {units['qminR']}")


return range_dict


def add(self, plname, date=date.today().isoformat(), idval=None):
Expand Down

0 comments on commit 99c7619

Please sign in to comment.