Skip to content
This repository was archived by the owner on Aug 28, 2024. It is now read-only.

Commit

Permalink
Merge branch 'debug' into Simulation_API_improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
daminton committed Nov 10, 2022
2 parents 1cc9745 + b2614ce commit 63cec3c
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 132 deletions.
46 changes: 0 additions & 46 deletions examples/Basic_Simulation/param.in

This file was deleted.

1 change: 1 addition & 0 deletions python/swiftest/swiftest/init_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,4 +417,5 @@ def vec2xr(param, idvals, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None,
ds_float = da_float.to_dataset(dim="vec")
ds_str = da_str.to_dataset(dim="vec")
ds = xr.combine_by_coords([ds_float, ds_str,info_ds])

return ds
132 changes: 54 additions & 78 deletions python/swiftest/swiftest/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,16 @@
"YARKOVSKY",
"YORP"]

int_param = ["ISTEP_OUT", "ISTEP_DUMP"]
float_param = ["T0", "TSTART", "TSTOP", "DT", "CHK_RMIN", "CHK_RMAX", "CHK_EJECT", "CHK_QMIN", "DU2M", "MU2KG",
"TU2S", "MIN_GMFRAG", "GMTINY"]

upper_str_param = ["OUT_TYPE","OUT_FORM","OUT_STAT","IN_TYPE","IN_FORM"]

# This defines Xarray Dataset variables that are strings, which must be processed due to quirks in how NetCDF-Fortran
# handles strings differently than Python's Xarray.
string_varnames = ["name", "particle_type", "status", "origin_type"]
int_varnames = ["id", "ntp", "npl", "nplm", "discard_body_id", "collision_id"]

def bool2yesno(boolval):
"""
Expand Down Expand Up @@ -106,10 +113,10 @@ def str2bool(input_str):
valid_false = ["NO", "N", "F", "FALSE", ".FALSE."]
if input_str.upper() in valid_true:
return True
elif input_str.lower() in valid_false:
elif input_str.upper() in valid_false:
return False
else:
raise ValueError(f"{input_str} cannot is not recognized as boolean")
raise ValueError(f"{input_str} is not recognized as boolean")



Expand Down Expand Up @@ -146,7 +153,8 @@ def read_swiftest_param(param_file_name, param, verbose=True):
A dictionary containing the entries in the user parameter file
"""
param['! VERSION'] = f"Swiftest parameter input from file {param_file_name}"



# Read param.in file
if verbose: print(f'Reading Swiftest file {param_file_name}')
try:
Expand All @@ -157,38 +165,23 @@ def read_swiftest_param(param_file_name, param, verbose=True):
if fields[0][0] != '!':
key = fields[0].upper()
param[key] = fields[1]
#for key in param:
# if (key == fields[0].upper()): param[key] = fields[1]
# Special case of CHK_QMIN_RANGE requires a second input
if fields[0].upper() == 'CHK_QMIN_RANGE':
alo = real2float(fields[1])
ahi = real2float(fields[2])
param['CHK_QMIN_RANGE'] = f"{alo} {ahi}"

param['ISTEP_OUT'] = int(param['ISTEP_OUT'])
param['ISTEP_DUMP'] = int(param['ISTEP_DUMP'])
param['OUT_TYPE'] = param['OUT_TYPE'].upper()
param['OUT_FORM'] = param['OUT_FORM'].upper()
param['OUT_STAT'] = param['OUT_STAT'].upper()
param['IN_TYPE'] = param['IN_TYPE'].upper()
param['IN_FORM'] = param['IN_FORM'].upper()
param['T0'] = real2float(param['T0'])
param['TSTART'] = real2float(param['TSTART'])
param['TSTOP'] = real2float(param['TSTOP'])
param['DT'] = real2float(param['DT'])
param['CHK_RMIN'] = real2float(param['CHK_RMIN'])
param['CHK_RMAX'] = real2float(param['CHK_RMAX'])
param['CHK_EJECT'] = real2float(param['CHK_EJECT'])
param['CHK_QMIN'] = real2float(param['CHK_QMIN'])
param['DU2M'] = real2float(param['DU2M'])
param['MU2KG'] = real2float(param['MU2KG'])
param['TU2S'] = real2float(param['TU2S'])
param['INTERACTION_LOOPS'] = param['INTERACTION_LOOPS'].upper()
param['ENCOUNTER_CHECK'] = param['ENCOUNTER_CHECK'].upper()
if 'GMTINY' in param:
param['GMTINY'] = real2float(param['GMTINY'])
if 'MIN_GMFRAG' in param:
param['MIN_GMFRAG'] = real2float(param['MIN_GMFRAG'])

