You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

209 lines
4.0 KiB
C++

#include <opencv2/opencv.hpp>
#include <opencv2/core/utility.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
//#include <opencv2/features2d.hpp>
#include <iostream>
#include <algorithm>
using namespace std;
using namespace cv;
using namespace cv::ml;
/**
 * @brief Trains an RBF-kernel C-SVC on eight hand-placed 2-D points and
 *        visualises the resulting decision regions.
 *
 * Every pixel of a 500x500 canvas is classified and coloured by the
 * predicted class (class 0 -> light red, class 1 -> light green). The
 * training points are then drawn on top in a darker shade of their own
 * label's colour. The image is shown in the "svm" window, written to
 * "svm_result1.png", and the function blocks until a key is pressed.
 */
void svmplane()
{
    // Each row is one (x, y) training sample in canvas coordinates.
    Mat train = Mat_<float>({ 8, 2 },
        {
            150, 200, 200, 250, 100, 250, 150, 300,
            350, 100, 400, 200, 400, 300, 350, 400 });
    Mat label = Mat_<int>({ 8, 1 }, { 0, 0, 0, 0, 1, 1, 1, 1 });

    Ptr<SVM> svm = SVM::create();
    svm->setType(SVM::Types::C_SVC);
    svm->setKernel(SVM::KernelTypes::RBF);
    // trainAuto picks the SVM parameters by cross-validation instead of
    // using fixed values.
    svm->trainAuto(train, ROW_SAMPLE, label);

    Mat img = Mat::zeros(Size(500, 500), CV_8UC3);

    // One reusable 1x2 float query sample instead of a fresh Mat per pixel.
    Mat test(1, 2, CV_32F);
    for (int j = 0; j < img.rows; j++)
    {
        for (int i = 0; i < img.cols; i++)
        {
            test.at<float>(0, 0) = static_cast<float>(i);
            test.at<float>(0, 1) = static_cast<float>(j);
            int res = cvRound(svm->predict(test));
            if (res == 0)
                img.at<Vec3b>(j, i) = Vec3b(128, 128, 255); // R
            else
                img.at<Vec3b>(j, i) = Vec3b(128, 255, 128); // G
        }
    }

    // Overlay the training samples, coloured by their ground-truth label.
    for (int i = 0; i < train.rows; i++)
    {
        int x = cvRound(train.at<float>(i, 0));
        int y = cvRound(train.at<float>(i, 1));
        int l = label.at<int>(i, 0);
        if (l == 0)
            circle(img, Point(x, y), 5, Scalar(0, 0, 128), -1, LINE_AA); // R
        else
            circle(img, Point(x, y), 5, Scalar(0, 128, 0), -1, LINE_AA); // G
    }

    imshow("svm", img);
    imwrite("svm_result1.png", img);
    waitKey();
}
// //
// Trains the HOG + SVM digit classifier; defined below svmdigits().
Ptr<SVM> train_hog_svm(const HOGDescriptor& hog);
// Mouse callback used by svmdigits() for drawing; defined below.
void on_mouse(int event, int x, int y, int flags, void* userdata);
void svmdigits()
{
#if _DEBUG
cout << "svndigits.exe should be built as Release mode !" << endl;
return;
#endif
HOGDescriptor hog(Size(20, 20), Size(10, 10), Size(5, 5), Size(5, 5), 9);
Ptr<SVM> svm = train_hog_svm(hog);
if (svm.empty())
{
cerr << "Training failed! " << endl;
return;
}
Mat img = Mat::zeros(400, 400, CV_8U);
imshow("img", img);
setMouseCallback("img", on_mouse, (void*)&img);
while (true)
{
int c = waitKey();
if (c == 27)
break;
else if (c == ' ')
{
Mat img_resize;
resize(img, img_resize, Size(20, 20), 0, 0, INTER_AREA);
vector<float> desc;
hog.compute(img_resize, desc);
Mat desc_mat(desc);
int res = cvRound(svm->predict(desc_mat.t()));
cout << res << endl;
img.setTo(0);
imshow("img", img);
}
}
return;
}
/**
 * @brief Builds and trains an RBF C-SVC digit classifier from "digits.png".
 *
 * The image is read as grayscale and split into a grid of 20x20 cells,
 * 50 rows by 100 columns. Every 5 consecutive grid rows share one label
 * (row j -> label j / 5, i.e. labels 0..9). The HOG descriptor of each cell,
 * computed with @p hog, becomes one training row.
 *
 * @param hog Descriptor used to turn each 20x20 cell into a feature vector.
 * @return The trained SVM, or an empty Ptr if the image could not be loaded,
 *         is smaller than the expected grid, or training failed.
 */
