Mercurial > hg > nsaunier > traffic-intelligence
comparison python/storage.py @ 871:6db83beb5350
work in progress to update gaussian mixtures
| author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
|---|---|
| date | Fri, 03 Feb 2017 16:26:18 -0500 |
| parents | 2d6249fe905a |
| children | c70adaeeddf5 |
comparison
equal
deleted
inserted
replaced
| 870:1535251a1f40 | 871:6db83beb5350 |
|---|---|
| 565 ######################### | 565 ######################### |
| 566 # saving and loading for scene interpretation | 566 # saving and loading for scene interpretation |
| 567 ######################### | 567 ######################### |
| 568 | 568 |
| 569 def savePOIs(filename, gmm, gmmType, gmmId): | 569 def savePOIs(filename, gmm, gmmType, gmmId): |
| 570 '''Saves a Gaussian mixture model (of class sklearn.mixture.GMM) | 570 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) |
| 571 gmmType is a type of GMM, learnt either from beginnings or ends of trajectories''' | 571 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' |
| 572 connection = sqlite3.connect(filename) | 572 connection = sqlite3.connect(filename) |
| 573 cursor = connection.cursor() | 573 cursor = connection.cursor() |
| 574 if gmmType not in ['beginning', 'end']: | 574 if gmmType not in ['beginning', 'end']: |
| 575 print('Unknown POI type {}. Exiting'.format(gmmType)) | 575 print('Unknown POI type {}. Exiting'.format(gmmType)) |
| 576 import sys | 576 import sys |
| 577 sys.exit() | 577 sys.exit() |
| 578 try: | 578 try: |
| 579 cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covar00 REAL, covar01 REAL, covar10 REAL, covar11 REAL, covariance_type VARCHAR, weight, mixture_id INTEGER, PRIMARY KEY(id, mixture_id))') | 579 cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covar00 REAL, covar01 REAL, covar10 REAL, covar11 REAL, covariance_type VARCHAR, weight, mixture_id INTEGER, PRIMARY KEY(id, mixture_id))') |
| 580 for i in xrange(gmm.n_components): | 580 for i in xrange(gmm.n_components): |
| 581 cursor.execute('INSERT INTO gaussians2d VALUES({}, \'{}\', {}, {}, {}, {}, {}, {}, \'{}\', {}, {})'.format(i, gmmType, gmm.means_[i][0], gmm.means_[i][1], gmm.covars_[i][0,0], gmm.covars_[i][0,1], gmm.covars_[i][1,0], gmm.covars_[i][1,1], gmm.covariance_type, gmm.weights_[i], gmmId)) | 581 cursor.execute('INSERT INTO gaussians2d VALUES({}, \'{}\', {}, {}, {}, {}, {}, {}, \'{}\', {}, {})'.format(i, gmmType, gmm.means_[i][0], gmm.means_[i][1], gmm.covariances_[i][0,0], gmm.covariances_[i][0,1], gmm.covariances_[i][1,0], gmm.covariances_[i][1,1], gmm.covariance_type, gmm.weights_[i], gmmId)) |
| 582 connection.commit() | 582 connection.commit() |
| 583 except sqlite3.OperationalError as error: | 583 except sqlite3.OperationalError as error: |
| 584 printDBError(error) | 584 printDBError(error) |
| 585 connection.close() | 585 connection.close() |
| 586 | 586 |
| 595 gmmId = None | 595 gmmId = None |
| 596 gmm = [] | 596 gmm = [] |
| 597 for row in cursor: | 597 for row in cursor: |
| 598 if gmmId is None or row[10] != gmmId: | 598 if gmmId is None or row[10] != gmmId: |
| 599 if len(gmm) > 0: | 599 if len(gmm) > 0: |
| 600 tmp = mixture.GMM(len(gmm), covarianceType) | 600 tmp = mixture.GaussianMixture(len(gmm), covarianceType) |
| 601 tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) | 601 tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) |
| 602 tmp.covars_ = array([gaussian['covar'] for gaussian in gmm]) | 602 tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm]) |
| 603 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) | 603 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) |
| 604 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] | 604 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] |
| 605 pois.append(tmp) | 605 pois.append(tmp) |
| 606 gaussian = {'type': row[1], | 606 gaussian = {'type': row[1], |
| 607 'mean': row[2:4], | 607 'mean': row[2:4], |
| 614 gmm.append({'type': row[1], | 614 gmm.append({'type': row[1], |
| 615 'mean': row[2:4], | 615 'mean': row[2:4], |
| 616 'covar': array(row[4:8]).reshape(2,2), | 616 'covar': array(row[4:8]).reshape(2,2), |
| 617 'weight': row[9]}) | 617 'weight': row[9]}) |
| 618 if len(gmm) > 0: | 618 if len(gmm) > 0: |
| 619 tmp = mixture.GMM(len(gmm), covarianceType) | 619 tmp = mixture.GaussianMixture(len(gmm), covarianceType) |
| 620 tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) | 620 tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) |
| 621 tmp.covars_ = array([gaussian['covar'] for gaussian in gmm]) | 621 tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm]) |
| 622 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) | 622 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) |
| 623 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] | 623 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] |
| 624 pois.append(tmp) | 624 pois.append(tmp) |
| 625 except sqlite3.OperationalError as error: | 625 except sqlite3.OperationalError as error: |
| 626 printDBError(error) | 626 printDBError(error) |
