Mercurial > hg > nsaunier > traffic-intelligence
comparison scripts/dltrack.py @ 1236:100fe098abe9
progress on classification
| author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
|---|---|
| date | Tue, 19 Sep 2023 17:04:30 -0400 |
| parents | 855abc69fa99 |
| children | 31a441efca6c |
comparison
Legend: equal, deleted, inserted, replaced
| 1235:855abc69fa99 | 1236:100fe098abe9 |
|---|---|
| 1 #! /usr/bin/env python3 | 1 #! /usr/bin/env python3 |
| 2 # from https://docs.ultralytics.com/modes/track/ | 2 # from https://docs.ultralytics.com/modes/track/ |
| 3 import sys, argparse | 3 import sys, argparse |
| 4 from copy import copy | 4 from copy import copy |
| 5 from collections import Counter | |
| 5 from ultralytics import YOLO | 6 from ultralytics import YOLO |
| 7 from torch import cat | |
| 8 from torchvision import ops | |
| 6 import cv2 | 9 import cv2 |
| 7 | 10 |
| 8 from trafficintelligence import cvutils, moving, storage, utils | 11 from trafficintelligence import cvutils, moving, storage, utils |
| 9 | 12 |
| 10 parser = argparse.ArgumentParser(description='The program tracks objects following the ultralytics yolo executable.')#, epilog = 'Either the configuration filename or the other parameters (at least video and database filenames) need to be provided.') | 13 parser = argparse.ArgumentParser(description='The program tracks objects following the ultralytics yolo executable.')#, epilog = 'Either the configuration filename or the other parameters (at least video and database filenames) need to be provided.') |
| 13 parser.add_argument('-m', dest = 'detectorFilename', help = 'name of the detection model file', required = True) | 16 parser.add_argument('-m', dest = 'detectorFilename', help = 'name of the detection model file', required = True) |
| 14 parser.add_argument('-t', dest = 'trackerFilename', help = 'name of the tracker file', required = True) | 17 parser.add_argument('-t', dest = 'trackerFilename', help = 'name of the tracker file', required = True) |
| 15 parser.add_argument('--display', dest = 'display', help = 'show the results (careful with long videos, risk of running out of memory)', action = 'store_true') | 18 parser.add_argument('--display', dest = 'display', help = 'show the results (careful with long videos, risk of running out of memory)', action = 'store_true') |
| 16 parser.add_argument('-f', dest = 'firstFrameNum', help = 'number of first frame number to process', type = int, default = 0) | 19 parser.add_argument('-f', dest = 'firstFrameNum', help = 'number of first frame number to process', type = int, default = 0) |
| 17 parser.add_argument('-l', dest = 'lastFrameNum', help = 'number of last frame number to process', type = int, default = float('Inf')) | 20 parser.add_argument('-l', dest = 'lastFrameNum', help = 'number of last frame number to process', type = int, default = float('Inf')) |
| 21 parser.add_argument('--bike-pct', dest = 'bikeProportion', help = 'percent of time a person classified as bike or motorbike to be classified as cyclist', type = float, default = 0.2) | |
| 18 args = parser.parse_args() | 22 args = parser.parse_args() |
| 19 | 23 |
| 20 # required functionality? | 24 # required functionality? |
| 21 # # filename of the video to process (can be images, eg image%04d.png) | 25 # # filename of the video to process (can be images, eg image%04d.png) |
| 22 # video-filename = laurier.avi | 26 # video-filename = laurier.avi |
| 64 model = YOLO(args.detectorFilename, ) # seg yolov8x-seg.pt | 68 model = YOLO(args.detectorFilename, ) # seg yolov8x-seg.pt |
| 65 # seg could be used on cropped image... if can be loaded and kept in memory | 69 # seg could be used on cropped image... if can be loaded and kept in memory |
| 66 # model = YOLO('/home/nicolas/Research/Data/classification-models/yolo_nas_l.pt ') # AttributeError: 'YoloNAS_L' object has no attribute 'get' | 70 # model = YOLO('/home/nicolas/Research/Data/classification-models/yolo_nas_l.pt ') # AttributeError: 'YoloNAS_L' object has no attribute 'get' |
| 67 | 71 |
| 68 # Track with the model | 72 # Track with the model |
| 69 #results = model.track(source=args.videoFilename, tracker="/home/nicolas/Research/Data/classification-models/bytetrack.yaml", classes=list(moving.cocoTypeNames.keys()), show=True) # , save_txt=True | |
| 70 if args.display: | 73 if args.display: |
| 71 windowName = 'frame' | 74 windowName = 'frame' |
| 72 cv2.namedWindow(windowName, cv2.WINDOW_NORMAL) | 75 cv2.namedWindow(windowName, cv2.WINDOW_NORMAL) |
| 73 | 76 |
| 74 capture = cv2.VideoCapture(args.videoFilename) | 77 capture = cv2.VideoCapture(args.videoFilename) |
| 85 results = model.track(frame, tracker=args.trackerFilename, classes=list(moving.cocoTypeNames.keys()), persist=True) | 88 results = model.track(frame, tracker=args.trackerFilename, classes=list(moving.cocoTypeNames.keys()), persist=True) |
| 86 # create object with user type and list of 3 features (bottom ones and middle) + projection | 89 # create object with user type and list of 3 features (bottom ones and middle) + projection |
| 87 while capture.isOpened() and success and frameNum <= lastFrameNum: | 90 while capture.isOpened() and success and frameNum <= lastFrameNum: |
| 88 #for frameNum, result in enumerate(results): | 91 #for frameNum, result in enumerate(results): |
| 89 result = results[0] | 92 result = results[0] |
| 90 print(frameNum, len(result.boxes)) | 93 print(frameNum, len(result.boxes), 'objects') |
| 91 for box in result.boxes: | 94 for box in result.boxes: |
| 92 #print(box.cls, box.id, box.xyxy) | 95 #print(box.cls, box.id, box.xyxy) |
| 93 if box.id is not None: # None are objects with low confidence | 96 if box.id is not None: # None are objects with low confidence |
| 94 num = int(box.id) | 97 num = int(box.id.item()) |
| 95 xyxy = box.xyxy[0].tolist() | 98 #xyxy = box.xyxy[0].tolist() |
| 96 if num in currentObjects: | 99 if num in currentObjects: |
| 97 currentObjects[num].timeInterval.last = frameNum | 100 currentObjects[num].timeInterval.last = frameNum |
| 98 currentObjects[num].userTypes.append(moving.coco2Types[int(box.cls)]) | 101 currentObjects[num].bboxes[frameNum] = copy(box.xyxy) |
| 99 currentObjects[num].features[0].tmpPositions[frameNum] = moving.Point(xyxy[0],xyxy[1]) | 102 currentObjects[num].userTypes.append(moving.coco2Types[int(box.cls.item())]) |
| 100 currentObjects[num].features[1].tmpPositions[frameNum] = moving.Point(xyxy[2],xyxy[3]) | 103 currentObjects[num].features[0].tmpPositions[frameNum] = moving.Point(box.xyxy[0,0].item(), box.xyxy[0,1].item()) |
| 101 #features[0].getPositions().addPositionXY(xyxy[0],xyxy[1]) | 104 currentObjects[num].features[1].tmpPositions[frameNum] = moving.Point(box.xyxy[0,2].item(), box.xyxy[0,3].item()) |
| 102 #features[1].getPositions().addPositionXY(xyxy[2],xyxy[3]) | |
| 103 else: | 105 else: |
| 104 inter = moving.TimeInterval(frameNum,frameNum) | 106 inter = moving.TimeInterval(frameNum,frameNum) |
| 105 currentObjects[num] = moving.MovingObject(num, inter) | 107 currentObjects[num] = moving.MovingObject(num, inter) |
| 106 currentObjects[num].userTypes = [moving.coco2Types[int(box.cls)]] | 108 currentObjects[num].bboxes = {frameNum: copy(box.xyxy)} |
| 109 currentObjects[num].userTypes = [moving.coco2Types[int(box.cls.item())]] | |
| 107 currentObjects[num].features = [moving.MovingObject(featureNum), moving.MovingObject(featureNum+1)] | 110 currentObjects[num].features = [moving.MovingObject(featureNum), moving.MovingObject(featureNum+1)] |
| 108 currentObjects[num].featureNumbers = [featureNum, featureNum+1] | 111 currentObjects[num].featureNumbers = [featureNum, featureNum+1] |
| 109 currentObjects[num].features[0].tmpPositions = {frameNum: moving.Point(xyxy[0],xyxy[1])} | 112 currentObjects[num].features[0].tmpPositions = {frameNum: moving.Point(box.xyxy[0,0].item(), box.xyxy[0,1].item())} |
| 110 currentObjects[num].features[1].tmpPositions = {frameNum: moving.Point(xyxy[2],xyxy[3])} | 113 currentObjects[num].features[1].tmpPositions = {frameNum: moving.Point(box.xyxy[0,2].item(), box.xyxy[0,3].item())} |
| 111 featureNum += 2 | 114 featureNum += 2 |
| 112 if args.display: | 115 if args.display: |
| 113 cvutils.cvImshow(windowName, result.plot()) # original image in orig_img | 116 cvutils.cvImshow(windowName, result.plot()) # original image in orig_img |
| 114 key = cv2.waitKey() | 117 key = cv2.waitKey() |
| 115 if cvutils.quitKey(key): | 118 if cvutils.quitKey(key): |
| 116 break | 119 break |
| 117 frameNum += 1 | 120 frameNum += 1 |
| 118 success, frame = capture.read() | 121 success, frame = capture.read() |
| 119 results = model.track(frame, persist=True) | 122 results = model.track(frame, persist=True) |
| 120 | 123 |
| 121 # interpolate and generate velocity before saving | 124 # classification |
| 122 for num, obj in currentObjects.items(): | 125 for num, obj in currentObjects.items(): |
| 123 obj.setUserType(utils.mostCommon(obj.userTypes)) | 126 #obj.setUserType(utils.mostCommon(obj.userTypes)) # improve? mix with speed? |
| 127 userTypeStats = Counter(obj.userTypes) | |
| 128 if (4 in userTypeStats or (3 in userTypeStats and 4 in userTypeStats and userTypeStats[3]<=userTypeStats[4])) and userTypeStats[3]+userTypeStats[4] > args.bikeProportion*userTypeStats.total(): # 3 is motorcycle and 4 is cyclist (verif if not turning all motorbike into cyclists) | |
| 129 obj.setUserType(4) | |
| 130 else: | |
| 131 obj.setUserType(userTypeStats.most_common()[0][0]) | |
| 132 | |
| 133 # merge bikes and people | |
| 134 #Construire graphe bipartite vélo/moto personne | |
| 135 #Lien = somme des iou / longueur track vélo | |
| 136 #Algo Hongrois | |
| 137 #Verif overlap piéton vélo : si long, changement mode (trouver exemples) | |
| 138 | |
| 139 # for all cyclists and motorbikes | |
| 140 | |
| 141 # interpolate and generate velocity (?) before saving | |
| 142 for num, obj in currentObjects.items(): | |
| 124 obj.features[0].timeInterval = copy(obj.getTimeInterval()) | 143 obj.features[0].timeInterval = copy(obj.getTimeInterval()) |
| 125 obj.features[1].timeInterval = copy(obj.getTimeInterval()) | 144 obj.features[1].timeInterval = copy(obj.getTimeInterval()) |
| 126 if obj.length() != len(obj.features[0].tmpPositions): # interpolate | 145 if obj.length() != len(obj.features[0].tmpPositions): # interpolate |
| 127 obj.features[0].positions = moving.Trajectory.fromPointDict(obj.features[0].tmpPositions) | 146 obj.features[0].positions = moving.Trajectory.fromPointDict(obj.features[0].tmpPositions) |
| 128 obj.features[1].positions = moving.Trajectory.fromPointDict(obj.features[1].tmpPositions) | 147 obj.features[1].positions = moving.Trajectory.fromPointDict(obj.features[1].tmpPositions) |
