trainer.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. //##########################################################################
  2. //# #
  3. //# CLOUDCOMPARE PLUGIN: qCANUPO #
  4. //# #
  5. //# This program is free software; you can redistribute it and/or modify #
  6. //# it under the terms of the GNU General Public License as published by #
  7. //# the Free Software Foundation; version 2 or later of the License. #
  8. //# #
  9. //# This program is distributed in the hope that it will be useful, #
  10. //# but WITHOUT ANY WARRANTY; without even the implied warranty of #
  11. //# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
  12. //# GNU General Public License for more details. #
  13. //# #
  14. //# COPYRIGHT: UEB (UNIVERSITE EUROPEENNE DE BRETAGNE) / CNRS #
  15. //# #
  16. //##########################################################################
  17. /** This file is directly inspired of the suggest_classifier_lda.cpp file in the
  18. original CANUPO project, by N. Brodu and D. Lague.
  19. **/
  20. #ifndef QCANUPO_TRAINER_HEADER
  21. #define QCANUPO_TRAINER_HEADER
//system
#include <cassert>
#include <cmath>
#include <vector>
//dlib
#include <dlib/matrix.h>
#include <dlib/svm.h>
  27. class LDATrainer
  28. {
  29. public:
  30. typedef dlib::matrix<float, 0, 1> sample_type;
  31. typedef dlib::linear_kernel<sample_type> kernel_type;
  32. typedef dlib::decision_function<kernel_type> trained_function_type;
  33. //typedef trained_function_type::mem_manager_type mem_manager_type;
  34. trained_function_type train(const std::vector<sample_type>& samplesvec, const std::vector<float>& labels) const
  35. {
  36. size_t fdim = samplesvec[0].size();
  37. size_t nsamples = samplesvec.size();
  38. long ndata_class1 = 0, ndata_class2 = 0;
  39. for (size_t i = 0; i < nsamples; ++i)
  40. {
  41. if (labels[i] > 0)
  42. ++ndata_class1;
  43. else
  44. ++ndata_class2;
  45. }
  46. dlib::matrix<sample_type, 0, 1> samples1, samples2;
  47. samples1.set_size(ndata_class1);
  48. samples2.set_size(ndata_class2);
  49. sample_type mu1; mu1.set_size(fdim);
  50. sample_type mu2; mu2.set_size(fdim);
  51. for (size_t i = 0; i < fdim; ++i)
  52. {
  53. mu1(i) = 0;
  54. mu2(i) = 0;
  55. }
  56. ndata_class1 = 0; ndata_class2 = 0;
  57. for (size_t i = 0; i < nsamples; ++i)
  58. {
  59. if (labels[i] > 0)
  60. {
  61. samples1(ndata_class1) = samplesvec[i];
  62. ++ndata_class1;
  63. mu1 += samplesvec[i];
  64. }
  65. else
  66. {
  67. samples2(ndata_class2) = samplesvec[i];
  68. ++ndata_class2;
  69. mu2 += samplesvec[i];
  70. }
  71. }
  72. if (ndata_class1 != 0)
  73. {
  74. mu1 /= ndata_class1;
  75. }
  76. if (ndata_class2 != 0)
  77. {
  78. mu2 /= ndata_class2;
  79. }
  80. // if you get a compilation error coming from here (with templates
  81. // and a 'visual_studio_sucks_cov_helper' structure involved) then
  82. // you may have to patch the dlib's file 'matrix_utilities.h":
  83. //
  84. // line 1611, replace
  85. // const matrix<double,EXP::type::NR,EXP::type::NC, typename EXP::mem_manager_type> avg = mean(m);
  86. // by
  87. // const typename EXP::type avg = mean(m);
  88. //
  89. dlib::matrix<float> sigma1 = covariance(samples1);
  90. dlib::matrix<float> sigma2 = covariance(samples2);
  91. sample_type w_vect = pinv(sigma1 + sigma2) * (mu2 - mu1);
  92. trained_function_type ret;
  93. //ret.alpha.set_size(fdim);
  94. //for (int i=0; i<fdim; ++i) ret.alpha(i) = w_vect(i);
  95. ret.alpha = w_vect;
  96. ret.b = dot(w_vect, (mu1 + mu2)*0.5);
  97. // linear kernel idiocy
  98. ret.basis_vectors.set_size(fdim);
  99. for (size_t i = 0; i < fdim; ++i)
  100. {
  101. ret.basis_vectors(i).set_size(fdim);
  102. for (size_t j = 0; j < fdim; ++j)
  103. ret.basis_vectors(i)(j) = 0;
  104. ret.basis_vectors(i)(i) = 1;
  105. }
  106. return ret;
  107. }
  108. #if 0
  109. trained_function_type train(const trained_function_type::sample_vector_type& samplesvec, const trained_function_type::scalar_vector_type& labels) const
  110. {
  111. int fdim = samplesvec(0).size();
  112. int nsamples = samplesvec.size();
  113. int ndata_class1 = 0, ndata_class2 = 0;
  114. for (int i=0; i<nsamples; ++i)
  115. {
  116. if (labels(i)>0)
  117. ++ndata_class1;
  118. else
  119. ++ndata_class2;
  120. }
  121. dlib::matrix<sample_type,0,1> samples1, samples2;
  122. samples1.set_size(ndata_class1);
  123. samples2.set_size(ndata_class2);
  124. sample_type mu1; mu1.set_size(fdim);
  125. sample_type mu2; mu2.set_size(fdim);
  126. for (int i=0; i<fdim; ++i)
  127. {
  128. mu1(i)=0;
  129. mu2(i)=0;
  130. }
  131. ndata_class1 = 0; ndata_class2 = 0;
  132. for (int i=0; i<nsamples; ++i)
  133. {
  134. if (labels(i)>0)
  135. {
  136. samples1(ndata_class1) = samplesvec(i);
  137. ++ndata_class1;
  138. mu1 += samplesvec(i);
  139. }
  140. else
  141. {
  142. samples2(ndata_class2) = samplesvec(i);
  143. ++ndata_class2;
  144. mu2 += samplesvec(i);
  145. }
  146. }
  147. mu1 /= ndata_class1;
  148. mu2 /= ndata_class2;
  149. // if you get a compilation error coming from here (with templates
  150. // and a 'visual_studio_sucks_cov_helper' structure involved) then
  151. // you may have to patch the dlib's file 'matrix_utilities.h":
  152. //
  153. // line 1611, replace
  154. // const matrix<double,EXP::type::NR,EXP::type::NC, typename EXP::mem_manager_type> avg = mean(m);
  155. // by
  156. // const typename EXP::type avg = mean(m);
  157. //
  158. dlib::matrix<float> sigma1 = covariance(samples1);
  159. dlib::matrix<float> sigma2 = covariance(samples2);
  160. sample_type w_vect = pinv(sigma1+sigma2) * (mu2 - mu1);
  161. trained_function_type ret;
  162. //ret.alpha.set_size(fdim);
  163. //for (int i=0; i<fdim; ++i) ret.alpha(i) = w_vect(i);
  164. ret.alpha = w_vect;
  165. ret.b = dot(w_vect,(mu1+mu2)*0.5);
  166. // linear kernel idiocy
  167. ret.basis_vectors.set_size(fdim);
  168. for (int i=0; i<fdim; ++i)
  169. {
  170. ret.basis_vectors(i).set_size(fdim);
  171. for (int j=0; j<fdim; ++j)
  172. ret.basis_vectors(i)(j)=0;
  173. ret.basis_vectors(i)(i) = 1;
  174. }
  175. return ret;
  176. /*LinearPredictor classifier;
  177. classifier.weights.resize(fdim+1);
  178. for (int i=0; i<fdim; ++i) classifier.weights[i] = w_vect(i);
  179. classifier.weights[fdim] = -dot(w_vect,(mu1+mu2)*0.5);
  180. return classifier;*/
  181. }
  182. #endif
  183. void train(int nfolds, const std::vector<sample_type>& samples, const std::vector<float>& labels)
  184. {
  185. dlib::probabilistic_decision_function<kernel_type> pdecfun = dlib::train_probabilistic_decision_function(*this, samples, labels, nfolds);
  186. dlib::decision_function<kernel_type>& decfun = pdecfun.decision_funct;
  187. int dim = samples.back().size();
  188. // see comments in linearSVM.hpp
  189. m_weights.clear();
  190. m_weights.resize(dim + 1, 0);
  191. dlib::matrix<float> w(dim, 1);
  192. w = 0;
  193. for (int i = 0; i < decfun.alpha.nr(); ++i)
  194. {
  195. w += decfun.alpha(i) * decfun.basis_vectors(i);
  196. }
  197. for (int i = 0; i < dim; ++i)
  198. m_weights[i] = w(i);
  199. m_weights[dim] = -decfun.b;
  200. for (int i = 0; i <= dim; ++i)
  201. m_weights[i] *= pdecfun.alpha;
  202. m_weights[dim] += pdecfun.beta;
  203. // TODO: check if necessary here
  204. for (int i = 0; i <= dim; ++i)
  205. m_weights[i] = -m_weights[i];
  206. }
  207. double predict(const sample_type& data) const
  208. {
  209. assert(!m_weights.empty());
  210. double ret = m_weights.back();
  211. for (size_t d = 0; d < m_weights.size() - 1; ++d)
  212. ret += static_cast<double>(m_weights[d]) * data(d);
  213. return ret;
  214. }
  215. //! Classifier weights
  216. std::vector<float> m_weights;
  217. };
  218. //! Gram-Schmidt process to re-orthonormalise the basis
  219. static void GramSchmidt(dlib::matrix<LDATrainer::sample_type,0,1>& basis, LDATrainer::sample_type& newX)
  220. {
  221. // goal: find a basis so that the given vector is the new X
  222. // principle: at least one basis vector is not orthogonal with newX (except if newX is null but we suppose this is not the case)
  223. // => use the max dot product vector, and replace it by newX. this forms a set of
  224. // linearly independent vectors.
  225. // then apply the Gram-Schmidt process
  226. long dim = basis.size();
  227. double maxabsdp = -1.0;
  228. long selectedCoord = 0;
  229. for (long i = 0; i < dim; ++i)
  230. {
  231. double absdp = std::abs(dot(basis(i), newX));
  232. if (absdp > maxabsdp)
  233. {
  234. absdp = maxabsdp;
  235. selectedCoord = i;
  236. }
  237. }
  238. // swap basis vectors to use the selected coord as the X vector, then replaced by newX
  239. basis(selectedCoord) = basis(0);
  240. basis(0) = newX;
  241. // Gram-Schmidt process to re-orthonormalise the basis.
  242. // Thanks Wikipedia for the stabilized version
  243. for (long j = 0; j < dim; ++j)
  244. {
  245. for (long i = 0; i < j; ++i)
  246. basis(j) -= (dot(basis(j), basis(i)) / dot(basis(i), basis(i))) * basis(i);
  247. basis(j) /= sqrt(dot(basis(j), basis(j)));
  248. }
  249. }
  250. //! Compute pos. and neg. reference points
  251. static void ComputeReferencePoints( Classifier::Point2D& refpt_pos,
  252. Classifier::Point2D& refpt_neg,
  253. const std::vector<float>& proj1,
  254. const std::vector<float>& proj2,
  255. const std::vector<float>& labels,
  256. unsigned* _npos = 0,
  257. unsigned* _nneg = 0)
  258. {
  259. assert(proj1.size() == proj2.size() && proj1.size() == labels.size());
  260. refpt_neg = refpt_pos = Classifier::Point2D(0, 0);
  261. size_t npos = 0;
  262. size_t nneg = 0;
  263. for (size_t i = 0; i < labels.size(); ++i)
  264. {
  265. if (labels[i] < 0)
  266. {
  267. refpt_neg += Classifier::Point2D(proj1[i], proj2[i]);
  268. ++nneg;
  269. }
  270. else
  271. {
  272. refpt_pos += Classifier::Point2D(proj1[i], proj2[i]);
  273. ++npos;
  274. }
  275. }
  276. if (npos)
  277. refpt_pos /= static_cast<PointCoordinateType>(npos);
  278. if (nneg)
  279. refpt_neg /= static_cast<PointCoordinateType>(nneg);
  280. if (_npos)
  281. *_npos = npos;
  282. if (_nneg)
  283. *_nneg = nneg;
  284. }
