Skip to content
This repository was archived by the owner on Aug 28, 2024. It is now read-only.

Commit

Permalink
More clean up of encounter data
Browse files Browse the repository at this point in the history
  • Loading branch information
daminton committed Dec 6, 2022
1 parent c25423d commit d34d36e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
12 changes: 6 additions & 6 deletions examples/Fragmentation/Fragmentation_Movie.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ class AnimatedScatter(object):
"""An animated scatter plot using matplotlib.animations.FuncAnimation."""

def __init__(self, sim, animfile, title, nskip=1):
nframes = int(sim.enc['time'].size)
self.ds = sim.enc.mean(dim="encounter") # Reduce out the encounter dimension
nframes = int(self.ds['time'].size)
self.sim = sim
self.title = title
self.body_color_list = {'Initial conditions': 'xkcd:windows blue',
Expand All @@ -106,11 +107,10 @@ def setup_plot(self):
fig = plt.figure(figsize=figsize, dpi=300)
plt.tight_layout(pad=0)


# Calculate the distance along the y-axis between the colliding bodies at the start of the simulation.
# This will be used to scale the axis limits on the movie.
rhy1 = sim.enc['rh'].isel(time=0).sel(name="Body1",space='y').values[()]
rhy2 = sim.enc['rh'].isel(time=0).sel(name="Body2",space='y').values[()]
rhy1 = self.ds['rh'].isel(time=0).sel(name="Body1",space='y').values[()]
rhy2 = self.ds['rh'].isel(time=0).sel(name="Body2",space='y').values[()]

scale_frame = abs(rhy1) + abs(rhy2)
ax = plt.Axes(fig, [0.1, 0.1, 0.8, 0.8])
Expand Down Expand Up @@ -145,7 +145,7 @@ def center(Gmass, x, y):

def data_stream(self, frame=0):
while True:
ds = self.sim.enc.isel(time=frame)
ds = self.ds.isel(time=frame)
ds = ds.where(ds['name'] != "Sun", drop=True)
radius = ds['radius'].values
Gmass = ds['Gmass'].values
Expand Down Expand Up @@ -181,4 +181,4 @@ def data_stream(self, frame=0):
sim.set_parameter(fragmentation=True, fragmentation_save="TRAJECTORY", gmtiny=gmtiny, minimum_fragment_gmass=minimum_fragment_gmass, verbose=False)
sim.run(dt=1e-4, tstop=2.0e-3, istep_out=1, dump_cadence=0)

anim = AnimatedScatter(sim,movie_filename,movie_titles[style],nskip=1)
anim = AnimatedScatter(sim,movie_filename,movie_titles[style],nskip=1)
34 changes: 26 additions & 8 deletions python/swiftest/swiftest/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,9 +823,10 @@ def process_netcdf_input(ds, param):
ds = fix_types(ds,ftype=np.float32)
ds = ds.where(ds.id >=0 ,drop=True)
# Check if the name variable contains unique values. If so, make name the dimension instead of id
if "name" not in ds.dims and len(np.unique(ds['name'])) == len(ds['name']):
ds = ds.swap_dims({"id" : "name"})
ds = ds.reset_coords("id")
if "id" in ds.dims:
if len(np.unique(ds['name'])) == len(ds['name']):
ds = ds.swap_dims({"id" : "name"})
ds = ds.reset_coords("id")

return ds

Expand Down Expand Up @@ -855,21 +856,37 @@ def swiftest2xr(param, verbose=True):

return ds

def xstrip(a):
def xstrip_nonstr(a):
"""
Cleans up the string values in the DataSet to remove extra white space
Parameters
----------
a : xarray dataset
Returns
-------
da : xarray dataset with the strings cleaned up
"""
func = lambda x: np.char.strip(x)
return xr.apply_ufunc(func, a.str.decode(encoding='utf-8'),dask='parallelized')

def xstrip_str(a):
"""
Cleans up the string values in the DataSet to remove extra white space
Parameters
----------
a : xarray dataset
Returns
-------
da : xarray dataset with the strings cleaned up
"""
func = lambda x: np.char.strip(x)
return xr.apply_ufunc(func, a,dask='parallelized')


def string_converter(da):
"""
Converts a string to a unicode string
Expand All @@ -883,9 +900,10 @@ def string_converter(da):
da : xarray dataset with the strings cleaned up
"""
if da.dtype == np.dtype(object):
da = da.astype('<U32')
da = da.astype('<U32')
da = xstrip_str(da)
elif type(da.values[0]) != np.str_:
da = xstrip(da)
da = xstrip_nonstr(da)
return da

def char_converter(da):
Expand All @@ -903,7 +921,7 @@ def char_converter(da):
if da.dtype == np.dtype(object):
da = da.astype('<U1')
elif type(da.values[0]) != np.str_:
da = xstrip(da)
da = xstrip_nonstr(da)
return da

def clean_string_values(ds):
Expand Down

0 comments on commit d34d36e

Please sign in to comment.