Ptr<SVM> train_hog_svm(const HOGDescriptor& hog)
{
    const int cell = 20;           // side length of one digit cell, in pixels
    const int grid_rows = 50;
    const int grid_cols = 100;
    const int rows_per_digit = 5;  // consecutive grid rows sharing one label

    Mat digits = imread("digits.png", IMREAD_GRAYSCALE);
    if (digits.empty())
    {
        cerr << "Image load failed!" << endl;
        return Ptr<SVM>();
    }
    // ROIs outside the image would throw; report a clean failure instead.
    if (digits.cols < grid_cols * cell || digits.rows < grid_rows * cell)
    {
        cerr << "digits.png is smaller than the expected 2000x1000 grid!" << endl;
        return Ptr<SVM>();
    }

    Mat train_hog, train_labels;
    for (int j = 0; j < grid_rows; j++)
    {
        for (int i = 0; i < grid_cols; i++)
        {
            Mat roi = digits(Rect(i * cell, j * cell, cell, cell));
            vector<float> desc;
            hog.compute(roi, desc);
            Mat desc_mat(desc);
            // One descriptor per row (ROW_SAMPLE layout).
            train_hog.push_back(desc_mat.t());
            train_labels.push_back(j / rows_per_digit);
        }
    }

    Ptr<SVM> svm = SVM::create();
    svm->setType(SVM::Types::C_SVC);
    svm->setKernel(SVM::KernelTypes::RBF);
    // NOTE(review): fixed hyper-parameters, presumably found offline (e.g. by
    // trainAuto) — confirm before changing the HOG configuration.
    svm->setC(2.5);
    svm->setGamma(0.50625);
    // The caller detects failure via svm.empty(), so surface a failed train().
    if (!svm->train(train_hog, ROW_SAMPLE, train_labels))
        return Ptr<SVM>();
    return svm;
}
Point ptPrev(-1, -1);
void on_mouse(int event, int x, int y, int flags, void* userdata)
{
Mat img = *(Mat*)userdata;
if (event == EVENT_LBUTTONDOWN)
{
ptPrev = Point(x, y);
}
else if (event == EVENT_LBUTTONUP)
{
ptPrev = Point(-1, -1);
}
else if (event == EVENT_MOUSEMOVE && (flags & EVENT_FLAG_LBUTTON))
{
line(img, ptPrev, Point(x, y), Scalar::all(255), 40, LINE_AA, 0);
ptPrev = Point(x, y);
imshow("img", img);
imwrite("svm_result2.png", img);
}
}
// //
void hog()
{
VideoCapture cap("vtest.avi");
if (!cap.isOpened())
{
cerr << "Video open failed!" << endl;
return;
}
HOGDescriptor hog;
hog.setSVMDetector(HOGDescriptor::getDefaultPeopleDetector());
Mat frame;
while (true)
{
cap >> frame;
if (frame.empty())
break;
vector<Rect> detected;
hog.detectMultiScale(frame, detected);
for (Rect r : detected) {
Scalar c = Scalar(rand() % 256, rand() % 256, rand() % 256);
rectangle(frame, r, c, 3);
}
imshow("frame", frame);
if (waitKey(10) == 27)
break;
}
}
// Runs the SVM demos one after the other; the HOG pedestrian demo is
// disabled.
int main()
{
    svmplane();   // 2-D SVM decision-region visualisation
    svmdigits();  // interactive handwritten-digit recognition
    // hog();     // pedestrian detection on vtest.avi (disabled)
    return 0;
}