for uc in upper_str_param:
if uc in param:
param[uc] = param[uc].upper()

for i in int_param:
if i in param and type(i) != int:
param[i] = int(param[i])

for f in float_param:
if f in param and type(f) is str:
param[f] = real2float(param[f])
for b in bool_param:
if b in param:
param[b] = str2bool(param[b])
Expand Down Expand Up @@ -847,54 +840,14 @@ def swiftest2xr(param, verbose=True):
-------
xarray dataset
"""
if ((param['OUT_TYPE'] == 'REAL8') or (param['OUT_TYPE'] == 'REAL4')):
dims = ['time', 'id', 'vec']
cb = []
pl = []
tp = []
cbn = None
try:
with FortranFile(param['BIN_OUT'], 'r') as f:
for t, cbid, cbnames, cvec, clab, \
npl, plid, plnames, pvec, plab, \
ntp, tpid, tpnames, tvec, tlab in swiftest_stream(f, param):
# Prepare frames by adding an extra axis for the time coordinate
cbframe = np.expand_dims(cvec, axis=0)
plframe = np.expand_dims(pvec, axis=0)
tpframe = np.expand_dims(tvec, axis=0)


# Create xarray DataArrays out of each body type
cbxr = xr.DataArray(cbframe, dims=dims, coords={'time': t, 'id': cbid, 'vec': clab})
cbxr = cbxr.assign_coords(name=("id", cbnames))
plxr = xr.DataArray(plframe, dims=dims, coords={'time': t, 'id': plid, 'vec': plab})
plxr = plxr.assign_coords(name=("id", plnames))
tpxr = xr.DataArray(tpframe, dims=dims, coords={'time': t, 'id': tpid, 'vec': tlab})
tpxr = tpxr.assign_coords(name=("id", tpnames))

cb.append(cbxr)
pl.append(plxr)
tp.append(tpxr)

sys.stdout.write('\r' + f"Reading in time {t[0]:.3e}")
sys.stdout.flush()
except IOError:
print(f"Error encountered reading in {param['BIN_OUT']}")

cbda = xr.concat(cb, dim='time')
plda = xr.concat(pl, dim='time')
tpda = xr.concat(tp, dim='time')

cbds = cbda.to_dataset(dim='vec')
plds = plda.to_dataset(dim='vec')
tpds = tpda.to_dataset(dim='vec')
if verbose: print('\nCreating Dataset')
ds = xr.combine_by_coords([cbds, plds, tpds])

elif ((param['OUT_TYPE'] == 'NETCDF_DOUBLE') or (param['OUT_TYPE'] == 'NETCDF_FLOAT')):
if ((param['OUT_TYPE'] == 'NETCDF_DOUBLE') or (param['OUT_TYPE'] == 'NETCDF_FLOAT')):
if verbose: print('\nCreating Dataset from NetCDF file')
ds = xr.open_dataset(param['BIN_OUT'], mask_and_scale=False)
ds = clean_string_values(ds)
if param['OUT_TYPE'] == "NETCDF_DOUBLE":
ds = fix_types(ds,ftype=np.float64)
elif param['OUT_TYPE'] == "NETCDF_FLOAT":
ds = fix_types(ds,ftype=np.float32)
else:
print(f"Error encountered. OUT_TYPE {param['OUT_TYPE']} not recognized.")
return None
Expand Down Expand Up @@ -931,7 +884,7 @@ def string_converter(da):
"""
if da.dtype == np.dtype(object):
da = da.astype('<U32')
elif da.dtype != np.dtype('<U32'):
elif type(da.values[0]) != np.str_:
da = xstrip(da)
return da

