Skip to content
This repository was archived by the owner on Aug 28, 2024. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
Made improvements to type handling by the Swiftest xarray methods
  • Loading branch information
daminton committed Nov 9, 2022
1 parent 3af0e10 commit d592f0e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 48 deletions.
1 change: 1 addition & 0 deletions python/swiftest/swiftest/init_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,4 +417,5 @@ def vec2xr(param, idvals, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None,
ds_float = da_float.to_dataset(dim="vec")
ds_str = da_str.to_dataset(dim="vec")
ds = xr.combine_by_coords([ds_float, ds_str,info_ds])

return ds
80 changes: 32 additions & 48 deletions python/swiftest/swiftest/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
# This defines Xarray Dataset variables that are strings, which must be processed due to quirks in how NetCDF-Fortran
# handles strings differently than Python's Xarray.
string_varnames = ["name", "particle_type", "status", "origin_type"]
int_varnames = ["id", "ntp", "npl", "nplm", "discard_body_id", "collision_id"]

def bool2yesno(boolval):
"""
Expand Down Expand Up @@ -847,54 +848,14 @@ def swiftest2xr(param, verbose=True):
-------
xarray dataset
"""
if ((param['OUT_TYPE'] == 'REAL8') or (param['OUT_TYPE'] == 'REAL4')):
dims = ['time', 'id', 'vec']
cb = []
pl = []
tp = []
cbn = None
try:
with FortranFile(param['BIN_OUT'], 'r') as f:
for t, cbid, cbnames, cvec, clab, \
npl, plid, plnames, pvec, plab, \
ntp, tpid, tpnames, tvec, tlab in swiftest_stream(f, param):
# Prepare frames by adding an extra axis for the time coordinate
cbframe = np.expand_dims(cvec, axis=0)
plframe = np.expand_dims(pvec, axis=0)
tpframe = np.expand_dims(tvec, axis=0)


# Create xarray DataArrays out of each body type
cbxr = xr.DataArray(cbframe, dims=dims, coords={'time': t, 'id': cbid, 'vec': clab})
cbxr = cbxr.assign_coords(name=("id", cbnames))
plxr = xr.DataArray(plframe, dims=dims, coords={'time': t, 'id': plid, 'vec': plab})
plxr = plxr.assign_coords(name=("id", plnames))
tpxr = xr.DataArray(tpframe, dims=dims, coords={'time': t, 'id': tpid, 'vec': tlab})
tpxr = tpxr.assign_coords(name=("id", tpnames))

cb.append(cbxr)
pl.append(plxr)
tp.append(tpxr)

sys.stdout.write('\r' + f"Reading in time {t[0]:.3e}")
sys.stdout.flush()
except IOError:
print(f"Error encountered reading in {param['BIN_OUT']}")

cbda = xr.concat(cb, dim='time')
plda = xr.concat(pl, dim='time')
tpda = xr.concat(tp, dim='time')

cbds = cbda.to_dataset(dim='vec')
plds = plda.to_dataset(dim='vec')
tpds = tpda.to_dataset(dim='vec')
if verbose: print('\nCreating Dataset')
ds = xr.combine_by_coords([cbds, plds, tpds])

elif ((param['OUT_TYPE'] == 'NETCDF_DOUBLE') or (param['OUT_TYPE'] == 'NETCDF_FLOAT')):
if ((param['OUT_TYPE'] == 'NETCDF_DOUBLE') or (param['OUT_TYPE'] == 'NETCDF_FLOAT')):
if verbose: print('\nCreating Dataset from NetCDF file')
ds = xr.open_dataset(param['BIN_OUT'], mask_and_scale=False)
ds = clean_string_values(ds)
if param['OUT_TYPE'] == "NETCDF_DOUBLE":
ds = fix_types(ds,ftype=np.float64)
elif param['OUT_TYPE'] == "NETCDF_FLOAT":
ds = fix_types(ds,ftype=np.float32)
else:
print(f"Error encountered. OUT_TYPE {param['OUT_TYPE']} not recognized.")
return None
Expand Down Expand Up @@ -931,7 +892,7 @@ def string_converter(da):
"""
if da.dtype == np.dtype(object):
da = da.astype('<U32')
elif da.dtype != np.dtype('<U32'):
elif type(da.values[0]) != np.str_:
da = xstrip(da)
return da

Expand Down Expand Up @@ -968,8 +929,28 @@ def unclean_string_values(ds):
"""

for c in string_varnames:
n = string_converter(ds[c])
ds[c] = n.str.ljust(32).str.encode('utf-8')
if c in ds:
n = string_converter(ds[c])
ds[c] = n.str.ljust(32).str.encode('utf-8')
return ds

def fix_types(ds,itype=np.int64,ftype=np.float64):

ds = clean_string_values(ds)
for intvar in int_varnames:
if intvar in ds:
ds[intvar] = ds[intvar].astype(itype)

float_varnames = [x for x in list(ds.keys()) if x not in string_varnames and x not in int_varnames]

for floatvar in float_varnames:
ds[floatvar] = ds[floatvar].astype(ftype)

float_coordnames = [x for x in list(ds.coords) if x not in string_varnames and x not in int_varnames]
for floatcoord in float_coordnames:
ds[floatcoord] = ds[floatcoord].astype(np.float64)


return ds


Expand Down Expand Up @@ -1102,6 +1083,9 @@ def swiftest_xr2infile(ds, param, in_type="NETCDF_DOUBLE", infile_name=None,fram
# Convert strings back to byte form and save the NetCDF file
# Note: xarray will call the character array dimension string32. The Fortran code
# will rename this after reading

if infile_name is None:
infile_name = param['NC_IN']
frame = unclean_string_values(frame)
print(f"Writing initial conditions to file {infile_name}")
frame.to_netcdf(path=infile_name)
Expand Down
5 changes: 5 additions & 0 deletions python/swiftest/swiftest/simulation_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,6 +1464,11 @@ def addp(self, idvals, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None, rh
self.ds['ntp'] = self.ds['id'].where(np.isnan(self.ds['Gmass'])).count(dim="id")
self.ds['npl'] = self.ds['id'].where(np.invert(np.isnan(self.ds['Gmass']))).count(dim="id") - 1

if self.param['OUT_TYPE'] == "NETCDF_DOUBLE":
self.ds = io.fix_types(self.ds,ftype=np.float64)
elif self.param['OUT_TYPE'] == "NETCDF_FLOAT":
self.ds = io.fix_types(self.ds,ftype=np.float32)

return


Expand Down

0 comments on commit d592f0e

Please sign in to comment.