diff python/moving.py @ 729:dad99b86a104 dev

merge with default
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Mon, 10 Aug 2015 17:52:19 -0400
parents c6d4ea05a2d0
children 15ddc8715236
line wrap: on
line diff
--- a/python/moving.py	Mon Aug 10 17:51:49 2015 -0400
+++ b/python/moving.py	Mon Aug 10 17:52:19 2015 -0400
@@ -78,12 +78,13 @@
         else:
             return None
 
-def unionIntervals(intervals):
-    'returns the smallest interval containing all intervals'
-    inter = intervals[0]
-    for i in intervals[1:]:
-        inter = Interval.union(inter, i)
-    return inter
+    @classmethod
+    def unionIntervals(cls, intervals):
+        'returns the smallest interval containing all intervals'
+        inter = cls(intervals[0].first, intervals[0].last)
+        for i in intervals[1:]:
+            inter = cls.union(inter, i)
+        return inter
 
 
 class TimeInterval(Interval):
@@ -1067,6 +1068,42 @@
             print 'The object does not exist at '+str(inter)
             return None
 
+    def getObjectsInMask(self, mask, homography = None, minLength = 1):
+        '''Returns new objects made of the positions in the mask
+        mask is in the destination of the homography space'''
+        if homography is not None:
+            self.projectedPositions = self.positions.project(homography)
+        else:
+            self.projectedPositions = self.positions
+        def inMask(positions, i, mask):
+            p = positions[i]
+            return mask[p.y, p.x] != 0.
+
+        #subTimeIntervals self.getFirstInstant()+i
+        filteredIndices = [inMask(self.projectedPositions, i, mask) for i in range(int(self.length()))]
+        # 'connected components' in subTimeIntervals
+        l = 0
+        intervalLabels = []
+        prev = True
+        for i in filteredIndices:
+            if i:
+                if not prev: # new interval
+                    l += 1
+                intervalLabels.append(l)
+            else:
+                intervalLabels.append(-1)
+            prev = i
+        intervalLabels = array(intervalLabels)
+        subObjects = []
+        for l in set(intervalLabels):
+            if l >= 0:
+                if sum(intervalLabels == l) >= minLength:
+                    times = [self.getFirstInstant()+i for i in range(len(intervalLabels)) if intervalLabels[i] == l]
+                    subTimeInterval = TimeInterval(min(times), max(times))
+                    subObjects.append(self.getObjectInTimeInterval(subTimeInterval))
+
+        return subObjects
+
     def getPositions(self):
         return self.positions
 
@@ -1517,7 +1554,7 @@
         else:
             return matchingDistance + 1
 
-def computeClearMOT(annotations, objects, matchingDistance, firstInstant, lastInstant, debug = False):
+def computeClearMOT(annotations, objects, matchingDistance, firstInstant, lastInstant, returnMatches = False, debug = False):
     '''Computes the CLEAR MOT metrics 
 
     Reference:
@@ -1536,6 +1573,12 @@
     fpt number of false alarm.frames (tracker objects without match in each frame)
     gt number of GT.frames
 
+    if returnMatches is True, return as 2 new arguments the GT and TO matches
+    matches is a dict
+    matches[i] is the list of matches for GT/TO i
+    the list of matches is a dict, indexed by time, for the TO/GT id matched at time t 
+    (an instant t not present in matches[i] at which GT/TO exists means a missed detection or false alarm)
+
     TODO: Should we use the distance as weights or just 1/0 if distance below matchingDistance?
     (add argument useDistanceForWeights = False)'''
     from munkres import Munkres
@@ -1548,6 +1591,9 @@
     fpt = 0 # number of false alarm.frames (tracker objects without match in each frame)
     mme = 0 # number of mismatches
     matches = {} # match[i] is the tracker track associated with GT i (using object references)
+    if returnMatches:
+        gtMatches = {a.getNum():{} for a in annotations}
+        toMatches = {o.getNum():{} for o in objects}
     for t in xrange(firstInstant, lastInstant+1):
         previousMatches = matches.copy()
         # go through currently matched GT-TO and check if they are still matched withing matchingDistance
@@ -1583,6 +1629,10 @@
                     dist += costs[k][v]
         if debug:
             print('{} '.format(t)+', '.join(['{} {}'.format(k.getNum(), v.getNum()) for k,v in matches.iteritems()]))
+        if returnMatches:
+            for a,o in matches.iteritems():
+                gtMatches[a.getNum()][t] = o.getNum()
+                toMatches[o.getNum()][t] = a.getNum()
         
         # compute metrics elements
         ct += len(matches)
@@ -1615,8 +1665,10 @@
         mota = 1.-float(mt+fpt+mme)/gt
     else:
         mota = None
-    return motp, mota, mt, mme, fpt, gt
-
+    if returnMatches:
+        return motp, mota, mt, mme, fpt, gt, gtMatches, toMatches
+    else:
+        return motp, mota, mt, mme, fpt, gt
 
 def plotRoadUsers(objects, colors):
     '''Colors is a PlottingPropertyValues instance'''