机器学习笔记(17.1): C2P2SL

可能更好的阅读体验:Jeefy's blog: \(C^2P^2SL\)

传统的 SL(Split Learning),例如 [[SplitNN.pdf]] 框架或者 [[SplitFed]],其核心就是将模型分层。在客户端向前几步,然后上传到服务器向前,然后返回。

常用的框架就这两种。上面是 Client,下面是 Server

  • UE:user equipments
  • BS: base station
  • UT: uplink transmission
  • DT: downlink transmission
  • SL: split learning
  • DNNs: deep neural networks
  • SFL: split federated learning
  • EPSL: efficient parallel split learning
  • PP: pipeline parallelism
  • AO: alternating optimization (交替优化)
  • FP: forward propagation
  • BP: backward propagation

  • \(\Delta\) 计算建模

一个 BS,\(n\) 个 UE。

其中 \(f_i = K_U F_i\) 表示计算容量,\(c_j^F, c_j^B\) 表示划分后的模型在每一层向前或者向后需要的计算量,\(s_l\) 表示层输出的参数量。

  • \(\Delta\) 通信模型(TDMA:Time Division Multiple Access)

将时间切片为长度为 \(T\) 的帧,每一帧划分为若干的槽(Slots),每一个槽的时间由每个设备的 \(\tau_i\) 确定。

这里假定每一个小段中,通信是稳定的,不会断开。
原文什么 AWGN 和通信速率,不懂。

上下传输速度由 \(r^i_u, r^i_d\) 变量衡量,\(p_i\) 表示上传能耗,\(p_B\) 表示下载能耗。\(G, h, d, f\) 是通信中的一些参数,不懂,略。


  • \(\Delta\) 算法流程
  1. Forward

BS 广播模型,UE 将数据划分为大小为 \(k\) 的 mini-batch,进行 FP。然后可以衡量 UE 的计算时间 \(t_i^F\)。计算完成后,将层输出和对应标签上传,然后继续下一次 FP。

上传的时间可以衡量 \(t_i^U\),上传将数据聚合后,可以计算 BS 的 FP 时间 \(t_b^F\)

  1. Backward

BS FP 后立即 BP,可以由 \(t_b^B\) 衡量时间。计算完后,向下传输梯度,由 \(t_i^D\) 衡量时间。

然后 UE 继续 BP,由 \(t_i^B\) 衡量时间。

对于每一个 mini-batch 的训练,整个时间帧上的设备运行状态可以这样表示:

设备 \ 时间
UE 1 FP 1 FP 2 FP k BP 1 BP 2 BP k
UE n FP 1 FP 2 FP k BP 1 BP 2 BP k
网络 UT 1 UT 2 UT k DT 1 DT 2
DT k
BS FP 1 BP 1 FP 2 BP 2 FP k BP k

注意,由于数量上的关系,中间一定会有 4 帧时间是空闲的,没办法。


接下来的问题就很简单了,怎么平衡时常?

首先是对于 FP 部分,需要满足:

\[\max \{ t_i^F, t_i^U \} \le t_b^F + t_b^F \]

也就是保证 BS 可以持续运行,每次通信和 UE FP 的时间不能长于 BS 处理当前数据的时常。

对于 BP 部分,由于一定会空出一个设备帧,没办法,所以可以适当的放松一点时间:

\[(k - 1) (\max_i t_i^U + \max_i t_i^D) \le k(t_b^F + t_b^B) \]

也就是要保证在最后一次 DT k 完成并下发前,前面的任务全部完成,不阻塞。

问题建模肯定是最小化空闲时间的占比,也就是:

\[\min BR = \frac {t_{idle}}{t_{idle} + t_{work}} = \frac {\max_i (t_i^F + t_i^U) + \max_i(t_i^D + t_i^B)}{\max_i (t_i^F + t_i^U) + \max_i(t_i^D + t_i^B) + k(t_b^F + t_b^B)} \]


然后就是复杂的 AO 优化环节。简单来说:

输入:收敛容差 ϵ(convergence tolerance)
初始化:k(0), l(0), b(0), τ(0),迭代索引 m = 0
重复:
  m ← m + 1,BR_min ← ∞
  for l = 1 to L do
    由 Lemma 1 计算 k
    if BR(l,k) < BR_min then
      BR_min ← BR(l,k),记录 l(m), k(m)
    end if
  end for
  固定 (l(m),k(m)),解 MILP P3 得 b(m)
  固定 (l(m),k(m),b(m)),解凸问题 P5 得 τ(m)
直到 |BR(m) - BR(m-1)| ≤ ϵ
输出:k*, l*, b*, τ*

没怎学过,先略

posted @ 2026-05-06 18:13  jeefy  阅读(5)  评论(0)    收藏  举报