机器学习笔记(17.1): C2P2SL
可能更好的阅读体验:Jeefy's blog: \(C^2P^2SL\)
传统的 SL(Split Learning),例如 [[SplitNN.pdf]] 框架或者 [[SplitFed]],其核心就是将模型分层。在客户端向前几步,然后上传到服务器向前,然后返回。

常用的框架就这两种。上面是 Client,下面是 Server
- UE:user equipments
- BS: base station
- UT: uplink transmission
- DT: downlink transmission
- SL: split learning
- DNNs: deep neural networks
- SFL: split federated learning
- EPSL: efficient parallel split learning
- PP: pipeline parallelism
- AO: alternating optimization (交替优化)
- FP: forward propagation
- BP: backward propagation
- \(\Delta\) 计算建模
一个 BS,\(n\) 个 UE。

其中 \(f_i = K_U F_i\) 表示计算容量,\(c_j^F, c_j^B\) 表示划分后的模型在每一层向前或者向后需要的计算量,\(s_l\) 表示层输出的参数量。
- \(\Delta\) 通信模型(TDMA:Time Division Multiple Access)
将时间切片为长度为 \(T\) 的帧,每一帧划分为若干的槽(Slots),每一个槽的时间由每个设备的 \(\tau_i\) 确定。
这里假定每一个小段中,通信是稳定的,不会断开。
原文什么 AWGN 和通信速率,不懂。
上下传输速度由 \(r^i_u, r^i_d\) 变量衡量,\(p_i\) 表示上传能耗,\(p_B\) 表示下载能耗。\(G, h, d, f\) 是通信中的一些参数,不懂,略。
- \(\Delta\) 算法流程
- Forward
BS 广播模型,UE 将数据划分为大小为 \(k\) 的 mini-batch,进行 FP。然后可以衡量 UE 的计算时间 \(t_i^F\)。计算完成后,将层输出和对应标签上传,然后继续下一次 FP。
上传的时间可以衡量 \(t_i^U\),上传将数据聚合后,可以计算 BS 的 FP 时间 \(t_b^F\)
- Backward
BS FP 后立即 BP,可以由 \(t_b^B\) 衡量时间。计算完后,向下传输梯度,由 \(t_i^D\) 衡量时间。
然后 UE 继续 BP,由 \(t_i^B\) 衡量时间。
对于每一个 mini-batch 的训练,整个时间帧上的设备运行状态可以这样表示:
| 设备 \ 时间 | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| UE 1 | FP 1 | FP 2 | FP k | BP 1 | BP 2 | BP k | ||||
| UE n | FP 1 | FP 2 | FP k | BP 1 | BP 2 | BP k | ||||
| 网络 | UT 1 | UT 2 | UT k | DT 1 | DT 2 |
DT k | ||||
| BS | FP 1 | BP 1 | FP 2 | BP 2 | FP k | BP k | ||||
注意,由于数量上的关系,中间一定会有 4 帧时间是空闲的,没办法。
接下来的问题就很简单了,怎么平衡时常?
首先是对于 FP 部分,需要满足:
也就是保证 BS 可以持续运行,每次通信和 UE FP 的时间不能长于 BS 处理当前数据的时常。
对于 BP 部分,由于一定会空出一个设备帧,没办法,所以可以适当的放松一点时间:
也就是要保证在最后一次 DT k 完成并下发前,前面的任务全部完成,不阻塞。
问题建模肯定是最小化空闲时间的占比,也就是:
然后就是复杂的 AO 优化环节。简单来说:
输入:收敛容差 ϵ(convergence tolerance)
初始化:k(0), l(0), b(0), τ(0),迭代索引 m = 0
重复:
m ← m + 1,BR_min ← ∞
for l = 1 to L do
由 Lemma 1 计算 k
if BR(l,k) < BR_min then
BR_min ← BR(l,k),记录 l(m), k(m)
end if
end for
固定 (l(m),k(m)),解 MILP P3 得 b(m)
固定 (l(m),k(m),b(m)),解凸问题 P5 得 τ(m)
直到 |BR(m) - BR(m-1)| ≤ ϵ
输出:k*, l*, b*, τ*
没怎学过,先略

浙公网安备 33010602011771号