Mercurial > hg > nsaunier > traffic-intelligence
comparison scripts/process.py @ 1071:58994b08be42
added multithreading for safety
| author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
|---|---|
| date | Wed, 18 Jul 2018 02:12:47 -0400 |
| parents | 0154133e77df |
| children | c67f8c36ebc7 8ab92ee3cbef |
Comparison legend: equal · deleted · inserted · replaced
| 1070:0154133e77df | 1071:58994b08be42 |
|---|---|
| 215 else: | 215 else: |
| 216 outputPrototypeDatabaseFilename = args.outputPrototypeDatabaseFilename | 216 outputPrototypeDatabaseFilename = args.outputPrototypeDatabaseFilename |
| 217 clusterSizes = ml.computeClusterSizes(labels, prototypeIndices, -1) | 217 clusterSizes = ml.computeClusterSizes(labels, prototypeIndices, -1) |
| 218 storage.savePrototypesToSqlite(str(parentPath/site.getPath()/outputPrototypeDatabaseFilename), [moving.Prototype(object2VideoSequences[trainingObjects[i]].getDatabaseFilename(False), trainingObjects[i].getNum(), prototypeType, clusterSizes[i]) for i in prototypeIndices]) | 218 storage.savePrototypesToSqlite(str(parentPath/site.getPath()/outputPrototypeDatabaseFilename), [moving.Prototype(object2VideoSequences[trainingObjects[i]].getDatabaseFilename(False), trainingObjects[i].getNum(), prototypeType, clusterSizes[i]) for i in prototypeIndices]) |
| 219 | 219 |
| 220 | |
| 221 elif args.process == 'interaction': | 220 elif args.process == 'interaction': |
| 222 # safety analysis TODO make function in safety analysis script | 221 # safety analysis TODO make function in safety analysis script |
| 223 if args.predictionMethod == 'cvd': | 222 if args.predictionMethod == 'cvd': |
| 224 predictionParameters = prediction.CVDirectPredictionParameters() | 223 predictionParameters = prediction.CVDirectPredictionParameters() |
| 225 if args.predictionMethod == 'cve': | 224 elif args.predictionMethod == 'cve': |
| 226 predictionParameters = prediction.CVExactPredictionParameters() | 225 predictionParameters = prediction.CVExactPredictionParameters() |
| 227 for vs in videoSequences: | 226 for vs in videoSequences: |
| 228 print('Processing '+vs.getDatabaseFilename()) | 227 print('Processing '+vs.getDatabaseFilename()) |
| 228 if args.configFilename is None: | |
| 229 params = storage.ProcessParameters(str(parentPath/vs.cameraView.getTrackingConfigurationFilename())) | |
| 230 else: | |
| 231 params = storage.ProcessParameters(args.configFilename) | |
| 229 objects = storage.loadTrajectoriesFromSqlite(str(parentPath/vs.getDatabaseFilename()), 'object')#, args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp')) | 232 objects = storage.loadTrajectoriesFromSqlite(str(parentPath/vs.getDatabaseFilename()), 'object')#, args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp')) |
| 230 interactions = events.createInteractions(objects) | 233 interactions = events.createInteractions(objects) |
| 231 #if args.nProcesses == 1: | 234 if args.nProcesses == 1: |
| 232 #print(str(parentPath/vs.cameraView.getTrackingConfigurationFilename())) | 235 #print(len(interactions), args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones) |
| 233 params = storage.ProcessParameters(str(parentPath/vs.cameraView.getTrackingConfigurationFilename())) | 236 processed = events.computeIndicators(interactions, True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, False, False, None) # params.crossingZones |
| 234 #print(len(interactions), args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones) | 237 else: |
| 235 processed = events.computeIndicators(interactions, True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None) | 238 #pool = Pool(processes = args.nProcesses) |
| 239 nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses))) | |
| 240 jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, False, False, None)) for i in range(args.nProcesses)] # params.crossingZones | |
| 241 processed = [] | |
| 242 for job in jobs: | |
| 243 processed += job.get() | |
| 244 #pool.close() | |
| 236 storage.saveIndicatorsToSqlite(str(parentPath/vs.getDatabaseFilename()), processed) | 245 storage.saveIndicatorsToSqlite(str(parentPath/vs.getDatabaseFilename()), processed) |
| 237 # else: | 246 |
| 238 # pool = Pool(processes = args.nProcesses) | |
| 239 # nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses))) | |
| 240 # jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)) for i in range(args.nProcesses)] | |
| 241 # processed = [] | |
| 242 # for job in jobs: | |
| 243 # processed += job.get() | |
| 244 # pool.close() | |
| 245 | |
| 246 ################################# | 247 ################################# |
| 247 # Analyze | 248 # Analyze |
| 248 ################################# | 249 ################################# |
| 249 if args.analyze == 'object': | 250 if args.analyze == 'object': |
| 250 # user speeds, accelerations | 251 # user speeds, accelerations |
