diff scripts/process.py @ 1015:cf9d29de3dbf

merge With Pr Saunier's code
author Wendlasida
date Mon, 04 Jun 2018 11:25:49 -0400
parents 0d29b75f74ea
children 16932cefabc1
line wrap: on
line diff
--- a/scripts/process.py	Fri Jun 01 17:32:52 2018 -0400
+++ b/scripts/process.py	Mon Jun 04 11:25:49 2018 -0400
@@ -2,18 +2,20 @@
 
 import sys, argparse
 from pathlib import Path
+from multiprocessing.pool import Pool
 
 import matplotlib
 matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 from numpy import percentile
 
-import storage, events, prediction
+import storage, events, prediction, cvutils
 from metadata import *
 
 parser = argparse.ArgumentParser(description='This program manages the processing of several files based on a description of the sites and video data in an SQLite database following the metadata module.')
 parser.add_argument('--db', dest = 'metadataFilename', help = 'name of the metadata file', required = True)
 parser.add_argument('--videos', dest = 'videoIds', help = 'indices of the video sequences', nargs = '*', type = int)
+parser.add_argument('--sites', dest = 'siteIds', help = 'indices of the video sequences', nargs = '*', type = int)
 parser.add_argument('--cfg', dest = 'configFilename', help = 'name of the configuration file')
 parser.add_argument('-n', dest = 'nObjects', help = 'number of objects/interactions to process', type = int)
 parser.add_argument('--prediction-method', dest = 'predictionMethod', help = 'prediction method (constant velocity (cvd: vector computation (approximate); cve: equation solving; cv: discrete time (approximate)), normal adaptation, point set prediction)', choices = ['cvd', 'cve', 'cv', 'na', 'ps', 'mp'])
@@ -23,46 +25,85 @@
 parser.add_argument('--process', dest = 'process', help = 'data to process', choices = ['feature', 'object', 'classification', 'interaction'])
 parser.add_argument('--display', dest = 'display', help = 'data to display (replay over video)', choices = ['feature', 'object', 'classification', 'interaction'])
 parser.add_argument('--analyze', dest = 'analyze', help = 'data to analyze (results)', choices = ['feature', 'object', 'classification', 'interaction'])
+parser.add_argument('--dry', dest = 'dryRun', help = 'dry run of processing', action = 'store_true')
+parser.add_argument('--nthreads', dest = 'nProcesses', help = 'number of processes to run in parallel', type = int, default = 1)
 
 # need way of selecting sites as similar as possible to sql alchemy syntax
 # override tracking.cfg from db
 # manage cfg files, overwrite them (or a subset of parameters)
 # delete sqlite files
-
 # info of metadata
 
-parser.add_argument('--nthreads', dest = 'nProcesses', help = 'number of processes to run in parallel', type = int, default = 1)
-
 args = parser.parse_args()
-# files are relative to metadata location
-
-session = connectDatabase(args.metadataFilename)
-parentDir = Path(args.metadataFilename).parent
 
+#################################
+# Data preparation
+#################################
+session = connectDatabase(args.metadataFilename)
+parentDir = Path(args.metadataFilename).parent # files are relative to metadata location
+videoSequences = []
+if args.videoIds is not None:
+    videoSequences = [session.query(VideoSequence).get(videoId) for videoId in args.videoIds]
+elif args.siteIds is not None:
+    for siteId in args.siteIds:
+        for site in getSite(session, siteId):
+            for cv in site.cameraViews:
+                videoSequences += cv.videoSequences
+else:
+    print('No video/site to process')
+
+#################################
+# Delete
+#################################
 if args.delete is not None:
-    if args.delete in ['object', 'interaction']:
+    if args.delete == 'feature':
+        pass
+    elif args.delete in ['object', 'interaction']:
         #parser.add_argument('-t', dest = 'dataType', help = 'type of the data to remove', required = True, choices = ['object','interaction', 'bb', 'pois', 'prototype'])
-        for videoId in args.videoIds:
-            vs = session.query(VideoSequence).get(videoId)
+        for vs in videoSequences:
             storage.deleteFromSqlite(str(parentDir/vs.getDatabaseFilename()), args.delete)
 
+#################################
+# Process
+#################################
 if args.process in ['feature', 'object']: # tracking
