diff scripts/learn-poi.py @ 915:13434f5017dd

work to save trajectory assignment to origin and destinations
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Tue, 04 Jul 2017 17:03:29 -0400
parents f228fd649644
children 7345f0d51faa
line wrap: on
line diff
--- a/scripts/learn-poi.py	Wed Jun 28 23:43:52 2017 -0400
+++ b/scripts/learn-poi.py	Tue Jul 04 17:03:29 2017 -0400
@@ -20,6 +20,8 @@
 parser.add_argument('--display', dest = 'display', help = 'display points of interests', action = 'store_true') # default is manhattan distance
 parser.add_argument('--assign', dest = 'assign', help = 'display points of interests', action = 'store_true')
 
+# TODO test Variational Bayesian Gaussian Mixture BayesianGaussianMixture
+
 args = parser.parse_args()
 
 objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, args.trajectoryType, args.nObjects)
@@ -29,6 +31,8 @@
 for o in objects:
     beginnings.append(o.getPositionAt(0).aslist())
     ends.append(o.getPositionAt(int(o.length())-1).aslist())
+    if args.assign:
+        o.od = [-1, -1]
 
 beginnings = np.array(beginnings)
 ends = np.array(ends)
@@ -38,42 +42,41 @@
     nDestinationClusters = args.nOriginClusters
 
 gmmId=0
+models = {}
 for nClusters, points, gmmType in zip([args.nOriginClusters, nDestinationClusters],
                                       [beginnings, ends],
                                       ['beginning', 'end']):
     # estimation
     gmm = mixture.GaussianMixture(n_components=nClusters, covariance_type = args.covarianceType)
-    model=gmm.fit(points)
-    if not model.converged_:
+    models[gmmType]=gmm.fit(points)
+    if not models[gmmType].converged_:
         print('Warning: model for '+gmmType+' points did not converge')
     if args.display or args.assign:
-        labels = model.predict(points)
+        labels = models[gmmType].predict(points)
     # plot
     if args.display:
         fig = plt.figure()
         if args.worldImageFilename is not None and args.unitsPerPixel is not None:
             img = plt.imread(args.worldImageFilename)
             plt.imshow(img)
-        ml.plotGMMClusters(model, labels, points, fig, nUnitsPerPixel = args.unitsPerPixel)
+        ml.plotGMMClusters(models[gmmType], labels, points, fig, nUnitsPerPixel = args.unitsPerPixel)
         plt.axis('image')
         plt.title(gmmType)
-        print(gmmType+' Clusters:\n{}'.format(ml.computeClusterSizes(labels, range(model.n_components))))
+        print(gmmType+' Clusters:\n{}'.format(ml.computeClusterSizes(labels, range(models[gmmType].n_components))))
     # save
-    storage.savePOIs(args.databaseFilename, model, gmmType, gmmId)
+    storage.savePOIs(args.databaseFilename, models[gmmType], gmmType, gmmId)
     # save assignments
     if args.assign:
-        pass # savePOIAssignments(
+        for o, l in zip(objects, labels):
+            if gmmType == 'beginning':
+                o.od[0] = l
+            elif gmmType == 'end':
+                o.od[1] = l
     gmmId += 1
 
+if args.assign:
+    storage.savePOIAssignments(args.databaseFilename, objects)
+
 if args.display:
     plt.axis('equal')
     plt.show()
-
-# fig = plt.figure()
-# if args.worldImageFilename is not None and args.pixelsPerUnit is not None:
-#     img = plt.imread(args.worldImageFilename)
-#     plt.imshow(img)
-# ml.plotGMMClusters(, , fig, nPixelsPerUnit = args.pixelsPerUnit)
-# plt.axis('equal')
-# plt.title()
-# print('Destination Clusters:\n{}'.format(ml.computeClusterSizes(endModel.predict(ends), range(args.nClusters))))