Expand Down Expand Up @@ -968,8 +921,28 @@ def unclean_string_values(ds):
"""

for c in string_varnames:
n = string_converter(ds[c])
ds[c] = n.str.ljust(32).str.encode('utf-8')
if c in ds:
n = string_converter(ds[c])
ds[c] = n.str.ljust(32).str.encode('utf-8')
return ds

def fix_types(ds,itype=np.int64,ftype=np.float64):

ds = clean_string_values(ds)
for intvar in int_varnames:
if intvar in ds:
ds[intvar] = ds[intvar].astype(itype)

float_varnames = [x for x in list(ds.keys()) if x not in string_varnames and x not in int_varnames]

for floatvar in float_varnames:
ds[floatvar] = ds[floatvar].astype(ftype)

float_coordnames = [x for x in list(ds.coords) if x not in string_varnames and x not in int_varnames]
for floatcoord in float_coordnames:
ds[floatcoord] = ds[floatcoord].astype(np.float64)


return ds


Expand Down Expand Up @@ -1102,6 +1075,9 @@ def swiftest_xr2infile(ds, param, in_type="NETCDF_DOUBLE", infile_name=None,fram
# Convert strings back to byte form and save the NetCDF file
# Note: xarray will call the character array dimension string32. The Fortran code
# will rename this after reading

if infile_name is None:
infile_name = param['NC_IN']
frame = unclean_string_values(frame)
print(f"Writing initial conditions to file {infile_name}")
frame.to_netcdf(path=infile_name)
Expand Down
8 changes: 7 additions & 1 deletion python/swiftest/swiftest/simulation_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,6 +1560,11 @@ def addp(self, idvals, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None, rh
self.ds['ntp'] = self.ds['id'].where(np.isnan(self.ds['Gmass'])).count(dim="id")
self.ds['npl'] = self.ds['id'].where(np.invert(np.isnan(self.ds['Gmass']))).count(dim="id") - 1

if self.param['OUT_TYPE'] == "NETCDF_DOUBLE":
self.ds = io.fix_types(self.ds,ftype=np.float64)
elif self.param['OUT_TYPE'] == "NETCDF_FLOAT":
self.ds = io.fix_types(self.ds,ftype=np.float32)

return


Expand All @@ -1578,7 +1583,8 @@ def read_param(self, param_file, codename="Swiftest", verbose=True):
self.ds : xarray dataset
"""
if codename == "Swiftest":
self.param = io.read_swiftest_param(param_file, self.param, verbose=verbose)
param_old = self.param.copy()
self.param = io.read_swiftest_param(param_file, param_old, verbose=verbose)
self.codename = "Swiftest"
elif codename == "Swifter":
self.param = io.read_swifter_param(param_file, verbose=verbose)
Expand Down
53 changes: 46 additions & 7 deletions src/netcdf/netcdf.f90
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,14 @@ module subroutine netcdf_open(self, param, readonly)
! Internals
integer(I4B) :: mode, status
character(len=NF90_MAX_NAME) :: str_dim_name
integer(I4B) :: idmax
real(DP), dimension(:), allocatable :: gmtemp
logical, dimension(:), allocatable :: tpmask, plmask, plmmask

mode = NF90_WRITE
if (present(readonly)) then
if (readonly) mode = NF90_NOWRITE
end if
!if (present(readonly)) then
! if (readonly) mode = NF90_NOWRITE
!end if

call check( nf90_open(param%outfile, mode, self%ncid), "netcdf_open nf90_open" )

Expand All @@ -366,12 +369,49 @@ module subroutine netcdf_open(self, param, readonly)

call check( nf90_inq_varid(self%ncid, TIME_DIMNAME, self%time_varid), "netcdf_open nf90_inq_varid time_varid" )
call check( nf90_inq_varid(self%ncid, ID_DIMNAME, self%id_varid), "netcdf_open nf90_inq_varid id_varid" )
call check( nf90_inq_varid(self%ncid, NPL_VARNAME, self%npl_varid), "netcdf_open nf90_inq_varid npl_varid" )
call check( nf90_inq_varid(self%ncid, NTP_VARNAME, self%ntp_varid), "netcdf_open nf90_inq_varid ntp_varid" )
if (param%integrator == SYMBA) call check( nf90_inq_varid(self%ncid, NPLM_VARNAME, self%nplm_varid), "netcdf_open nf90_inq_varid nplm_varid" )
call check( nf90_inq_varid(self%ncid, NAME_VARNAME, self%name_varid), "netcdf_open nf90_inq_varid name_varid" )
call check( nf90_inq_varid(self%ncid, PTYPE_VARNAME, self%ptype_varid), "netcdf_open nf90_inq_varid ptype_varid" )
call check( nf90_inq_varid(self%ncid, STATUS_VARNAME, self%status_varid), "netcdf_open nf90_inq_varid status_varid" )
call check( nf90_inq_varid(self%ncid, GMASS_VARNAME, self%Gmass_varid), "netcdf_open nf90_inq_varid Gmass_varid" )

