diff --git a/python/swiftest/swiftest/io.py b/python/swiftest/swiftest/io.py index 5331e50fd..beb5ebab9 100644 --- a/python/swiftest/swiftest/io.py +++ b/python/swiftest/swiftest/io.py @@ -33,7 +33,7 @@ "TSTART", "DUMP_CADENCE", "ENCOUNTER_SAVE", - "FRAGMENTATION_SAVE") + "COLLISION_SAVE") @@ -55,14 +55,14 @@ float_param = ["T0", "TSTART", "TSTOP", "DT", "CHK_RMIN", "CHK_RMAX", "CHK_EJECT", "CHK_QMIN", "DU2M", "MU2KG", "TU2S", "MIN_GMFRAG", "GMTINY"] -upper_str_param = ["OUT_TYPE","OUT_FORM","OUT_STAT","IN_TYPE","IN_FORM","ENCOUNTER_SAVE","FRAGMENTATION_SAVE", "CHK_QMIN_COORD"] +upper_str_param = ["OUT_TYPE","OUT_FORM","OUT_STAT","IN_TYPE","IN_FORM","ENCOUNTER_SAVE","COLLISION_SAVE", "CHK_QMIN_COORD"] lower_str_param = ["NC_IN", "PL_IN", "TP_IN", "CB_IN", "CHK_QMIN_RANGE"] param_keys = ['! VERSION'] + int_param + float_param + upper_str_param + lower_str_param+ bool_param # This defines Xarray Dataset variables that are strings, which must be processed due to quirks in how NetCDF-Fortran # handles strings differently than Python's Xarray. -string_varnames = ["name", "particle_type", "status", "origin_type"] +string_varnames = ["name", "particle_type", "status", "origin_type", "stage", "regime"] char_varnames = ["space"] int_varnames = ["id", "ntp", "npl", "nplm", "discard_body_id", "collision_id", "loopnum"] diff --git a/python/swiftest/swiftest/simulation_class.py b/python/swiftest/swiftest/simulation_class.py index ca11f121e..58bc3b713 100644 --- a/python/swiftest/swiftest/simulation_class.py +++ b/python/swiftest/swiftest/simulation_class.py @@ -320,7 +320,7 @@ def __init__(self,read_param: bool = False, read_old_output_file: bool = False, self.data = xr.Dataset() self.ic = xr.Dataset() self.encounters = xr.Dataset() - self.collision = xr.Dataset() + self.collisions = xr.Dataset() self.simdir = Path(simdir) if self.simdir.exists(): @@ -1234,10 +1234,10 @@ def set_feature(self, msg = f"{collision_save} is not a valid option for collision_save." msg += f"\nMust be one of {valid_vals}" warnings.warn(msg,stacklevel=2) - if "FRAGMENTATION_SAVE" not in self.param: - self.param["FRAGMENTATION_SAVE"] = valid_vals[0] + if "COLLISION_SAVE" not in self.param: + self.param["COLLISION_SAVE"] = valid_vals[0] else: - self.param["FRAGMENTATION_SAVE"] = collision_save + self.param["COLLISION_SAVE"] = collision_save update_list.append("collision_save") self.param["TIDES"] = False @@ -1272,7 +1272,7 @@ def get_feature(self, arg_list: str | List[str] | None = None, verbose: bool | N valid_var = {"close_encounter_check": "CHK_CLOSE", "fragmentation": "FRAGMENTATION", "encounter_save": "ENCOUNTER_SAVE", - "collision_save": "FRAGMENTATION_SAVE", + "collision_save": "COLLISION_SAVE", "minimum_fragment_gmass": "MIN_GMFRAG", "rotation": "ROTATION", "general_relativity": "GR", @@ -2735,10 +2735,16 @@ def read_output_file(self,read_init_cond : bool = True): # This is done to handle cases where the method is called from a different working directory than the simulation # results - if "ENCOUNTER_SAVE" in self.param or "FRAGMENTATION_SAVE" in self.param: - read_encounters = self.param["ENCOUNTER_SAVE"] != "NONE" or self.param["FRAGMENTATION_SAVE"] != "NONE" + if "ENCOUNTER_SAVE" in self.param: + read_encounters = self.param["ENCOUNTER_SAVE"] != "NONE" else: read_encounters = False + + if "COLLISION_SAVE" in self.param: + read_collisions = self.param["COLLISION_SAVE"] != "NONE" + else: + read_collisions = False + param_tmp = self.param.copy() param_tmp['BIN_OUT'] = os.path.join(self.simdir, self.param['BIN_OUT']) if self.codename == "Swiftest": @@ -2755,6 +2761,8 @@ def read_output_file(self,read_init_cond : bool = True): self.ic = self.data.isel(time=0) if read_encounters: self.read_encounters() + if read_collisions: + self.read_collisions() elif self.codename == "Swifter": self.data = io.swifter2xr(param_tmp, verbose=self.verbose) @@ -2790,6 +2798,28 @@ def _preprocess(ds, param): return + def read_collisions(self): + if self.verbose: + print("Reading collision history file as .collisions") + col_files = glob(f"{self.simdir}{os.path.sep}collision_*.nc") + col_files.sort() + + # This is needed in order to pass the param argument down to the io.process_netcdf_input function + def _preprocess(ds, param): + return io.process_netcdf_input(ds,param) + partial_func = partial(_preprocess, param=self.param) + + self.collisions = xr.open_mfdataset(col_files,parallel=True,combine="nested",concat_dim="collision",join="left",preprocess=partial_func,mask_and_scale=True) + self.collisions = io.process_netcdf_input(self.collisions, self.param) + + # # Reduce the dimensionality of variables that got expanded in the combine process + # self.encounters['loopnum'] = self.encounters['loopnum'].max(dim="name") + # self.encounters['id'] = self.encounters['id'].max(dim="time") + # self.encounters['particle_type'] = self.encounters['particle_type'].max(dim="time") + + return + + def follow(self, codestyle="Swifter"): """ An implementation of the Swift tool_follow algorithm. Under development. Currently only for Swift simulations. diff --git a/src/io/io.f90 b/src/io/io.f90 index 062d0f70a..71815b4be 100644 --- a/src/io/io.f90 +++ b/src/io/io.f90 @@ -682,7 +682,7 @@ module subroutine io_param_reader(self, unit, iotype, v_list, iostat, iomsg) param%lrestart = .true. end if ! Ignore SyMBA-specific, not-yet-implemented, or obsolete input parameters - case ("NPLMAX", "NTPMAX", "GMTINY", "MIN_GMFRAG", "FRAGMENTATION", "SEED", "YARKOVSKY", "YORP", "ENCOUNTER_SAVE", "FRAGMENTATION_SAVE") + case ("NPLMAX", "NTPMAX", "GMTINY", "MIN_GMFRAG", "FRAGMENTATION", "SEED", "YARKOVSKY", "YORP", "ENCOUNTER_SAVE", "COLLISION_SAVE") case default write(*,*) "Ignoring unknown parameter -> ",param_name end select diff --git a/src/symba/symba_io.f90 b/src/symba/symba_io.f90 index 18b56767e..4f19bfd30 100644 --- a/src/symba/symba_io.f90 +++ b/src/symba/symba_io.f90 @@ -68,7 +68,7 @@ module subroutine symba_io_param_reader(self, unit, iotype, v_list, iostat, ioms case ("ENCOUNTER_SAVE") call io_toupper(param_value) read(param_value, *) param%encounter_save - case ("FRAGMENTATION_SAVE") + case ("COLLISION_SAVE") call io_toupper(param_value) read(param_value, *) param%collision_save case("SEED")