Source code for EXOSIMS.MissionSim

from EXOSIMS.util.vprint import vprint
from EXOSIMS.util.get_module import get_module
from EXOSIMS.util.waypoint import waypoint
from EXOSIMS.util.CheckScript import CheckScript
from EXOSIMS.util.keyword_fun import get_all_mod_kws, check_opticalsystem_kws
import logging
import json
import os.path
import tempfile
import numpy as np
import astropy.units as u
import copy
import inspect
import warnings
from typing import Dict, Optional, Any


class MissionSim(object):
    """Mission Simulation (backbone) class

    This class is responsible for instantiating all objects required
    to carry out a mission simulation.

    Args:
        scriptfile (string):
            Full path to JSON script file. If not set, assumes that dictionary has
            been passed through specs.
        nopar (bool):
            Ignore any provided ensemble module in the script or specs and force the
            prototype :py:class:`~EXOSIMS.Prototypes.SurveyEnsemble`. Defaults False.
        verbose (bool):
            Input to :py:meth:`~EXOSIMS.util.vprint.vprint`, toggling verbosity of
            print statements. Defaults True.
        logfile (str or None):
            Path to the log file. If None, logging is turned off.
            If supplied but empty string (''), a temporary file is generated.
        loglevel (str):
            The level of log, defaults to 'INFO'. Valid levels are: CRITICAL,
            ERROR, WARNING, INFO, DEBUG (the value is upper-cased before use).
        checkInputs (bool):
            Validate inputs against selected modules. Defaults True.
        **specs (dict):
            :ref:`sec:inputspec`

    Attributes:
        StarCatalog (StarCatalog module):
            StarCatalog class object
        PlanetPopulation (PlanetPopulation module):
            PlanetPopulation class object
        PlanetPhysicalModel (PlanetPhysicalModel module):
            PlanetPhysicalModel class object
        OpticalSystem (OpticalSystem module):
            OpticalSystem class object
        ZodiacalLight (ZodiacalLight module):
            ZodiacalLight class object
        BackgroundSources (BackgroundSources module):
            Background Source class object
        PostProcessing (PostProcessing module):
            PostProcessing class object
        Completeness (Completeness module):
            Completeness class object
        TargetList (TargetList module):
            TargetList class object
        SimulatedUniverse (SimulatedUniverse module):
            SimulatedUniverse class object
        Observatory (Observatory module):
            Observatory class object
        TimeKeeping (TimeKeeping module):
            TimeKeeping class object
        SurveySimulation (SurveySimulation module):
            SurveySimulation class object
        SurveyEnsemble (SurveyEnsemble module):
            SurveyEnsemble class object
        modules (dict):
            Dictionary of all modules, except StarCatalog
        verbose (bool):
            Boolean used to create the vprint function, equivalent to the
            python print function with an extra verbose toggle parameter
            (True by default). The vprint function can be accessed by all
            modules from EXOSIMS.util.vprint.
        seed (int):
            Number used to seed the NumPy generator. Generated randomly
            by default.
        logfile (str):
            Path to the log file. If None, logging is turned off.
            If supplied but empty string (''), a temporary file is generated.
        loglevel (str):
            The level of log, defaults to 'INFO'. Valid levels are: CRITICAL,
            ERROR, WARNING, INFO, DEBUG (stored upper-cased).

    """

    _modtype = "MissionSim"
    _outspec = {}

    def __init__(
        self,
        scriptfile=None,
        nopar=False,
        verbose=True,
        logfile=None,
        loglevel="INFO",
        checkInputs=True,
        **specs,
    ):
        """Initializes all modules from a given script file or specs dictionary."""

        # extend given specs with (JSON) script file; explicitly passed keyword
        # specs take precedence over values read from the file
        if scriptfile is not None:
            assert os.path.isfile(scriptfile), "%s is not a file." % scriptfile

            try:
                with open(scriptfile, "r") as ff:
                    script = ff.read()
                specs_from_file = json.loads(script)
                specs_from_file.update(specs)
            except ValueError as err:
                print(
                    "Error: %s: Input file `%s' improperly formatted."
                    % (self._modtype, scriptfile)
                )
                print("Error: JSON error was: %s" % err)
                # re-raise here to suppress the rest of the backtrace.
                # it is only confusing details about the bowels of json.loads()
                # (`from None` is what actually drops the chained traceback)
                raise ValueError(err) from None
        else:
            specs_from_file = {}
        specs.update(specs_from_file)

        if "modules" not in specs:
            raise ValueError("No modules field found in script.")

        # push all inputs into combined spec dict and save a copy before it gets
        # modified through module instantiations
        specs["verbose"] = bool(verbose)
        specs["logfile"] = logfile
        specs["loglevel"] = loglevel
        specs["nopar"] = bool(nopar)
        specs["checkInputs"] = bool(checkInputs)
        specs0 = copy.deepcopy(specs)

        # load the vprint function (same line in all prototype module constructors)
        self.verbose = specs["verbose"]
        self.vprint = vprint(self.verbose)

        # overwrite any ensemble setting if nopar is set
        self.nopar = specs["nopar"]
        if self.nopar:
            self.vprint("No-parallel: resetting SurveyEnsemble to Prototype")
            specs["modules"]["SurveyEnsemble"] = " "

        # start logging, with log file and logging level (default: INFO)
        self.logfile = specs.get("logfile", None)
        self.loglevel = specs.get("loglevel", "INFO").upper()
        specs["logger"] = self.get_logger(self.logfile, self.loglevel)
        specs["logger"].info(
            "Start Logging: loglevel = %s" % specs["logger"].level
            + " (%s)" % self.loglevel
        )

        # populate outspec
        # Build a per-instance dict instead of writing into the class-level
        # ``_outspec`` (which would be shared by every instance and subclass).
        # The new dict is only bound to ``self`` after it has been filled, and
        # "_outspec" is skipped, so it can never end up containing itself.
        self.checkInputs = specs["checkInputs"]
        outspec = dict(self._outspec)
        outspec.update(
            (att, val)
            for att, val in self.__dict__.items()
            if att not in ("vprint", "_outspec")
        )
        self._outspec = outspec

        # create a surveysimulation object (triggering init of everything else)
        self.SurveySimulation = get_module(
            specs["modules"]["SurveySimulation"], "SurveySimulation"
        )(**specs)

        # collect sub-initializations
        SS = self.SurveySimulation
        self.StarCatalog = SS.StarCatalog
        self.PlanetPopulation = SS.PlanetPopulation
        self.PlanetPhysicalModel = SS.PlanetPhysicalModel
        self.OpticalSystem = SS.OpticalSystem
        self.ZodiacalLight = SS.ZodiacalLight
        self.BackgroundSources = SS.BackgroundSources
        self.PostProcessing = SS.PostProcessing
        self.Completeness = SS.Completeness
        self.TargetList = SS.TargetList
        self.SimulatedUniverse = SS.SimulatedUniverse
        self.Observatory = SS.Observatory
        self.TimeKeeping = SS.TimeKeeping

        # now that everything has successfully built, you can create the ensemble
        # (it receives a fresh copy of the untouched input specs)
        self.SurveyEnsemble = get_module(
            specs["modules"]["SurveyEnsemble"], "SurveyEnsemble"
        )(**copy.deepcopy(specs0))

        # create a dictionary of all modules, except StarCatalog
        self.modules = SS.modules
        self.modules["SurveyEnsemble"] = self.SurveyEnsemble

        # alias SurveySimulation random seed to attribute for easier access
        self.seed = self.SurveySimulation.seed
        self.specs0 = specs0

        # run keywords check if requested
        if self.checkInputs:
            self.check_ioscripts()

    def check_ioscripts(self) -> None:
        """Collect all input and output scripts against selected module inits
        and report any discrepancies.

        Findings are reported through :py:func:`warnings.warn` (unused inputs,
        optical system issues, outspec mismatches) and through ``vprint``
        (informational summaries). Nothing is returned.
        """

        # get a list of all modules in use
        mods = {name: mod.__class__ for name, mod in self.modules.items()}
        mods["MissionSim"] = self.__class__
        mods["StarCatalog"] = self.TargetList.StarCatalog.__class__

        # collect keywords
        allkws, allkwmods, ukws, ukwcounts = get_all_mod_kws(mods)

        self.vprint(
            (
                "\nThe following keywords are used in multiple inits (this is ok):"
                "\n\t{}"
            ).format("\n\t".join(ukws[ukwcounts > 1]))
        )

        # now let's compare against specs0
        # (sorted so that the warning text is deterministic between runs)
        known_kws = set(ukws)
        input_kws = set(self.specs0.keys())
        unused = sorted(input_kws - known_kws - {"modules", "seed"})
        if len(unused) > 0:
            warnstr = (
                "\nThe following input keywords were not used in any "
                "module init:\n\t{}".format("\n\t".join(unused))
            )
            warnings.warn(warnstr)

        self.vprint(
            "\n{} keywords were set to their default values.".format(
                len(known_kws - input_kws)
            )
        )

        # check the optical system
        out = check_opticalsystem_kws(self.specs0, self.OpticalSystem)
        if out != "":
            warnings.warn(f"\n{out}")

        # and finally, let's look at the outspec
        outspec = self.genOutSpec(modnames=True)

        # these are extraneous things allowed to be in outspec:
        whitelist = ["modules", "Version", "seed", "nStars"]
        for w in whitelist:
            outspec.pop(w, None)

        extraouts = sorted(set(outspec.keys()) - known_kws)
        if len(extraouts) > 0:
            warnstr = (
                "\nThe following outspec keywords were not used in any "
                "module init:\n"
            )
            for e in extraouts:
                warnstr += "\t{:>20} ({})\n".format(e, outspec[e])
            warnings.warn(warnstr)

        missingouts = sorted(known_kws - set(outspec.keys()))
        if len(missingouts) > 0:
            # arrays allow boolean-mask lookup of the modules using each keyword
            allkws = np.array(allkws)
            allkwmods = np.array(allkwmods)
            warnstr = "\nThe following init keywords were not found in any outspec:\n"
            for m in missingouts:
                warnstr += "\t{:>20} ({})\n".format(
                    m, ", ".join(allkwmods[allkws == m])
                )
            warnings.warn(warnstr)

    def get_logger(self, logfile, loglevel):
        r"""Set up logging object so other modules can use logging.info(),
        logging.warning, etc.

        Args:
            logfile (string):
                Path to the log file. If None, logging is turned off.
                If supplied but empty string (''), a temporary file is generated.
            loglevel (string):
                The level of log, defaults to 'INFO'. Valid levels are: CRITICAL,
                ERROR, WARNING, INFO, DEBUG (upper-cased before lookup).

        Returns:
            logger (logging object):
                Mission Simulation logger. If ``logfile`` is None, or if the
                requested log file cannot be opened for writing, the default
                module logger is returned without any file handler attached.

        Raises:
            ValueError:
                If a log file is requested and ``loglevel`` does not name an
                integer level of the :py:mod:`logging` module.

        """

        # this leaves the default logger in place, so logger.warn will appear on stderr
        if logfile is None:
            return logging.getLogger(__name__)

        # convert string to a logging.* level; validate before touching any
        # file so that a bad level fails cleanly with ValueError
        numeric_level = getattr(logging, loglevel.upper(), None)
        if not isinstance(numeric_level, int):
            raise ValueError("Invalid log level: %s" % loglevel.upper())

        # if empty string, a temporary file is generated
        if logfile == "":
            # prefer /tmp, but fall back to the platform default location
            # (dir=None) where /tmp does not exist
            tmpdir = "/tmp" if os.path.isdir("/tmp") else None
            fd, logfile = tempfile.mkstemp(
                prefix="EXOSIMS.", suffix=".log", dir=tmpdir, text=True
            )
            # mkstemp hands back an open descriptor; the FileHandler below
            # reopens the file by name, so release this one to avoid a leak
            os.close(fd)
        else:
            # ensure we can write it
            try:
                with open(logfile, "w"):
                    pass
            except (IOError, OSError):
                print('%s: Failed to open logfile "%s"' % (__file__, logfile))
                # fall back to the default logger so that callers always
                # receive a usable logger object
                return logging.getLogger(__name__)

        self.vprint("Logging to '%s' at level '%s'" % (logfile, loglevel.upper()))

        # set up the top-level logger
        logger = logging.getLogger(__name__.split(".")[0])
        logger.setLevel(numeric_level)

        # do not propagate EXOSIMS messages to higher loggers in this case
        logger.propagate = False

        # create a handler that outputs to the named file
        handler = logging.FileHandler(logfile, mode="w")
        handler.setLevel(numeric_level)

        # logging format
        formatter = logging.Formatter(
            "%(levelname)s: %(filename)s(%(lineno)s): " + "%(funcName)s: %(message)s"
        )
        handler.setFormatter(formatter)

        # add the handler to the logger
        # NOTE(review): every call adds another handler to the same top-level
        # logger; confirm whether repeated instantiation in one process should
        # replace earlier handlers.
        logger.addHandler(handler)

        return logger

    def run_sim(self):
        """Convenience method that simply calls the SurveySimulation run_sim method.

        Returns:
            The value returned by ``SurveySimulation.run_sim()``.
        """

        return self.SurveySimulation.run_sim()

    def reset_sim(self, genNewPlanets=True, rewindPlanets=True, seed=None):
        """
        Convenience method that simply calls the SurveySimulation reset_sim method.

        Args:
            genNewPlanets (bool):
                Passed through to ``SurveySimulation.reset_sim``.
            rewindPlanets (bool):
                Passed through to ``SurveySimulation.reset_sim``.
            seed (int or None):
                Passed through to ``SurveySimulation.reset_sim``.

        Returns:
            The value returned by ``SurveySimulation.reset_sim``.
        """

        res = self.SurveySimulation.reset_sim(
            genNewPlanets=genNewPlanets, rewindPlanets=rewindPlanets, seed=seed
        )
        # re-sync the module dictionary with the SurveySimulation's one
        self.modules = self.SurveySimulation.modules
        self.modules["SurveyEnsemble"] = self.SurveyEnsemble  # replace SurveyEnsemble

        return res

    def run_ensemble(
        self,
        nb_run_sim,
        run_one=None,
        genNewPlanets=True,
        rewindPlanets=True,
        kwargs=None,
    ):
        """
        Convenience method that simply calls the SurveyEnsemble run_ensemble method.

        Args:
            nb_run_sim (int):
                Passed through to ``SurveyEnsemble.run_ensemble``.
            run_one (callable or None):
                Passed through to ``SurveyEnsemble.run_ensemble``.
            genNewPlanets (bool):
                Passed through to ``SurveyEnsemble.run_ensemble``.
            rewindPlanets (bool):
                Passed through to ``SurveyEnsemble.run_ensemble``.
            kwargs (dict or None):
                Extra keyword dictionary passed through to
                ``SurveyEnsemble.run_ensemble``. None (default) is replaced
                by a new empty dictionary on every call.

        Returns:
            The value returned by ``SurveyEnsemble.run_ensemble``.
        """

        # a fresh dict per call: a ``{}`` default would be shared between calls
        if kwargs is None:
            kwargs = {}

        res = self.SurveyEnsemble.run_ensemble(
            self,
            nb_run_sim,
            run_one=run_one,
            genNewPlanets=genNewPlanets,
            rewindPlanets=rewindPlanets,
            kwargs=kwargs,
        )

        return res

    def genOutSpec(
        self,
        tofile: Optional[str] = None,
        modnames: bool = False,
    ) -> Dict[str, Any]:
        """Join all _outspec dicts from all modules into one output dict
        and optionally write out to JSON file on disk.

        Args:
            tofile (str):
                Name of the file containing all output specifications (outspecs).
                Defaults to None.
            modnames (bool):
                If True, populate outspec dictionary with the module it originated
                from, instead of the actual value of the keyword. Defaults False.

        Returns:
            dict:
                Dictionary containing the full :ref:`sec:inputspec`, including all
                filled-in default values. Combination of all individual module
                _outspec attributes.

        """

        # shallow copy so that relabelling below never alters self._outspec
        starting_outspec = copy.copy(self._outspec)
        if modnames:
            for k in starting_outspec:
                starting_outspec[k] = "MissionSim"

        out = self.SurveySimulation.genOutSpec(
            starting_outspec=starting_outspec, tofile=tofile, modnames=modnames
        )

        return out

    def genWaypoint(self, targetlist=None, duration=365, tofile=None, charmode=False):
        """generates a ballpark estimate of the expected number of star visits and
        the total completeness of these visits for a given mission duration

        Args:
            targetlist (list, optional):
                List of target indices
            duration (int):
                The length of time allowed for the waypoint calculation, defaults to
                365
            tofile (str):
                Name of the file containing a plot of total completeness over mission
                time, by default genWaypoint does not create this plot
            charmode (bool):
                Run the waypoint calculation using either the char mode instead of
                the det mode

        Returns:
            dict:
                Output dictionary containing the number of stars visited, the total
                completeness achieved, and the amount of time spent integrating.

        """

        SS = self.SurveySimulation
        OS = SS.OpticalSystem
        ZL = SS.ZodiacalLight
        Comp = SS.Completeness
        TL = SS.TargetList
        Obs = SS.Observatory
        TK = SS.TimeKeeping

        # pick the first matching observing mode: an instrument whose name
        # contains "spec" in charmode, otherwise the first detection mode
        allModes = OS.observingModes
        if charmode:
            int_mode = list(
                filter(lambda mode: "spec" in mode["inst"]["name"], allModes)
            )[0]
        else:
            int_mode = list(filter(lambda mode: mode["detectionMode"], allModes))[0]

        # directory containing this class's source file (handed to waypoint)
        mpath = os.path.split(inspect.getfile(self.__class__))[0]

        if targetlist is not None:
            num_stars = len(targetlist)
            sInds = np.array(targetlist)
        else:
            num_stars = TL.nStars
            sInds = np.arange(TL.nStars)

        # every target is evaluated at the current mission time
        startTimes = TK.currentTimeAbs + np.zeros(num_stars) * u.d
        fZ = ZL.fZ(Obs, TL, sInds, startTimes, int_mode)
        # NOTE(review): JEZ is not indexed by sInds, unlike dMag and WA -
        # confirm this is intended when a custom targetlist is supplied.
        JEZ = TL.JEZ0[int_mode["hex"]]
        dMag = TL.int_dMag[sInds]
        WA = TL.int_WA[sInds]

        # sort star indices by completeness divided by integration time
        intTimes = OS.calc_intTime(TL, sInds, fZ, JEZ, dMag, WA, int_mode)
        # NOTE(review): only the first working angle (WA[0]) is passed here -
        # confirm against Completeness.comp_per_intTime.
        comps = Comp.comp_per_intTime(intTimes, TL, sInds, fZ, JEZ, WA[0], int_mode)

        wp = waypoint(comps, intTimes, duration, mpath, tofile)

        return wp

    def checkScript(self, scriptfile, prettyprint=False, tofile=None):
        """Calls CheckScript and checks the script file against the mission outspec.

        Args:
            scriptfile (str):
                The path to the scriptfile being used by the sim
            prettyprint (bool):
                Outputs the results of Checkscript in a readable format.
            tofile (str):
                Name of the file containing all output specifications (outspecs).
                Default to None. The file is written into the directory that
                contains this class's source file.

        Returns:
            str:
                Output string containing the results of the check, or None if
                ``scriptfile`` is None.

        """

        if scriptfile is None:
            return None

        cs = CheckScript(scriptfile, self.genOutSpec())
        out = cs.recurse(cs.specs_from_file, cs.outspec, pretty_print=prettyprint)
        if tofile is not None:
            mpath = os.path.split(inspect.getfile(self.__class__))[0]
            cs.write_file(os.path.join(mpath, tofile))

        return out

    def DRM2array(self, key, DRM=None):
        """Creates an array corresponding to one element of the DRM dictionary.

        Args:
            key (str):
                Name of an element of the DRM dictionary
            DRM (list(dict)):
                Design Reference Mission, contains the results of a survey simulation

        Returns:
            ~numpy.ndarray or ~astropy.units.Quantity(~numpy.ndarray):
                Array containing all the DRM values of the selected element

        Raises:
            AssertionError:
                If the DRM is empty or ``key`` is not a recognized DRM keyword.

        """

        # if the DRM was not specified, get it from the current SurveySimulation
        if DRM is None:
            DRM = self.SurveySimulation.DRM
        assert DRM != [], "DRM is empty. Use MissionSim.run_sim() to start simulation."

        # lists of relevant DRM elements
        keysStar = [
            "star_ind",
            "star_name",
            "arrival_time",
            "OB_nb",
            "det_time",
            "det_fZ",
            "char_time",
            "char_fZ",
        ]
        keysPlans = ["plan_inds", "det_status", "det_SNR", "char_status", "char_SNR"]
        keysParams = [
            "det_JEZ",
            "det_dMag",
            "det_WA",
            "det_d",
            "char_JEZ",
            "char_dMag",
            "char_WA",
            "char_d",
        ]
        keysFA = [
            "FA_det_status",
            "FA_char_status",
            "FA_char_SNR",
            "FA_char_JEZ",
            "FA_char_dMag",
            "FA_char_WA",
        ]
        keysOcculter = [
            "slew_time",
            "slew_dV",
            "det_dF_lateral",
            "scMass",
            "slewMass",
            "skMass",
            "char_dF_axial",
            "det_mass_used",
            "slew_mass_used",
            "det_dF_axial",
            "det_dV",
            "slew_angle",
            "char_dF_lateral",
        ]

        # the message is interpolated with the offending key
        assert key in (keysStar + keysPlans + keysParams + keysFA + keysOcculter), (
            "'%s' is not a relevant DRM keyword." % key
        )

        # extract arrays for each relevant keyword in the DRM
        if key in keysParams:
            # parameter keys live in a nested dict, stored without their
            # "det_" / "char_" prefix (every keysParams entry has one of them)
            if key.startswith("det_"):
                elem = [obs["det_params"][key[4:]] for obs in DRM]
            else:
                elem = [obs["char_params"][key[5:]] for obs in DRM]
        elif isinstance(DRM[0][key], u.Quantity):
            # strip units per observation, then re-apply the first one's unit
            elem = [obs[key].value for obs in DRM] * DRM[0][key].unit
        else:
            elem = [obs[key] for obs in DRM]

        # NOTE(review): np.array() on a Quantity appears to return a plain
        # ndarray (unit dropped) - confirm against the documented return type.
        try:
            elem = np.array(elem)
        except ValueError:
            # fallback when entries cannot be combined into a regular array
            elem = np.array(elem, dtype=object)

        return elem

    def filter_status(self, key, status, DRM=None, obsMode=None):
        """Finds the values of one DRM element, corresponding to a status value,
        for detection or characterization.

        Args:
            key (string):
                Name of an element of the DRM dictionary
            status (integer):
                Status value for detection or characterization
            DRM (list of dicts):
                Design Reference Mission, contains the results of a survey simulation
            obsMode (string):
                Observing mode type ('det' or 'char'). If None, 'char' is used
                when ``key`` contains 'char_' and 'det' otherwise.

        Returns:
            elemStat (ndarray / astropy Quantity array):
                Array containing all the DRM values of the selected element,
                and filtered by the value of the corresponding status array

        Raises:
            AssertionError:
                If ``obsMode`` is neither 'det' nor 'char'.

        """

        # false-alarm keys are matched against the false-alarm status arrays
        is_fa = "FA_" in key

        # get DRM detection status array
        det = self.DRM2array("FA_det_status" if is_fa else "det_status", DRM=DRM)

        # get DRM characterization status array
        char = self.DRM2array("FA_char_status" if is_fa else "char_status", DRM=DRM)

        # get DRM key element array
        elem = self.DRM2array(key, DRM=DRM)

        # reshape elem array, for keys with 1 value per observation:
        # repeat that value once per entry of the observation's status array
        if elem[0].shape == () and "FA_" not in key:
            if isinstance(elem[0], u.Quantity):
                elem = np.array(
                    [
                        np.array([elem[x].value] * len(det[x])) * elem[0].unit
                        for x in range(len(elem))
                    ]
                )
            else:
                elem = np.array(
                    [np.array([elem[x]] * len(det[x])) for x in range(len(elem))]
                )

        # assign a default observing mode type ('det' or 'char')
        if obsMode is None:
            obsMode = "char" if "char_" in key else "det"
        assert obsMode in (
            "det",
            "char",
        ), "Observing mode type must be 'det' or 'char'."

        # now, find the values of elem corresponding to the specified status value
        if obsMode == "det":
            if isinstance(elem[0], u.Quantity):
                elemStat = (
                    np.concatenate(
                        [elem[x][det[x] == status].value for x in range(len(elem))]
                    )
                    * elem[0].unit
                )
            else:
                elemStat = np.concatenate(
                    [elem[x][det[x] == status] for x in range(len(elem))]
                )
        else:  # if obsMode is 'char'
            # keep only entries with detection status 1, then filter those by
            # the requested characterization status
            if isinstance(elem[0], u.Quantity):
                elemDet = (
                    np.concatenate(
                        [elem[x][det[x] == 1].value for x in range(len(elem))]
                    )
                    * elem[0].unit
                )
            else:
                elemDet = np.concatenate(
                    [elem[x][det[x] == 1] for x in range(len(elem))]
                )
            charDet = np.concatenate([char[x][det[x] == 1] for x in range(len(elem))])
            elemStat = elemDet[charDet == status]

        return elemStat