Mercurial > hg > nsaunier > traffic-intelligence
comparison scripts/train-object-classification.py @ 680:da1352b89d02 dev
classification is working
| author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
|---|---|
| date | Fri, 05 Jun 2015 02:25:30 +0200 |
| parents | ce40a89bd6ae |
| children | 5b970a5bc233 |
comparison
equal
deleted
inserted
replaced
| 678:97c305108460 | 680:da1352b89d02 |
|---|---|
| 34 trainingSamplesPV = {} | 34 trainingSamplesPV = {} |
| 35 trainingLabelsPV = {} | 35 trainingLabelsPV = {} |
| 36 | 36 |
| 37 for k, v in imageDirectories.iteritems(): | 37 for k, v in imageDirectories.iteritems(): |
| 38 print('Loading {} samples'.format(k)) | 38 print('Loading {} samples'.format(k)) |
| 39 trainingSamplesPBV[k], trainingLabelsPBV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock) | 39 trainingSamples, trainingLabels = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock) |
| | 40 trainingSamplesPBV[k], trainingLabelsPBV[k] = trainingSamples, trainingLabels |
| 40 if k != 'pedestrian': | 41 if k != 'pedestrian': |
| 41 trainingSamplesBV[k], trainingLabelsBV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock) | 42 trainingSamplesBV[k], trainingLabelsBV[k] = trainingSamples, trainingLabels |
| 42 if k != 'car': | 43 if k != 'car': |
| 43 trainingSamplesPB[k], trainingLabelsPB[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock) | 44 trainingSamplesPB[k], trainingLabelsPB[k] = trainingSamples, trainingLabels |
| 44 if k != 'bicycle': | 45 if k != 'bicycle': |
| 45 trainingSamplesPV[k], trainingLabelsPV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock) | 46 trainingSamplesPV[k], trainingLabelsPV[k] = trainingSamples, trainingLabels |
| 46 | 47 |
| 47 # Training the Support Vector Machine | 48 # Training the Support Vector Machine |
| 48 print "Training Pedestrian-Cyclist-Vehicle Model" | 49 print "Training Pedestrian-Cyclist-Vehicle Model" |
| 49 model = ml.SVM(args.svmType, args.kernelType) | 50 model = ml.SVM() |
| 50 model.train(np.concatenate(trainingSamplesPBV.values()), np.concatenate(trainingLabelsPBV.values())) | 51 model.train(np.concatenate(trainingSamplesPBV.values()), np.concatenate(trainingLabelsPBV.values()), args.svmType, args.kernelType) |
| 51 model.save(args.directoryName + "/modelPBV.xml") | 52 model.save(args.directoryName + "/modelPBV.xml") |
| 52 | 53 |
| 53 print "Training Cyclist-Vehicle Model" | 54 print "Training Cyclist-Vehicle Model" |
| 54 model = ml.SVM(args.svmType, args.kernelType) | 55 model = ml.SVM() |
| 55 model.train(np.concatenate(trainingSamplesBV.values()), np.concatenate(trainingLabelsBV.values())) | 56 model.train(np.concatenate(trainingSamplesBV.values()), np.concatenate(trainingLabelsBV.values()), args.svmType, args.kernelType) |
| 56 model.save(args.directoryName + "/modelBV.xml") | 57 model.save(args.directoryName + "/modelBV.xml") |
| 57 | 58 |
| 58 print "Training Pedestrian-Cyclist Model" | 59 print "Training Pedestrian-Cyclist Model" |
| 59 model = ml.SVM(args.svmType, args.kernelType) | 60 model = ml.SVM() |
| 60 model.train(np.concatenate(trainingSamplesPB.values()), np.concatenate(trainingLabelsPB.values())) | 61 model.train(np.concatenate(trainingSamplesPB.values()), np.concatenate(trainingLabelsPB.values()), args.svmType, args.kernelType) |
| 61 model.save(args.directoryName + "/modelPB.xml") | 62 model.save(args.directoryName + "/modelPB.xml") |
| 62 | 63 |
| 63 print "Training Pedestrian-Vehicle Model" | 64 print "Training Pedestrian-Vehicle Model" |
| 64 model = ml.SVM(args.svmType, args.kernelType) | 65 model = ml.SVM() |
| 65 model.train(np.concatenate(trainingSamplesPV.values()), np.concatenate(trainingLabelsPV.values())) | 66 model.train(np.concatenate(trainingSamplesPV.values()), np.concatenate(trainingLabelsPV.values()), args.svmType, args.kernelType) |
| 66 model.save(args.directoryName + "/modelPV.xml") | 67 model.save(args.directoryName + "/modelPV.xml") |
