IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    ps-lite_part5_kvworker和kvsever

    admin发表于 2023-03-23 14:21:29
    love 0

    先回顾一下之前写到哪里了

    1. 介绍ps-lite的基本概念 https://www.deeplearn.me/4302.html
    2. 介绍ps-lite核心组成 postOffice https://www.deeplearn.me/4303.html
    3. 介绍ps-lite 通信模块van https://www.deeplearn.me/4306.html
    4. 介绍ps-lite 中介 customer https://www.deeplearn.me/4308.html

    这篇文章主要讲一下server 和woker,在扒拉一下ps架构的一张图

    image-20230323205127231

    一般意义上来说:

    1. server负责梯度和参数的更新
    2. woker端负责前向和后向的计算

    这也是之前有customer出现的缘故,server和worker集中去计算,负责通信的任务就交给customer。在上一节讲customer在哪里被创建的时候就提到kvworker和kvserver,这里在着重讲一下吧!

    在这之前还是要补充一点kvworker 和kvserver都继承 SimpleApp,那么SimpleApp 又是啥?

    SimpleApp:KVServer和KVWorker的父类,它提供了简单的Request, Wait, Response,Process功能;KVServer和KVWorker分别根据自己的使命重写了这些功能;

    kvwoker

    构造函数

     explicit KVWorker(int app_id, int customer_id) : SimpleApp() {
        using namespace std::placeholders;
        slicer_ = std::bind(&KVWorker<Val>::DefaultSlicer, this, _1, _2, _3);
        obj_ = new Customer(app_id, customer_id, std::bind(&KVWorker<Val>::Process, this, _1));
      }
    

    这里关于构造函数的定义也在上一节提到了,此处略过哈!

    PULL函数

    从开始的图你也看到worker需要从server拉取参数数据,那么肯定需要pull。

     int Pull(const std::vector<Key>& keys,
               std::vector<Val>* vals,
               std::vector<int>* lens = nullptr,
               int cmd = 0,
               const Callback& cb = nullptr,
               int priority = 0) {
        SArray<Key> skeys(keys);
        int ts = AddPullCB(skeys, vals, lens, cmd, cb);
        KVPairs<Val> kvs;
        kvs.keys = skeys;
        kvs.priority = priority;
        Send(ts, false, true, cmd, kvs);
        return ts;
      }
    
    

    这里面有两个需要关注的调用,AddPullCB 和 Send,依次来看下这两个函数的定义和功能

    AddPullCB 是添加一个callback,这个callback等所有server返回结果之后在执行,可以认为是一个阻塞等操作。

    int KVWorker<Val>::AddPullCB(
    // C* vals和D* lens指向由调用者指定的结构体。
    // 等所有server都返回后,从所有server拉来的数据
        const SArray<Key>& keys, C* vals, D* lens, int cmd,
    // Callback& cb代表在所有server回复后要执行的额外的回调
    // 一般我们都是在pull后就立刻阻塞等待,所以cb一般为空
        const Callback& cb) {
    // ************** 创建request,返回的ts是该request_id
      int ts = obj_->NewRequest(kServerGroup);
    
    // ************** 添加callback,等所有server都回复后再执行
      AddCallback(ts, [this, ts, keys, vals, lens, cb]() mutable {
          ......
          // 容纳ts(即request_id)所接受数据的缓冲区
          auto& kvs = recv_kvs_[ts];
          ......
    
          // total_keys是根据kvs统计出来的接收到的key的总数
          // keys是当初请求的所有keys,检查二者是否相等
          ......
          CHECK_EQ(total_key, keys.size()) << "lost some servers?";
    
    // ************** 将所有server返回的数据,合并,填充到用户指定的输出位置
          // vals和lens都指向调用者传入的结构体
          // p_vals和p_lens都是指向输出区的指针
          Val* p_vals = vals->data();
          ......
            p_lens = lens->data();
          ......
          // 遍历从各台server接收到的内容,填充到输出区p_vals和p_lens
          for (const auto& s : kvs) {
            memcpy(p_vals, s.vals.data(), s.vals.size() * sizeof(Val));
            p_vals += s.vals.size();
            if (p_lens) {
              memcpy(p_lens, s.lens.data(), s.lens.size() * sizeof(int));
              p_lens += s.lens.size();
            }
          }
          ......
          recv_kvs_.erase(ts);//清空本次请求的接收缓冲区
          ......
          if (cb) cb();// 如果有额外的callback,执行之
        });
    
      return ts;
    }
    

    send的操作才是真正的去请求server,下面看下send的定义

    void KVWorker<Val>::Send(int timestamp, bool push, bool pull, int cmd, const KVPairs<Val>& kvs) {
      // ****************** 决定要向哪些server发送请求
      SlicedKVs sliced;// 存储分配结果
      slicer_(kvs, Postoffice::Get()->GetServerKeyRanges(), &sliced);
    
      // ****************** 有些server不包含本次请求要求的keys,提前处理
      int skipped = 0;// 本次请求不涉及的servers的总数
      //这里调用first参数,需要去追溯一下SlicedKVs 的定义
      // using SlicedKVs = std::vector<std::pair<bool, KVPairs<Val>>>;
      // bool 参数决定是否需要去这个server节点拉取数据,不需要直接跳过
      for (size_t i = 0; i < sliced.size(); ++i) {
        if (!sliced[i].first) ++skipped;
      }
      // 内部不过是tracker_[timestamp].second += skipped
      // 假设这些不涉及的servers已经返回了
      obj_->AddResponse(timestamp, skipped);
    
      ......
    
      // ****************** 向所有涉及到的server发送请求
      for (size_t i = 0; i < sliced.size(); ++i) {
        const auto& s = sliced[i];
        if (!s.first) continue;//本次请求不需要访问的server节点直接跳过
    
        Message msg;
        msg.meta.app_id = obj_->app_id();
        msg.meta.customer_id = obj_->customer_id();
        msg.meta.request     = true;
        msg.meta.push        = push;
        msg.meta.pull        = pull;
        msg.meta.head        = cmd;
        msg.meta.timestamp   = timestamp;
        msg.meta.recver      = Postoffice::Get()->ServerRankToID(i);
        msg.meta.priority    = kvs.priority;
    
        const auto& kvs = s.second;//分配到当前节点上的key-value pairs
        if (kvs.keys.size()) {
          msg.AddData(kvs.keys);
          msg.AddData(kvs.vals);
          if (kvs.lens.size()) {
            msg.AddData(kvs.lens);
          }
        }
        //通过van通信模块发送请求
        Postoffice::Get()->van()->Send(msg);
      }
    }
    

    至此再回去看pull 应该就差不多了,除了pull之外还有一个zpull ,全称是zero pull,说是实现了零拷贝,起到一个加速的作用,这里就不细看了。

    PUSH

    说完pull 就是push了,woker的push 就是要把梯度传给server,让server 去更新参数。

      int ZPush(const SArray<Key>& keys,
                const SArray<Val>& vals,
                const SArray<int>& lens = {},
                int cmd = 0,
                const Callback& cb = nullptr,
                int priority = 0) {
        int ts = obj_->NewRequest(kServerGroup);
        AddCallback(ts, cb);
        KVPairs<Val> kvs;
        kvs.keys = keys;
        kvs.vals = vals;
        kvs.lens = lens;
        kvs.priority = priority;
        // send 将这些梯度传递到指定的server上
        Send(ts, true, false, cmd, kvs);
        return ts;
      }
    

    同时也还有一个zpush,本质上实现的功能是一致的。

    差不多 woker 就这些事情,接下来讲下server ,其实都差不多,因为只是各自干的事情内容又一点不一样而已。

    kvserver

    构造函数

    explicit KVServer(int app_id) : SimpleApp() {
        using namespace std::placeholders;
        obj_ = new Customer(app_id, app_id, std::bind(&KVServer<Val>::Process, this, _1));
      }
    

    Server 主要是处理参数更新和数据查询

    1. 参数更新:根据梯度更新相应的神经网络参数
    2. 数据查询:worker需要拉取参数去执行前向传播

    完成上述需求主要依靠两个函数

    Process

    这个主要是来处理woker push 过来的数据

    template <typename Val>
    void KVServer<Val>::Process(const Message& msg) {
      if (msg.meta.simple_app) {
        SimpleApp::Process(msg); return;
      }
      KVMeta meta;
      meta.cmd       = msg.meta.head;
      meta.push      = msg.meta.push;
      meta.pull      = msg.meta.pull;
      meta.sender    = msg.meta.sender;
      meta.timestamp = msg.meta.timestamp;
      meta.customer_id = msg.meta.customer_id;
      //KVPairs 保存的就是传递的数据
      KVPairs<Val> data;
      int n = msg.data.size();
      if (n) {
        CHECK_GE(n, 2);
        data.keys = msg.data[0];
        data.vals = msg.data[1];
        if (n > 2) {
          CHECK_EQ(n, 3);
          data.lens = msg.data[2];
          CHECK_EQ(data.lens.size(), data.keys.size());
        }
      }
      CHECK(request_handle_);
      //这个request_handle_是用户自定义的处理逻辑函数,主要是梯度更新参数的规则等
      request_handle_(meta, data, this);
    }
    

    这里给出test里面的一个实例

    void StartServer() {
      if (!IsServer()) return;
      auto server = new KVServer<float>(0);
      //这一步就是在设置 request_handle_
      server->set_request_handle(KVServerDefaultHandle<float>());
      RegisterExitCallback([server](){ delete server; });
    }
    

    Response

    故名思义就是将数据回复给worker,好像没啥要讲的。。。

    template <typename Val>
    void KVServer<Val>::Response(const KVMeta& req, const KVPairs<Val>& res) {
      //res里存储的就是worker需要数据,这里只是在包装以 Message 封装一下,最后在通过send回复给worker
      Message msg;
      msg.meta.app_id = obj_->app_id();
      msg.meta.customer_id = req.customer_id;
      msg.meta.request     = false;
      msg.meta.push        = req.push;
      msg.meta.pull        = req.pull;
      msg.meta.head        = req.cmd;
      msg.meta.timestamp   = req.timestamp;
      msg.meta.recver      = req.sender;
      if (res.keys.size()) {
        msg.AddData(res.keys);
        msg.AddData(res.vals);
        if (res.lens.size()) {
          msg.AddData(res.lens);
        }
      }
      Postoffice::Get()->van()->Send(msg);
    }
    


沪ICP备19023445号-2号
友情链接