Skip to content
This repository was archived by the owner on Aug 28, 2024. It is now read-only.

Commit

Permalink
Collisions are now processed in the Python side
Browse files Browse the repository at this point in the history
  • Loading branch information
daminton committed Dec 12, 2022
1 parent 1b82c22 commit c12ec19
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 12 deletions.
6 changes: 3 additions & 3 deletions python/swiftest/swiftest/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"TSTART",
"DUMP_CADENCE",
"ENCOUNTER_SAVE",
"FRAGMENTATION_SAVE")
"COLLISION_SAVE")



Expand All @@ -55,14 +55,14 @@
float_param = ["T0", "TSTART", "TSTOP", "DT", "CHK_RMIN", "CHK_RMAX", "CHK_EJECT", "CHK_QMIN", "DU2M", "MU2KG",
"TU2S", "MIN_GMFRAG", "GMTINY"]

upper_str_param = ["OUT_TYPE","OUT_FORM","OUT_STAT","IN_TYPE","IN_FORM","ENCOUNTER_SAVE","FRAGMENTATION_SAVE", "CHK_QMIN_COORD"]
upper_str_param = ["OUT_TYPE","OUT_FORM","OUT_STAT","IN_TYPE","IN_FORM","ENCOUNTER_SAVE","COLLISION_SAVE", "CHK_QMIN_COORD"]
lower_str_param = ["NC_IN", "PL_IN", "TP_IN", "CB_IN", "CHK_QMIN_RANGE"]

param_keys = ['! VERSION'] + int_param + float_param + upper_str_param + lower_str_param+ bool_param

# This defines Xarray Dataset variables that are strings, which must be processed due to quirks in how NetCDF-Fortran
# handles strings differently than Python's Xarray.
string_varnames = ["name", "particle_type", "status", "origin_type"]
string_varnames = ["name", "particle_type", "status", "origin_type", "stage", "regime"]
char_varnames = ["space"]
int_varnames = ["id", "ntp", "npl", "nplm", "discard_body_id", "collision_id", "loopnum"]

Expand Down
44 changes: 37 additions & 7 deletions python/swiftest/swiftest/simulation_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def __init__(self,read_param: bool = False, read_old_output_file: bool = False,
self.data = xr.Dataset()
self.ic = xr.Dataset()
self.encounters = xr.Dataset()
self.collision = xr.Dataset()
self.collisions = xr.Dataset()

self.simdir = Path(simdir)
if self.simdir.exists():
Expand Down Expand Up @@ -1234,10 +1234,10 @@ def set_feature(self,
msg = f"{collision_save} is not a valid option for collision_save."
msg += f"\nMust be one of {valid_vals}"
warnings.warn(msg,stacklevel=2)
if "FRAGMENTATION_SAVE" not in self.param:
self.param["FRAGMENTATION_SAVE"] = valid_vals[0]
if "COLLISION_SAVE" not in self.param:
self.param["COLLISION_SAVE"] = valid_vals[0]
else:
self.param["FRAGMENTATION_SAVE"] = collision_save
self.param["COLLISION_SAVE"] = collision_save
update_list.append("collision_save")

self.param["TIDES"] = False
Expand Down Expand Up @@ -1272,7 +1272,7 @@ def get_feature(self, arg_list: str | List[str] | None = None, verbose: bool | N
valid_var = {"close_encounter_check": "CHK_CLOSE",
"fragmentation": "FRAGMENTATION",
"encounter_save": "ENCOUNTER_SAVE",
"collision_save": "FRAGMENTATION_SAVE",
"collision_save": "COLLISION_SAVE",
"minimum_fragment_gmass": "MIN_GMFRAG",
"rotation": "ROTATION",
"general_relativity": "GR",
Expand Down Expand Up @@ -2735,10 +2735,16 @@ def read_output_file(self,read_init_cond : bool = True):
# This is done to handle cases where the method is called from a different working directory than the simulation
# results

if "ENCOUNTER_SAVE" in self.param or "FRAGMENTATION_SAVE" in self.param:
read_encounters = self.param["ENCOUNTER_SAVE"] != "NONE" or self.param["FRAGMENTATION_SAVE"] != "NONE"
if "ENCOUNTER_SAVE" in self.param:
read_encounters = self.param["ENCOUNTER_SAVE"] != "NONE"
else:
read_encounters = False

if "COLLISION_SAVE" in self.param:
read_collisions = self.param["COLLISION_SAVE"] != "NONE"
else:
read_collisions = False

param_tmp = self.param.copy()
param_tmp['BIN_OUT'] = os.path.join(self.simdir, self.param['BIN_OUT'])
if self.codename == "Swiftest":
Expand All @@ -2755,6 +2761,8 @@ def read_output_file(self,read_init_cond : bool = True):
self.ic = self.data.isel(time=0)
if read_encounters:
self.read_encounters()
if read_collisions:
self.read_collisions()

elif self.codename == "Swifter":
self.data = io.swifter2xr(param_tmp, verbose=self.verbose)
Expand Down Expand Up @@ -2790,6 +2798,28 @@ def _preprocess(ds, param):
return


def read_collisions(self):
if self.verbose:
print("Reading collision history file as .collisions")
col_files = glob(f"{self.simdir}{os.path.sep}collision_*.nc")
col_files.sort()

# This is needed in order to pass the param argument down to the io.process_netcdf_input function
def _preprocess(ds, param):
return io.process_netcdf_input(ds,param)
partial_func = partial(_preprocess, param=self.param)

self.collisions = xr.open_mfdataset(col_files,parallel=True,combine="nested",concat_dim="collision",join="left",preprocess=partial_func,mask_and_scale=True)
self.collisions = io.process_netcdf_input(self.collisions, self.param)

# # Reduce the dimensionality of variables that got expanded in the combine process
# self.encounters['loopnum'] = self.encounters['loopnum'].max(dim="name")
# self.encounters['id'] = self.encounters['id'].max(dim="time")
# self.encounters['particle_type'] = self.encounters['particle_type'].max(dim="time")

return


def follow(self, codestyle="Swifter"):
"""
An implementation of the Swift tool_follow algorithm. Under development. Currently only for Swift simulations.
Expand Down
2 changes: 1 addition & 1 deletion src/io/io.f90
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ module subroutine io_param_reader(self, unit, iotype, v_list, iostat, iomsg)
param%lrestart = .true.
end if
! Ignore SyMBA-specific, not-yet-implemented, or obsolete input parameters
case ("NPLMAX", "NTPMAX", "GMTINY", "MIN_GMFRAG", "FRAGMENTATION", "SEED", "YARKOVSKY", "YORP", "ENCOUNTER_SAVE", "FRAGMENTATION_SAVE")
case ("NPLMAX", "NTPMAX", "GMTINY", "MIN_GMFRAG", "FRAGMENTATION", "SEED", "YARKOVSKY", "YORP", "ENCOUNTER_SAVE", "COLLISION_SAVE")
case default
write(*,*) "Ignoring unknown parameter -> ",param_name
end select
Expand Down
2 changes: 1 addition & 1 deletion src/symba/symba_io.f90
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ module subroutine symba_io_param_reader(self, unit, iotype, v_list, iostat, ioms
case ("ENCOUNTER_SAVE")
call io_toupper(param_value)
read(param_value, *) param%encounter_save
case ("FRAGMENTATION_SAVE")
case ("COLLISION_SAVE")
call io_toupper(param_value)
read(param_value, *) param%collision_save
case("SEED")
Expand Down

0 comments on commit c12ec19

Please sign in to comment.