IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    k-means 之 C++ 的实现

    月影云帆发表于 2016-07-22 12:06:07
    love 0

    物以类聚,人以群分

    所谓k-means,即k均值聚类.聚类过程好比中国历史上的“春秋五霸,战国七雄”,它们同属与中国大地,同时被周王室分封。分封的过程就相当于K类的指定过程,每一个诸侯国都对应于一个聚类。五霸即五类,七雄即七类,从五霸到七雄,即相当于一个聚类生长的过程。

    用数学的语言来说就是,假设N个样点构成集合A,根据欧式距离需要将A划分为K个子集,则划分子集的过程就是k均值聚类实现的过程。

    简而言之,物以类聚,人以群分,在数学中亦是如此。

    K均值是怎么实现的

    就像周王室分封诸侯,k均值聚类也需要被告知“到底要分多少诸侯”。有鉴于诸侯王们都不傻,都想要土地肥沃+物产丰绕+风调雨顺+。。。,所以周王室干脆一刀切“那就随机指定吧!”。于是,诸侯们到达封地后,为了得到更适合他们居住的地方,不断变换他们的国都,不断蚕食周围的群落,直到有一天,他们各自发现已经达到了自己理想国度--他们有无尽的子民,无数子民围绕在他们周边,他们有广阔的土地,他们就位于着土地中央! 最终,每个诸侯王不再迁都,定居过程也随之结束。

    拿k均值来类比,总结以下几点:

    1. 有多少诸侯要分封 -- k值

    2. 一开始怎么分 -- 随机

    3. 诸侯国迁徙 -- 距离

    4. 还要迁徙吗 -- 聚类最优

    5. 定居 -- 聚类结束

    结构设计

    当然,要实现一个算法,其数据结构的设计是必不可少的!因为主要是针对三维数据的K均值计算,所以每一个样点需声明为一个结构体类型:

    typedef struct st_pointxyz
    {
            float x;
            float y;
            float z;
    }st_pointxyz;
    

    为了便于后续计算, 还需再设计一个结构,用于存贮某点和该点的索引号:

    typedef struct st_point
    {
            st_pointxyz pnt;
            int groupID;
            st_point()  
                    {
                    } 
            st_point(st_pointxyz &p,int id)
                    {
                            pnt =p;
                            groupID= id;
                    }
    }st_point;
    

    既然是实现k均值算法,那就先定义一个class KMeans吧!

    既然定义了class,就应该考虑其应该包含的具体实现函数了. 首先,聚类簇数K自不必说吧,定义SetK()。其次我想到的是应该包含输入输出,那就再构造一个成员输入函数:SetInputCloud() ,一个输出函数:SaveFile()。包含了输入输出,自然必须包含聚类过程的实现函数,就先定义为Cluster()吧!

    接下来思考以下聚类过程是怎么实现的?哦,诸侯是被随机分封的,那我们就给它一个初始化随机函数InitKCenter(),接着,诸侯的不断迁移,就是聚类中心不断变化的过程,似乎也应该包含一个聚类中心更新的函数,那就定义为UpdateGroupCenter(),想起来了,他们聚类的过程是通过两点的欧式距离实现的,似乎DisBetweenPoints()也少不了,到这里似乎聚类过程还没有结束,我们必须再给定一个结束聚类计算的“终止函数”,就像诸侯王定居,国都不再改变,k均值聚类的中心不再变化即可认为聚类过程的结束,那就再定义一个判断中心点是否移动的函数ExistCenterShift()。

    KMeans类的成员函数似乎都找齐了,但是成员变量还没说明。int m_k自不必说,接着再定义一个命令别名以便后用typedef vector<st_point> VecPoint_t(打算用vector存储数据),然后定义需要计算的输入点云VecPoint_t mv_pntcloud,还需要定义一个保存聚类结果的结构,定义为vector<VecPoint_t>m_grp_pclcloud,最后我们还要知道每类的聚类中心vector<st_pointxyz> m_center。

    到现在,k均值聚类整体结构已经有了,接下来就是将他们组合到一起(这里借助了pcl库,因为目前为止pcl中还没有K-means算法功能,ps:如果有谁能在pcl中找到k-means算法,请一定留言通知,不胜感激. 借助pcl只是为了省去三维点云读取与存贮的麻烦)

    class KMeans
    {
    public:
            int m_k;
            typedef vector<st_point> VecPoint_t;  //定义命令别名
    
            VecPoint_t mv_pntcloud; //要聚类的点云
            vector<VecPoint_t>m_grp_pntcloud;  //k类,每一类存储若干点
            vector<st_pointxyz>mv_center; //每个类的中心
    
            KMeans() 
            {
                    m_k =0;
            }
            inline void SetK(int k_) //设置聚类簇数
            {
                    m_k = k_;
                    m_grp_pntcloud.resize(m_k); 
            }
            //设置输入点云
            bool SetInputCloud(pcl::PointCloud<pcl::PointXYZ>::Ptr pPntCloud);  
    
            //初始化最初的k个类的中心
            bool InitKCenter();  
    
            //聚类
            bool Cluster();
    
            //更新k类的中心(参数为类和中心点)
           vector<st_pointxyz>  UpdateGroupCenter(vector<VecPoint_t> &grp_pntcloud,vector<st_pointxyz> cer);
    
            //计算两点欧式距离
            double DistBetweenPoints(st_pointxyz &p1,st_pointxyz &p2);
    
            //是否存在中心点转移动
            bool ExistCenterShift(vector<st_pointxyz> &prev_center,vector<st_pointxyz> &cur_center);
    
            //将聚类分别存储到各自的pcd文件中
            bool SaveFile(const char *fname);
    
    };
        
    
    

    具体实现

    首先设置一个判断聚类中心是否移动的阀值cosnt float DIST_NRAR = 0.001,也就是说当两次聚类中心的差值小于此值时,聚类则停止。

    上代码:

    
        bool KMeans::InitKCenter( )
        {
                mv_center.resize(m_k);
                int size = mv_pntcloud.size();
                srand(unsigned(time(NULL)));  
                for (int i =0; i< m_k;i++)
                {
                        int seed = random()%(size+1);
                        mv_center[i].x = mv_pntcloud[seed].pnt.x;
                        mv_center[i].y = mv_pntcloud[seed].pnt.y;
                        mv_center[i].z = mv_pntcloud[seed].pnt.z;   
                }
                return true;
        }
        bool KMeans::SetInputCloud(pcl::PointCloud<pcl::PointXYZ>::Ptr pPntCloud)
        {
                size_t pntCount = (size_t) pPntCloud->points.size();
                for (size_t i = 0; i< pntCount;++i)
                {
                        st_point point;
                        point.pnt.x = pPntCloud->points[i].x;
                        point.pnt.y = pPntCloud->points[i].y;
                        point.pnt.z = pPntCloud->points[i].z;
                        point.groupID = 0;
        
                        mv_pntcloud.push_back(point);
                }
                
                return true;
        }
        bool KMeans::Cluster()
        {
                InitKCenter();
                vector<st_pointxyz>v_center(mv_center.size());
                size_t pntCount = mv_pntcloud.size();
        
                do
                {
                        for (size_t i = 0;i < pntCount;++i)  
                        {
                                double min_dist = DBL_MAX;  
                                int pnt_grp = 0;   //聚类群组索引号
                                for (size_t j =0;j <m_k;++j) 
                                {
                                         double dist = DistBetweenPoints(mv_pntcloud[i].pnt, mv_center[j]);  
                                         if (min_dist - dist > 0.000001)  
                                         {  
                                                 min_dist = dist;  
                                                 pnt_grp = j;
                                         }
                                }
                                m_grp_pntcloud[pnt_grp].push_back(st_point(mv_pntcloud[i].pnt,pnt_grp)); //将该点和该点群组的索引存入聚类中
                        }
        
                        //保存上一次迭代的中心点
                        for (size_t i = 0; i<mv_center.size();++i)
                        {
                                v_center[i] = mv_center[i];
                        }
        
                        mv_center=UpdateGroupCenter(m_grp_pntcloud,mv_center);
                        if ( !ExistCenterShift(v_center, mv_center))  
                        {  
                                break;   
                        }  
                        for (size_t i = 0; i < m_k; ++i){  
                                m_grp_pntcloud[i].clear();  
                        }  
                        
                }while(true);
                
                return true;
    }
    double KMeans::DistBetweenPoints(st_pointxyz &p1, st_pointxyz &p2)  
    {  
            double dist = 0;  
            double x_diff = 0, y_diff = 0, z_diff = 0;  
      
            x_diff = p1.x - p2.x;  
            y_diff = p1.y - p2.y;  
            z_diff = p1.z - p2.z;  
            dist = sqrt(x_diff * x_diff + y_diff * y_diff + z_diff * z_diff);  
          
            return dist;  
    }  
    vector<st_pointxyz> KMeans::UpdateGroupCenter(std::vector<VecPoint_t> &grp_pntcloud, std::vector<st_pointxyz> center) 
    {
        for (size_t i = 0; i < m_k; ++i)  
        {  
            float x = 0, y = 0, z = 0;  
            size_t pnt_num_in_grp = grp_pntcloud[i].size();  
      
            for (size_t j = 0; j < pnt_num_in_grp; ++j)  
            {             
                    x += grp_pntcloud[i][j].pnt.x;  
                    y += grp_pntcloud[i][j].pnt.y;  
                    z += grp_pntcloud[i][j].pnt.z;  
            }  
            x /= pnt_num_in_grp;  
            y /= pnt_num_in_grp;  
            z /= pnt_num_in_grp;  
            center[i].x = x;   
            center[i].y = y;  
            center[i].z = z;  
        }  
        return center;
        
    }
    //是否存在中心点移动  
    bool KMeans::ExistCenterShift(std::vector<st_pointxyz> &prev_center, std::vector<st_pointxyz> &cur_center)  
    {  
        for (size_t i = 0; i < m_k; ++i)  
        {  
            double dist = DistBetweenPoints(prev_center[i], cur_center[i]);  
            if (dist > DIST_NEAR_ZERO)  
            {  
                return true;  
            }  
        }  
      
        return false;  
    }
    //将聚类的点分别存到各自的pcd文件中  
    bool KMeans::SaveFile(const char *prex_name)  
    {  
        for (int i = 0; i < m_k; ++i)  
        {  
            pcl::PointCloud<pcl::PointXYZ>::Ptr p_pnt_cloud(new pcl::PointCloud<pcl::PointXYZ> ());  
      
            for (size_t j = 0, grp_pnt_count = m_grp_pntcloud[i].size(); j < grp_pnt_count; ++j)  
            {  
                pcl::PointXYZ pt;  
                pt.x = m_grp_pntcloud[i][j].pnt.x;  
                pt.y = m_grp_pntcloud[i][j].pnt.y;  
                pt.z = m_grp_pntcloud[i][j].pnt.z;  
      
                p_pnt_cloud->points.push_back(pt);  
            }  
      
            p_pnt_cloud->width = (int)m_grp_pntcloud[i].size();
            p_pnt_cloud->height = 1;  
      
            char newFileName[256] = {0};  
            char indexStr[16] = {0};  
      
            strcat(newFileName, szFileName);  
            strcat(newFileName, "-");  
            strcat(newFileName, prex_name);  
            strcat(newFileName, "-");  
            sprintf(indexStr, "%d", i + 1);  
            strcat(newFileName, indexStr);  
            strcat(newFileName, ".pcd");  
            pcl::io::savePCDFileASCII(newFileName, *p_pnt_cloud);  
        }  
          
        return true;  
    } 
    

    实例检测

    k = 2

    k = 5



沪ICP备19023445号-2号
友情链接