comparison scripts/safety-analysis.py @ 949:d6c1c05d11f5

modified multithreading at the interaction level for safety computations
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Fri, 21 Jul 2017 17:52:56 -0400
parents 584b9405e494
children eb42f2f51490
comparison
equal deleted inserted replaced
948:584b9405e494 949:d6c1c05d11f5
1 #! /usr/bin/env python 1 #! /usr/bin/env python
2 2
3 import storage, prediction, events, moving 3 import storage, prediction, events, moving
4 4
5 import sys, argparse, random 5 import sys, argparse, random
6 from multiprocessing import Pool
6 7
7 import matplotlib.pyplot as plt 8 import matplotlib.pyplot as plt
8 import numpy as np 9 import numpy as np
9 10
10 # todo: very slow if too many predicted trajectories 11 # todo: very slow if too many predicted trajectories
68 # params.useFeaturesForPrediction) 69 # params.useFeaturesForPrediction)
69 70
70 objects = storage.loadTrajectoriesFromSqlite(params.databaseFilename, 'object', args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp')) 71 objects = storage.loadTrajectoriesFromSqlite(params.databaseFilename, 'object', args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp'))
71 72
72 interactions = events.createInteractions(objects) 73 interactions = events.createInteractions(objects)
73 for inter in interactions: 74 if args.nProcesses == 1:
74 print('processing interaction {}'.format(inter.getNum()) 75 processed = events.computeIndicators(interactions, not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)
75 inter.computeIndicators() 76 else:
76 if not args.noMotionPrediction: 77 pool = Pool(processes = args.nProcesses)
77 inter.computeCrossingsCollisions(predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, nProcesses = args.nProcesses) 78 nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses)))
78 79 jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)) for i in range(args.nProcesses)]
79 if args.computePET: 80 processed = []
80 for inter in interactions: 81 for job in jobs:
81 inter.computePET(params.collisionDistance) 82 processed += job.get()
82 83 pool.close()
83 storage.saveIndicatorsToSqlite(params.databaseFilename, interactions) 84 storage.saveIndicatorsToSqlite(params.databaseFilename, processed)
84 85
85 if args.displayCollisionPoints: 86 if args.displayCollisionPoints:
86 plt.figure() 87 plt.figure()
87 allCollisionPoints = [] 88 allCollisionPoints = []
88 for inter in interactions: 89 for inter in processed:
89 for collisionPoints in inter.collisionPoints.values(): 90 for collisionPoints in inter.collisionPoints.values():
90 allCollisionPoints += collisionPoints 91 allCollisionPoints += collisionPoints
91 moving.Point.plotAll(allCollisionPoints) 92 moving.Point.plotAll(allCollisionPoints)
92 plt.axis('equal') 93 plt.axis('equal')
93 94