Skip to content
This repository was archived by the owner on Aug 28, 2024. It is now read-only.

Commit

Permalink
Merge branch 'Simulation_API_improvements' into debug
Browse files Browse the repository at this point in the history
  • Loading branch information
daminton committed Nov 15, 2022
2 parents 4b9d337 + dad3e36 commit d823790
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 19 deletions.
1 change: 1 addition & 0 deletions python/swiftest/swiftest/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ def write_labeled_param(param, param_file_name):
'TU2S',
'DU2M',
'GMTINY',
'FRAGMENTATION'
'MIN_GMFRAG',
'RESTART']
ptmp = param.copy()
Expand Down
41 changes: 31 additions & 10 deletions python/swiftest/swiftest/simulation_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy.typing as npt
import shutil
import subprocess
import shlex
from typing import (
Literal,
Dict,
Expand Down Expand Up @@ -302,6 +303,9 @@ def __init__(self,read_param: bool = True, **kwargs: Any):
if os.path.exists(self.param_file):
self.read_param(self.param_file, codename=self.codename, verbose=self.verbose)
param_file_found = True
# We will add the parameter file to the kwarg list. This will keep the set_parameter method from
# overriding everything with defaults when there are no arguments passed to Simulation()
kwargs['param_file'] = self.param_file
else:
param_file_found = False

Expand Down Expand Up @@ -358,9 +362,26 @@ def run(self,**kwargs):

print(f"Running a {self.codename} {self.integrator} run from tstart={self.param['TSTART']} {self.TU_name} to tstop={self.param['TSTOP']} {self.TU_name}")

with subprocess.Popen([self.driver_executable, self.integrator, self.param_file], stdout=subprocess.PIPE, bufsize=1,universal_newlines=True) as p:
# Get current environment variables
env = os.environ.copy()

try:
cmd = f"{self.driver_executable} {self.integrator} {self.param_file}"
p = subprocess.Popen(shlex.split(cmd),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env,
universal_newlines=True)
for line in p.stdout:
print(line, end='')
res = p.communicate()
if p.returncode != 0:
for line in res[1]:
print(line, end='')
raise Exception ("Failure in swiftest_driver")
except:
print(f"Error executing main swiftest_driver program")

return

def _get_valid_arg_list(self, arg_list: str | List[str] | None = None, valid_var: Dict | None = None):
Expand Down Expand Up @@ -2306,17 +2327,17 @@ def get_nvals(ds):

def read_param(self, param_file, codename="Swiftest", verbose=True):
"""
Reads in a param.in file and determines whether it is a Swift/Swifter/Swiftest parameter file.
Reads in an input parameter file and stores the values in the param dictionary.
Parameters
----------
param_file : string
File name of the input parameter file
codename : string
Type of parameter file, either "Swift", "Swifter", or "Swiftest"
param_file : string
File name of the input parameter file
codename : string
Type of parameter file, either "Swift", "Swifter", or "Swiftest"
Returns
-------
self.ds : xarray dataset
"""
if codename == "Swiftest":
param_old = self.param.copy()
Expand All @@ -2335,7 +2356,7 @@ def read_param(self, param_file, codename="Swiftest", verbose=True):

def write_param(self,
codename: Literal["Swiftest", "Swifter", "Swift"] | None = None,
param_file: str | PathLike | None = None,
param_file: str | os.PathLike | None = None,
param: Dict | None = None,
**kwargs: Any):
"""
Expand All @@ -2345,7 +2366,7 @@ def write_param(self,
----------
codename : {"Swiftest", "Swifter", "Swift"}, optional
Alternative name of the n-body code that the parameter file will be formatted for. Defaults to current instance
variable self.codename
variable codename
param_file : str or path-like, optional
Alternative file name of the input parameter file. Defaults to current instance variable self.param_file
param: Dict, optional
Expand Down Expand Up @@ -2506,7 +2527,7 @@ def follow(self, codestyle="Swifter"):

def save(self,
codename: Literal["Swiftest", "Swifter", "Swift"] | None = None,
param_file: str | PathLike | None = None,
param_file: str | os.PathLike | None = None,
param: Dict | None = None,
framenum: int = -1,
**kwargs: Any):
Expand Down
2 changes: 1 addition & 1 deletion src/modules/swiftest_globals.f90
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ module swiftest_globals
character(*), parameter :: NETCDF_OUTFILE = 'bin.nc' !! Default output file name
character(*), parameter :: TIME_DIMNAME = "time" !! NetCDF name of the time dimension
character(*), parameter :: ID_DIMNAME = "id" !! NetCDF name of the particle id dimension
character(*), parameter :: STR_DIMNAME = "str" !! NetCDF name of the particle id dimension
character(*), parameter :: STR_DIMNAME = "string32" !! NetCDF name of the character string dimension
character(*), parameter :: PTYPE_VARNAME = "particle_type" !! NetCDF name of the particle type variable
character(*), parameter :: NAME_VARNAME = "name" !! NetCDF name of the particle name variable
character(*), parameter :: NPL_VARNAME = "npl" !! NetCDF name of the number of active massive bodies variable
Expand Down
34 changes: 26 additions & 8 deletions src/netcdf/netcdf.f90
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ module subroutine netcdf_initialize_output(self, param)

! Define the NetCDF dimensions with particle name as the record dimension
call check( nf90_def_dim(self%ncid, ID_DIMNAME, NF90_UNLIMITED, self%id_dimid), "netcdf_initialize_output nf90_def_dim id_dimid" ) ! 'x' dimension
call check( nf90_def_dim(self%ncid, TIME_DIMNAME, NF90_UNLIMITED, self%time_dimid), "netcdf_initialize_output nf90_def_dim time_dimid" ) ! 'y' dimension
call check( nf90_def_dim(self%ncid, STR_DIMNAME, NAMELEN, self%str_dimid), "netcdf_initialize_output nf90_def_dim str_dimid" ) ! Dimension for string variables (aka character arrays)
call check( nf90_def_dim(self%ncid, TIME_DIMNAME, NF90_UNLIMITED, self%time_dimid), "netcdf_initialize_output nf90_def_dim time_dimid" ) ! 'y' dimension

select case (param%out_type)
case(NETCDF_FLOAT_TYPE)
Expand Down Expand Up @@ -361,7 +361,14 @@ module subroutine netcdf_open(self, param, readonly)

call check( nf90_inq_dimid(self%ncid, TIME_DIMNAME, self%time_dimid), "netcdf_open nf90_inq_dimid time_dimid" )
call check( nf90_inq_dimid(self%ncid, ID_DIMNAME, self%id_dimid), "netcdf_open nf90_inq_dimid id_dimid" )
call check( nf90_inquire_dimension(self%ncid, max(self%time_dimid,self%id_dimid)+1, name=str_dim_name), "netcdf_open nf90_inquire_dimension str_dim_name" )
if (max(self%time_dimid,self%id_dimid) == 2) then
self%str_dimid = 3
else if (min(self%time_dimid,self%id_dimid) == 0) then
self%str_dimid = 1
else
self%str_dimid = 2
end if
call check( nf90_inquire_dimension(self%ncid, self%str_dimid, name=str_dim_name), "netcdf_open nf90_inquire_dimension str_dim_name" )
call check( nf90_inq_dimid(self%ncid, str_dim_name, self%str_dimid), "netcdf_open nf90_inq_dimid str_dimid" )

! Required Variables
Expand All @@ -370,7 +377,6 @@ module subroutine netcdf_open(self, param, readonly)
call check( nf90_inq_varid(self%ncid, ID_DIMNAME, self%id_varid), "netcdf_open nf90_inq_varid id_varid" )
call check( nf90_inq_varid(self%ncid, NAME_VARNAME, self%name_varid), "netcdf_open nf90_inq_varid name_varid" )
call check( nf90_inq_varid(self%ncid, PTYPE_VARNAME, self%ptype_varid), "netcdf_open nf90_inq_varid ptype_varid" )
call check( nf90_inq_varid(self%ncid, STATUS_VARNAME, self%status_varid), "netcdf_open nf90_inq_varid status_varid" )
call check( nf90_inq_varid(self%ncid, GMASS_VARNAME, self%Gmass_varid), "netcdf_open nf90_inq_varid Gmass_varid" )

if ((param%out_form == XV) .or. (param%out_form == XVEL)) then
Expand Down Expand Up @@ -448,6 +454,7 @@ module subroutine netcdf_open(self, param, readonly)

! Variables The User Doesn't Need to Know About

status = nf90_inq_varid(self%ncid, STATUS_VARNAME, self%status_varid)
status = nf90_inq_varid(self%ncid, J2RP2_VARNAME, self%j2rp2_varid)
status = nf90_inq_varid(self%ncid, J4RP4_VARNAME, self%j4rp4_varid)

Expand Down Expand Up @@ -562,7 +569,7 @@ module function netcdf_read_frame_system(self, iu, param) result(ierr)
class is (symba_pl)
select type (param)
class is (symba_parameters)
nplm_check = count(rtemp(:) > param%GMTINY .and. plmask(:))
nplm_check = count(pack(rtemp,plmask) > param%GMTINY )
if (nplm_check /= pl%nplm) then
write(*,*) "Error reading in NetCDF file: The recorded value of nplm does not match the number of active fully interacting massive bodies"
call util_exit(failure)
Expand Down Expand Up @@ -774,17 +781,23 @@ module subroutine netcdf_read_hdr_system(self, iu, param)
tslot = int(param%ioutput, kind=I4B) + 1
call check( nf90_inquire_dimension(iu%ncid, iu%id_dimid, len=idmax), "netcdf_read_frame_system nf90_inquire_dimension id_dimid" )
call check( nf90_get_var(iu%ncid, iu%time_varid, param%t, start=[tslot]), "netcdf_read_hdr_system nf90_getvar time_varid" )
call check( nf90_get_var(iu%ncid, iu%Gmass_varid, gmtemp, start=[1,1]), "netcdf_read_frame_system nf90_getvar Gmass_varid" )

allocate(gmtemp(idmax))
allocate(tpmask(idmax))
allocate(plmask(idmax))
allocate(plmmask(idmax))

call check( nf90_get_var(iu%ncid, iu%Gmass_varid, gmtemp, start=[1,1]), "netcdf_read_frame_system nf90_getvar Gmass_varid" )

plmask(:) = gmtemp(:) == gmtemp(:)
tpmask(:) = .not. plmask(:)
plmask(1) = .false. ! This is the central body
select type (param)
class is (symba_parameters)
plmmask(:) = gmtemp(:) > param%GMTINY .and. plmask(:)
plmmask(:) = plmask(:)
where(plmask(:))
plmmask(:) = gmtemp(:) > param%GMTINY
endwhere
end select

status = nf90_inq_varid(iu%ncid, NPL_VARNAME, iu%npl_varid)
Expand Down Expand Up @@ -920,8 +933,13 @@ module subroutine netcdf_read_particle_info_system(self, iu, param, plmask, tpma
call tp%info(i)%set_value(particle_type=ctemp(tpind(i)))
end do

call check( nf90_get_var(iu%ncid, iu%status_varid, ctemp, count=[NAMELEN, idmax]), "netcdf_read_particle_info_system nf90_getvar status_varid" )
call cb%info%set_value(status=ctemp(1))
status = nf90_inq_varid(iu%ncid, STATUS_VARNAME, iu%status_varid)
if (status == nf90_noerr) then
call check( nf90_get_var(iu%ncid, iu%status_varid, ctemp, count=[NAMELEN, idmax]), "netcdf_read_particle_info_system nf90_getvar status_varid")
call cb%info%set_value(status=ctemp(1))
else
call cb%info%set_value(status="ACTIVE")
end if
do i = 1, npl
call pl%info(i)%set_value(status=ctemp(plind(i)))
end do
Expand Down

0 comments on commit d823790

Please sign in to comment.