//! Experimental (same as Brodu's code): dilatation to highlight the internal data structure
/** Rescales the 2D projection space so that the two classes get comparable
	spreads along both axes, then updates the trainers' weights, the per-sample
	projections and the classifier (scale ratio, weights, reference points).
	\param classifier classifier to update (axisScaleRatio, weightsAxis1/2, refPointPos/Neg)
	\param proj1 projections on the first axis (recomputed in place)
	\param proj2 projections on the second axis (recomputed in place)
	\param labels per-sample labels (< 0 --> negative class)
	\param samples input feature vectors
	\param trainer trainer for the first axis (m_weights updated in place)
	\param orthoTrainer trainer for the second axis (m_weights updated in place)
	\return success (false if a class has less than 2 samples, or if not enough memory)
**/
static bool DilateClassifier( Classifier& classifier,
							std::vector<float>& proj1,
							std::vector<float>& proj2,
							const std::vector<float>& labels,
							const std::vector<LDATrainer::sample_type>& samples,
							LDATrainer& trainer,
							LDATrainer& orthoTrainer)
{
	//m_app->dispToConsole("[Cloud dilatation]");

	//local frame: e1 = unit axis joining the two reference points, e2 = its perpendicular
	Classifier::Point2D e1 = classifier.refPointPos - classifier.refPointNeg;
	e1.normalize();
	Classifier::Point2D e2(-e1.y, e1.x);
	//origin = midpoint between the two reference points
	Classifier::Point2D ori = (classifier.refPointPos + classifier.refPointNeg) / 2;

	//per-class sums (m..) and squared sums (v..) of the projections in the (e1, e2) frame
	//(first index = class, second index = axis)
	float m11 = 0, m21 = 0, m12 = 0, m22 = 0; // m12, m22 null by construction
	float v11 = 0, v12 = 0, v21 = 0, v22 = 0;
	size_t nsamples1 = 0;
	size_t nsamples2 = 0;
	size_t nsamples = proj1.size();
	assert(proj1.size() == proj2.size());
	for (size_t i = 0; i < nsamples; ++i)
	{
		//express the sample's 2D projection in the (ori, e1, e2) frame
		Classifier::Point2D p(proj1[i], proj2[i]);
		p -= ori;
		float p1 = p.dot(e1);
		float p2 = p.dot(e2);
		if (labels[i] < 0)
		{
			m11 += p1; v11 += p1 * p1;
			m12 += p2; v12 += p2 * p2;
			++nsamples1;
		}
		else
		{
			m21 += p1; v21 += p1 * p1;
			m22 += p2; v22 += p2 * p2;
			++nsamples2;
		}
	}

	//at least 2 samples per class are required to estimate the variances below
	if (nsamples1 < 2 || nsamples2 < 2)
	{
		assert(false);
		return false;
	}

	//unbiased variance estimates: (sum(x^2) - n*mean^2) / (n-1)
	m11 /= nsamples1;
	v11 = (v11 - m11 * m11*nsamples1) / (nsamples1 - 1);
	m21 /= nsamples2;
	v21 = (v21 - m21 * m21*nsamples2) / (nsamples2 - 1);
	m12 /= nsamples1;
	v12 = (v12 - m12 * m12*nsamples1) / (nsamples1 - 1);
	m22 /= nsamples2;
	v22 = (v22 - m22 * m22*nsamples2) / (nsamples2 - 1);

	//per-class ratio between the spreads along e1 and e2
	float d1 = sqrt(v11 / v12);
	float d2 = sqrt(v21 / v22);
	//retained scale ratio = geometric mean of the two per-class ratios
	classifier.axisScaleRatio = sqrt(d1*d2);

	//'dilated' basis: e2 scaled down by the axis scale ratio (row-wise: e1, then scaled e2)
	float bdValues[4] { e1.x, e1.y, e2.x / classifier.axisScaleRatio, e2.y / classifier.axisScaleRatio };
	dlib::matrix<float, 2, 2> bd(bdValues);
	float biValues[4] { e1.x, e2.x, e1.y, e2.y };
	dlib::matrix<float, 2, 2> bi(biValues);
	//NOTE(review): the 'bi * bd' factor is commented out (as in the original CANUPO
	//code), which leaves 'bi' unused - presumably intentional, but TODO confirm
	dlib::matrix<float, 2, 2> c = inv(trans(bd)) /* bi * bd */;

	std::vector<float>& w1 = trainer.m_weights;
	std::vector<float>& w2 = orthoTrainer.m_weights;
	assert(w1.size() == w2.size());

	std::vector<float> wn1, wn2;
	try
	{
		wn1.resize(w1.size());
		wn2.resize(w2.size());
	}
	catch (const std::bad_alloc&)
	{
		//not enough memory
		return false;
	}

	// first shift so the center of the figure is at the midpoint
	//(the last weight is the hyperplane offset - see LDATrainer::predict)
	w1.back() -= ori.x;
	w2.back() -= ori.y;

	// now transform / scale along e2
	//(apply the 2x2 transform 'c' to each (w1[i], w2[i]) weight pair)
	{
		for (size_t i = 0; i < w1.size(); ++i)
		{
			wn1[i] = c(0, 0) * w1[i] + c(0, 1) * w2[i];
			wn2[i] = c(1, 0) * w1[i] + c(1, 1) * w2[i];
		}
	}

	trainer.m_weights = wn1;
	orthoTrainer.m_weights = wn2;

	// reset projections (recompute them with the transformed weights)
	{
		for (size_t i = 0; i < nsamples; ++i)
		{
			proj1[i] = trainer.predict(samples[i]);
			proj2[i] = orthoTrainer.predict(samples[i]);
		}
	}

	classifier.weightsAxis1 = wn1;
	classifier.weightsAxis2 = wn2;

	//update reference points (class centers moved along with the projections)
	ComputeReferencePoints( classifier.refPointPos,
							classifier.refPointNeg,
							proj1,
							proj2,
							labels);

	return true;
}
#endif //QCANUPO_TRAINER_HEADER