comparison scripts/learn-poi.py @ 915:13434f5017dd

Work to save trajectory assignments to origins and destinations
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Tue, 04 Jul 2017 17:03:29 -0400
parents f228fd649644
children 7345f0d51faa
comparison
equal deleted inserted replaced
914:f228fd649644 915:13434f5017dd
18 parser.add_argument('-w', dest = 'worldImageFilename', help = 'filename of the world image') 18 parser.add_argument('-w', dest = 'worldImageFilename', help = 'filename of the world image')
19 parser.add_argument('-u', dest = 'unitsPerPixel', help = 'number of units of distance per pixel', type = float, default = 1.) 19 parser.add_argument('-u', dest = 'unitsPerPixel', help = 'number of units of distance per pixel', type = float, default = 1.)
20 parser.add_argument('--display', dest = 'display', help = 'display points of interests', action = 'store_true') # default is manhattan distance 20 parser.add_argument('--display', dest = 'display', help = 'display points of interests', action = 'store_true') # default is manhattan distance
21 parser.add_argument('--assign', dest = 'assign', help = 'display points of interests', action = 'store_true') 21 parser.add_argument('--assign', dest = 'assign', help = 'display points of interests', action = 'store_true')
22 22
23 # TODO test Variational Bayesian Gaussian Mixture BayesianGaussianMixture
24
23 args = parser.parse_args() 25 args = parser.parse_args()
24 26
25 objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, args.trajectoryType, args.nObjects) 27 objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, args.trajectoryType, args.nObjects)
26 28
27 beginnings = [] 29 beginnings = []
28 ends = [] 30 ends = []
29 for o in objects: 31 for o in objects:
30 beginnings.append(o.getPositionAt(0).aslist()) 32 beginnings.append(o.getPositionAt(0).aslist())
31 ends.append(o.getPositionAt(int(o.length())-1).aslist()) 33 ends.append(o.getPositionAt(int(o.length())-1).aslist())
34 if args.assign:
35 o.od = [-1, -1]
32 36
33 beginnings = np.array(beginnings) 37 beginnings = np.array(beginnings)
34 ends = np.array(ends) 38 ends = np.array(ends)
35 39
36 nDestinationClusters = args.nDestinationClusters 40 nDestinationClusters = args.nDestinationClusters
37 if args.nDestinationClusters is None: 41 if args.nDestinationClusters is None:
38 nDestinationClusters = args.nOriginClusters 42 nDestinationClusters = args.nOriginClusters
39 43
40 gmmId=0 44 gmmId=0
45 models = {}
41 for nClusters, points, gmmType in zip([args.nOriginClusters, nDestinationClusters], 46 for nClusters, points, gmmType in zip([args.nOriginClusters, nDestinationClusters],
42 [beginnings, ends], 47 [beginnings, ends],
43 ['beginning', 'end']): 48 ['beginning', 'end']):
44 # estimation 49 # estimation
45 gmm = mixture.GaussianMixture(n_components=nClusters, covariance_type = args.covarianceType) 50 gmm = mixture.GaussianMixture(n_components=nClusters, covariance_type = args.covarianceType)
46 model=gmm.fit(points) 51 models[gmmType]=gmm.fit(points)
47 if not model.converged_: 52 if not models[gmmType].converged_:
48 print('Warning: model for '+gmmType+' points did not converge') 53 print('Warning: model for '+gmmType+' points did not converge')
49 if args.display or args.assign: 54 if args.display or args.assign:
50 labels = model.predict(points) 55 labels = models[gmmType].predict(points)
51 # plot 56 # plot
52 if args.display: 57 if args.display:
53 fig = plt.figure() 58 fig = plt.figure()
54 if args.worldImageFilename is not None and args.unitsPerPixel is not None: 59 if args.worldImageFilename is not None and args.unitsPerPixel is not None:
55 img = plt.imread(args.worldImageFilename) 60 img = plt.imread(args.worldImageFilename)
56 plt.imshow(img) 61 plt.imshow(img)
57 ml.plotGMMClusters(model, labels, points, fig, nUnitsPerPixel = args.unitsPerPixel) 62 ml.plotGMMClusters(models[gmmType], labels, points, fig, nUnitsPerPixel = args.unitsPerPixel)
58 plt.axis('image') 63 plt.axis('image')
59 plt.title(gmmType) 64 plt.title(gmmType)
60 print(gmmType+' Clusters:\n{}'.format(ml.computeClusterSizes(labels, range(model.n_components)))) 65 print(gmmType+' Clusters:\n{}'.format(ml.computeClusterSizes(labels, range(models[gmmType].n_components))))
61 # save 66 # save
62 storage.savePOIs(args.databaseFilename, model, gmmType, gmmId) 67 storage.savePOIs(args.databaseFilename, models[gmmType], gmmType, gmmId)
63 # save assignments 68 # save assignments
64 if args.assign: 69 if args.assign:
65 pass # savePOIAssignments( 70 for o, l in zip(objects, labels):
71 if gmmType == 'beginning':
72 o.od[0] = l
73 elif gmmType == 'end':
74 o.od[1] = l
66 gmmId += 1 75 gmmId += 1
76
77 if args.assign:
78 storage.savePOIAssignments(args.databaseFilename, objects)
67 79
68 if args.display: 80 if args.display:
69 plt.axis('equal') 81 plt.axis('equal')
70 plt.show() 82 plt.show()
71
72 # fig = plt.figure()
73 # if args.worldImageFilename is not None and args.pixelsPerUnit is not None:
74 # img = plt.imread(args.worldImageFilename)
75 # plt.imshow(img)
76 # ml.plotGMMClusters(, , fig, nPixelsPerUnit = args.pixelsPerUnit)
77 # plt.axis('equal')
78 # plt.title()
79 # print('Destination Clusters:\n{}'.format(ml.computeClusterSizes(endModel.predict(ends), range(args.nClusters))))