Mercurial > hg > nsaunier > traffic-intelligence
comparison python/ml.py @ 788:5b970a5bc233 dev
updated classifying code to OpenCV 3.x (bug in function to load classification models)
| author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
|---|---|
| date | Thu, 24 Mar 2016 16:37:37 -0400 |
| parents | 0a428b449b80 |
| children | 1158a6e2d28e |
comparison
equal
deleted
inserted
replaced
| 787:0a428b449b80 | 788:5b970a5bc233 |
|---|---|
| 9 from matplotlib.pylab import text | 9 from matplotlib.pylab import text |
| 10 import matplotlib as mpl | 10 import matplotlib as mpl |
| 11 import matplotlib.pyplot as plt | 11 import matplotlib.pyplot as plt |
| 12 from scipy.cluster.vq import kmeans, whiten, vq | 12 from scipy.cluster.vq import kmeans, whiten, vq |
| 13 from sklearn import mixture | 13 from sklearn import mixture |
| 14 import cv2 | |
| 14 | 15 |
| 15 import utils | 16 import utils |
| 16 | 17 |
| 17 ##################### | 18 ##################### |
| 18 # OpenCV ML models | 19 # OpenCV ML models |
| 19 ##################### | 20 ##################### |
| 20 | 21 |
| 21 class Model(object): | 22 class StatModel(object): |
| 22 '''Abstract class for loading/saving model''' | 23 '''Abstract class for loading/saving model''' |
| 23 def load(self, filename): | 24 def load(self, filename): |
| 24 if path.exists(filename): | 25 if path.exists(filename): |
| 25 self.model.load(filename) | 26 self.model.load(filename) |
| 26 else: | 27 else: |
| 27 print('Provided filename {} does not exist: model not loaded!'.format(filename)) | 28 print('Provided filename {} does not exist: model not loaded!'.format(filename)) |
| 28 | 29 |
| 29 def save(self, filename): | 30 def save(self, filename): |
| 30 self.model.save(filename) | 31 self.model.save(filename) |
| 31 | 32 |
| 32 class SVM(Model): | 33 class SVM(StatModel): |
| 33 '''wrapper for OpenCV SimpleVectorMachine algorithm''' | 34 '''wrapper for OpenCV SimpleVectorMachine algorithm''' |
| 34 | 35 def __init__(self, svmType = cv2.ml.SVM_C_SVC, kernelType = cv2.ml.SVM_RBF, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0): |
| 35 def __init__(self): | 36 self.model = cv2.ml.SVM_create() |
| 36 import cv2 | 37 self.model.setType(svmType) |
| 37 self.model = cv2.SVM() | 38 self.model.setKernel(kernelType) |
| 38 | 39 self.model.setDegree(degree) |
| 39 def train(self, samples, responses, svm_type, kernel_type, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0): | 40 self.model.setGamma(gamma) |
| 40 self.params = dict(svm_type = svm_type, kernel_type = kernel_type, degree = degree, gamma = gamma, coef0 = coef0, Cvalue = Cvalue, nu = nu, p = p) | 41 self.model.setCoef0(coef0) |
| 41 self.model.train(samples, responses, params = self.params) | 42 self.model.setC(Cvalue) |
| 43 self.model.setNu(nu) | |
| 44 self.model.setP(p) | |
| 45 | |
| 46 def train(self, samples, layout, responses): | |
| 47 self.model.train(samples, layout, responses) | |
| 42 | 48 |
| 43 def predict(self, hog): | 49 def predict(self, hog): |
| 44 return self.model.predict(hog) | 50 return self.model.predict(hog) |
| 45 | 51 |
| 46 | 52 |
