IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    ps-lite_part2_Postoffice讲解

    admin发表于 2023-03-09 14:46:23
    love 0

    晚上也在思考这篇文章该以怎样的结构来撰写,ps-lite虽然是轻量版,但是很多东西都是在复用,东西都混在一起,所以在解释理清思路的时候有点绕。思来想去还是按照ps程序启动的逻辑来讲,涉及到本文需要讲的点在重点描写。

    ps启动

    启动一个完整的 ps 服务,需要启动scheduler、sever和worker,启动的时候都会产生一个实际的物理进程,这个进程里都会包含PostOffice,负责管理全局信息。

    ps-lite给出一个简单的demo,启动的顺序是先启动schedule,然后是server 、worker

    #!/bin/bash
    # set -x
    if [ # -lt 3 ]; then
        echo "usage:0 num_servers num_workers bin [args..]"
        exit -1;
    fi
    
    export DMLC_NUM_SERVER=1
    shift
    export DMLC_NUM_WORKER=1
    shift
    bin=1
    shift
    arg="@"
    
    # start the scheduler
    export DMLC_PS_ROOT_URI='127.0.0.1'
    export DMLC_PS_ROOT_PORT=8000
    export DMLC_ROLE='scheduler'
    {bin}{arg} &
    
    
    # start servers
    export DMLC_ROLE='server'
    for ((i=0; i<{DMLC_NUM_SERVER}; ++i)); do
        export HEAPPROFILE=./S{i}
        {bin}{arg} &
    done
    
    # start workers
    export DMLC_ROLE='worker'
    for ((i=0; i<{DMLC_NUM_WORKER}; ++i)); do
        export HEAPPROFILE=./W{i}
        {bin}{arg} &
    done
    

    从上面的启动程序可以看到,ps-lite启动任务一些参数是来自环境变量的,你会看到shell脚本里充斥着export。

    那我们现在看看它是怎么启动程序的?看一个demo

    #include <cmath>
    #include "ps/ps.h"
    
    using namespace ps;
    
    void StartServer() {
      if (!IsServer()) {
        return;
      }
      auto server = new KVServer<float>(0);
      server->set_request_handle(KVServerDefaultHandle<float>());
      RegisterExitCallback([server](){ delete server; });
    }
    
    void RunWorker() {
      if (!IsWorker()) return;
      KVWorker<float> kv(0, 0);
    
      // init
      int num = 10000;
      std::vector<Key> keys(num);
      std::vector<float> vals(num);
    
      int rank = MyRank();
      srand(rank + 7);
      for (int i = 0; i < num; ++i) {
        keys[i] = kMaxKey / num * i + rank;
        vals[i] = (rand() % 1000);
      }
    
      // push
      int repeat = 50;
      std::vector<int> ts;
      for (int i = 0; i < repeat; ++i) {
        ts.push_back(kv.Push(keys, vals));
    
        // to avoid too frequency push, which leads huge memory usage
        if (i > 10) kv.Wait(ts[ts.size()-10]);
      }
      for (int t : ts) kv.Wait(t);
    
      // pull
      std::vector<float> rets;
      kv.Wait(kv.Pull(keys, &rets));
    
      // pushpull
      std::vector<float> outs;
      for (int i = 0; i < repeat; ++i) {
        // PushPull on the same keys should be called serially
        kv.Wait(kv.PushPull(keys, vals, &outs));
      }
    
      float res = 0;
      float res2 = 0;
      for (int i = 0; i < num; ++i) {
        res += std::fabs(rets[i] - vals[i] * repeat);
        res2 += std::fabs(outs[i] - vals[i] * 2 * repeat);
      }
      CHECK_LT(res / repeat, 1e-5);
      CHECK_LT(res2 / (2 * repeat), 1e-5);
      LL << "error: " << res / repeat << ", " << res2 / (2 * repeat);
    }
    
    int main(int argc, char *argv[]) {
      // start system ,这里会根据实际的角色名执行相应的逻辑,比如一开始执行是scheduler的启动任务,
      // 会以scheduler的角色启动,后面的StartServer也是不会执行的,当然这里会在是server角色的时候还会在执行一次
      Start(0);
      // setup server nodes
      StartServer();
      // run worker nodes
      RunWorker();
      // stop system
      Finalize(0, true);
      return 0;
    }
    

    ps-lite在这里其实共用了一套代码,也就是说你启动scheduler、sever和worker 这些都会走一套启动代码,根据不同的角色名称去执行相应的代码逻辑,比如只有在角色scheduler的时候才会触发scheduler相关的代码逻辑。

    PostOffice启动

    接下来就先以scheduler 启动来介绍

      Start(0);
    //实际调用的方法
    inline void Start(int customer_id, const char* argv0 = nullptr) {
      Postoffice::Get()->Start(customer_id, argv0, true);
    }
    

    这里会有一个

    Postoffice::Get()
    

    调用Get方法是去获取PostOffice全局单例对象,这里想要强调的一点就是PostOffice是单例,即一个进程内只有这一个对象,全局变量。

    无论scheduler还是worker 都会调用,那么你可以理解这里PostOffice单例是相对而言的,scheduler进程下有一个,sever进程下也有一个。

    接下来再来看看 Start 函数做了哪些事情?

    void Postoffice::Start(int customer_id, const char* argv0, const bool do_barrier) {
      start_mu_.lock();
      if (init_stage_ == 0) {
        // 初始化环境变量,主要是从 shell 执行脚本中获取相关参数比如 role 角色变量、server 数量和worker 数量
        InitEnvironment();
        // init glog
        if (argv0) {
          dmlc::InitLogging(argv0);
        } else {
          dmlc::InitLogging("ps-lite\0");
        }
    
        // init node info.
        for (int i = 0; i < num_workers_; ++i) {
          int id = WorkerRankToID(i);
          for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                        kWorkerGroup + kScheduler,
                        kWorkerGroup + kServerGroup + kScheduler}) {
            node_ids_[g].push_back(id);
          }
        }
    
        for (int i = 0; i < num_servers_; ++i) {
          int id = ServerRankToID(i);
          for (int g : {id, kServerGroup, kWorkerGroup + kServerGroup,
                        kServerGroup + kScheduler,
                        kWorkerGroup + kServerGroup + kScheduler}) {
            node_ids_[g].push_back(id);
          }
        }
    
        for (int g : {kScheduler, kScheduler + kServerGroup + kWorkerGroup,
                      kScheduler + kWorkerGroup, kScheduler + kServerGroup}) {
          node_ids_[g].push_back(kScheduler);
        }
        init_stage_++;
      }
      start_mu_.unlock();
    
      // start van
      van_->Start(customer_id);
    
      start_mu_.lock();
      if (init_stage_ == 1) {
        // record start time
        start_time_ = time(NULL);
        init_stage_++;
      }
      start_mu_.unlock();
      // do a barrier here
      if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);
    }
    

    现在来看看这个初始化环境的函数做了哪些事情?

    void Postoffice::InitEnvironment() {
      const char* val = NULL;
      std::string van_type = GetEnv("DMLC_PS_VAN_TYPE", "zmq");
      // 核心的一个点就是创建了 Van,至于Van 是什么后续也会做相应的详细介绍
      van_ = Van::Create(van_type);
      //接下来都是在解析环境变量,对于我们而言就是判断这次启动的是哪个?
      val = CHECK_NOTNULL(Environment::Get()->find("DMLC_NUM_WORKER"));
      num_workers_ = atoi(val);
      val =  CHECK_NOTNULL(Environment::Get()->find("DMLC_NUM_SERVER"));
      num_servers_ = atoi(val);
      val = CHECK_NOTNULL(Environment::Get()->find("DMLC_ROLE"));
      std::string role(val);
      is_worker_ = role == "worker";
      is_server_ = role == "server";
      is_scheduler_ = role == "scheduler";
      verbose_ = GetEnv("PS_VERBOSE", 0);
    }
    

    ok,让我们再回到PostOffice的start函数里,假设我们启动的是scheduler角色下的程序

        for (int g : {kScheduler, kScheduler + kServerGroup + kWorkerGroup,
                      kScheduler + kWorkerGroup, kScheduler + kServerGroup}) {
          node_ids_[g].push_back(kScheduler);
        }
        init_stage_++;
    

    那这个node_ids_ 是干啥的?这里就需要讨论下 node_id 的概念

    Node管理

    其实是可以分为两个部分:node group 和 single node_id

    首先我们介绍下 node id 映射功能,就是如何在逻辑节点和物理节点之间做映射,如何把物理节点划分成各个逻辑组,如何用简便的方法做到给组内物理节点统一发消息。

    • 1,2,4分别标识Scheduler, ServerGroup, WorkerGroup。
    • SingleWorker:rank * 2 + 9;SingleServer:rank * 2 + 8。
    • 任意一组节点都可以用单个id标识,等于所有id之和。

    概念

    • Rank 是一个逻辑概念,是每一个节点(scheduler,work,server)内部的唯一逻辑标示。
    • Node id 是物理节点的唯一标识,可以和一个 host + port 的二元组唯一对应。
    • Node Group 是一个逻辑概念,每一个 group 可以包含多个 node id。ps-lite 一共有三组 group : scheduler 组,server 组,worker 组。
    • Node group id 是 是节点组的唯一标示。
      • ps-lite 使用 1,2,4 这三个数字分别标识 Scheduler,ServerGroup,WorkerGroup。每一个数字都代表着一组节点,等于所有该类型节点 id 之和。比如 2 就代表server 组,就是所有 server node 的组合。
      • 为什么选择这三个数字?因为在二进制下这三个数值分别是 “001, 010, 100″,这样如果想给多个 group 发消息,直接把 几个 node group id 做 或操作 就行。
      • 即 1-7 内任意一个数字都代表的是Scheduler / ServerGroup / WorkerGroup的某一种组合。
      • 如果想把某一个请求发送给所有的 worker node,把请求目标节点 id 设置为 4 即可。
      • 假设某一个 worker 希望向所有的 server 节点 和 scheduler 节点同时发送请求,则只要把请求目标节点的 id 设置为 3 即可,因为 3 = 2 + 1 = kServerGroup + kScheduler。
      • 如果想给所有节点发送消息,则设置为 7 即可。

    逻辑组的实现

    三个逻辑组的定义如下:

    /** \brief node ID for the scheduler */
    static const int kScheduler = 1;
    /**
     * \brief the server node group ID
     *
     * group id can be combined:
     * - kServerGroup + kScheduler means all server nodes and the scheuduler
     * - kServerGroup + kWorkerGroup means all server and worker nodes
     */
    static const int kServerGroup = 2;
    /** \brief the worker node group ID */
    static const int kWorkerGroup = 4;
    for (int i = 0; i < num_workers_; ++i) {
          int id = WorkerRankToID(i);
          for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                        kWorkerGroup + kScheduler,
                        kWorkerGroup + kServerGroup + kScheduler}) {
            node_ids_[g].push_back(id);
          }
    
    

    如下面代码所示,如果配置了 3 个worker,则 worker 的 rank 从 0 ~ 2,那么这几个 worker 实际对应的 物理 node ID 就会使用 WorkerRankToID 来计算出来。

    node id 是物理节点的唯一标示,rank 是每一个逻辑概念(scheduler,work,server)内部的唯一标示。这两个标示由一个算法来确定。

    Rank vs node id

    node id 是物理节点的唯一标示,rank 是每一个逻辑概念(scheduler,work,server)内部的唯一标示。这两个标示由一个算法来确定。

    如下面代码所示,如果配置了 3 个worker,则 worker 的 rank 从 0 ~ 2,那么这几个 worker 实际对应的 物理 node ID 就会使用 WorkerRankToID 来计算出来。

        for (int i = 0; i < num_workers_; ++i) {
          int id = WorkerRankToID(i);
          for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                        kWorkerGroup + kScheduler,
                        kWorkerGroup + kServerGroup + kScheduler}) {
            node_ids_[g].push_back(id);
          }
        }
    

    具体计算规则如下:

      /**
       * \brief convert from a worker rank into a node id
       * \param rank the worker rank
       */
      static inline int WorkerRankToID(int rank) {
        return rank * 2 + 9;
      }
      /**
       * \brief convert from a server rank into a node id
       * \param rank the server rank
       */
      static inline int ServerRankToID(int rank) {
        return rank * 2 + 8;
      }
      /**
       * \brief convert from a node id into a server or worker rank
       * \param id the node id
       */
      static inline int IDtoRank(int id) {
    #ifdef _MSC_VER
    #undef max
    #endif
        return std::max((id - 8) / 2, 0);
      }
    
    
    
    • SingleWorker:rank * 2 + 9;
    • SingleServer:rank * 2 + 8;

    而且这个算法保证server id为偶数,node id为奇数。

    这样我们可以知道,1-7 的id表示的是node group,单个节点的id 就从 8 开始。

    具体计算规则如下:

    Group vs node

    因为有时请求要发送给多个节点,所以ps-lite用了一个 map 来存储每个 node group / single node 对应的实际的node节点集合,即 确定每个id值对应的节点id集。

    std::unordered_map<int, std::vector<int>> node_ids_ 
    
        for (int i = 0; i < num_workers_; ++i) {
          int id = WorkerRankToID(i);
          for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                        kWorkerGroup + kScheduler,
                        kWorkerGroup + kServerGroup + kScheduler}) {
            node_ids_[g].push_back(id);
          }
        }
    

    这 5 个id 相对应,即需要在 node_ids_ 这个映射表中对应的 4, 4 + 1, 4 + 2, 4 +1 + 2, 12 这五个 item 之中添加。就是上面代码中的内部 for 循环条件。即,node_ids_ [4], node_ids_ [5],node_ids_ [6],node_ids_ [7] ,node_ids_ [12] 之中,都需要把 12 添加到 vector 最后。

    • 12(本身)
    • 4(kWorkerGroup)
    • 4+1(kWorkerGroup + kScheduler)
    • 4+2(kWorkerGroup + kServerGroup)
    • 4+1+2,(kWorkerGroup + kServerGroup + kScheduler )

    所以,为了实现 “设置 1-7 内任意一个数字 可以发送给其对应的 所有node” 这个功能,对于每一个新节点,需要将其对应多个id(node,node group)上,这些id组就是本节点可以与之通讯的节点。例如对于 worker 2 来说,其 node id 是 2 * 2 + 8 = 12,所以需要将它与

    • 1 ~ 7 的 id 表示的是 node group;
    • 后续的 id(8,9,10,11 …)表示单个的 node。其中双数 8,10,12… 表示 worker 0, worker 1, worker 2,… 即(2n + 8),9,11,13,…,表示 server 0, server 1,server 2,…,即(2n + 9);

    还是花了不少的功夫在讲解node,那么这个node 的标记是用来干啥的?

    这些node的标记实际上与我们的worker还有server都是对应的关心,所以通过这些node标记就可以快速找打,这样通信同步一些数据就方便。

    在记录完node_id之后,开始调用Van的启动程序。Van其实是一个通信模块。Van的东西还是蛮多的,打算放在下一篇文章里讲了。

    在继续就是讲到 Barrier

      if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);
    

    Barrier

    同步

    总的来讲,schedular节点通过计数的方式实现各个节点的同步。具体来说就是:

    • 每个节点在自己指定的命令运行完后会向schedular节点发送一个Control::BARRIER命令的请求并自己阻塞直到收到schedular对应的返回后才解除阻塞;
    • schedular节点收到请求后则会在本地计数,看收到的请求数是否和barrier_group的数量是否相等,相等则表示每个机器都运行完指定的命令了,此时schedular节点会向barrier_group的每个机器发送一个返回的信息,并解除其阻塞。

    初始化

    ps-lite 使用 Barrier 来控制系统的初始化,就是大家都准备好了再一起前进。这是一个可选项。具体如下:

    • Scheduler等待所有的worker和server发送BARRIER信息;
    • 在完成ADD_NODE后,各个节点会进入指定 group 的Barrier阻塞同步机制(发送 BARRIER 给 Scheduler),以保证上述过程每个节点都已经完成;
    • 所有节点(worker和server,包括scheduler) 等待scheduler收到所有节点 BARRIER 信息后的应答;
    • 最终所有节点收到scheduler 应答的Barrier message后退出阻塞状态;
    等待 BARRIER 消息

    Node会调用 Barrier 函数 告知Scheduler,随即自己进入等待状态。

    注意,调用时候是

    if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);  
    
    复制代码
    void Postoffice::Barrier(int customer_id, int node_group) {
      if (GetNodeIDs(node_group).size() <= 1) return;
      auto role = van_->my_node().role;
      if (role == Node::SCHEDULER) {
        CHECK(node_group & kScheduler);
      } else if (role == Node::WORKER) {
        CHECK(node_group & kWorkerGroup);
      } else if (role == Node::SERVER) {
        CHECK(node_group & kServerGroup);
      }
    
      std::unique_lock<std::mutex> ulk(barrier_mu_);
      barrier_done_[0][customer_id] = false;
      Message req;
      req.meta.recver = kScheduler;
      req.meta.request = true;
      req.meta.control.cmd = Control::BARRIER;
      req.meta.app_id = 0;
      req.meta.customer_id = customer_id;
      req.meta.control.barrier_group = node_group; // 记录了等待哪些
      req.meta.timestamp = van_->GetTimestamp();
      van_->Send(req); // 给 scheduler 发给 BARRIER
      barrier_cond_.wait(ulk, [this, customer_id] { // 然后等待
          return barrier_done_[0][customer_id];
        });
    }
    
    

    这就是说,等待所有的 group,即 scheduler 节点也要给自己发送消息。

    处理 BARRIER 消息

    处理等待的动作在 Van 类之中,我们提前放出来。

    具体ProcessBarrierCommand逻辑如下:

    • 如果 msg->meta.request 为true,说明是 scheduler 收到消息进行处理。
      • Scheduler会对Barrier请求进行增加计数。
      • 当 Scheduler 收到最后一个请求时(计数等于此group节点总数),则将计数清零,发送结束Barrier的命令。这时候 meta.request 设置为 false;
      • 向此group所有节点发送request==false的BARRIER消息。
    • 如果 msg->meta.request 为 false,说明是收到消息这个 respones,可以解除barrier了,于是进行处理,调用 Manage 函数 。
      • Manage 函数 将app_id对应的所有costomer的barrier_done_置为true,然后通知所有等待条件变量barrier_cond_.notify_all()。
    void Van::ProcessBarrierCommand(Message* msg) {
      auto& ctrl = msg->meta.control;
      if (msg->meta.request) {  // scheduler收到了消息,因为 Postoffice::Barrier函数 会在发送时候做设置为true。
        if (barrier_count_.empty()) {
          barrier_count_.resize(8, 0);
        }
        int group = ctrl.barrier_group;
        ++barrier_count_[group]; // Scheduler会对Barrier请求进行计数
        if (barrier_count_[group] ==
            static_cast<int>(Postoffice::Get()->GetNodeIDs(group).size())) { // 如果相等,说明已经收到了最后一个请求,所以发送解除 barrier 消息。
          barrier_count_[group] = 0;
          Message res;
          res.meta.request = false; // 回复时候,这里就是false
          res.meta.app_id = msg->meta.app_id;
          res.meta.customer_id = msg->meta.customer_id;
          res.meta.control.cmd = Control::BARRIER;
          for (int r : Postoffice::Get()->GetNodeIDs(group)) {
            int recver_id = r;
            if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) {
              res.meta.recver = recver_id;
              res.meta.timestamp = timestamp_++;
              Send(res);
            }
          }
        }
      } else { // 说明这里收到了 barrier respones,可以解除 barrier了。具体见上面的设置为false处。
        Postoffice::Get()->Manage(*msg);
      }
    }
    
    
    

    Manage 函数就是解除了 barrier。

    void Postoffice::Manage(const Message& recv) {
      CHECK(!recv.meta.control.empty());
      const auto& ctrl = recv.meta.control;
      if (ctrl.cmd == Control::BARRIER && !recv.meta.request) {
        barrier_mu_.lock();
        auto size = barrier_done_[recv.meta.app_id].size();
        for (size_t customer_id = 0; customer_id < size; customer_id++) {
          barrier_done_[recv.meta.app_id][customer_id] = true;
        }
        barrier_mu_.unlock();
        barrier_cond_.notify_all(); // 这里解除了barrier
      }
    }
    
    

    在上面的启动程序中可能没见到下面两个函数的调用,但是这也是 Postoffice 重要的成员组成

    数据key分布式存储

    到现在为止,邮车和customer都有了,信件本身无非就是embedding这些参数,但是这些参数的存放也是有讲究的,这也是在上一篇文章中提到的分布式存储,这个分布式是如何体现的?

    const std::vector<Range>& Postoffice::GetServerKeyRanges() {
      server_key_ranges_mu_.lock();
      //循环遍历所有的server,配置server key 的范围
      //本质上就是根据server的数量均匀划分而已,就是这么简单
      if (server_key_ranges_.empty()) {
        for (int i = 0; i < num_servers_; ++i) {
          server_key_ranges_.push_back(Range(
              kMaxKey / num_servers_ * i,
              kMaxKey / num_servers_ * (i+1)));
        }
      }
      server_key_ranges_mu_.unlock();
      return server_key_ranges_;
    }
    

    通过以上的操作的确解决了数据分布式存储,而且可以明确在worker向server端拉取数据的时候要去哪个server拉数据的问题。

    用户管理

    现在大概知道了邮车,那么怎么知道要给哪些customer送信件呢?邮局需要管理一份用户的名单。

    Customer* Postoffice::GetCustomer(int app_id, int customer_id, int timeout) const {
      Customer* obj = nullptr;
      for (int i = 0; i < timeout * 1000 + 1; ++i) {
        {
          std::lock_guard<std::mutex> lk(mu_);
          // app_id 是对应 kv存储的id,举个例子FM 里存在一阶weight app_id=0
          // 通过app_id 去寻找customer,一般 worker 会有多个thread 对应不同的customer
            //但是消费的都是同一个 kv,所以根据app_id可以找到对应的 customer
          const auto it = customers_.find(app_id);
          if (it != customers_.end()) {
            std::unordered_map<int, Customer*> customers_in_app = it->second;
            obj = customers_in_app[customer_id];
            break;
          }
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
      }
      return obj;
    }
    

    这个 GetCustomer 的操作主要是在Van中的 ProcessDataMsg 调用,这里就是Van要把传递的信件交给customer,然后通过 GetCustomer 这个方式来获取相应的customer。

    上面的函数列的是读取,还有 AddCustomer 和 RemoveCustomer 负责添加和删除。



沪ICP备19023445号-2号
友情链接