diff --git a/pyemu/utils/os_utils.py b/pyemu/utils/os_utils.py index 4d98ec3a..b5dc1974 100644 --- a/pyemu/utils/os_utils.py +++ b/pyemu/utils/os_utils.py @@ -409,6 +409,7 @@ def start_workers( args = (os.path.join(worker_dir,pst_rel_path),hostname,port) for i in range(num_workers): p = mp.Process(target=ppw_function,args=args,kwargs=ppw_kwargs) + p.daemon = True p.start() procs.append(p) @@ -674,7 +675,7 @@ def __init__(self, pst, host, port, timeout=0.1,verbose=True): self.obs_names = None self.par_values = None - + self.max_reconnect_attempts = 10 self._process_pst() self.connect() self._lock = threading.Lock() @@ -682,6 +683,7 @@ def __init__(self, pst, host, port, timeout=0.1,verbose=True): self._listen_thread = threading.Thread(target=self.listen,args=(self._lock,self._send_lock)) self._listen_thread.start() + def _process_pst(self): if isinstance(self._pst_arg,str): self._pst = pst_handler.Pst(self._pst_arg) @@ -692,7 +694,7 @@ def _process_pst(self): format(type(self._pst_arg))) - def connect(self): + def connect(self,is_reconnect=False): self.message("trying to connect to {0}:{1}...".format(self.host,self.port)) self.s = None c = 0 @@ -703,6 +705,10 @@ def connect(self): c += 1 if c % 75 == 0: print('') + print(c) + if is_reconnect and c > self.max_reconnect_attempts: + print("max reconnect attempts reached...") + return False self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.s.connect((self.host, self.port)) self.message("connected to {0}:{1}".format(self.host,self.port)) @@ -712,6 +718,10 @@ def connect(self): continue except Exception as e: continue + + self.net_pack = NetPack(timeout=self.timeout,verbose=self.verbose) + return True + def message(self,msg): if self.verbose: @@ -726,18 +736,31 @@ def recv(self,dtype=None): def send(self,mtype,group,runid,desc="",data=0): - self.net_pack.send(self.s,mtype,group,runid,desc,data) + try: + self.net_pack.send(self.s,mtype,group,runid,desc,data) + except Exception as e: + print("WARNING: error 
sending message:{0}".format(str(e))) + return False self.message("sent message type:{0}".format(NetPack.netpack_type[mtype])) + return True def listen(self,lock=None,send_lock=None): self.s.settimeout(self.timeout) + failed_reconnect = False while True: time.sleep(self.timeout) try: n = self.recv() except Exception as e: print("WARNING: recv exception:"+str(e)+"...trying to reconnect...") - self.connect() + success = self.connect(is_reconnect=True) + if not success: + print("...exiting") + time.sleep(self.timeout) + return + else: + print("...reconnect successfully...") + continue if n > 0: # need to sync here @@ -776,20 +799,40 @@ def listen(self,lock=None,send_lock=None): elif self.net_pack.mtype == 6: if self._send_lock is not None: self._send_lock.acquire() - self.send(7, self.net_pack.group, + success = self.send(7, self.net_pack.group, self.net_pack.runid, "fake linpack result", data=1) if self._send_lock is not None: self._send_lock.release() + if not success: + print("...trying to reconnect...") + success = self.connect(is_reconnect=True) + if not success: + print("...exiting") + time.sleep(self.timeout) + return + else: + print("reconnect successfully...") + continue elif self.net_pack.mtype == 15: if self._send_lock is not None: self._send_lock.acquire() - self.send(15, self.net_pack.group, + success = self.send(15, self.net_pack.group, self.net_pack.runid, "ping back") if self._send_lock is not None: self._send_lock.release() + if not success: + print("...trying to reconnect...") + success = self.connect(is_reconnect=True) + if not success: + print("...exiting") + time.sleep(self.timeout) + return + else: + print("reconnect successfully...") + continue elif self.net_pack.mtype == 14: #print("recv'd terminate signal") self.message("recv'd terminate signal") @@ -819,6 +862,7 @@ def get_parameters(self): raise Exception("len(par vals) {0} != len(par names)".format(len(pars),len(self.par_names))) return pd.Series(data=pars,index=self.par_names) + def 
send_observations(self,obsvals,parvals=None,request_more_pars=True): if len(obsvals) != len(self.obs_names): raise Exception("len(obs vals) {0} != len(obs names)".format(len(obsvals), len(self.obs_names))) @@ -862,11 +906,13 @@ def send_observations(self,obsvals,parvals=None,request_more_pars=True): self.send(3,0,0,"ready for next run",data=0) self._send_lock.release() + def request_more_pars(self): self._send_lock.acquire() self.send(3, 0, 0, "ready for next run", data=0.0) self._send_lock.release() + def send_failed_run(self,group=None,runid=None,desc="failed"): if group is None: group = self.net_pack.group @@ -876,6 +922,7 @@ def send_failed_run(self,group=None,runid=None,desc="failed"): self.send(12, int(group), int(runid), desc, data=0.0) self._send_lock.release() + def send_killed_run(self,group=None,runid=None,desc="killed"): if group is None: group = self.net_pack.group