# HG changeset patch # User Nicolas Saunier # Date 1486157178 18000 # Node ID 6db83beb53501eeba3d0dbaa49372e91c95c6efb # Parent 1535251a1f40254e13b6cd37980d576c29d7899b work in progress to update gaussian mixtures diff -r 1535251a1f40 -r 6db83beb5350 python/storage.py --- a/python/storage.py Fri Feb 03 16:15:06 2017 -0500 +++ b/python/storage.py Fri Feb 03 16:26:18 2017 -0500 @@ -567,8 +567,8 @@ ######################### def savePOIs(filename, gmm, gmmType, gmmId): - '''Saves a Gaussian mixture model (of class sklearn.mixture.GMM) - gmmType is a type of GMM, learnt either from beginnings or ends of trajectories''' + '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) + gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' connection = sqlite3.connect(filename) cursor = connection.cursor() if gmmType not in ['beginning', 'end']: @@ -578,7 +578,7 @@ try: cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covar00 REAL, covar01 REAL, covar10 REAL, covar11 REAL, covariance_type VARCHAR, weight, mixture_id INTEGER, PRIMARY KEY(id, mixture_id))') for i in xrange(gmm.n_components): - cursor.execute('INSERT INTO gaussians2d VALUES({}, \'{}\', {}, {}, {}, {}, {}, {}, \'{}\', {}, {})'.format(i, gmmType, gmm.means_[i][0], gmm.means_[i][1], gmm.covars_[i][0,0], gmm.covars_[i][0,1], gmm.covars_[i][1,0], gmm.covars_[i][1,1], gmm.covariance_type, gmm.weights_[i], gmmId)) + cursor.execute('INSERT INTO gaussians2d VALUES({}, \'{}\', {}, {}, {}, {}, {}, {}, \'{}\', {}, {})'.format(i, gmmType, gmm.means_[i][0], gmm.means_[i][1], gmm.covariances_[i][0,0], gmm.covariances_[i][0,1], gmm.covariances_[i][1,0], gmm.covariances_[i][1,1], gmm.covariance_type, gmm.weights_[i], gmmId)) connection.commit() except sqlite3.OperationalError as error: printDBError(error) @@ -597,9 +597,9 @@ for row in cursor: if gmmId is None or row[10] != gmmId: if len(gmm) > 0: - tmp = mixture.GMM(len(gmm), covarianceType) + tmp = mixture.GaussianMixture(len(gmm), covarianceType) tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) - tmp.covars_ = array([gaussian['covar'] for gaussian in gmm]) + tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm]) tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] pois.append(tmp) @@ -616,9 +616,9 @@ 'covar': array(row[4:8]).reshape(2,2), 'weight': row[9]}) if len(gmm) > 0: - tmp = mixture.GMM(len(gmm), covarianceType) + tmp = mixture.GaussianMixture(len(gmm), covarianceType) tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) - tmp.covars_ = array([gaussian['covar'] for gaussian in gmm]) + tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm]) tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] pois.append(tmp) diff -r 1535251a1f40 -r 6db83beb5350 python/tests/storage.txt --- a/python/tests/storage.txt Fri Feb 03 16:15:06 2017 -0500 +++ b/python/tests/storage.txt Fri Feb 03 16:26:18 2017 -0500 @@ -86,11 +86,11 @@ >>> readline(strio, '%#') 'sadlkfjsdlakjf' ->>> from sklearn.mixture import GMM +>>> from sklearn.mixture import GaussianMixture >>> from numpy.random import random_sample >>> nPoints = 50 >>> points = random_sample(nPoints*2).reshape(nPoints,2) ->>> gmm = GMM(4, covariance_type = 'full') +>>> gmm = GaussianMixture(4, covariance_type = 'full') >>> tmp = gmm.fit(points) >>> id = 0 >>> savePOIs('pois-tmp.sqlite', gmm, 'end', id) diff -r 1535251a1f40 -r 6db83beb5350 scripts/learn-poi.py --- a/scripts/learn-poi.py Fri Feb 03 16:15:06 2017 -0500 +++ b/scripts/learn-poi.py Fri Feb 03 16:26:18 2017 -0500 @@ -40,7 +40,7 @@ [beginnings, ends], ['beginning', 'end']): # estimation - gmm = mixture.GMM(n_components=nClusters, covariance_type = args.covarianceType) + gmm = mixture.GaussianMixture(n_components=nClusters, covariance_type = args.covarianceType) model=gmm.fit(beginnings) if not model.converged_: print('Warning: model for '+gmmType+' points did not converge')