if ((nf90_inq_varid(self%ncid, NPL_VARNAME, self%npl_varid) /= nf90_noerr) .or. &
(nf90_inq_varid(self%ncid, NTP_VARNAME, self%ntp_varid) /= nf90_noerr) .or. &
((nf90_inq_varid(self%ncid, NPLM_VARNAME, self%nplm_varid) /= nf90_noerr) .and. (param%integrator == SYMBA))) then
call check( nf90_inquire_dimension(self%ncid, self%id_dimid, len=idmax), "netcdf_open nf90_inquire_dimension id_dimid" )
allocate(gmtemp(idmax))
call check( nf90_get_var(self%ncid, self%Gmass_varid, gmtemp, start=[1,1]), "netcdf_open nf90_getvar Gmass_varid" )
allocate(tpmask(idmax))
allocate(plmask(idmax))
allocate(plmmask(idmax))
plmask(:) = gmtemp(:) == gmtemp(:)
tpmask(:) = .not. plmask(:)
plmask(1) = .false. ! This is the central body
select type (param)
class is (symba_parameters)
plmmask(:) = gmtemp(:) > param%GMTINY .and. plmask(:)
end select
if ((nf90_inq_varid(self%ncid, NPL_VARNAME, self%npl_varid) /= nf90_noerr)) then
call check( nf90_redef(self%ncid), "netcdf_open nf90_redef npl_varid")
call check( nf90_def_var(self%ncid, NPL_VARNAME, NF90_INT, self%time_dimid, self%npl_varid), "netcdf_open nf90_def_var npl_varid" )
call check( nf90_enddef(self%ncid), "netcdf_open nf90_enddef npl_varid")
call check( nf90_put_var(self%ncid, self%npl_varid, count(plmask(:)), start=[1]), "netcdf_open nf90_put_var npl_varid" )
call check( nf90_inq_varid(self%ncid, NPL_VARNAME, self%npl_varid), "netcdf_open nf90_inq_varid npl_varid" )
end if
if (nf90_inq_varid(self%ncid, NTP_VARNAME, self%ntp_varid) /= nf90_noerr) then
call check( nf90_redef(self%ncid), "netcdf_open nf90_redef ntp_varid")
call check( nf90_def_var(self%ncid, NTP_VARNAME, NF90_INT, self%time_dimid, self%ntp_varid), "netcdf_open nf90_def_var ntp_varid" )
call check( nf90_enddef(self%ncid), "netcdf_open nf90_enddef ntp_varid")
call check( nf90_put_var(self%ncid, self%ntp_varid, count(tpmask(:)), start=[1]), "netcdf_open nf90_put_var ntp_varid" )
call check( nf90_inq_varid(self%ncid, NTP_VARNAME, self%ntp_varid), "netcdf_open nf90_inq_varid ntp_varid" )
end if
if ((nf90_inq_varid(self%ncid, NPLM_VARNAME, self%nplm_varid) /= nf90_noerr) .and. (param%integrator == SYMBA)) then
call check( nf90_redef(self%ncid), "netcdf_open nf90_redef nplm_varid")
call check( nf90_def_var(self%ncid, NPLM_VARNAME, NF90_INT, self%time_dimid, self%nplm_varid), "netcdf_open nf90_def_var nplm_varid" )
call check( nf90_enddef(self%ncid), "netcdf_open nf90_enddef nplm_varid")
call check( nf90_put_var(self%ncid, self%nplm_varid, count(plmmask(:)), start=[1]), "netcdf_open nf90_put_var nplm_varid" )
call check( nf90_inq_varid(self%ncid, NPLM_VARNAME, self%nplm_varid), "netcdf_open nf90_inq_varid nplm_varid" )
end if
end if

if ((param%out_form == XV) .or. (param%out_form == XVEL)) then
call check( nf90_inq_varid(self%ncid, XHX_VARNAME, self%xhx_varid), "netcdf_open nf90_inq_varid xhx_varid" )
Expand Down Expand Up @@ -409,7 +449,6 @@ module subroutine netcdf_open(self, param, readonly)
call check( nf90_inq_varid(self%ncid, OMEGA_VARNAME, self%omega_varid), "netcdf_open nf90_inq_varid omega_varid" )
call check( nf90_inq_varid(self%ncid, CAPM_VARNAME, self%capm_varid), "netcdf_open nf90_inq_varid capm_varid" )
end if
call check( nf90_inq_varid(self%ncid, GMASS_VARNAME, self%Gmass_varid), "netcdf_open nf90_inq_varid Gmass_varid" )

if (param%lrhill_present) call check( nf90_inq_varid(self%ncid, RHILL_VARNAME, self%rhill_varid), "netcdf_open nf90_inq_varid rhill_varid" )

Expand Down

0 comments on commit 63cec3c

Please sign in to comment.