comparison scripts/process.py @ 1071:58994b08be42

added multithreading for safety
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 18 Jul 2018 02:12:47 -0400
parents 0154133e77df
children c67f8c36ebc7 8ab92ee3cbef
comparison
equal deleted inserted replaced
1070:0154133e77df 1071:58994b08be42
215 else: 215 else:
216 outputPrototypeDatabaseFilename = args.outputPrototypeDatabaseFilename 216 outputPrototypeDatabaseFilename = args.outputPrototypeDatabaseFilename
217 clusterSizes = ml.computeClusterSizes(labels, prototypeIndices, -1) 217 clusterSizes = ml.computeClusterSizes(labels, prototypeIndices, -1)
218 storage.savePrototypesToSqlite(str(parentPath/site.getPath()/outputPrototypeDatabaseFilename), [moving.Prototype(object2VideoSequences[trainingObjects[i]].getDatabaseFilename(False), trainingObjects[i].getNum(), prototypeType, clusterSizes[i]) for i in prototypeIndices]) 218 storage.savePrototypesToSqlite(str(parentPath/site.getPath()/outputPrototypeDatabaseFilename), [moving.Prototype(object2VideoSequences[trainingObjects[i]].getDatabaseFilename(False), trainingObjects[i].getNum(), prototypeType, clusterSizes[i]) for i in prototypeIndices])
219 219
220
221 elif args.process == 'interaction': 220 elif args.process == 'interaction':
222 # safety analysis TODO make function in safety analysis script 221 # safety analysis TODO make function in safety analysis script
223 if args.predictionMethod == 'cvd': 222 if args.predictionMethod == 'cvd':
224 predictionParameters = prediction.CVDirectPredictionParameters() 223 predictionParameters = prediction.CVDirectPredictionParameters()
225 if args.predictionMethod == 'cve': 224 elif args.predictionMethod == 'cve':
226 predictionParameters = prediction.CVExactPredictionParameters() 225 predictionParameters = prediction.CVExactPredictionParameters()
227 for vs in videoSequences: 226 for vs in videoSequences:
228 print('Processing '+vs.getDatabaseFilename()) 227 print('Processing '+vs.getDatabaseFilename())
228 if args.configFilename is None:
229 params = storage.ProcessParameters(str(parentPath/vs.cameraView.getTrackingConfigurationFilename()))
230 else:
231 params = storage.ProcessParameters(args.configFilename)
229 objects = storage.loadTrajectoriesFromSqlite(str(parentPath/vs.getDatabaseFilename()), 'object')#, args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp')) 232 objects = storage.loadTrajectoriesFromSqlite(str(parentPath/vs.getDatabaseFilename()), 'object')#, args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp'))
230 interactions = events.createInteractions(objects) 233 interactions = events.createInteractions(objects)
231 #if args.nProcesses == 1: 234 if args.nProcesses == 1:
232 #print(str(parentPath/vs.cameraView.getTrackingConfigurationFilename())) 235 #print(len(interactions), args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones)
233 params = storage.ProcessParameters(str(parentPath/vs.cameraView.getTrackingConfigurationFilename())) 236 processed = events.computeIndicators(interactions, True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, False, False, None) # params.crossingZones
234 #print(len(interactions), args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones) 237 else:
235 processed = events.computeIndicators(interactions, True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None) 238 #pool = Pool(processes = args.nProcesses)
239 nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses)))
240 jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, False, False, None)) for i in range(args.nProcesses)] # params.crossingZones
241 processed = []
242 for job in jobs:
243 processed += job.get()
244 #pool.close()
236 storage.saveIndicatorsToSqlite(str(parentPath/vs.getDatabaseFilename()), processed) 245 storage.saveIndicatorsToSqlite(str(parentPath/vs.getDatabaseFilename()), processed)
237 # else: 246
238 # pool = Pool(processes = args.nProcesses)
239 # nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses)))
240 # jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)) for i in range(args.nProcesses)]
241 # processed = []
242 # for job in jobs:
243 # processed += job.get()
244 # pool.close()
245
246 ################################# 247 #################################
247 # Analyze 248 # Analyze
248 ################################# 249 #################################
249 if args.analyze == 'object': 250 if args.analyze == 'object':
250 # user speeds, accelerations 251 # user speeds, accelerations