Mercurial > hg > nsaunier > traffic-intelligence
comparison python/storage.py @ 919:7b3f2e0a2652
saving and loading prototype trajectories
| author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
|---|---|
| date | Wed, 05 Jul 2017 13:16:47 -0400 |
| parents | 3a06007a4bb7 |
| children | 499154254f37 |
comparison
equal
deleted
inserted
replaced
| 918:3a06007a4bb7 | 919:7b3f2e0a2652 |
|---|---|
| 4 | 4 |
| 5 import utils, moving, events, indicators, shutil | 5 import utils, moving, events, indicators, shutil |
| 6 from base import VideoFilenameAddable | 6 from base import VideoFilenameAddable |
| 7 | 7 |
| 8 from os import path | 8 from os import path |
| 9 from copy import copy | |
| 9 import sqlite3, logging | 10 import sqlite3, logging |
| 10 from numpy import log, min as npmin, max as npmax, round as npround, array, sum as npsum, loadtxt, floor as npfloor, ceil as npceil, linalg | 11 from numpy import log, min as npmin, max as npmax, round as npround, array, sum as npsum, loadtxt, floor as npfloor, ceil as npceil, linalg |
| 11 from pandas import read_csv, merge | 12 from pandas import read_csv, merge |
| 12 | 13 |
| 13 | 14 |
| 51 elif dataType == 'bb': | 52 elif dataType == 'bb': |
| 52 dropTables(connection, ['bounding_boxes']) | 53 dropTables(connection, ['bounding_boxes']) |
| 53 elif dataType == 'pois': | 54 elif dataType == 'pois': |
| 54 dropTables(connection, ['gaussians2d', 'objects_pois']) | 55 dropTables(connection, ['gaussians2d', 'objects_pois']) |
| 55 elif dataType == 'prototype': | 56 elif dataType == 'prototype': |
| 56 dropTables(connection, ['prototypes']) | 57 dropTables(connection, ['prototypes', 'prototype_positions', 'prototype_velocities']) |
| 57 else: | 58 else: |
| 58 print('Unknown data type {} to delete from database'.format(dataType)) | 59 print('Unknown data type {} to delete from database'.format(dataType)) |
| 59 connection.close() | 60 connection.close() |
| 60 else: | 61 else: |
| 61 print('{} does not exist'.format(filename)) | 62 print('{} does not exist'.format(filename)) |
| 62 | 63 |
| 63 def tableExists(filename, tableName): | 64 def tableExists(connection, tableName): |
| 64 'indicates if the table exists in the database' | 65 'indicates if the table exists in the database' |
| 65 try: | 66 try: |
| 66 connection = sqlite3.connect(filename) | 67 #connection = sqlite3.connect(filename) |
| 67 cursor = connection.cursor() | 68 cursor = connection.cursor() |
| 68 cursor.execute('SELECT COUNT(*) FROM SQLITE_MASTER WHERE type = \'table\' AND name = \''+tableName+'\'') | 69 cursor.execute('SELECT COUNT(*) FROM SQLITE_MASTER WHERE type = \'table\' AND name = \''+tableName+'\'') |
| 69 return cursor.fetchone()[0] == 1 | 70 return cursor.fetchone()[0] == 1 |
| 70 except sqlite3.OperationalError as error: | 71 except sqlite3.OperationalError as error: |
| 71 printDBError(error) | 72 printDBError(error) |
| 72 | 73 |
| 73 def createTrajectoryTable(cursor, tableName): | 74 def createTrajectoryTable(cursor, tableName): |
| 74 if tableName in ['positions', 'velocities']: | 75 if tableName.endswith('positions') or tableName.endswith('velocities'): |
| 75 cursor.execute("CREATE TABLE IF NOT EXISTS "+tableName+" (trajectory_id INTEGER, frame_number INTEGER, x_coordinate REAL, y_coordinate REAL, PRIMARY KEY(trajectory_id, frame_number))") | 76 cursor.execute("CREATE TABLE IF NOT EXISTS "+tableName+" (trajectory_id INTEGER, frame_number INTEGER, x_coordinate REAL, y_coordinate REAL, PRIMARY KEY(trajectory_id, frame_number))") |
| 76 else: | 77 else: |
| 77 print('Unallowed name {} for trajectory table'.format(tableName)) | 78 print('Unallowed name {} for trajectory table'.format(tableName)) |
| 78 | 79 |
| 79 def createObjectsTable(cursor): | 80 def createObjectsTable(cursor): |
| 264 userTypes = {} | 265 userTypes = {} |
| 265 for row in cursor: | 266 for row in cursor: |
| 266 userTypes[row[0]] = row[1] | 267 userTypes[row[0]] = row[1] |
| 267 return userTypes | 268 return userTypes |
| 268 | 269 |
| 269 def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None, withFeatures = False, timeStep = None): | 270 def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None, withFeatures = False, timeStep = None, tablePrefix = None): |
| 270 '''Loads the trajectories (in the general sense, | 271 '''Loads the trajectories (in the general sense, |
| 271 either features, objects (feature groups) or bounding box series) | 272 either features, objects (feature groups) or bounding box series) |
| 272 The number loaded is either the first objectNumbers objects, | 273 The number loaded is either the first objectNumbers objects, |
| 273 or the indices in objectNumbers from the database''' | 274 or the indices in objectNumbers from the database''' |
| 274 connection = sqlite3.connect(filename) | 275 connection = sqlite3.connect(filename) |
| 275 | 276 |
| 276 objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers, timeStep) | 277 if tablePrefix is None: |
| 277 objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers, timeStep) | 278 prefix = '' |
| 279 else: | |
| 280 prefix = tablePrefix + '_' | |
| 281 objects = loadTrajectoriesFromTable(connection, prefix+'positions', trajectoryType, objectNumbers, timeStep) | |
| 282 objectVelocities = loadTrajectoriesFromTable(connection, prefix+'velocities', trajectoryType, objectNumbers, timeStep) | |
| 278 | 283 |
| 279 if len(objectVelocities) > 0: | 284 if len(objectVelocities) > 0: |
| 280 for o,v in zip(objects, objectVelocities): | 285 for o,v in zip(objects, objectVelocities): |
| 281 if o.getNum() == v.getNum(): | 286 if o.getNum() == v.getNum(): |
| 282 o.velocities = v.positions | 287 o.velocities = v.positions |
| 588 ######################### | 593 ######################### |
| 589 | 594 |
| 590 def savePrototypesToSqlite(filename, prototypeIndices, trajectoryType, objects = None, nMatchings = None, dbFilenames = None): | 595 def savePrototypesToSqlite(filename, prototypeIndices, trajectoryType, objects = None, nMatchings = None, dbFilenames = None): |
| 591 '''save the prototype indices | 596 '''save the prototype indices |
| 592 if objects is not None, the trajectories are also saved in prototype_positions and _velocities | 597 if objects is not None, the trajectories are also saved in prototype_positions and _velocities |
| 593 (prototypeIndices have to be in objects) | 598 (prototypeIndices have to be in objects |
| 599 objects will be saved as features, with the centroid trajectory as if it is a feature) | |
| 594 nMatchings, if not None, is a list of the number of matches | 600 nMatchings, if not None, is a list of the number of matches |
| 595 dbFilenames, if not None, is a list of the DB filenames''' | 601 dbFilenames, if not None, is a list of the DB filenames |
| 602 | |
| 603 The order of prototypeIndices, objects, nMatchings and dbFilenames should be consistent''' | |
| 596 connection = sqlite3.connect(filename) | 604 connection = sqlite3.connect(filename) |
| 597 cursor = connection.cursor() | 605 cursor = connection.cursor() |
| 598 try: | 606 try: |
| 599 cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (id INTEGER, dbfilename VARCHAR, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, positions_id INTEGER, PRIMARY KEY (id, dbfilename))') | 607 cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (id INTEGER, dbfilename VARCHAR, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, positions_id INTEGER, PRIMARY KEY (id, dbfilename))') |
| 600 for i, protoId in enumerate(prototypeIndices): | 608 for i, protoId in enumerate(prototypeIndices): |
| 605 if dbFilenames is not None: | 613 if dbFilenames is not None: |
| 606 dbfn = dbFilenames[i] | 614 dbfn = dbFilenames[i] |
| 607 else: | 615 else: |
| 608 dbfn = filename | 616 dbfn = filename |
| 609 cursor.execute('INSERT INTO prototypes (id, dbfilename, trajectory_type, nmatchings, positions_id) VALUES ({},\"{}\",\"{}\",{}, {})'.format(protoId, dbfn, trajectoryType, n, i)) | 617 cursor.execute('INSERT INTO prototypes (id, dbfilename, trajectory_type, nmatchings, positions_id) VALUES ({},\"{}\",\"{}\",{}, {})'.format(protoId, dbfn, trajectoryType, n, i)) |
| 610 #cursor.execute('SELECT * from sqlite_master WHERE type = \"table\" and name = \"{}\"'.format(tableNames[trajectoryType])) | |
| 611 if objects is not None: # save positions and velocities | 618 if objects is not None: # save positions and velocities |
| 612 pass | 619 features = [] |
| 620 for i, o in enumerate(objects): | |
| 621 f = copy(o) | |
| 622 f.num = i | |
| 623 features.append(f) | |
| 624 saveTrajectoriesToTable(connection, features, 'feature', 'prototype') | |
| 613 except sqlite3.OperationalError as error: | 625 except sqlite3.OperationalError as error: |
| 614 printDBError(error) | 626 printDBError(error) |
| 615 connection.commit() | 627 connection.commit() |
| 616 connection.close() | 628 connection.close() |
| 617 | 629 |
| 624 cursor = connection.cursor() | 636 cursor = connection.cursor() |
| 625 prototypeIndices = [] | 637 prototypeIndices = [] |
| 626 dbFilenames = [] | 638 dbFilenames = [] |
| 627 trajectoryTypes = [] | 639 trajectoryTypes = [] |
| 628 nMatchings = [] | 640 nMatchings = [] |
| 641 trajectoryNumbers = [] | |
| 629 try: | 642 try: |
| 630 cursor.execute('SELECT * FROM prototypes') | 643 cursor.execute('SELECT * FROM prototypes') |
| 631 for row in cursor: | 644 for row in cursor: |
| 632 prototypeIndices.append(row[0]) | 645 prototypeIndices.append(row[0]) |
| 633 dbFilenames.append(row[1]) | 646 dbFilenames.append(row[1]) |
| 634 trajectoryTypes.append(row[2]) | 647 trajectoryTypes.append(row[2]) |
| 635 if row[3] is not None: | 648 if row[3] is not None: |
| 636 nMatchings.append(row[3]) | 649 nMatchings.append(row[3]) |
| 650 if row[4] is not None: | |
| 651 trajectoryNumbers.append(row[4]) | |
| 652 if tableExists(connection, 'prototype_positions'): # load prototypes trajectories | |
| 653 objects = loadTrajectoriesFromSqlite(filename, 'feature', trajectoryNumbers, tablePrefix = 'prototype') | |
| 654 else: | |
| 655 objects = None | |
| 637 except sqlite3.OperationalError as error: | 656 except sqlite3.OperationalError as error: |
| 638 printDBError(error) | 657 printDBError(error) |
| 639 connection.close() | 658 connection.close() |
| 640 if len(set(trajectoryTypes)) > 1: | 659 if len(set(trajectoryTypes)) > 1: |
| 641 print('Different types of prototypes in database ({}).'.format(set(trajectoryTypes))) | 660 print('Different types of prototypes in database ({}).'.format(set(trajectoryTypes))) |
| 642 return prototypeIndices, dbFilenames, trajectoryTypes, nMatchings | 661 return prototypeIndices, dbFilenames, trajectoryTypes, nMatchings, objects |
| 643 | 662 |
| 644 def savePOIs(filename, gmm, gmmType, gmmId): | 663 def savePOIs(filename, gmm, gmmType, gmmId): |
| 645 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) | 664 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) |
| 646 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' | 665 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' |
| 647 connection = sqlite3.connect(filename) | 666 connection = sqlite3.connect(filename) |
