Mercurial > hg > nsaunier > traffic-intelligence
comparison python/storage.py @ 588:c5406edbcf12
Added loading of ground-truth annotations from the Polytrack format
| author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
|---|---|
| date | Fri, 05 Dec 2014 00:54:38 -0500 |
| parents | cf578ba866da |
| children | 5800a87f11ae 0954aaf28231 |
Comparison legend: equal, deleted, inserted, replaced
| 587:cf578ba866da | 588:c5406edbcf12 |
|---|---|
| 32 for tableName in tableNames: | 32 for tableName in tableNames: |
| 33 cursor.execute('DROP TABLE IF EXISTS '+tableName) | 33 cursor.execute('DROP TABLE IF EXISTS '+tableName) |
| 34 except sqlite3.OperationalError as error: | 34 except sqlite3.OperationalError as error: |
| 35 printDBError(error) | 35 printDBError(error) |
| 36 | 36 |
| 37 # TODO: add test if database connection is open | |
| 37 # IO to sqlite | 38 # IO to sqlite |
| 38 def writeTrajectoriesToSqlite(objects, outputFilename, trajectoryType, objectNumbers = -1): | 39 def writeTrajectoriesToSqlite(objects, outputFilename, trajectoryType, objectNumbers = -1): |
| 39 """ | 40 """ |
| 40 This function writers trajectories to a specified sqlite file | 41 This function writers trajectories to a specified sqlite file |
| 41 @param[in] objects -> a list of trajectories | 42 @param[in] objects -> a list of trajectories |
| 269 | 270 |
| 270 returns a moving object''' | 271 returns a moving object''' |
| 271 cursor = connection.cursor() | 272 cursor = connection.cursor() |
| 272 | 273 |
| 273 try: | 274 try: |
| 274 trajectoryIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType) | 275 idQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType) |
| 275 if trajectoryType == 'feature': | 276 if trajectoryType == 'feature': |
| 276 queryStatement = 'SELECT * from '+tableName+' '+trajectoryIdQuery+'ORDER BY trajectory_id, frame_number' | 277 queryStatement = 'SELECT * from '+tableName+' '+idQuery+'ORDER BY trajectory_id, frame_number' |
| 277 cursor.execute(queryStatement) | 278 cursor.execute(queryStatement) |
| 278 logging.debug(queryStatement) | 279 logging.debug(queryStatement) |
| 279 elif trajectoryType == 'object': | 280 elif trajectoryType == 'object': |
| 280 queryStatement = 'SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from '+tableName+' P, objects_features OF where P.trajectory_id = OF.trajectory_id '+objectIdQuery+'group by OF.object_id, P.frame_number ORDER BY OF.object_id, P.frame_number' | 281 queryStatement = 'SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from '+tableName+' P, objects_features OF where P.trajectory_id = OF.trajectory_id '+idQuery+'group by OF.object_id, P.frame_number ORDER BY OF.object_id, P.frame_number' |
| 281 cursor.execute(queryStatement) | 282 cursor.execute(queryStatement) |
| 282 logging.debug(queryStatement) | 283 logging.debug(queryStatement) |
| 283 elif trajectoryType == 'bbtop' or trajectoryType == 'bbbottom': | 284 elif trajectoryType in ['bbtop', 'bbbottom']: |
| 284 if trajectoryType == 'bbtop': | 285 if trajectoryType == 'bbtop': |
| 285 corner = 'top_left' | 286 corner = 'top_left' |
| 286 elif trajectoryType == 'bbbottom': | 287 elif trajectoryType == 'bbbottom': |
| 287 corner = 'bottom_right' | 288 corner = 'bottom_right' |
| 288 queryStatement = 'SELECT object_id, frame_number, x_'+corner+', y_'+corner+' FROM '+tableName+' '+trajectoryIdQuery+'ORDER BY object_id, frame_number' | 289 queryStatement = 'SELECT object_id, frame_number, x_'+corner+', y_'+corner+' FROM '+tableName+' '+trajectoryIdQuery+'ORDER BY object_id, frame_number' |
| 298 obj = None | 299 obj = None |
| 299 objects = [] | 300 objects = [] |
| 300 for row in cursor: | 301 for row in cursor: |
| 301 if row[0] != objId: | 302 if row[0] != objId: |
| 302 objId = row[0] | 303 objId = row[0] |
| 303 if obj != None: | 304 if obj != None and obj.length() == obj.positions.length(): |
| 304 objects.append(obj) | 305 objects.append(obj) |
| 306 elif obj != None: | |
| 307 print('Object {} is missing {} positions'.format(obj.getNum(), int(obj.length())-obj.positions.length())) | |
| 305 obj = moving.MovingObject(row[0], timeInterval = moving.TimeInterval(row[1], row[1]), positions = moving.Trajectory([[row[2]],[row[3]]])) | 308 obj = moving.MovingObject(row[0], timeInterval = moving.TimeInterval(row[1], row[1]), positions = moving.Trajectory([[row[2]],[row[3]]])) |
| 306 else: | 309 else: |
| 307 obj.timeInterval.last = row[1] | 310 obj.timeInterval.last = row[1] |
| 308 obj.positions.addPositionXY(row[2],row[3]) | 311 obj.positions.addPositionXY(row[2],row[3]) |
| 309 | 312 |
| 310 if obj: | 313 if obj != None and obj.length() == obj.positions.length(): |
| 311 objects.append(obj) | 314 objects.append(obj) |
| 315 elif obj != None: | |
| 316 print('Object {} is missing {} positions'.format(obj.getNum(), int(obj.length())-obj.positions.length())) | |
| 312 | 317 |
| 313 return objects | 318 return objects |
| 319 | |
| 320 def loadUserTypesFromTable(cursor, trajectoryType, objectNumbers): | |
| 321 objectIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType) | |
| 322 if objectIdQuery == '': | |
| 323 cursor.execute('SELECT object_id, road_user_type from objects') | |
| 324 else: | |
| 325 cursor.execute('SELECT object_id, road_user_type from objects where '+objectIdQuery[7:]) | |
| 326 userTypes = {} | |
| 327 for row in cursor: | |
| 328 userTypes[row[0]] = row[1] | |
| 329 return userTypes | |
| 314 | 330 |
| 315 def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None): | 331 def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None): |
| 316 '''Loads the first objectNumbers objects or the indices in objectNumbers from the database''' | 332 '''Loads the first objectNumbers objects or the indices in objectNumbers from the database''' |
| 317 connection = sqlite3.connect(filename) # add test if it open | 333 connection = sqlite3.connect(filename) |
| 318 | 334 |
| 319 objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers) | 335 objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers) |
| 320 objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers) | 336 objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers) |
| 321 | 337 |
| 322 if len(objectVelocities) > 0: | 338 if len(objectVelocities) > 0: |
| 346 | 362 |
| 347 for obj in objects: | 363 for obj in objects: |
| 348 obj.featureNumbers = featureNumbers[obj.getNum()] | 364 obj.featureNumbers = featureNumbers[obj.getNum()] |
| 349 | 365 |
| 350 # load userType | 366 # load userType |
| 351 if objectIdQuery == '': | 367 userTypes = loadUserTypesFromTable(cursor, trajectoryType, objectNumbers) |
| 352 cursor.execute('SELECT object_id, road_user_type from objects') | |
| 353 else: | |
| 354 cursor.execute('SELECT object_id, road_user_type from objects where '+objectIdQuery[7:]) | |
| 355 userTypes = {} | |
| 356 for row in cursor: | |
| 357 userTypes[row[0]] = row[1] | |
| 358 | |
| 359 for obj in objects: | 368 for obj in objects: |
| 360 obj.userType = userTypes[obj.getNum()] | 369 obj.userType = userTypes[obj.getNum()] |
| 361 | 370 |
| 362 except sqlite3.OperationalError as error: | 371 except sqlite3.OperationalError as error: |
| 363 printDBError(error) | 372 printDBError(error) |
| 364 return [] | 373 objects = [] |
| 365 | 374 |
| 366 connection.close() | 375 connection.close() |
| 367 return objects | 376 return objects |
| 377 | |
| 378 def loadGroundTruthFromSqlite(filename, gtType, gtNumbers = None): | |
| 379 'Loads bounding box annotations (ground truth) from an SQLite ' | |
| 380 connection = sqlite3.connect(filename) | |
| 381 gt = [] | |
| 382 | |
| 383 if gtType == 'bb': | |
| 384 topCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbtop', gtNumbers) | |
| 385 bottomCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbbottom', gtNumbers) | |
| 386 userTypes = loadUserTypesFromTable(connection.cursor(), 'object', gtNumbers) # string format is same as object | |
| 387 | |
| 388 for t, b in zip(topCorners, bottomCorners): | |
| 389 num = t.getNum() | |
| 390 if t.getNum() == b.getNum(): | |
| 391 annotation = moving.BBAnnotation(num, t.getTimeInterval(), t, b, userTypes[num]) | |
| 392 gt.append(annotation) | |
| 393 else: | |
| 394 print ('Unknown type of annotation {}'.format(gtType)) | |
| 395 | |
| 396 connection.close() | |
| 397 return gt | |
| 368 | 398 |
| 369 def deleteFromSqlite(filename, dataType): | 399 def deleteFromSqlite(filename, dataType): |
| 370 'Deletes (drops) some tables in the filename depending on type of data' | 400 'Deletes (drops) some tables in the filename depending on type of data' |
| 371 import os | 401 import os |
| 372 if os.path.isfile(filename): | 402 if os.path.isfile(filename): |
