Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

227

228

229

230

231

232

233

234

235

236

237

238

239

240

241

242

243

244

245

246

247

248

249

250

251

252

253

254

255

256

257

258

259

260

261

262

263

264

265

266

267

268

269

270

271

272

273

274

275

276

277

278

279

280

281

282

283

284

285

286

287

288

289

290

291

292

293

294

295

296

297

298

299

300

301

302

303

304

305

306

307

308

309

310

311

312

313

314

315

316

317

318

319

320

321

322

323

324

325

326

327

328

329

330

331

332

333

334

335

336

337

338

339

340

341

342

343

344

345

346

347

348

349

350

351

352

353

354

355

356

357

358

359

360

361

362

363

364

365

366

367

368

369

370

371

372

373

374

375

376

377

378

379

380

381

382

383

384

385

386

387

388

389

390

391

392

393

394

395

396

397

""" 

libEnsemble manager routines 

==================================================== 

""" 

 

from __future__ import division 

from __future__ import absolute_import 

 

import time 

import sys 

import os 

import logging 

import socket 

import pickle 

 

from mpi4py import MPI 

import numpy as np 

 

from libensemble.message_numbers import \ 

EVAL_SIM_TAG, FINISHED_PERSISTENT_SIM_TAG, \ 

EVAL_GEN_TAG, FINISHED_PERSISTENT_GEN_TAG, \ 

STOP_TAG, UNSET_TAG, \ 

WORKER_KILL, WORKER_KILL_ON_ERR, WORKER_KILL_ON_TIMEOUT, \ 

JOB_FAILED, WORKER_DONE, \ 

MAN_SIGNAL_FINISH, MAN_SIGNAL_KILL, \ 

MAN_SIGNAL_REQ_RESEND, MAN_SIGNAL_REQ_PICKLE_DUMP, \ 

ABORT_ENSEMBLE 

 

# Module-level logger, named after this module so that log output can be
# filtered per module.
logger = logging.getLogger(__name__)
# To see this module's debug messages, uncomment the line below.
# logger.setLevel(logging.DEBUG)

 

class ManagerException(Exception):
    """Error raised by the manager, e.g. on an abort signal from a worker."""

 

 

def manager_main(hist, libE_specs, alloc_specs,
                 sim_specs, gen_specs, exit_criteria, persis_info):
    """Coordinate the generation and simulation evaluations.

    Builds a ``Manager`` from the history and specification arguments and
    runs it to completion.

    Returns whatever ``Manager.run`` returns for ``persis_info``.
    """
    return Manager(hist, libE_specs, alloc_specs,
                   sim_specs, gen_specs, exit_criteria).run(persis_info)

 

 

def get_stopwatch():
    """Return a function reporting the seconds elapsed since this call."""
    began_at = time.time()

    def elapsed():
        """Return the time elapsed since the stopwatch was created."""
        now = time.time()
        return now - began_at

    return elapsed

 

 

def filter_nans(array):
    """Return the entries of a numpy array that are not NaN."""
    keep = np.logical_not(np.isnan(array))
    return array[keep]

 

 

