Mercurial > hg > nsaunier > traffic-intelligence
comparison scripts/safety-analysis.py @ 949:d6c1c05d11f5
modified multithreading at the interaction level for safety computations
| author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
|---|---|
| date | Fri, 21 Jul 2017 17:52:56 -0400 |
| parents | 584b9405e494 |
| children | eb42f2f51490 |
comparison legend: equal | deleted | inserted | replaced
| 948:584b9405e494 | 949:d6c1c05d11f5 |
|---|---|
| 1 #! /usr/bin/env python | 1 #! /usr/bin/env python |
| 2 | 2 |
| 3 import storage, prediction, events, moving | 3 import storage, prediction, events, moving |
| 4 | 4 |
| 5 import sys, argparse, random | 5 import sys, argparse, random |
| 6 from multiprocessing import Pool | |
| 6 | 7 |
| 7 import matplotlib.pyplot as plt | 8 import matplotlib.pyplot as plt |
| 8 import numpy as np | 9 import numpy as np |
| 9 | 10 |
| 10 # todo: very slow if too many predicted trajectories | 11 # todo: very slow if too many predicted trajectories |
| 68 # params.useFeaturesForPrediction) | 69 # params.useFeaturesForPrediction) |
| 69 | 70 |
| 70 objects = storage.loadTrajectoriesFromSqlite(params.databaseFilename, 'object', args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp')) | 71 objects = storage.loadTrajectoriesFromSqlite(params.databaseFilename, 'object', args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp')) |
| 71 | 72 |
| 72 interactions = events.createInteractions(objects) | 73 interactions = events.createInteractions(objects) |
| 73 for inter in interactions: | 74 if args.nProcesses == 1: |
| 74 print('processing interaction {}'.format(inter.getNum()) | 75 processed = events.computeIndicators(interactions, not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None) |
| 75 inter.computeIndicators() | 76 else: |
| 76 if not args.noMotionPrediction: | 77 pool = Pool(processes = args.nProcesses) |
| 77 inter.computeCrossingsCollisions(predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, nProcesses = args.nProcesses) | 78 nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses))) |
| 78 | 79 jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)) for i in range(args.nProcesses)] |
| 79 if args.computePET: | 80 processed = [] |
| 80 for inter in interactions: | 81 for job in jobs: |
| 81 inter.computePET(params.collisionDistance) | 82 processed += job.get() |
| 82 | 83 pool.close() |
| 83 storage.saveIndicatorsToSqlite(params.databaseFilename, interactions) | 84 storage.saveIndicatorsToSqlite(params.databaseFilename, processed) |
| 84 | 85 |
| 85 if args.displayCollisionPoints: | 86 if args.displayCollisionPoints: |
| 86 plt.figure() | 87 plt.figure() |
| 87 allCollisionPoints = [] | 88 allCollisionPoints = [] |
| 88 for inter in interactions: | 89 for inter in processed: |
| 89 for collisionPoints in inter.collisionPoints.values(): | 90 for collisionPoints in inter.collisionPoints.values(): |
| 90 allCollisionPoints += collisionPoints | 91 allCollisionPoints += collisionPoints |
| 91 moving.Point.plotAll(allCollisionPoints) | 92 moving.Point.plotAll(allCollisionPoints) |
| 92 plt.axis('equal') | 93 plt.axis('equal') |
| 93 | 94 |