-    for videoId in args.videoIds:
-        vs = session.query(VideoSequence).get(videoId)
-        if args.configFilename is None:
-            configFilename = vs.cameraView.getTrackingConfigurationFilename()
-        else:
-            configFilename = args.configFilename
-        #todo cvutils.tracking(configFilename, args.process == 'object', str(parentDir/vs.getVideoSequenceFilename(), str(parentDir/vs.getDatabaseFilename(), configFilename = vs.cameraView.getHomographyFilename())
-    
+    if args.nProcesses == 1:
+        for vs in videoSequences:
+            if not (parentDir/vs.getDatabaseFilename()).exists() or args.process == 'object':
+                if args.configFilename is None:
+                    configFilename = str(parentDir/vs.cameraView.getTrackingConfigurationFilename())
+                else:
+                    configFilename = args.configFilename
+                if vs.cameraView.cameraType is None:
+                    cvutils.tracking(configFilename, args.process == 'object', str(parentDir.absolute()/vs.getVideoSequenceFilename()), str(parentDir.absolute()/vs.getDatabaseFilename()), str(parentDir.absolute()/vs.cameraView.getHomographyFilename()), str(parentDir.absolute()/vs.cameraView.getMaskFilename()), False, None, None, args.dryRun)
+                else:
+                    cvutils.tracking(configFilename, args.process == 'object', str(parentDir.absolute()/vs.getVideoSequenceFilename()), str(parentDir.absolute()/vs.getDatabaseFilename()), str(parentDir.absolute()/vs.cameraView.getHomographyFilename()), str(parentDir.absolute()/vs.cameraView.getMaskFilename()), True, vs.cameraView.cameraType.intrinsicCameraMatrix, vs.cameraView.cameraType.distortionCoefficients, args.dryRun)
+            else:
+                print('SQLite already exists: {}'.format(parentDir/vs.getDatabaseFilename()))
+    else:
+        pool = Pool(args.nProcesses)
+        for vs in videoSequences:
+            if not (parentDir/vs.getDatabaseFilename()).exists() or args.process == 'object':
+                if args.configFilename is None:
+                    configFilename = str(parentDir/vs.cameraView.getTrackingConfigurationFilename())
+                else:
+                    configFilename = args.configFilename
+                if vs.cameraView.cameraType is None:
+                    pool.apply_async(cvutils.tracking, args = (configFilename, args.process == 'object', str(parentDir.absolute()/vs.getVideoSequenceFilename()), str(parentDir.absolute()/vs.getDatabaseFilename()), str(parentDir.absolute()/vs.cameraView.getHomographyFilename()), str(parentDir.absolute()/vs.cameraView.getMaskFilename()), False, None, None, args.dryRun))
+                else:
+                    pool.apply_async(cvutils.tracking, args = (configFilename, args.process == 'object', str(parentDir.absolute()/vs.getVideoSequenceFilename()), str(parentDir.absolute()/vs.getDatabaseFilename()), str(parentDir.absolute()/vs.cameraView.getHomographyFilename()), str(parentDir.absolute()/vs.cameraView.getMaskFilename()), True, vs.cameraView.cameraType.intrinsicCameraMatrix, vs.cameraView.cameraType.distortionCoefficients, args.dryRun))
+            else:
+                print('SQLite already exists: {}'.format(parentDir/vs.getDatabaseFilename()))
+        pool.close()
+        pool.join()
+
 elif args.process == 'interaction':
     # safety analysis TODO make function in safety analysis script
     if args.predictionMethod == 'cvd':
         predictionParameters = prediction.CVDirectPredictionParameters()
     if args.predictionMethod == 'cve':
         predictionParameters = prediction.CVExactPredictionParameters()
-    for videoId in args.videoIds:
-        vs = session.query(VideoSequence).get(videoId)
+    for vs in videoSequences:
         print('Processing '+vs.getDatabaseFilename())
         objects = storage.loadTrajectoriesFromSqlite(str(parentDir/vs.getDatabaseFilename()), 'object')#, args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp'))
         interactions = events.createInteractions(objects)
@@ -81,12 +122,14 @@
     #         processed += job.get()
     #     pool.close()
 
+#################################
+# Analyze
+#################################
 if args.analyze == 'object': # user speed for now
     medianSpeeds = {}
     speeds85 = {}
     minLength = 2*30
-    for videoId in args.videoIds:
-        vs = session.query(VideoSequence).get(videoId)
+    for vs in videoSequences:
         if not vs.cameraView.siteIdx in medianSpeeds:
             medianSpeeds[vs.cameraView.siteIdx] = []
             speeds85[vs.cameraView.siteIdx] = []
@@ -111,8 +154,7 @@
     maxIndicatorValue = {2: float('inf'), 5: float('inf'), 7:10., 10:10.}
     indicators = {}
     interactions = {}
-    for videoId in args.videoIds:
-        vs = session.query(VideoSequence).get(videoId)
+    for vs in videoSequences:
         if not vs.cameraView.siteIdx in interactions:
             interactions[vs.cameraView.siteIdx] = []
             indicators[vs.cameraView.siteIdx] = {}