Mercurial > hg > nsaunier > traffic-intelligence
comparison scripts/learn-poi.py @ 913:1cd878812529
work in progress
| author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
|---|---|
| date | Wed, 28 Jun 2017 17:57:06 -0400 |
| parents | 6db83beb5350 |
| children | f228fd649644 |
comparison
equal
deleted
inserted
replaced
| 912:fd057a6b04db | 913:1cd878812529 |
|---|---|
| 15 parser.add_argument('-ndestinations', dest = 'nDestinationClusters', help = 'number of clusters for trajectory destinations (=norigins if not provided)', type = int) | 15 parser.add_argument('-ndestinations', dest = 'nDestinationClusters', help = 'number of clusters for trajectory destinations (=norigins if not provided)', type = int) |
| 16 parser.add_argument('--covariance-type', dest = 'covarianceType', help = 'type of covariance of Gaussian model', default = "full") | 16 parser.add_argument('--covariance-type', dest = 'covarianceType', help = 'type of covariance of Gaussian model', default = "full") |
| 17 parser.add_argument('-w', dest = 'worldImageFilename', help = 'filename of the world image') | 17 parser.add_argument('-w', dest = 'worldImageFilename', help = 'filename of the world image') |
| 18 parser.add_argument('-u', dest = 'unitsPerPixel', help = 'number of units of distance per pixel', type = float, default = 1.) | 18 parser.add_argument('-u', dest = 'unitsPerPixel', help = 'number of units of distance per pixel', type = float, default = 1.) |
| 19 parser.add_argument('--display', dest = 'display', help = 'display points of interests', action = 'store_true') # default is manhattan distance | 19 parser.add_argument('--display', dest = 'display', help = 'display points of interests', action = 'store_true') # default is manhattan distance |
| | 20 parser.add_argument('--assign', dest = 'display', help = 'display points of interests', action = 'store_true') # default is manhattan distance |
| 20 | 21 |
| 21 args = parser.parse_args() | 22 args = parser.parse_args() |
| 22 | 23 |
| 23 objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, args.trajectoryType) | 24 objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, args.trajectoryType) |
| 24 | 25 |
| 35 if args.nDestinationClusters is None: | 36 if args.nDestinationClusters is None: |
| 36 nDestinationClusters = args.nOriginClusters | 37 nDestinationClusters = args.nOriginClusters |
| 37 | 38 |
| 38 gmmId=0 | 39 gmmId=0 |
| 39 for nClusters, points, gmmType in zip([args.nOriginClusters, nDestinationClusters], | 40 for nClusters, points, gmmType in zip([args.nOriginClusters, nDestinationClusters], |
| 40 [beginnings, ends], | 41 [beginnings, ends], |
| 41 ['beginning', 'end']): | 42 ['beginning', 'end']): |
| 42 # estimation | 43 # estimation |
| 43 gmm = mixture.GaussianMixture(n_components=nClusters, covariance_type = args.covarianceType) | 44 gmm = mixture.GaussianMixture(n_components=nClusters, covariance_type = args.covarianceType) |
| 44 model=gmm.fit(beginnings) | 45 model=gmm.fit(points) |
| 45 if not model.converged_: | 46 if not model.converged_: |
| 46 print('Warning: model for '+gmmType+' points did not converge') | 47 print('Warning: model for '+gmmType+' points did not converge') |
| 47 # plot | 48 # plot |
| 48 if args.display: | 49 if args.display: |
| 49 fig = plt.figure() | 50 fig = plt.figure() |
| 50 if args.worldImageFilename is not None and args.unitsPerPixel is not None: | 51 if args.worldImageFilename is not None and args.unitsPerPixel is not None: |
| 51 img = plt.imread(args.worldImageFilename) | 52 img = plt.imread(args.worldImageFilename) |
| 52 plt.imshow(img) | 53 plt.imshow(img) |
| 53 labels = ml.plotGMMClusters(model, points, fig, nUnitsPerPixel = args.unitsPerPixel) | 54 labels = model.predict(points) |
| | 55 labels = ml.plotGMMClusters(model, labels, points, fig, nUnitsPerPixel = args.unitsPerPixel) |
| 54 plt.axis('image') | 56 plt.axis('image') |
| 55 plt.title(gmmType) | 57 plt.title(gmmType) |
| 56 print(gmmType+' Clusters:\n{}'.format(ml.computeClusterSizes(labels, range(model.n_components)))) | 58 print(gmmType+' Clusters:\n{}'.format(ml.computeClusterSizes(labels, range(model.n_components)))) |
| 57 # save | 59 # save |
| 58 storage.savePOIs(args.databaseFilename, model, gmmType, gmmId) | 60 storage.savePOIs(args.databaseFilename, model, gmmType, gmmId) |
| | 61 # save assignments |
| | 62 if args.assign: |
| | 63 pass |
| 59 gmmId += 1 | 64 gmmId += 1 |
| 60 | 65 |
| 61 if args.display: | 66 if args.display: |
| 62 plt.axis('equal') | 67 plt.axis('equal') |
| 63 plt.show() | 68 plt.show() |