class Manager:
    """Manager class for libensemble.

    Holds the run history, the user-supplied specifications and an array of
    per-worker states, and drives the receive / allocate / send loop until
    one of the exit criteria is met.
    """

    # Layout of one record in the worker-state array ``W``:
    #   worker_id    -- rank used to address the worker (1-based; 0 is the
    #                   manager itself)
    #   active       -- tag of the work in progress, 1 while the worker is
    #                   blocked on behalf of another worker, 0 when idle
    #   persis_state -- tag of the persistent calculation, 0 when none
    #   blocked      -- set while another worker's work order blocks this one
    worker_dtype = [('worker_id', int),
                    ('active', int),
                    ('persis_state', int),
                    ('blocked', bool)]

    def __init__(self, hist, libE_specs, alloc_specs,
                 sim_specs, gen_specs, exit_criteria):
        """Initialize the manager.

        The communicator is taken from ``libE_specs['comm']`` and the worker
        pool is sized from it. The wallclock stopwatch starts here.
        """
        self.hist = hist
        self.libE_specs = libE_specs
        self.alloc_specs = alloc_specs
        self.sim_specs = sim_specs
        self.gen_specs = gen_specs
        self.exit_criteria = exit_criteria
        self.elapsed = get_stopwatch()
        self.comm = libE_specs['comm']
        self.W = self._make_worker_pool(self.comm)
        # Each entry is (value returned by term_test, exit_criteria key,
        # test function). Tests are tried in this order, so a wallclock
        # timeout (return value 2) takes precedence over the others.
        self.term_tests = \
            [(2, 'elapsed_wallclock_time', self.term_test_wallclock),
             (1, 'sim_max', self.term_test_sim_max),
             (1, 'gen_max', self.term_test_gen_max),
             (1, 'stop_val', self.term_test_stop_val)]

    @staticmethod
    def _make_worker_pool(comm):
        """Set up an array of worker states.

        One record per rank other than the manager; all workers start idle
        and are numbered from 1.
        """
        num_workers = comm.Get_size() - 1
        W = np.zeros(num_workers, dtype=Manager.worker_dtype)
        W['worker_id'] = np.arange(num_workers) + 1
        return W

    # --- Termination logic routines

    def term_test_wallclock(self, max_elapsed):
        """Check against wallclock timeout."""
        return self.elapsed() >= max_elapsed

    def term_test_sim_max(self, sim_max):
        """Check against max simulations."""
        return self.hist.given_count >= sim_max + self.hist.offset

    def term_test_gen_max(self, gen_max):
        """Check against max generator calls."""
        return self.hist.index >= gen_max + self.hist.offset

    def term_test_stop_val(self, stop_val):
        """Check against stop value criterion.

        ``stop_val`` is a ``(key, val)`` pair; the test trips when any
        non-NaN entry of ``H[key]`` below the current history index is
        less than or equal to ``val``.
        """
        key, val = stop_val
        H = self.hist.H
        idx = self.hist.index
        return np.any(filter_nans(H[key][:idx]) <= val)

    def term_test(self, logged=True):
        """Check termination criteria.

        Returns the code of the first tripped test that is present in
        ``exit_criteria`` (2 for wallclock, 1 for the others), or 0 when
        none has tripped. When ``logged`` is true the tripped key is logged.
        """
        for retval, key, testf in self.term_tests:
            if key not in self.exit_criteria:
                continue
            if testf(self.exit_criteria[key]):
                if logged:
                    logger.info("Term test tripped: {}".format(key))
                return retval
        return 0

    # --- Low-level communication routines (use MPI directly)

    def Iprobe(self, w, status=None):
        """Check whether there is a message from a worker."""
        return self.comm.Iprobe(source=w, tag=MPI.ANY_TAG, status=status)

    def recv(self, w, status=None):
        """Receive from a worker."""
        return self.comm.recv(source=w, tag=MPI.ANY_TAG, status=status)

    def send(self, obj, w, tag=0):
        """Send to a worker."""
        return self.comm.send(obj=obj, dest=w, tag=tag)

    def _send_dtypes_to_workers(self):
        """Broadcast sim_spec/gen_spec input dtypes to workers."""
        self.comm.bcast(obj=self.hist.H[self.sim_specs['in']].dtype)
        self.comm.bcast(obj=self.hist.H[self.gen_specs['in']].dtype)

    def _kill_workers(self):
        """Send the finish signal to every worker."""
        for w in self.W['worker_id']:
            self.send(MAN_SIGNAL_FINISH, w, tag=STOP_TAG)

    def _man_request_resend_on_error(self, w):
        """Request the worker resend data on error; return what it sends."""
        self.send(MAN_SIGNAL_REQ_RESEND, w, tag=STOP_TAG)
        return self.recv(w)

    def _man_request_pkl_dump_on_error(self, w):
        """Request the worker dump a pickle on error and load it.

        The worker replies with the path of the pickle file; the file is
        loaded, removed, and its content returned.
        """
        self.send(MAN_SIGNAL_REQ_PICKLE_DUMP, w, tag=STOP_TAG)
        pkl_recv = self.recv(w)
        # NOTE(review): unpickling can execute arbitrary code; this relies
        # on the path coming from a worker of the same ensemble.
        # The context manager closes the handle even if loading fails.
        with open(pkl_recv, "rb") as pkl_file:
            D_recv = pickle.load(pkl_file)
        os.remove(pkl_recv)  # Dump file is no longer needed once loaded
        return D_recv

    # --- Checkpointing logic

    def _save_every_k(self, fname, count, k):
        """Save history every kth step.

        ``count`` is rounded down to a multiple of ``k`` and substituted
        into ``fname``; nothing is written if that file already exists or
        the rounded count is zero.
        """
        count = k*(count//k)
        filename = fname.format(count)
        if not os.path.isfile(filename) and count > 0:
            np.save(filename, self.hist.H)

    def _save_every_k_sims(self):
        """Save history every kth sim step."""
        self._save_every_k('libE_history_after_sim_{}.npy',
                           self.hist.sim_count,
                           self.sim_specs['save_every_k'])

    def _save_every_k_gens(self):
        """Save history every kth gen step."""
        self._save_every_k('libE_history_after_gen_{}.npy',
                           self.hist.index,
                           self.gen_specs['save_every_k'])

    # --- Handle outgoing messages to workers (work orders from alloc)

    def _check_work_order(self, Work, w):
        """Check validity of an allocation function order.

        Raises AssertionError if the order targets the manager (worker 0),
        targets an already active worker, or requests history fields that
        do not exist.
        """
        assert w != 0, "Can't send to worker 0; this is the manager. Aborting"
        assert self.W[w-1]['active'] == 0, \
            "Allocation function requested work to an already active worker. Aborting"
        work_rows = Work['libE_info']['H_rows']
        if len(work_rows):
            work_fields = set(Work['H_fields'])
            hist_fields = self.hist.H.dtype.names
            diff_fields = list(work_fields.difference(hist_fields))
            # The space before "be sent" keeps the field list and the rest
            # of the sentence from running together.
            assert not diff_fields, \
                ("Allocation function requested invalid fields {} "
                 "be sent to worker={}.".format(diff_fields, w))

    def _send_work_order(self, Work, w):
        """Send an allocation function order to a worker.

        The work order is sent first; if it refers to history rows, the
        requested fields of those rows follow in a second message.
        """
        logger.debug("Manager sending work unit to worker {}".format(w))
        self.send(Work, w, tag=Work['tag'])
        work_rows = Work['libE_info']['H_rows']
        if len(work_rows):
            self.send(self.hist.H[Work['H_fields']][work_rows], w)

    def _update_state_on_alloc(self, Work, w):
        """Update worker active/idle status following an allocation order."""
        self.W[w-1]['active'] = Work['tag']
        if 'libE_info' in Work and 'persistent' in Work['libE_info']:
            self.W[w-1]['persis_state'] = Work['tag']

        if 'blocking' in Work['libE_info']:
            # Blocked workers are also marked active so that they are not
            # seen as idle while the blocking work is in progress.
            for w_i in Work['libE_info']['blocking']:
                assert self.W[w_i-1]['active'] == 0, \
                    "Active worker being blocked; aborting"
                self.W[w_i-1]['blocked'] = 1
                self.W[w_i-1]['active'] = 1

        if Work['tag'] == EVAL_SIM_TAG:
            work_rows = Work['libE_info']['H_rows']
            self.hist.update_history_x_out(work_rows, w)

    # --- Handle incoming messages from workers

    @staticmethod
    def _check_received_calc(D_recv):
        """Check the type and status fields on a received calculation.

        Raises AssertionError on an unknown ``calc_type`` or ``calc_status``.
        """
        calc_type = D_recv['calc_type']
        calc_status = D_recv['calc_status']
        assert calc_type in [EVAL_SIM_TAG, EVAL_GEN_TAG], \
            'Aborting, Unknown calculation type received. Received type: ' + str(calc_type)
        assert calc_status in [FINISHED_PERSISTENT_SIM_TAG,
                               FINISHED_PERSISTENT_GEN_TAG,
                               UNSET_TAG,
                               MAN_SIGNAL_FINISH,
                               MAN_SIGNAL_KILL,
                               WORKER_KILL_ON_ERR,
                               WORKER_KILL_ON_TIMEOUT,
                               WORKER_KILL,
                               JOB_FAILED,
                               WORKER_DONE], \
            'Aborting: Unknown calculation status received. Received status: ' + str(calc_status)

    def _receive_from_workers(self, persis_info):
        """Receive calculation output from workers.

        Loops over all active workers and probes to see if each is ready
        to communicate. If any output is received, all other workers are
        looped back over. History checkpoints are written afterwards when
        ``save_every_k`` is set in the sim or gen specs.
        """
        status = MPI.Status()

        new_stuff = True
        while new_stuff and any(self.W['active']):
            new_stuff = False
            for w in self.W['worker_id'][self.W['active'] > 0]:
                if self.Iprobe(w, status):
                    new_stuff = True
                    self._handle_msg_from_worker(persis_info, w, status)

        if 'save_every_k' in self.sim_specs:
            self._save_every_k_sims()
        if 'save_every_k' in self.gen_specs:
            self._save_every_k_gens()
        return persis_info

    def _update_state_on_worker_msg(self, persis_info, D_recv, w):
        """Update history and worker info on worker message."""
        calc_type = D_recv['calc_type']
        calc_status = D_recv['calc_status']
        Manager._check_received_calc(D_recv)

        self.W[w-1]['active'] = 0
        if calc_status in [FINISHED_PERSISTENT_SIM_TAG,
                           FINISHED_PERSISTENT_GEN_TAG]:
            self.W[w-1]['persis_state'] = 0
        else:
            if calc_type == EVAL_SIM_TAG:
                self.hist.update_history_f(D_recv)
            if calc_type == EVAL_GEN_TAG:
                self.hist.update_history_x_in(w, D_recv['calc_out'])
            if 'libE_info' in D_recv and 'persistent' in D_recv['libE_info']:
                # Now a waiting, persistent worker
                self.W[w-1]['persis_state'] = calc_type

        if 'libE_info' in D_recv and 'blocking' in D_recv['libE_info']:
            # Now done blocking these workers
            for w_i in D_recv['libE_info']['blocking']:
                self.W[w_i-1]['blocked'] = 0
                self.W[w_i-1]['active'] = 0

        if 'persis_info' in D_recv:
            persis_info[w].update(D_recv['persis_info'])

    def _handle_msg_from_worker(self, persis_info, w, status):
        """Handle a message from worker w.

        ``status`` is the status object filled in by the preceding probe.
        If the receive fails, the error is logged and the worker is asked
        to dump its data to a pickle file instead.

        Raises ManagerException if the message carries the abort tag.
        """
        logger.debug("Manager receiving from Worker: {}".format(w))
        try:
            D_recv = self.recv(w)
            logger.debug("Message size {}".format(status.Get_count()))
        except Exception as e:
            logger.error("Exception caught on Manager receive: {}".format(e))
            logger.error("From worker: {}".format(w))
            logger.error("Message size of errored message {}".
                         format(status.Get_count()))
            logger.error("Message status error code {}".
                         format(status.Get_error()))

            # Check on working with persistent data - currently only use one
            #D_recv = _man_request_resend_on_error(w)
            D_recv = self._man_request_pkl_dump_on_error(w)

        if status.Get_tag() == ABORT_ENSEMBLE:
            raise ManagerException('Received abort signal from worker')

        self._update_state_on_worker_msg(persis_info, D_recv, w)

    # --- Handle termination

    def _read_final_messages(self):
        """Read and discard pending messages from any active workers."""
        for w in self.W['worker_id'][self.W['active'] > 0]:
            if self.Iprobe(w):
                self.recv(w)

    def _final_receive_and_kill(self, persis_info):
        """Try to receive from any active workers, then stop all workers.

        Keeps receiving while workers are active. If the wallclock
        criterion trips while workers are still active, any messages
        already pending are read and discarded and the loop ends with
        exit flag 2. The finish signal is then sent to every worker.

        Returns ``(persis_info, exit_flag)``; ``exit_flag`` is 0 unless
        the wallclock timeout was hit.
        """
        exit_flag = 0
        while any(self.W['active']) and exit_flag == 0:
            persis_info = self._receive_from_workers(persis_info)
            if self.term_test(logged=False) == 2 and any(self.W['active']):
                self._print_wallclock_term()
                self._read_final_messages()
                exit_flag = 2

        self._kill_workers()
        print("\nlibEnsemble manager total time:", self.elapsed())
        return persis_info, exit_flag

    @staticmethod
    def _print_wallclock_term():
        """Print termination message for wall clock elapsed."""
        print("Termination due to elapsed_wallclock_time has occurred.\n"
              "A last attempt has been made to receive any completed work.\n"
              "Posting nonblocking receives and kill messages for all active workers\n")
        sys.stdout.flush()
        sys.stderr.flush()

    # --- Main loop

    def _queue_update(self, H, persis_info):
        """Call queue update function from libE_specs (if defined).

        Returns ``persis_info`` unchanged when no function is configured
        or ``H`` is empty.
        """
        if 'queue_update_function' not in self.libE_specs or not len(H):
            return persis_info
        qfun = self.libE_specs['queue_update_function']
        return qfun(H, self.gen_specs, persis_info)

    def _alloc_work(self, H, persis_info):
        """Call work allocation function from alloc_specs."""
        alloc_f = self.alloc_specs['alloc_f']
        return alloc_f(self.W, H, self.sim_specs, self.gen_specs, persis_info)

    def run(self, persis_info):
        """Run the manager.

        Returns ``(persis_info, exit_flag)`` as produced by the final
        receive-and-kill step.
        """
        logger.info("Manager initiated on MPI rank {} on node {}".
                    format(self.comm.Get_rank(), socket.gethostname()))
        logger.info("Manager exit_criteria: {}".format(self.exit_criteria))

        # Send initial info to workers
        self._send_dtypes_to_workers()

        ### Continue receiving and giving until termination test is satisfied
        while not self.term_test():
            persis_info = self._receive_from_workers(persis_info)
            persis_info = self._queue_update(self.hist.trim_H(), persis_info)
            if any(self.W['active'] == 0):
                Work, persis_info = self._alloc_work(self.hist.trim_H(),
                                                     persis_info)
                for w in Work:
                    # Stop handing out work as soon as a criterion trips.
                    if self.term_test():
                        break
                    self._check_work_order(Work[w], w)
                    self._send_work_order(Work[w], w)
                    self._update_state_on_alloc(Work[w], w)

        # Return persis_info, exit_flag
        return self._final_receive_and_kill(persis_info)