Mercurial > hg > nsaunier > traffic-intelligence
comparison python/storage.py @ 927:c030f735c594
added assignment of trajectories to prototypes and cleanup of insert queries
| author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
|---|---|
| date | Tue, 11 Jul 2017 17:56:23 -0400 |
| parents | acb5379c5fd7 |
| children | 0e63a918a1ca |
Comparison legend (row highlighting is not preserved in this text rendering): equal, deleted, inserted, replaced
| 926:dbd81710d515 | 927:c030f735c594 |
|---|---|
| 52 elif dataType == 'bb': | 52 elif dataType == 'bb': |
| 53 dropTables(connection, ['bounding_boxes']) | 53 dropTables(connection, ['bounding_boxes']) |
| 54 elif dataType == 'pois': | 54 elif dataType == 'pois': |
| 55 dropTables(connection, ['gaussians2d', 'objects_pois']) | 55 dropTables(connection, ['gaussians2d', 'objects_pois']) |
| 56 elif dataType == 'prototype': | 56 elif dataType == 'prototype': |
| 57 dropTables(connection, ['prototypes']) | 57 dropTables(connection, ['prototypes', 'objects_prototypes']) |
| 58 else: | 58 else: |
| 59 print('Unknown data type {} to delete from database'.format(dataType)) | 59 print('Unknown data type {} to delete from database'.format(dataType)) |
| 60 connection.close() | 60 connection.close() |
| 61 else: | 61 else: |
| 62 print('{} does not exist'.format(filename)) | 62 print('{} does not exist'.format(filename)) |
| 98 | 98 |
| 99 def createIndicatorTable(cursor): | 99 def createIndicatorTable(cursor): |
| 100 cursor.execute('CREATE TABLE IF NOT EXISTS indicators (interaction_id INTEGER, indicator_type INTEGER, frame_number INTEGER, value REAL, FOREIGN KEY(interaction_id) REFERENCES interactions(id), PRIMARY KEY(interaction_id, indicator_type, frame_number))') | 100 cursor.execute('CREATE TABLE IF NOT EXISTS indicators (interaction_id INTEGER, indicator_type INTEGER, frame_number INTEGER, value REAL, FOREIGN KEY(interaction_id) REFERENCES interactions(id), PRIMARY KEY(interaction_id, indicator_type, frame_number))') |
| 101 | 101 |
| 102 def insertTrajectoryQuery(tableName): | 102 def insertTrajectoryQuery(tableName): |
| 103 return "INSERT INTO "+tableName+" (trajectory_id, frame_number, x_coordinate, y_coordinate) VALUES (?,?,?,?)" | 103 return "INSERT INTO "+tableName+" VALUES (?,?,?,?)" |
| 104 | 104 |
| 105 def insertObjectQuery(): | 105 def insertObjectQuery(): |
| 106 return "INSERT INTO objects (object_id, road_user_type, n_objects) VALUES (?,?,?)" | 106 return "INSERT INTO objects VALUES (?,?,?)" |
| 107 | 107 |
| 108 def insertObjectFeatureQuery(): | 108 def insertObjectFeatureQuery(): |
| 109 return "INSERT INTO objects_features (object_id, trajectory_id) VALUES (?,?)" | 109 return "INSERT INTO objects_features VALUES (?,?)" |
| 110 | 110 |
| 111 def createIndex(connection, tableName, columnName, unique = False): | 111 def createIndex(connection, tableName, columnName, unique = False): |
| 112 '''Creates an index for the column in the table | 112 '''Creates an index for the column in the table |
| 113 I will make querying with a condition on this column faster''' | 113 I will make querying with a condition on this column faster''' |
| 114 try: | 114 try: |
| 149 else: | 149 else: |
| 150 print("Argument minmax unknown: {}".format(minmax)) | 150 print("Argument minmax unknown: {}".format(minmax)) |
| 151 return cursor.fetchone()[0] | 151 return cursor.fetchone()[0] |
| 152 except sqlite3.OperationalError as error: | 152 except sqlite3.OperationalError as error: |
| 153 printDBError(error) | 153 printDBError(error) |
| 154 | |
| 155 def loadPrototypeMatchIndexesFromSqlite(filename): | |
| 156 """ | |
| 157 This function loads the prototypes table in the database of name <filename>. | |
| 158 It returns a list of tuples representing matching ids : [(prototype_id, matched_trajectory_id),...] | |
| 159 """ | |
| 160 matched_indexes = [] | |
| 161 | |
| 162 connection = sqlite3.connect(filename) | |
| 163 cursor = connection.cursor() | |
| 164 | |
| 165 try: | |
| 166 cursor.execute('SELECT * from prototypes order by prototype_id, trajectory_id_matched') | |
| 167 except sqlite3.OperationalError as error: | |
| 168 printDBError(error) | |
| 169 return [] | |
| 170 | |
| 171 for row in cursor: | |
| 172 matched_indexes.append((row[0],row[1])) | |
| 173 | |
| 174 connection.close() | |
| 175 return matched_indexes | |
| 176 | 154 |
| 177 def getObjectCriteria(objectNumbers): | 155 def getObjectCriteria(objectNumbers): |
| 178 if objectNumbers is None: | 156 if objectNumbers is None: |
| 179 query = '' | 157 query = '' |
| 180 elif type(objectNumbers) == int: | 158 elif type(objectNumbers) == int: |
| 430 cursor.execute(objectFeatureQuery, (obj.getNum(), featureNum)) | 408 cursor.execute(objectFeatureQuery, (obj.getNum(), featureNum)) |
| 431 cursor.execute(objectQuery, (obj.getNum(), obj.getUserType(), 1)) | 409 cursor.execute(objectQuery, (obj.getNum(), obj.getUserType(), 1)) |
| 432 # Parse curvilinear position structure | 410 # Parse curvilinear position structure |
| 433 elif(trajectoryType == 'curvilinear'): | 411 elif(trajectoryType == 'curvilinear'): |
| 434 createCurvilinearTrajectoryTable(cursor) | 412 createCurvilinearTrajectoryTable(cursor) |
| 435 curvilinearQuery = "insert into curvilinear_positions (trajectory_id, frame_number, s_coordinate, y_coordinate, lane) values (?,?,?,?,?)" | 413 curvilinearQuery = "INSERT INTO curvilinear_positions VALUES (?,?,?,?,?)" |
| 436 for obj in objects: | 414 for obj in objects: |
| 437 num = obj.getNum() | 415 num = obj.getNum() |
| 438 frameNum = obj.getFirstInstant() | 416 frameNum = obj.getFirstInstant() |
| 439 for p in obj.getCurvilinearPositions(): | 417 for p in obj.getCurvilinearPositions(): |
| 440 cursor.execute(curvilinearQuery, (num, frameNum, p[0], p[1], p[2])) | 418 cursor.execute(curvilinearQuery, (num, frameNum, p[0], p[1], p[2])) |
| 483 | 461 |
| 484 def saveInteraction(cursor, interaction): | 462 def saveInteraction(cursor, interaction): |
| 485 roadUserNumbers = list(interaction.getRoadUserNumbers()) | 463 roadUserNumbers = list(interaction.getRoadUserNumbers()) |
| 486 cursor.execute('INSERT INTO interactions VALUES({}, {}, {}, {}, {})'.format(interaction.getNum(), roadUserNumbers[0], roadUserNumbers[1], interaction.getFirstInstant(), interaction.getLastInstant())) | 464 cursor.execute('INSERT INTO interactions VALUES({}, {}, {}, {}, {})'.format(interaction.getNum(), roadUserNumbers[0], roadUserNumbers[1], interaction.getFirstInstant(), interaction.getLastInstant())) |
| 487 | 465 |
| 488 def saveInteractions(filename, interactions): | 466 def saveInteractionsToSqlite(filename, interactions): |
| 489 'Saves the interactions in the table' | 467 'Saves the interactions in the table' |
| 490 connection = sqlite3.connect(filename) | 468 connection = sqlite3.connect(filename) |
| 491 cursor = connection.cursor() | 469 cursor = connection.cursor() |
| 492 try: | 470 try: |
| 493 createInteractionTable(cursor) | 471 createInteractionTable(cursor) |
| 501 def saveIndicator(cursor, interactionNum, indicator): | 479 def saveIndicator(cursor, interactionNum, indicator): |
| 502 for instant in indicator.getTimeInterval(): | 480 for instant in indicator.getTimeInterval(): |
| 503 if indicator[instant]: | 481 if indicator[instant]: |
| 504 cursor.execute('INSERT INTO indicators VALUES({}, {}, {}, {})'.format(interactionNum, events.Interaction.indicatorNameToIndices[indicator.getName()], instant, indicator[instant])) | 482 cursor.execute('INSERT INTO indicators VALUES({}, {}, {}, {})'.format(interactionNum, events.Interaction.indicatorNameToIndices[indicator.getName()], instant, indicator[instant])) |
| 505 | 483 |
| 506 def saveIndicators(filename, interactions, indicatorNames = events.Interaction.indicatorNames): | 484 def saveIndicatorsToSqlite(filename, interactions, indicatorNames = events.Interaction.indicatorNames): |
| 507 'Saves the indicator values in the table' | 485 'Saves the indicator values in the table' |
| 508 connection = sqlite3.connect(filename) | 486 connection = sqlite3.connect(filename) |
| 509 cursor = connection.cursor() | 487 cursor = connection.cursor() |
| 510 try: | 488 try: |
| 511 createInteractionTable(cursor) | 489 createInteractionTable(cursor) |
| 519 except sqlite3.OperationalError as error: | 497 except sqlite3.OperationalError as error: |
| 520 printDBError(error) | 498 printDBError(error) |
| 521 connection.commit() | 499 connection.commit() |
| 522 connection.close() | 500 connection.close() |
| 523 | 501 |
| 524 def loadInteractions(filename): | 502 def loadInteractionsFromSqlite(filename): |
| 525 '''Loads interaction and their indicators | 503 '''Loads interaction and their indicators |
| 526 | 504 |
| 527 TODO choose the interactions to load''' | 505 TODO choose the interactions to load''' |
| 528 interactions = [] | 506 interactions = [] |
| 529 connection = sqlite3.connect(filename) | 507 connection = sqlite3.connect(filename) |
| 530 cursor = connection.cursor() | 508 cursor = connection.cursor() |
| 531 try: | 509 try: |
| 532 cursor.execute('select INT.id, INT.object_id1, INT.object_id2, INT.first_frame_number, INT.last_frame_number, IND.indicator_type, IND.frame_number, IND.value from interactions INT, indicators IND WHERE INT.id = IND.interaction_id ORDER BY INT.id, IND.indicator_type, IND.frame_number') | 510 cursor.execute('SELECT INT.id, INT.object_id1, INT.object_id2, INT.first_frame_number, INT.last_frame_number, IND.indicator_type, IND.frame_number, IND.value from interactions INT, indicators IND WHERE INT.id = IND.interaction_id ORDER BY INT.id, IND.indicator_type, IND.frame_number') |
| 533 interactionNum = -1 | 511 interactionNum = -1 |
| 534 indicatorTypeNum = -1 | 512 indicatorTypeNum = -1 |
| 535 tmpIndicators = {} | 513 tmpIndicators = {} |
| 536 for row in cursor: | 514 for row in cursor: |
| 537 if row[0] != interactionNum: | 515 if row[0] != interactionNum: |
| 595 def savePrototypesToSqlite(filename, prototypes): | 573 def savePrototypesToSqlite(filename, prototypes): |
| 596 '''save the prototypes (a prototype is defined by a filename, a number and type''' | 574 '''save the prototypes (a prototype is defined by a filename, a number and type''' |
| 597 connection = sqlite3.connect(filename) | 575 connection = sqlite3.connect(filename) |
| 598 cursor = connection.cursor() | 576 cursor = connection.cursor() |
| 599 try: | 577 try: |
| 600 cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (filename VARCHAR, id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, PRIMARY KEY (filename, id))') | 578 cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, PRIMARY KEY (prototype_filename, prototype_id, trajectory_type))') |
| 601 for p in prototypes: | 579 for p in prototypes: |
| 602 cursor.execute('INSERT INTO prototypes (filename, id, trajectory_type, nmatchings) VALUES (?,?,?,?)', (p.getFilename(), p.getNum(), p.getTrajectoryType(), p.getNMatchings())) | 580 cursor.execute('INSERT INTO prototypes VALUES(?,?,?,?)', (p.getFilename(), p.getNum(), p.getTrajectoryType(), p.getNMatchings())) |
| 603 except sqlite3.OperationalError as error: | 581 except sqlite3.OperationalError as error: |
| 604 printDBError(error) | 582 printDBError(error) |
| 605 connection.commit() | 583 connection.commit() |
| 606 connection.close() | 584 connection.close() |
| 607 | 585 |
| 608 def savePrototypeAssignments(filename, objects): | 586 def savePrototypeAssignmentsToSqlite(filename, objects, labels, prototypes): |
| 609 pass | 587 connection = sqlite3.connect(filename) |
| 588 cursor = connection.cursor() | |
| 589 try: | |
| 590 cursor.execute('CREATE TABLE IF NOT EXISTS objects_prototypes (object_id INTEGER, prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), PRIMARY KEY(object_id, prototype_filename, prototype_id, trajectory_type))') | |
| 591 for obj, label in zip(objects, labels): | |
| 592 proto = prototypes[label] | |
| 593 cursor.execute('INSERT INTO objects_prototypes VALUES(?,?,?,?)', (obj.getNum(), proto.getFilename(), proto.getNum(), proto.getTrajectoryType())) | |
| 594 except sqlite3.OperationalError as error: | |
| 595 printDBError(error) | |
| 596 connection.commit() | |
| 597 connection.close() | |
| 610 | 598 |
| 611 def loadPrototypesFromSqlite(filename, withTrajectories = True): | 599 def loadPrototypesFromSqlite(filename, withTrajectories = True): |
| 612 'Loads prototype ids and matchings (if stored)' | 600 'Loads prototype ids and matchings (if stored)' |
| 613 connection = sqlite3.connect(filename) | 601 connection = sqlite3.connect(filename) |
| 614 cursor = connection.cursor() | 602 cursor = connection.cursor() |
| 636 connection.close() | 624 connection.close() |
| 637 if len(set([p.getTrajectoryType() for p in prototypes])) > 1: | 625 if len(set([p.getTrajectoryType() for p in prototypes])) > 1: |
| 638 print('Different types of prototypes in database ({}).'.format(set([p.getTrajectoryType() for p in prototypes]))) | 626 print('Different types of prototypes in database ({}).'.format(set([p.getTrajectoryType() for p in prototypes]))) |
| 639 return prototypes | 627 return prototypes |
| 640 | 628 |
| 641 def savePOIs(filename, gmm, gmmType, gmmId): | 629 def savePOIsToSqlite(filename, gmm, gmmType, gmmId): |
| 642 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) | 630 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) |
| 643 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' | 631 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' |
| 644 connection = sqlite3.connect(filename) | 632 connection = sqlite3.connect(filename) |
| 645 cursor = connection.cursor() | 633 cursor = connection.cursor() |
| 646 if gmmType not in ['beginning', 'end']: | 634 if gmmType not in ['beginning', 'end']: |
| 648 import sys | 636 import sys |
| 649 sys.exit() | 637 sys.exit() |
| 650 try: | 638 try: |
| 651 cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (poi_id INTEGER, id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covariance VARCHAR, covariance_type VARCHAR, weight, precisions_cholesky VARCHAR, PRIMARY KEY(poi_id, id))') | 639 cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (poi_id INTEGER, id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covariance VARCHAR, covariance_type VARCHAR, weight, precisions_cholesky VARCHAR, PRIMARY KEY(poi_id, id))') |
| 652 for i in xrange(gmm.n_components): | 640 for i in xrange(gmm.n_components): |
| 653 cursor.execute('INSERT INTO gaussians2d VALUES({}, {}, \'{}\', {}, {}, \'{}\', \'{}\', {}, \'{}\')'.format(gmmId, i, gmmType, gmm.means_[i][0], gmm.means_[i][1], str(gmm.covariances_[i].tolist()), gmm.covariance_type, gmm.weights_[i], str(gmm.precisions_cholesky_[i].tolist()))) | 641 cursor.execute('INSERT INTO gaussians2d VALUES(?,?,?,?,?,?,?,?,?)', (gmmId, i, gmmType, gmm.means_[i][0], gmm.means_[i][1], str(gmm.covariances_[i].tolist()), gmm.covariance_type, gmm.weights_[i], str(gmm.precisions_cholesky_[i].tolist()))) |
| 654 connection.commit() | 642 connection.commit() |
| 655 except sqlite3.OperationalError as error: | 643 except sqlite3.OperationalError as error: |
| 656 printDBError(error) | 644 printDBError(error) |
| 657 connection.close() | 645 connection.close() |
| 658 | 646 |
| 659 def savePOIAssignments(filename, objects): | 647 def savePOIAssignmentsToSqlite(filename, objects): |
| 660 'save the od fields of objects' | 648 'save the od fields of objects' |
| 661 connection = sqlite3.connect(filename) | 649 connection = sqlite3.connect(filename) |
| 662 cursor = connection.cursor() | 650 cursor = connection.cursor() |
| 663 try: | 651 try: |
| 664 cursor.execute('CREATE TABLE IF NOT EXISTS objects_pois (object_id INTEGER, origin_poi_id INTEGER, destination_poi_id INTEGER, PRIMARY KEY(object_id))') | 652 cursor.execute('CREATE TABLE IF NOT EXISTS objects_pois (object_id INTEGER, origin_poi_id INTEGER, destination_poi_id INTEGER, PRIMARY KEY(object_id))') |
| 665 for o in objects: | 653 for o in objects: |
| 666 cursor.execute('INSERT INTO objects_pois VALUES({},{},{})'.format(o.getNum(), o.od[0], o.od[1])) | 654 cursor.execute('INSERT INTO objects_pois VALUES(?,?,?)', (o.getNum(), o.od[0], o.od[1])) |
| 667 connection.commit() | 655 connection.commit() |
| 668 except sqlite3.OperationalError as error: | 656 except sqlite3.OperationalError as error: |
| 669 printDBError(error) | 657 printDBError(error) |
| 670 connection.close() | 658 connection.close() |
| 671 | 659 |
| 672 def loadPOIs(filename): | 660 def loadPOIsFromSqlite(filename): |
| 673 'Loads all 2D Gaussians in the database' | 661 'Loads all 2D Gaussians in the database' |
| 674 from sklearn import mixture # todo if not avalaible, load data in duck-typed class with same fields | 662 from sklearn import mixture # todo if not avalaible, load data in duck-typed class with same fields |
| 675 from ast import literal_eval | 663 from ast import literal_eval |
| 676 connection = sqlite3.connect(filename) | 664 connection = sqlite3.connect(filename) |
| 677 cursor = connection.cursor() | 665 cursor = connection.cursor() |
