diff --git a/python/swiftest/swiftest/init_cond.py b/python/swiftest/swiftest/init_cond.py index 7d7c20e14..78ef51be9 100644 --- a/python/swiftest/swiftest/init_cond.py +++ b/python/swiftest/swiftest/init_cond.py @@ -400,7 +400,7 @@ def vec2xr(param, idvals, namevals, v1, v2, v3, v4, v5, v6, GMpl=None, Rpl=None, frame_str = info_vec_str.T if param['IN_TYPE'] == 'NETCDF_FLOAT': ftype=np.float32 - elif param['IN_TYPE'] == 'NETCDF_DOUBLE': + elif param['IN_TYPE'] == 'NETCDF_DOUBLE' or param['IN_TYPE'] == 'ASCII': ftype=np.float64 da_float = xr.DataArray(frame_float, dims=infodims, coords={'id': idvals, 'vec': infolab_float}).astype(ftype) da_int = xr.DataArray(frame_int, dims=infodims, coords={'id': idvals, 'vec': infolab_int}) diff --git a/python/swiftest/swiftest/io.py b/python/swiftest/swiftest/io.py index 176237b0a..4678b4784 100644 --- a/python/swiftest/swiftest/io.py +++ b/python/swiftest/swiftest/io.py @@ -999,7 +999,9 @@ def swiftest_xr2infile(ds, param, in_type="NETCDF_DOUBLE", infile_name=None,fram ------- A set of input files for a new Swiftest run """ - frame = select_active_from_frame(ds, param, framenum) + param_tmp = param.copy() + param_tmp['OUT_FORM'] = param['IN_FORM'] + frame = select_active_from_frame(ds, param_tmp, framenum) if in_type == "NETCDF_DOUBLE" or in_type == "NETCDF_FLOAT": # Convert strings back to byte form and save the NetCDF file @@ -1051,22 +1053,22 @@ def swiftest_xr2infile(ds, param, in_type="NETCDF_DOUBLE", infile_name=None,fram for i in pl.id: pli = pl.sel(id=i) if param['RHILL_PRESENT'] == 'YES': - print(pli['name'].values, pli['Gmass'].values, pli['rhill'].values, file=plfile) + print(pli['name'].values[0], pli['Gmass'].values[0], pli['rhill'].values[0], file=plfile) else: - print(pli['name'].values, pli['Gmass'].values, file=plfile) + print(pli['name'].values[0], pli['Gmass'].values[0], file=plfile) if param['CHK_CLOSE'] == 'YES': - print(pli['radius'].values, file=plfile) + print(pli['radius'].values[0], file=plfile) if param['IN_FORM'] == 'XV': - print(pli['xhx'].values, pli['xhy'].values, pli['xhz'].values, file=plfile) - print(pli['vhx'].values, pli['vhy'].values, pli['vhz'].values, file=plfile) + print(pli['xhx'].values[0], pli['xhy'].values[0], pli['xhz'].values[0], file=plfile) + print(pli['vhx'].values[0], pli['vhy'].values[0], pli['vhz'].values[0], file=plfile) elif param['IN_FORM'] == 'EL': - print(pli['a'].values, pli['e'].values, pli['inc'].values, file=plfile) - print(pli['capom'].values, pli['omega'].values, pli['capm'].values, file=plfile) + print(pli['a'].values[0], pli['e'].values[0], pli['inc'].values[0], file=plfile) + print(pli['capom'].values[0], pli['omega'].values[0], pli['capm'].values[0], file=plfile) else: print(f"{param['IN_FORM']} is not a valid input format type.") if param['ROTATION'] == 'YES': - print(pli['Ip1'].values, pli['Ip2'].values, pli['Ip3'].values, file=plfile) - print(pli['rotx'].values, pli['roty'].values, pli['rotz'].values, file=plfile) + print(pli['Ip1'].values[0], pli['Ip2'].values[0], pli['Ip3'].values[0], file=plfile) + print(pli['rotx'].values[0], pli['roty'].values[0], pli['rotz'].values[0], file=plfile) plfile.close() # TP file @@ -1074,105 +1076,16 @@ def swiftest_xr2infile(ds, param, in_type="NETCDF_DOUBLE", infile_name=None,fram print(tp.id.count().values, file=tpfile) for i in tp.id: tpi = tp.sel(id=i) - print(tpi['name'].values, file=tpfile) + print(tpi['name'].values[0], file=tpfile) if param['IN_FORM'] == 'XV': - print(tpi['xhx'].values, tpi['xhy'].values, tpi['xhz'].values, file=tpfile) - print(tpi['vhx'].values, tpi['vhy'].values, tpi['vhz'].values, file=tpfile) + print(tpi['xhx'].values[0], tpi['xhy'].values[0], tpi['xhz'].values[0], file=tpfile) + print(tpi['vhx'].values[0], tpi['vhy'].values[0], tpi['vhz'].values[0], file=tpfile) elif param['IN_FORM'] == 'EL': - print(tpi['a'].values, tpi['e'].values, tpi['inc'].values, file=tpfile) - print(tpi['capom'].values, tpi['omega'].values, tpi['capm'].values, file=tpfile) + print(tpi['a'].values[0], tpi['e'].values[0], tpi['inc'].values[0], file=tpfile) + print(tpi['capom'].values[0], tpi['omega'].values[0], tpi['capm'].values[0], file=tpfile) else: print(f"{param['IN_FORM']} is not a valid input format type.") tpfile.close() - elif in_type == 'REAL8': - # Now make Swiftest files - cbfile = FortranFile(param['CB_IN'], 'w') - cbfile.write_record(cbid) - cbfile.write_record(np.double(GMSun)) - cbfile.write_record(np.double(RSun)) - cbfile.write_record(np.double(J2)) - cbfile.write_record(np.double(J4)) - if param['ROTATION'] == 'YES': - cbfile.write_record(np.double(Ip1cb)) - cbfile.write_record(np.double(Ip2cb)) - cbfile.write_record(np.double(Ip3cb)) - cbfile.write_record(np.double(rotxcb)) - cbfile.write_record(np.double(rotycb)) - cbfile.write_record(np.double(rotzcb)) - - cbfile.close() - - plfile = FortranFile(param['PL_IN'], 'w') - npl = pl.id.count().values - plid = pl.id.values - if param['IN_FORM'] == 'XV': - v1 = pl['xhx'].values - v2 = pl['xhy'].values - v3 = pl['xhz'].values - v4 = pl['vhx'].values - v5 = pl['vhy'].values - v6 = pl['vhz'].values - elif param['IN_FORM'] == 'EL': - v1 = pl['a'].values - v2 = pl['e'].values - v3 = pl['inc'].values - v4 = pl['capom'].values - v5 = pl['omega'].values - v6 = pl['capm'].values - else: - print(f"{param['IN_FORM']} is not a valid input format type.") - Gmass = pl['Gmass'].values - if param['CHK_CLOSE'] == 'YES': - radius = pl['radius'].values - - plfile.write_record(npl) - plfile.write_record(plid) - plfile.write_record(v1) - plfile.write_record(v2) - plfile.write_record(v3) - plfile.write_record(v4) - plfile.write_record(v5) - plfile.write_record(v6) - plfile.write_record(Gmass) - if param['RHILL_PRESENT'] == 'YES': - plfile.write_record(pl['rhill'].values) - if param['CHK_CLOSE'] == 'YES': - plfile.write_record(radius) - if param['ROTATION'] == 'YES': - plfile.write_record(pl['Ip1'].values) - plfile.write_record(pl['Ip2'].values) - plfile.write_record(pl['Ip3'].values) - plfile.write_record(pl['rotx'].values) - plfile.write_record(pl['roty'].values) - plfile.write_record(pl['rotz'].values) - plfile.close() - tpfile = FortranFile(param['TP_IN'], 'w') - ntp = tp.id.count().values - tpid = tp.id.values - if param['IN_FORM'] == 'XV': - v1 = tp['xhx'].values - v2 = tp['xhy'].values - v3 = tp['xhz'].values - v4 = tp['vhx'].values - v5 = tp['vhy'].values - v6 = tp['vhz'].values - elif param['IN_FORM'] == 'EL': - v1 = tp['a'].values - v2 = tp['e'].values - v3 = tp['inc'].values - v4 = tp['capom'].values - v5 = tp['omega'].values - v6 = tp['capm'].values - else: - print(f"{param['IN_FORM']} is not a valid input format type.") - tpfile.write_record(ntp) - tpfile.write_record(tpid) - tpfile.write_record(v1) - tpfile.write_record(v2) - tpfile.write_record(v3) - tpfile.write_record(v4) - tpfile.write_record(v5) - tpfile.write_record(v6) else: print(f"{in_type} is an unknown file type") diff --git a/python/swiftest/swiftest/simulation_class.py b/python/swiftest/swiftest/simulation_class.py index 153e92225..d1fa28ede 100644 --- a/python/swiftest/swiftest/simulation_class.py +++ b/python/swiftest/swiftest/simulation_class.py @@ -33,6 +33,9 @@ def __init__(self, codename="Swiftest", param_file="param.in", readbin=True, ver 'IN_FORM': "XV", 'IN_TYPE': "NETCDF_DOUBLE", 'NC_IN' : "init_cond.nc", + 'CB_IN' : "cb.in", + 'PL_IN' : "pl.in", + 'TP_IN' : "tp.in", 'ISTEP_OUT': "1", 'ISTEP_DUMP': "1", 'BIN_OUT': "bin.nc", @@ -341,7 +344,7 @@ def save(self, param_file, framenum=-1, codename="Swiftest"): """ if codename == "Swiftest": - io.swiftest_xr2infile(ds=self.ds, param=self.param, framenum=framenum,infile_name=self.param['NC_IN']) + io.swiftest_xr2infile(ds=self.ds, param=self.param, in_type=self.param['IN_TYPE'], framenum=framenum,infile_name=self.param['NC_IN']) self.write_param(param_file) elif codename == "Swifter": if self.codename == "Swiftest":