A First Look at AES

AES (1)

Part 0: Preface

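Gathering morning blossoms at dusk, with wine in the cup. My path to AES started with attacking CBC-mode encryption flaws in CTF challenges, continued through the Fundamentals of Cryptography course in my fourth semester, and has now reached the Cryptographic Engineering course. Driven by that course and by my own interest in symmetric ciphers, I finally decided to write these notes on what I have learned about AES. They took a long time to finish, partly because of procrastination and partly because my grasp of applied cryptography and my coding skills were not yet strong enough, so I consulted many write-ups by other practitioners as well as several textbooks and references, which are listed at the end. Given the limits of my ability and experience, the content, technical details, layout, and wording surely contain mistakes; please read critically and point them out. I am glad to have finished these notes. 😊😊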

Part 1: Background of AES

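In September 1997, the U.S. National Institute of Standards and Technology (NIST) issued a public call for a new Advanced Encryption Standard (AES) to replace DES. Submissions closed in June 1998, with 21 algorithms received. After an initial screening, 15 candidates from 12 countries were announced for the first round in August 1998. In August 1999, five candidates advanced to the second round: MARS, RC6, Rijndael, Serpent, and Twofish. In October 2000, Rijndael, designed by the Belgian cryptographers Joan Daemen and Vincent Rijmen, was selected; after further review and standardization it was published in December 2001 as FIPS 197 and became the Advanced Encryption Standard.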

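AES uses a substitution-permutation network (SPN) structure with a block length of 128 bits (16 bytes). The input is split into bytes numbered 0 to 15, which are arranged column by column as shown below: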

0 4 8 12
1 5 9 13
2 6 10 14
3 7 11 15

AES comes in three versions with different key lengths and security levels: a 128-bit key uses 10 rounds, a 192-bit key uses 12 rounds, and a 256-bit key uses 14 rounds. They are denoted AES-128, AES-192, and AES-256.

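This article covers the internal structure of AES, implementations of AES encryption and decryption in several programming languages, the AES modes of operation, and AES optimization strategies together with their results.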

Part 2: Internal Structure

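AES is a symmetric block cipher. Internally, each encryption (decryption) round consists of four steps: (Inv)SubBytes, (Inv)ShiftRows, (Inv)MixColumns, and AddRoundKey; the final encryption round has no MixColumns, and decryption correspondingly skips one InvMixColumns. "Block" means that the plaintext message is split into 128-bit groups, and each group goes through 10/12/14 rounds of these operations. For example, if message = b'abcdefghijklmnopqrstuvwxyz', then block0 = b'abcdefghijklmnop' and block1 = b'qrstuvwxyz'. Because block1 is shorter than 128 bits, it must be padded; with zero padding, block1 = b'qrstuvwxyz\x00\x00\x00\x00\x00\x00' (PKCS#7 padding, used in the C++ code below, would append six 0x06 bytes instead). All four AES components are invertible, so decryption applies the inverse components in mirrored order. The following subsections describe the four components used in encryption and decryption in detail.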

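To make AES easier to present and implement, we describe the encryption or decryption of one block through its state matrix. For example, block0 is written into the state matrix column by column (down the first column, then the next) before the components are applied: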

a e i m
b f j n
c g k o
d h l p

Each plaintext byte is represented by its ASCII code, and encryption and decryption operate on these values:

0x61 0x65 0x69 0x6d
0x62 0x66 0x6a 0x6e
0x63 0x67 0x6b 0x6f
0x64 0x68 0x6c 0x70
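
The column-major filling described above can be sketched in Python as follows. `bytes_to_state` and `state_to_bytes` are hypothetical helper names chosen for this illustration; they index the state as `state[row][col]`, matching the layout in the tables above.

```python
def bytes_to_state(block: bytes):
    """Fill a 4x4 state column by column: state[row][col] = block[4*col + row]."""
    if len(block) != 16:
        raise ValueError("an AES block is exactly 16 bytes")
    return [[block[4 * col + row] for col in range(4)] for row in range(4)]


def state_to_bytes(state) -> bytes:
    """Read the state back out column by column."""
    return bytes(state[row][col] for col in range(4) for row in range(4))


block0 = b"abcdefghijklmnop"
state = bytes_to_state(block0)
for row in state:
    print(" ".join(f"0x{b:02x}" for b in row))
# 0x61 0x65 0x69 0x6d
# 0x62 0x66 0x6a 0x6e
# 0x63 0x67 0x6b 0x6f
# 0x64 0x68 0x6c 0x70
```

Reading the state back with `state_to_bytes` returns the original 16 bytes, which is a cheap sanity check for any implementation's byte ordering.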

1. SubBytes (byte substitution)

SubBytes

SubBytes is simply a table lookup; the table is called the S-box.
The AES S-box:

row/col 0 1 2 3 4 5 6 7 8 9 A B C D E F
0 0x63 0x7c 0x77 0x7b 0xf2 0x6b 0x6f 0xc5 0x30 0x01 0x67 0x2b 0xfe 0xd7 0xab 0x76
1 0xca 0x82 0xc9 0x7d 0xfa 0x59 0x47 0xf0 0xad 0xd4 0xa2 0xaf 0x9c 0xa4 0x72 0xc0
2 0xb7 0xfd 0x93 0x26 0x36 0x3f 0xf7 0xcc 0x34 0xa5 0xe5 0xf1 0x71 0xd8 0x31 0x15
3 0x04 0xc7 0x23 0xc3 0x18 0x96 0x05 0x9a 0x07 0x12 0x80 0xe2 0xeb 0x27 0xb2 0x75
4 0x09 0x83 0x2c 0x1a 0x1b 0x6e 0x5a 0xa0 0x52 0x3b 0xd6 0xb3 0x29 0xe3 0x2f 0x84
5 0x53 0xd1 0x00 0xed 0x20 0xfc 0xb1 0x5b 0x6a 0xcb 0xbe 0x39 0x4a 0x4c 0x58 0xcf
6 0xd0 0xef 0xaa 0xfb 0x43 0x4d 0x33 0x85 0x45 0xf9 0x02 0x7f 0x50 0x3c 0x9f 0xa8
7 0x51 0xa3 0x40 0x8f 0x92 0x9d 0x38 0xf5 0xbc 0xb6 0xda 0x21 0x10 0xff 0xf3 0xd2
8 0xcd 0x0c 0x13 0xec 0x5f 0x97 0x44 0x17 0xc4 0xa7 0x7e 0x3d 0x64 0x5d 0x19 0x73
9 0x60 0x81 0x4f 0xdc 0x22 0x2a 0x90 0x88 0x46 0xee 0xb8 0x14 0xde 0x5e 0x0b 0xdb
A 0xe0 0x32 0x3a 0x0a 0x49 0x06 0x24 0x5c 0xc2 0xd3 0xac 0x62 0x91 0x95 0xe4 0x79
B 0xe7 0xc8 0x37 0x6d 0x8d 0xd5 0x4e 0xa9 0x6c 0x56 0xf4 0xea 0x65 0x7a 0xae 0x08
C 0xba 0x78 0x25 0x2e 0x1c 0xa6 0xb4 0xc6 0xe8 0xdd 0x74 0x1f 0x4b 0xbd 0x8b 0x8a
D 0x70 0x3e 0xb5 0x66 0x48 0x03 0xf6 0x0e 0x61 0x35 0x57 0xb9 0x86 0xc1 0x1d 0x9e
E 0xe1 0xf8 0x98 0x11 0x69 0xd9 0x8e 0x94 0x9b 0x1e 0x87 0xe9 0xce 0x55 0x28 0xdf
F 0x8c 0xa1 0x89 0x0d 0xbf 0xe6 0x42 0x68 0x41 0x99 0x2d 0x0f 0xb0 0x54 0xbb 0x16

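Each element of the state matrix is mapped to a new byte as follows: the high 4 bits of the byte select the row, the low 4 bits select the column, and the entry at that position of the S-box (or, for decryption, the inverse S-box) is the output. For example, during encryption, if the input byte is S1 = 0x12, look up row 0x1, column 0x2 of the S-box to get 0xc9, and replace the original 0x12 in S1 with 0xc9.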
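
Rather than copying the table by hand, the S-box can also be rebuilt from its mathematical definition, which is a handy way to check a transcribed table. The sketch below assumes the standard definition (multiplicative inverse in \(GF(2^8)\), with 0 mapped to 0, followed by the affine transformation with constant 0x63); `gf_mul`, `gf_inv`, `build_sbox` and `sub_byte` are illustrative names, not part of any library. `sub_byte` performs the row/column lookup described above.

```python
def gf_mul(a, b):
    """Multiply two bytes in GF(2^8) modulo m(x) = x^8 + x^4 + x^3 + x + 1 (0x11B)."""
    p = 0
    while b:
        if b & 1:
            p ^= a
        a <<= 1
        if a & 0x100:
            a ^= 0x11B
        b >>= 1
    return p


def gf_inv(a):
    """Multiplicative inverse as a^254 (a^255 = 1 for a != 0); yields 0 for a = 0."""
    result = 1
    for _ in range(254):
        result = gf_mul(result, a)
    return result


def rotl8(x, n):
    return ((x << n) | (x >> (8 - n))) & 0xFF


def build_sbox():
    sbox = []
    for a in range(256):
        b = gf_inv(a)
        # Affine transformation: b XOR four left-rotations of b XOR 0x63
        sbox.append(b ^ rotl8(b, 1) ^ rotl8(b, 2) ^ rotl8(b, 3) ^ rotl8(b, 4) ^ 0x63)
    return sbox


SBOX = build_sbox()


def sub_byte(x):
    row, col = x >> 4, x & 0x0F  # high nibble selects the row, low nibble the column
    return SBOX[16 * row + col]  # identical to SBOX[x]


print(hex(sub_byte(0x12)))  # 0xc9
```

Comparing `SBOX` entry by entry against a printed table quickly exposes transcription slips such as a mistyped byte.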

InvSubBytes

InvSubBytes undoes SubBytes and works in exactly the same way, using a different table called the inverse S-box.
The AES inverse S-box:

row/col 0 1 2 3 4 5 6 7 8 9 A B C D E F
0 0x52 0x09 0x6a 0xd5 0x30 0x36 0xa5 0x38 0xbf 0x40 0xa3 0x9e 0x81 0xf3 0xd7 0xfb
1 0x7c 0xe3 0x39 0x82 0x9b 0x2f 0xff 0x87 0x34 0x8e 0x43 0x44 0xc4 0xde 0xe9 0xcb
2 0x54 0x7b 0x94 0x32 0xa6 0xc2 0x23 0x3d 0xee 0x4c 0x95 0x0b 0x42 0xfa 0xc3 0x4e
3 0x08 0x2e 0xa1 0x66 0x28 0xd9 0x24 0xb2 0x76 0x5b 0xa2 0x49 0x6d 0x8b 0xd1 0x25
4 0x72 0xf8 0xf6 0x64 0x86 0x68 0x98 0x16 0xd4 0xa4 0x5c 0xcc 0x5d 0x65 0xb6 0x92
5 0x6c 0x70 0x48 0x50 0xfd 0xed 0xb9 0xda 0x5e 0x15 0x46 0x57 0xa7 0x8d 0x9d 0x84
6 0x90 0xd8 0xab 0x00 0x8c 0xbc 0xd3 0x0a 0xf7 0xe4 0x58 0x05 0xb8 0xb3 0x45 0x06
7 0xd0 0x2c 0x1e 0x8f 0xca 0x3f 0x0f 0x02 0xc1 0xaf 0xbd 0x03 0x01 0x13 0x8a 0x6b
8 0x3a 0x91 0x11 0x41 0x4f 0x67 0xdc 0xea 0x97 0xf2 0xcf 0xce 0xf0 0xb4 0xe6 0x73
9 0x96 0xac 0x74 0x22 0xe7 0xad 0x35 0x85 0xe2 0xf9 0x37 0xe8 0x1c 0x75 0xdf 0x6e
A 0x47 0xf1 0x1a 0x71 0x1d 0x29 0xc5 0x89 0x6f 0xb7 0x62 0x0e 0xaa 0x18 0xbe 0x1b
B 0xfc 0x56 0x3e 0x4b 0xc6 0xd2 0x79 0x20 0x9a 0xdb 0xc0 0xfe 0x78 0xcd 0x5a 0xf4
C 0x1f 0xdd 0xa8 0x33 0x88 0x07 0xc7 0x31 0xb1 0x12 0x10 0x59 0x27 0x80 0xec 0x5f
D 0x60 0x51 0x7f 0xa9 0x19 0xb5 0x4a 0x0d 0x2d 0xe5 0x7a 0x9f 0x93 0xc9 0x9c 0xef
E 0xa0 0xe0 0x3b 0x4d 0xae 0x2a 0xf5 0xb0 0xc8 0xeb 0xbb 0x3c 0x83 0x53 0x99 0x61
F 0x17 0x2b 0x04 0x7e 0xba 0x77 0xd6 0x26 0xe1 0x69 0x14 0x63 0x55 0x21 0x0c 0x7d
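
The inverse S-box does not need to be stored or derived separately: it is the inverse permutation of the S-box, InvS[S[x]] = x. The sketch below rebuilds the S-box with the same p/q iteration used by the Python class later in this post (p steps through the nonzero field elements by multiplying by 3, and q is kept equal to the inverse of p) and then inverts it; `build_sbox` and `INV_SBOX` are illustrative names.

```python
def rotl8(x, n):
    return ((x << n) | (x >> (8 - n))) & 0xFF


def build_sbox():
    sbox = [0x63] + [0] * 255  # S(0) = 0x63 by definition
    p = q = 1
    for _ in range(255):
        # p <- 3*p in GF(2^8)
        p = (p ^ (p << 1) ^ (0x1B if p & 0x80 else 0)) & 0xFF
        # q <- q / 3, so q stays the multiplicative inverse of p
        q ^= q << 1
        q ^= q << 2
        q ^= q << 4
        q &= 0xFF
        q ^= 0x09 if q & 0x80 else 0
        sbox[p] = q ^ rotl8(q, 1) ^ rotl8(q, 2) ^ rotl8(q, 3) ^ rotl8(q, 4) ^ 0x63
    return sbox


SBOX = build_sbox()
INV_SBOX = [0] * 256
for x, y in enumerate(SBOX):
    INV_SBOX[y] = x  # InvS[S[x]] = x

print(hex(INV_SBOX[0xc9]))  # 0x12, undoing S(0x12) = 0xc9
```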

2. ShiftRows

ShiftRows

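ShiftRows is a simple cyclic left shift of each row. The offsets are the same for AES-128/192/256, because the block is always 128 bits: row 0 is rotated left by 0 bytes, row 1 by 1 byte, row 2 by 2 bytes, and row 3 by 3 bytes.
Let the state before ShiftRows be: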

S0 S4 S8 SC
S1 S5 S9 SD
S2 S6 SA SE
S3 S7 SB SF

After ShiftRows, the state is:

S0 S4 S8 SC
S5 S9 SD S1
SA SE S2 S6
SF S3 S7 SB
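
ShiftRows only permutes positions, so it can be checked with symbolic labels. This minimal sketch (the function name `shift_rows` is illustrative) rotates row \(r\) left by \(r\) and reproduces the matrix above:

```python
def shift_rows(state):
    """Rotate row r of the 4x4 state left by r positions."""
    return [row[r:] + row[:r] for r, row in enumerate(state)]


state = [["S0", "S4", "S8", "SC"],
         ["S1", "S5", "S9", "SD"],
         ["S2", "S6", "SA", "SE"],
         ["S3", "S7", "SB", "SF"]]
for row in shift_rows(state):
    print(" ".join(row))
# S0 S4 S8 SC
# S5 S9 SD S1
# SA SE S2 S6
# SF S3 S7 SB
```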

InvShiftRows

The inverse of ShiftRows rotates each row in the opposite direction: row 0 is rotated right by 0 bytes, row 1 by 1 byte, row 2 by 2 bytes, and row 3 by 3 bytes (again for every AES key length).
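
A right rotation by \(r\) is the same as a left rotation by \(4-r\), which gives a one-line inverse. In this sketch (illustrative names), the state holds the byte indices 0 to 15 in column-major order, and the round trip returns the original state:

```python
def shift_rows(state):
    return [row[r:] + row[:r] for r, row in enumerate(state)]


def inv_shift_rows(state):
    """Rotate row r right by r positions (equivalently, left by 4 - r)."""
    return [row[4 - r:] + row[:4 - r] for r, row in enumerate(state)]


state = [[4 * c + r for c in range(4)] for r in range(4)]  # byte indices, column-major
print(inv_shift_rows(state)[1])  # [13, 1, 5, 9]
print(inv_shift_rows(shift_rows(state)) == state)  # True
```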

3. MixColumns

MixColumns: matrix multiplication

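MixColumns is a matrix multiplication over \(GF(2^8)\): the fixed matrix below is multiplied by the state matrix produced by ShiftRows, giving the mixed state. Here addition is \(a+b=a \oplus b\), and multiplication \(a \cdot b\) is polynomial multiplication modulo the irreducible polynomial \(x^8 + x^4 + x^3 + x + 1\).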

\[\begin{bmatrix} s'_{0,0} & s'_{0,1} & s'_{0,2} & s'_{0,3} \\ s'_{1,0} & s'_{1,1} & s'_{1,2} & s'_{1,3} \\ s'_{2,0} & s'_{2,1} & s'_{2,2} & s'_{2,3} \\ s'_{3,0} & s'_{3,1} & s'_{3,2} & s'_{3,3} \end{bmatrix} = \begin{bmatrix} 02 & 03 & 01 & 01 \\ 01 & 02 & 03 & 01 \\ 01 & 01 & 02 & 03 \\ 03 & 01 & 01 & 02 \end{bmatrix} \begin{bmatrix} s_{0,0} & s_{0,1} & s_{0,2} & s_{0,3} \\ s_{1,0} & s_{1,1} & s_{1,2} & s_{1,3} \\ s_{2,0} & s_{2,1} & s_{2,2} & s_{2,3} \\ s_{3,0} & s_{3,1} & s_{3,2} & s_{3,3} \end{bmatrix} \]

MixColumns: multiplication in \(GF(2^8)\)

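Here is how addition and multiplication work in \(GF(2^8)\). Addition is simply bitwise XOR, while multiplication is somewhat more involved. Each element of \(GF(2^8)\) is a binary polynomial of degree at most 7, \(a(x) = a_7x^7+a_6x^6+a_5x^5+a_4x^4+a_3x^3+a_2x^2+a_1x+a_0\), where \(a_i\in \{0,1\}\) and the coefficients follow \(GF(2)\) arithmetic: \(0+0=0,\ 0+1=1,\ 1+1=0,\ 0\cdot0=0,\ 0\cdot1=0,\ 1\cdot1=1\).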

For two elements \(A(x) = \sum_{i=0}^7 a_i x^i\) and \(B(x) = \sum_{i=0}^7 b_i x^i\) of \({GF}(2^8)\), first compute the ordinary polynomial product, with coefficient arithmetic in \({GF}(2)\). The product has degree at most 14: \(C(x) = A(x) \cdot B(x) = \sum_{k=0}^{14} c_k x^k\), where \(c_k = \left( \sum_{\substack{i+j=k \\ 0 \leq i,j \leq 7}} a_i b_j \right) \bmod 2\) (for each degree \(k\), combine all terms with \(i+j=k\) and reduce the coefficient modulo 2).

Because the final result must belong to \(GF(2^8)\), \(C(x)\) is reduced modulo \(m(x)=x^8+x^4+x^3+x+1\), leaving a remainder of degree less than 8. Since \(m(x)\equiv 0 \pmod{m(x)}\), we have \(x^8 \equiv -(x^4+x^3+x+1) \pmod{m(x)}\), and because \(-1=1\) in \(GF(2)\), this gives \(x^8 \equiv x^4 + x^3 + x + 1 \pmod{m(x)}\). Higher powers \(x^k\) (\(k \ge 8\)) are lowered by repeatedly replacing \(x^8\) with \(x^4+x^3+x+1\). For example:

  • \(x ^ 9 \equiv x \cdot x ^ 8 \equiv x \cdot (x ^ 4 + x ^ 3 + x + 1) = x ^ 5 + x ^ 4 + x ^ 2 + x\)
  • \(x ^ {10} \equiv x \cdot x ^ 9 \equiv x \cdot (x ^ 5 + x ^ 4 + x ^ 2 + x) = x ^ 6 + x ^ 5 + x ^ 3 + x ^ 2\)
  • and so on, until every term of degree \(\ge 8\) has been replaced by terms of degree \(<8\).

Hence each element of a column of the result relates to the original column as follows:

\[\begin{align*} s'_{0,j} &= (2 \cdot s_{0,j}) \oplus (3 \cdot s_{1,j}) \oplus s_{2,j} \oplus s_{3,j} \\ s'_{1,j} &= s_{0,j} \oplus (2 \cdot s_{1,j}) \oplus (3 \cdot s_{2,j}) \oplus s_{3,j} \\ s'_{2,j} &= s_{0,j} \oplus s_{1,j} \oplus (2 \cdot s_{2,j}) \oplus (3 \cdot s_{3,j}) \\ s'_{3,j} &= (3 \cdot s_{0,j}) \oplus s_{1,j} \oplus s_{2,j} \oplus (2 \cdot s_{3,j}) \end{align*} \]
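
The four equations above translate almost line for line into code. The sketch below is illustrative (`xtime`, `mul3`, `mix_column` are names chosen here): it computes \(2 \cdot a\) with a shift plus a conditional reduction and \(3 \cdot a\) as \(2 \cdot a \oplus a\). The column db 13 53 45 → 8e 4d a1 bc is a commonly quoted MixColumns test value (it can be verified by hand), and a constant column maps to itself because \(2 \oplus 3 \oplus 1 \oplus 1 = 1\).

```python
def xtime(a):
    """2*a in GF(2^8): shift left, and reduce by 0x11B if bit 8 was produced."""
    a <<= 1
    return (a ^ 0x11B) if a & 0x100 else a


def mul3(a):
    return xtime(a) ^ a  # 3*a = 2*a XOR a


def mix_column(col):
    s0, s1, s2, s3 = col
    return [
        xtime(s0) ^ mul3(s1) ^ s2 ^ s3,
        s0 ^ xtime(s1) ^ mul3(s2) ^ s3,
        s0 ^ s1 ^ xtime(s2) ^ mul3(s3),
        mul3(s0) ^ s1 ^ s2 ^ xtime(s3),
    ]


print([f"{b:02x}" for b in mix_column([0xDB, 0x13, 0x53, 0x45])])
# ['8e', '4d', 'a1', 'bc']
```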

\(AES\)中我们采用位操作快速约简:若乘法结果的多项式含\(x^k(k≥8)\),则将结果与\(m(x)\)左移\(k-8\)位的多项式异或(对应\(GF(2)\)的加法)

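  • if the result contains \(x^8\), XOR it with \(m(x)=x^8+x^4+x^3+x+1\);
  • if it contains \(x^9\), XOR it with \(x \cdot m(x) = x ^ 9 + x ^ 5 + x ^4 + x ^ 2 + x\);
  • repeat, starting from the highest degree, until no term of degree \(\ge 8\) remains.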
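
The two-stage recipe (carry-less product of degree at most 14, then cancelling each \(x^k\), \(k \ge 8\), by XOR-ing in \(m(x)\) shifted left by \(k-8\)) can be sketched as follows; the function names are illustrative. The checks use the reductions of \(x^9\) and \(x^{10}\) derived above and the product {57}·{83} = {c1}, the multiplication example given in FIPS 197.

```python
M = 0x11B  # m(x) = x^8 + x^4 + x^3 + x + 1


def clmul(a, b):
    """Polynomial product over GF(2): partial products are combined with XOR."""
    c = 0
    for i in range(8):
        if (b >> i) & 1:
            c ^= a << i
    return c  # degree <= 14


def reduce_mod_m(c):
    """Clear every x^k with k >= 8, highest first, by XOR with m(x) << (k - 8)."""
    for k in range(14, 7, -1):
        if (c >> k) & 1:
            c ^= M << (k - 8)
    return c


def gf_mul(a, b):
    return reduce_mod_m(clmul(a, b))


print(hex(reduce_mod_m(1 << 9)))   # x^9  -> x^5 + x^4 + x^2 + x   = 0x36
print(hex(reduce_mod_m(1 << 10)))  # x^10 -> x^6 + x^5 + x^3 + x^2 = 0x6c
print(hex(clmul(0x57, 0x83)), hex(gf_mul(0x57, 0x83)))  # 0x2b79 0xc1
```

Processing from the highest degree down is what makes a single pass enough: each XOR clears bit \(k\) and only touches lower bits.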

Worked example

For \(a, b \in GF(2 ^ 8)\), \(a + b = a \oplus b\) and \(3 \cdot a = 2 \cdot a \oplus a\). Write the byte \(a\) in bits as \(a_7a_6a_5a_4a_3a_2a_1a_0\) (\(a_7\) is the most significant bit); it corresponds to the polynomial \(a_7x ^ 7 + a_6 x ^ 6 + a_5x^5+a_4x^4+a_3x^3+a_2x^2+a_1x+a_0\) in \(GF(2 ^ 8)\). Since \(2\) corresponds to \(x\), \(2 \cdot a\) is computed as:

\[2 \cdot a = x\,(a_7x^7+a_6x^6+\cdots+a_1x+a_0) \bmod (x^8+x^4+x^3+x+1) = (a_7x^8+a_6x^7+\cdots+a_1x^2+a_0x) \bmod (x^8+x^4+x^3+x+1) \]

\[= \begin{cases} a_6x^7+a_5x^6+\cdots+a_1x^2+a_0x, & a_7 = 0 \\ a_6x^7+\cdots+(a_3+1)x^4+(a_2+1)x^3+a_1x^2+(a_0+1)x+1, & a_7 = 1 \end{cases} \]

Writing the polynomial back as a bit string:

\[2 \cdot a = \begin{cases} a_6a_5\cdots a_1a_0\,0, & a_7 = 0 \\ a_6a_5\cdots a_1a_0\,0 \oplus 00011011, & a_7 = 1 \end{cases} \quad\Longrightarrow\quad 2 \cdot a = \big((a_7a_6\cdots a_1a_0 \ll 1) \bmod 2^8\big) \oplus (a_7 \cdot 00011011) \]
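
The final formula maps directly onto integer operations. In this sketch (names are illustrative), `& 0xFF` drops the bit shifted out of the byte and `0x1B * a7` is either 0 or 0x1B. The values {57}·{02} = {ae}, {57}·{04} = {47}, {57}·{08} = {8e}, {57}·{10} = {07} are the repeated-doubling example from FIPS 197.

```python
def xtime(a):
    """2*a = ((a << 1) mod 2^8) XOR (a7 * 00011011), where a7 is the top bit of a."""
    a7 = (a >> 7) & 1
    return ((a << 1) & 0xFF) ^ (0x1B * a7)


def mul3(a):
    """3*a = 2*a XOR a: one shift plus at most two XORs."""
    return xtime(a) ^ a


x = 0x57
for _ in range(4):
    x = xtime(x)
    print(f"{x:02x}")   # ae, 47, 8e, 07
print(f"{mul3(0x57):02x}")  # f9
```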

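So even for the largest MixColumns coefficient, 0x03, computing \(3 \cdot a = 2 \cdot a \oplus a\) takes only one left shift and at most two XORs. The inverse matrix used for decryption has larger coefficients (0x09, 0x0B, 0x0D, 0x0E) and therefore costs more, which makes encryption cheaper than decryption; since encryption is the more frequently used direction, this trade-off suits practice.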

InvMixColumns

\[\begin{bmatrix} s'_{0,0} & s'_{0,1} & s'_{0,2} & s'_{0,3} \\ s'_{1,0} & s'_{1,1} & s'_{1,2} & s'_{1,3} \\ s'_{2,0} & s'_{2,1} & s'_{2,2} & s'_{2,3} \\ s'_{3,0} & s'_{3,1} & s'_{3,2} & s'_{3,3} \end{bmatrix} = \begin{bmatrix} 0\text{E} & 0\text{B} & 0\text{D} & 0\text{9} \\ 0\text{9} & 0\text{E} & 0\text{B} & 0\text{D} \\ 0\text{D} & 0\text{9} & 0\text{E} & 0\text{B} \\ 0\text{B} & 0\text{D} & 0\text{9} & 0\text{E} \end{bmatrix} \begin{bmatrix} s_{0,0} & s_{0,1} & s_{0,2} & s_{0,3} \\ s_{1,0} & s_{1,1} & s_{1,2} & s_{1,3} \\ s_{2,0} & s_{2,1} & s_{2,2} & s_{2,3} \\ s_{3,0} & s_{3,1} & s_{3,2} & s_{3,3} \end{bmatrix} \]
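
A quick way to convince yourself that this matrix inverts MixColumns is to apply both matrices in sequence. The sketch below uses illustrative names and general shift-and-add multiplication (the coefficients 0x09 to 0x0E need several doublings, not just one `xtime`). It undoes the commonly quoted test column db 13 53 45 → 8e 4d a1 bc and checks that the product of the two matrices is the identity.

```python
def xtime(a):
    a <<= 1
    return (a ^ 0x11B) if a & 0x100 else a


def gmul(a, b):
    """Shift-and-add multiplication in GF(2^8): XOR in a*x^i for each set bit i of b."""
    p = 0
    while b:
        if b & 1:
            p ^= a
        a = xtime(a)
        b >>= 1
    return p


MIX = [[0x02, 0x03, 0x01, 0x01], [0x01, 0x02, 0x03, 0x01],
       [0x01, 0x01, 0x02, 0x03], [0x03, 0x01, 0x01, 0x02]]
INV_MIX = [[0x0E, 0x0B, 0x0D, 0x09], [0x09, 0x0E, 0x0B, 0x0D],
           [0x0D, 0x09, 0x0E, 0x0B], [0x0B, 0x0D, 0x09, 0x0E]]


def mat_vec(matrix, col):
    """Matrix-vector product over GF(2^8): addition is XOR, multiplication is gmul."""
    out = []
    for row in matrix:
        acc = 0
        for coef, s in zip(row, col):
            acc ^= gmul(coef, s)
        out.append(acc)
    return out


print([f"{b:02x}" for b in mat_vec(INV_MIX, [0x8E, 0x4D, 0xA1, 0xBC])])
# ['db', '13', '53', '45']

# INV_MIX * MIX is the identity: check it on the four unit columns.
for i in range(4):
    e = [1 if j == i else 0 for j in range(4)]
    assert mat_vec(INV_MIX, mat_vec(MIX, e)) == e
```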

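Note: the largest coefficient in the encryption matrix is 0x03, so \(3 \cdot a\) needs at most three operations (two XORs and one shift), whereas decryption needs multiplications by larger constants such as 0x0E. This asymmetry is acceptable because in engineering practice encryption is used far more often than decryption: the CTR and OFB modes of operation use only the block cipher's encryption direction, and when a block cipher serves as a building block for hash functions or message authentication codes, usually only encryption is needed as well.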

4. AddRoundKey

AddRoundKey XORs the 128-bit state with the 128-bit round key bit by bit (equivalently, byte by byte), as shown below:

Figure: AddRoundKey. The 4×4 input state (bytes X₀₀ to X₃₃) and the 4×4 round key (bytes K₀₀ to K₃₃) are combined by byte-wise XOR into the 4×4 output state, Rᵢⱼ = Xᵢⱼ ⊕ Kᵢⱼ.
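
AddRoundKey is a byte-wise XOR, so applying it twice with the same key restores the state. The sketch below (illustrative names) uses the input block and cipher key from the worked example in FIPS 197 Appendix B; the first row of the result, 19 a0 9a e9, is the state at the start of round 1 there.

```python
def to_state(block: bytes):
    """Column-major 4x4 state: state[row][col] = block[4*col + row]."""
    return [[block[4 * c + r] for c in range(4)] for r in range(4)]


def add_round_key(state, round_key):
    """R[r][c] = X[r][c] XOR K[r][c] for every byte of the 4x4 state."""
    return [[x ^ k for x, k in zip(srow, krow)] for srow, krow in zip(state, round_key)]


X = to_state(bytes.fromhex("3243f6a8885a308d313198a2e0370734"))
K = to_state(bytes.fromhex("2b7e151628aed2a6abf7158809cf4f3c"))
R = add_round_key(X, K)
print(" ".join(f"{b:02x}" for b in R[0]))  # 19 a0 9a e9
print(add_round_key(R, K) == X)  # True: XOR-ing the same key again restores the state
```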

5. Key Expansion

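AES-128/192/256 expand their keys with the same basic procedure, but since the number of rounds differs, the expanded keys have different lengths (AES-256 also applies an extra SubWord step when \(i \bmod 8 = 4\)). We describe AES-128 here. The input key is 16 bytes, and each column of the key matrix, 4 bytes, is called a word. The initial four words form the master key and are expanded into 44 words in total. Each group of 4 words is used in one AddRoundKey step, and the first four words are XORed with the plaintext before the first round.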

Expansion rules:

Let the master key be \(w[0],w[1],w[2],w[3]\). The array \(W\) is extended by 40 new columns (words), giving an expanded-key array of 44 columns in total:

  1. If \(i\) is not a multiple of 4:
    \(w[i]=w[i-4] \oplus w[i-1]\)
  2. \(i\)\(4\)的倍数
    \(w[i]=w[i-4] \oplus T(w[i-1])\)

The T function

The T function consists of three steps: RotWord, SubWord, and XOR with a round constant.

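  1. RotWord: rotate the 4 bytes of the word left by 1 byte
    [b₀, b₁, b₂, b₃] → [b₁, b₂, b₃, b₀]
  2. SubWord: substitute each byte using the S-box
  3. Round-constant XOR: XOR with the round constant \(Rcon[j]\), where \(j = i/4\) for AES-128
    Round constants: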
Round j Rcon[j] Round j Rcon[j]
1 01 00 00 00 6 20 00 00 00
2 02 00 00 00 7 40 00 00 00
3 04 00 00 00 8 80 00 00 00
4 08 00 00 00 9 1B 00 00 00
5 10 00 00 00 10 36 00 00 00
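
The table does not have to be memorized: the first byte of \(Rcon[j]\) is \(x^{j-1}\) in \(GF(2^8)\), so each entry is the previous one multiplied by 2 (that is, by \(x\)). A short sketch, with `xtime` as an illustrative helper name:

```python
def xtime(a):
    a <<= 1
    return (a ^ 0x11B) if a & 0x100 else a


rcon = [0x01]
for _ in range(9):
    rcon.append(xtime(rcon[-1]))  # Rcon[j+1] = 2 * Rcon[j] in GF(2^8)
print(" ".join(f"{r:02X}" for r in rcon))
# 01 02 04 08 10 20 40 80 1B 36
```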

Worked example

Initial key: 3C A1 0B 21 57 F0 19 16 90 2E 13 80 AC C1 07 BD

Initial W array:

  • W[0] = 3C A1 0B 21
  • W[1] = 57 F0 19 16
  • W[2] = 90 2E 13 80
  • W[3] = AC C1 07 BD

Computing the first round key (W[4] to W[7])
W[4] (\(i = 4\), a multiple of 4)

Computing T(W[3]):
1. RotWord: AC C1 07 BD → C1 07 BD AC
2. SubWord: C1 07 BD AC → 78 C5 7A 91
3. Round-constant XOR: 78 C5 7A 91 ⊕ 01 00 00 00 = 79 C5 7A 91

W[4] = W[0] ⊕ T(W[3]) = 3C A1 0B 21 ⊕ 79 C5 7A 91 = 45 64 71 B0

W[5] (\(i = 5\), not a multiple of 4)

W[5] = W[1] ⊕ W[4] = 57 F0 19 16 ⊕ 45 64 71 B0 = 12 94 68 A6

W[6] (\(i = 6\), not a multiple of 4)

W[6] = W[2] ⊕ W[5] = 90 2E 13 80 ⊕ 12 94 68 A6 = 82 BA 7B 26

W[7] (\(i = 7\), not a multiple of 4)

W[7] = W[3] ⊕ W[6] = AC C1 07 BD ⊕ 82 BA 7B 26 = 2E 7B 7C 9B

First round key: 45 64 71 B0 12 94 68 A6 82 BA 7B 26 2E 7B 7C 9B
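
The hand calculation above can be replayed in code. Below is a minimal AES-128-only key expansion (`t_function` and `expand_key_128` are illustrative names), with the S-box rebuilt from its definition so the block runs on its own; it prints W[4] to W[7] for the example key. As an extra check, with the FIPS 197 Appendix A.1 key 2b7e151628aed2a6abf7158809cf4f3c, W[4] should come out as a0 fa fe 17.

```python
def gf_mul(a, b):
    p = 0
    while b:
        if b & 1:
            p ^= a
        a <<= 1
        if a & 0x100:
            a ^= 0x11B
        b >>= 1
    return p


def build_sbox():
    def rotl8(x, n):
        return ((x << n) | (x >> (8 - n))) & 0xFF
    sbox = []
    for a in range(256):
        inv = 1
        for _ in range(254):  # a^254 is the inverse of a (0 stays 0)
            inv = gf_mul(inv, a)
        sbox.append(inv ^ rotl8(inv, 1) ^ rotl8(inv, 2) ^ rotl8(inv, 3) ^ rotl8(inv, 4) ^ 0x63)
    return sbox


SBOX = build_sbox()
RCON = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]


def t_function(word, j):
    word = word[1:] + word[:1]                 # RotWord: [b0,b1,b2,b3] -> [b1,b2,b3,b0]
    word = [SBOX[b] for b in word]             # SubWord
    return [word[0] ^ RCON[j - 1]] + word[1:]  # XOR with Rcon[j] = (RC_j, 00, 00, 00)


def expand_key_128(key: bytes):
    w = [list(key[4 * i:4 * i + 4]) for i in range(4)]
    for i in range(4, 44):
        temp = w[i - 1]
        if i % 4 == 0:
            temp = t_function(temp, i // 4)
        w.append([x ^ y for x, y in zip(w[i - 4], temp)])
    return w


w = expand_key_128(bytes.fromhex("3CA10B2157F01916902E1380ACC107BD"))
for i in range(4, 8):
    print(f"W[{i}] =", " ".join(f"{b:02X}" for b in w[i]))
# W[4] = 45 64 71 B0
# W[5] = 12 94 68 A6
# W[6] = 82 BA 7B 26
# W[7] = 2E 7B 7C 9B
```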

Part 3: Encryption and Decryption Flow

Figure (flowchart): AES-128 encryption and decryption.
  • Key expansion: the 128-bit master key is expanded into the round keys K0 to K10.
  • Encryption: 128-bit plaintext → AddRoundKey(K0) → rounds 1 to 9, each SubBytes → ShiftRows → MixColumns → AddRoundKey(Ki) → final round: SubBytes → ShiftRows → AddRoundKey(K10) → 128-bit ciphertext.
  • Decryption: 128-bit ciphertext → AddRoundKey(K10) → InvShiftRows → InvSubBytes → for i = 9 down to 1: AddRoundKey(Ki) → InvMixColumns → InvShiftRows → InvSubBytes → finally AddRoundKey(K0) → 128-bit plaintext.

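Note: the decryption order is not unique. Within each decryption round, InvShiftRows and InvSubBytes can be applied in either order, because InvShiftRows only moves bytes between positions while InvSubBytes transforms each byte independently of its position.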
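
To check the whole flow end to end, including the note that InvShiftRows and InvSubBytes may be swapped, the compact AES-128 sketch below decrypts with InvSubBytes applied before InvShiftRows. It is a single-block, AES-128-only illustration with hypothetical function names (not a substitute for a vetted library) and is checked against the FIPS 197 Appendix C.1 test vector.

```python
def xtime(a):
    a <<= 1
    return (a ^ 0x11B) if a & 0x100 else a

def gmul(a, b):
    p = 0
    while b:
        if b & 1:
            p ^= a
        a, b = xtime(a), b >> 1
    return p

def rotl8(x, n):
    return ((x << n) | (x >> (8 - n))) & 0xFF

def gf_inv(a):
    r = 1
    for _ in range(254):  # a^254 = a^-1 for a != 0, and 0 -> 0
        r = gmul(r, a)
    return r

SBOX = [b ^ rotl8(b, 1) ^ rotl8(b, 2) ^ rotl8(b, 3) ^ rotl8(b, 4) ^ 0x63
        for b in map(gf_inv, range(256))]
INV_SBOX = [0] * 256
for i, s in enumerate(SBOX):
    INV_SBOX[s] = i
MIX = [[2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]]
INV_MIX = [[14, 11, 13, 9], [9, 14, 11, 13], [13, 9, 14, 11], [11, 13, 9, 14]]

def to_state(b):
    return [[b[4 * c + r] for c in range(4)] for r in range(4)]

def from_state(s):
    return bytes(s[r][c] for c in range(4) for r in range(4))

def expand_key(key):
    """AES-128 key schedule; returns the 11 round keys as 4x4 states."""
    w = [list(key[i:i + 4]) for i in range(0, 16, 4)]
    rcon = 0x01
    for i in range(4, 44):
        t = w[i - 1]
        if i % 4 == 0:
            t = [SBOX[b] for b in t[1:] + t[:1]]  # RotWord + SubWord
            t[0] ^= rcon
            rcon = xtime(rcon)
        w.append([a ^ b for a, b in zip(w[i - 4], t)])
    return [[[w[4 * k + c][r] for c in range(4)] for r in range(4)] for k in range(11)]

def add_key(s, k):
    return [[a ^ b for a, b in zip(sr, kr)] for sr, kr in zip(s, k)]

def sub(s, box):
    return [[box[b] for b in row] for row in s]

def shift(s, d):  # d = +1: ShiftRows (left by r); d = -1: InvShiftRows (right by r)
    return [row[(d * r) % 4:] + row[:(d * r) % 4] for r, row in enumerate(s)]

def mix(s, m):
    return [[gmul(m[r][0], s[0][c]) ^ gmul(m[r][1], s[1][c]) ^
             gmul(m[r][2], s[2][c]) ^ gmul(m[r][3], s[3][c]) for c in range(4)]
            for r in range(4)]

def encrypt_block(block, key):
    ks = expand_key(key)
    s = add_key(to_state(block), ks[0])
    for rnd in range(1, 10):
        s = add_key(mix(shift(sub(s, SBOX), 1), MIX), ks[rnd])
    return from_state(add_key(shift(sub(s, SBOX), 1), ks[10]))  # last round: no MixColumns

def decrypt_block(block, key):
    ks = expand_key(key)
    s = add_key(to_state(block), ks[10])
    for rnd in range(9, 0, -1):
        s = shift(sub(s, INV_SBOX), -1)        # InvSubBytes first, then InvShiftRows
        s = mix(add_key(s, ks[rnd]), INV_MIX)  # AddRoundKey, then InvMixColumns
    s = shift(sub(s, INV_SBOX), -1)
    return from_state(add_key(s, ks[0]))

key = bytes.fromhex("000102030405060708090a0b0c0d0e0f")
pt = bytes.fromhex("00112233445566778899aabbccddeeff")
ct = encrypt_block(pt, key)
print(ct.hex())  # 69c4e0d86a7b0430d8cdb78070b4c55a (FIPS 197, Appendix C.1)
print(decrypt_block(ct, key) == pt)  # True
```

Because `decrypt_block` uses the swapped order and still recovers the plaintext, the sketch also serves as a concrete check of the note above.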

Part 4: AES Encryption and Decryption Code

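Disclaimer: to meet most readers' needs without repeating material, the code below targets AES-128 (the classes also accept 192- and 256-bit keys). Mainstream programming languages already have libraries that support AES, but we still sometimes need AES source code of our own. To keep the blog manageable and this article from growing too long, this part only gives basic encryption and decryption code in mainstream languages that follows the process above; each implementation encrypts every 16-byte block independently (ECB mode). The course also calls for various software implementation and optimization strategies, such as using the AES-NI instructions and the T-table method; those are covered in detail in later articles.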

C/C++

/*
 * File Name: AES_test.cpp
 * Author: chen_xing
 */
#include <iostream>
#include <vector>
#include <iomanip>
#include <sstream>
#include <cstdint>
#include <algorithm>
#include <stdexcept>
using namespace std;

class AES
{
private:
    // Round constants x^(j-1) in GF(2^8); static constexpr members are implicitly inline since C++17
    static constexpr uint8_t Rcon[14] = {
        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D};

    vector<uint8_t> key;
    vector<vector<uint8_t>> W;
    vector<uint8_t> S, S_inv;
    int rounds, nk;

    // Multiplication in GF(2^8) (shift-and-add, reducing by 0x1B)
    uint8_t _gmul(uint8_t a, uint8_t b)
    {
        uint8_t p = 0;
        for (int i = 0; i < 8; ++i)
        {
            if (b & 1)
                p ^= a;
            bool hi = a & 0x80;
            a <<= 1;
            if (hi)
                a ^= 0x1B;
            b >>= 1;
        }
        return p;
    }

    // Element-wise XOR of two byte vectors
    vector<uint8_t> _xor_vec(const vector<uint8_t> &a, const vector<uint8_t> &b)
    {
        vector<uint8_t> res(a.size());
        for (size_t i = 0; i < a.size(); ++i)
            res[i] = a[i] ^ b[i];
        return res;
    }

    // Bytes to state matrix (column-major)
    vector<vector<uint8_t>> _bytes_to_matrix(const vector<uint8_t> &text)
    {
        vector<vector<uint8_t>> m(4, vector<uint8_t>(4));
        for (int i = 0; i < 16; i++)
            m[i % 4][i / 4] = text[i];
        return m;
    }

    // State matrix to bytes (column-major)
    vector<uint8_t> _matrix_to_bytes(const vector<vector<uint8_t>> &m)
    {
        vector<uint8_t> res(16);
        for (int i = 0; i < 16; i++)
            res[i] = m[i % 4][i / 4];
        return res;
    }

    // RotWord
    vector<uint8_t> _rot_word(vector<uint8_t> w)
    {
        rotate(w.begin(), w.begin() + 1, w.end());
        return w;
    }

    // SubWord
    vector<uint8_t> _sub_word(vector<uint8_t> w)
    {
        for (auto &x : w)
            x = S[x];
        return w;
    }

    // Generate the S-box
    void _generate_sbox()
    {
        S.assign(256, 0);
        S[0] = 0x63;
        uint8_t p = 1, q = 1;
        for (int i = 0; i < 255; ++i)
        {
            p = p ^ (p << 1) ^ ((p & 0x80) ? 0x1B : 0);
            q ^= q << 1;
            q ^= q << 2;
            q ^= q << 4;
            q ^= (q & 0x80) ? 0x09 : 0x00;
            uint8_t xformed = q ^ ((q << 1) | (q >> 7)) ^
                              ((q << 2) | (q >> 6)) ^
                              ((q << 3) | (q >> 5)) ^
                              ((q << 4) | (q >> 4)) ^ 0x63;
            S[p] = xformed;
        }
    }

    // Generate the inverse S-box
    void _generate_sinvbox()
    {
        S_inv.assign(256, 0);
        for (int i = 0; i < 256; ++i)
            S_inv[S[i]] = i;
    }

    // Key expansion
    void _key_expansion()
    {
        for (int i = nk; i < (rounds + 1) * 4; ++i)
        {
            vector<uint8_t> temp = W[i - 1];
            if (i % nk == 0)
            {
                temp = _sub_word(_rot_word(temp));
                temp[0] ^= Rcon[i / nk - 1];
            }
            else if (nk == 8 && i % nk == 4)
            {
                temp = _sub_word(temp);
            }
            W.push_back(_xor_vec(W[i - nk], temp));
        }
    }

    // AddRoundKey
    vector<vector<uint8_t>> _round_key_add(vector<vector<uint8_t>> &block, int round)
    {
        for (int c = 0; c < 4; c++)
            for (int r = 0; r < 4; r++)
                block[r][c] ^= W[round * 4 + c][r];
        return block;
    }

    // SubBytes
    vector<vector<uint8_t>> _sub_bytes(vector<vector<uint8_t>> &block)
    {
        for (auto &row : block)
            for (auto &b : row)
                b = S[b];
        return block;
    }

    // ShiftRows
    vector<vector<uint8_t>> _shift_rows(vector<vector<uint8_t>> &block)
    {
        for (int r = 1; r < 4; r++)
            rotate(block[r].begin(), block[r].begin() + r, block[r].end());
        return block;
    }

    // MixColumns
    vector<vector<uint8_t>> _mix_columns(vector<vector<uint8_t>> &s)
    {
        for (int c = 0; c < 4; c++)
        {
            uint8_t a0 = s[0][c], a1 = s[1][c], a2 = s[2][c], a3 = s[3][c];
            s[0][c] = _gmul(a0, 2) ^ _gmul(a1, 3) ^ a2 ^ a3;
            s[1][c] = a0 ^ _gmul(a1, 2) ^ _gmul(a2, 3) ^ a3;
            s[2][c] = a0 ^ a1 ^ _gmul(a2, 2) ^ _gmul(a3, 3);
            s[3][c] = _gmul(a0, 3) ^ a1 ^ a2 ^ _gmul(a3, 2);
        }
        return s;
    }

    // InvSubBytes
    vector<vector<uint8_t>> _inv_sub_bytes(vector<vector<uint8_t>> &block)
    {
        for (auto &row : block)
            for (auto &b : row)
                b = S_inv[b];
        return block;
    }

    // InvShiftRows
    vector<vector<uint8_t>> _inv_shift_rows(vector<vector<uint8_t>> &block)
    {
        for (int r = 1; r < 4; r++)
            rotate(block[r].begin(), block[r].begin() + (4 - r), block[r].end());
        return block;
    }

    // InvMixColumns
    vector<vector<uint8_t>> _inv_mix_columns(vector<vector<uint8_t>> &s)
    {
        for (int c = 0; c < 4; c++)
        {
            uint8_t a0 = s[0][c], a1 = s[1][c], a2 = s[2][c], a3 = s[3][c];
            s[0][c] = _gmul(a0, 14) ^ _gmul(a1, 11) ^ _gmul(a2, 13) ^ _gmul(a3, 9);
            s[1][c] = _gmul(a0, 9) ^ _gmul(a1, 14) ^ _gmul(a2, 11) ^ _gmul(a3, 13);
            s[2][c] = _gmul(a0, 13) ^ _gmul(a1, 9) ^ _gmul(a2, 14) ^ _gmul(a3, 11);
            s[3][c] = _gmul(a0, 11) ^ _gmul(a1, 13) ^ _gmul(a2, 9) ^ _gmul(a3, 14);
        }
        return s;
    }

public:
    AES(const vector<uint8_t> &key_)
    {
        key = key_;
        size_t len = key.size();
        if (len == 16)
        {
            rounds = 10;
            nk = 4;
        }
        else if (len == 24)
        {
            rounds = 12;
            nk = 6;
        }
        else if (len == 32)
        {
            rounds = 14;
            nk = 8;
        }
        else
            throw invalid_argument("Invalid AES key length");

        _generate_sbox();
        _generate_sinvbox();

        // Load the key words, then run the key expansion
        for (size_t i = 0; i < len; i += 4)
            W.push_back(vector<uint8_t>(key.begin() + i, key.begin() + i + 4));
        _key_expansion();
    }

    // PKCS#7 padding
    vector<uint8_t> pad(const vector<uint8_t> &data)
    {
        size_t pad_len = 16 - (data.size() % 16);
        vector<uint8_t> res = data;
        res.insert(res.end(), pad_len, static_cast<uint8_t>(pad_len));
        return res;
    }

    // Remove PKCS#7 padding
    vector<uint8_t> unpad(const vector<uint8_t> &data)
    {
        if (data.empty() || data.size() % 16 != 0)
            throw runtime_error("Invalid data length for unpad");
        uint8_t pad_len = data.back();
        if (pad_len == 0 || pad_len > 16)
            throw runtime_error("Invalid padding");
        for (size_t i = data.size() - pad_len; i < data.size(); ++i)
            if (data[i] != pad_len)
                throw runtime_error("Bad padding byte");
        return vector<uint8_t>(data.begin(), data.end() - pad_len);
    }

    vector<uint8_t> encrypt(const vector<uint8_t> &plaintext)
    {
        vector<uint8_t> data = pad(plaintext);
        vector<uint8_t> out;

        for (size_t i = 0; i < data.size(); i += 16)
        {
            auto block = vector<uint8_t>(data.begin() + i, data.begin() + i + 16);
            auto state = _bytes_to_matrix(block);

            // Initial AddRoundKey
            state = _round_key_add(state, 0);

            // Main rounds
            for (int r = 1; r < rounds; ++r)
            {
                state = _sub_bytes(state);
                state = _shift_rows(state);
                state = _mix_columns(state);
                state = _round_key_add(state, r);
            }

            // Final round (no MixColumns)
            state = _sub_bytes(state);
            state = _shift_rows(state);
            state = _round_key_add(state, rounds);

            auto encrypted_block = _matrix_to_bytes(state);
            out.insert(out.end(), encrypted_block.begin(), encrypted_block.end());
        }
        return out;
    }

    vector<uint8_t> decrypt(const vector<uint8_t> &ciphertext)
    {
        if (ciphertext.size() % 16 != 0)
            throw invalid_argument("Ciphertext length must be multiple of 16");

        vector<uint8_t> out;

        for (size_t i = 0; i < ciphertext.size(); i += 16)
        {
            auto block = vector<uint8_t>(ciphertext.begin() + i, ciphertext.begin() + i + 16);
            auto state = _bytes_to_matrix(block);

            // Initial round
            state = _round_key_add(state, rounds);
            state = _inv_shift_rows(state);
            state = _inv_sub_bytes(state);

            // Main rounds
            for (int r = rounds - 1; r > 0; --r)
            {
                state = _round_key_add(state, r);
                state = _inv_mix_columns(state);
                state = _inv_shift_rows(state);
                state = _inv_sub_bytes(state);
            }

            // Final AddRoundKey
            state = _round_key_add(state, 0);

            auto decrypted_block = _matrix_to_bytes(state);
            out.insert(out.end(), decrypted_block.begin(), decrypted_block.end());
        }

        return unpad(out);
    }
};

int main()
{
    string msg = "He who conquers others is strong; he who conquers himself is mighty.";
    vector<uint8_t> data(msg.begin(), msg.end());
    // key = chenxing_AES_c++
    vector<uint8_t> key = {'c', 'h', 'e', 'n', 'x', 'i', 'n', 'g', '_', 'A', 'E', 'S', '_', 'c', '+', '+'};

    AES aes(key);

    // Encrypt
    auto enc = aes.encrypt(data);
    cout << "Encrypted (hex): ";
    for (auto c : enc)
        cout << hex << setw(2) << setfill('0') << (int)c;
    cout << dec << endl;

    // Decrypt
    auto dec = aes.decrypt(enc);
    cout << "Decrypted: " << string(dec.begin(), dec.end()) << endl;

    // Verify
    cout << "Match: " << (msg == string(dec.begin(), dec.end())) << endl;

    return 0;
}

Python

class AES():
    
    Rcon = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D]

    def __init__(self, key: bytes):
        key_length = len(key)
        if key_length == 16:  # 128-bit
            self.rounds = 10
            self.nk = 4  # number of 32-bit words in the key
        elif key_length == 24:  # 192-bit
            self.rounds = 12
            self.nk = 6
        elif key_length == 32:  # 256-bit
            self.rounds = 14
            self.nk = 8
        else:
            raise ValueError(f"Invalid AES key length: {key_length} bytes")
        
        self.key = key
        self.S = self._generate_sbox()
        self.S_inv = self._generate_sinvbox()
        self.W = [list(key[i:i + 4]) for i in range(0, len(key), 4)]
        self._key_expansion()

    def _xor(self, a, b=None):
        if b is None:
            assert len(a) > 0
            return a[0] ^ self._xor(a[1:]) if len(a) > 1 else a[0]
        assert len(a) == len(b)
        return [x ^ y for x, y in zip(a, b)]

    def _gmul(self, a, b):
        p = 0
        while b:
            if b & 1:
                p ^= a
            a = a << 1
            if a >> 8:
                a ^= 0x11B  # AES irreducible polynomial x^8 + x^4 + x^3 + x + 1
            b >>= 1
        return p

    def _multiply(self, a, b):
        return self._xor([self._gmul(x, y) for x, y in zip(a, b)])

    def _matrix_multiply(self, const, state):
        result = [[0] * 4 for _ in range(4)]
        for i in range(4):
            for j in range(4):
                for k in range(4):
                    result[i][j] ^= self._gmul(const[i][k], state[k][j])
        return result

    def _permutation(self, lis, table):
        return [table[i] for i in lis]

    def _block_permutation(self, block, table):
        return [[table[block[i][j]] for j in range(4)] for i in range(4)]

    def _left_shift(self, block, n):
        return block[n:] + block[:n]

    def _rot_word(self, word):
        return word[1:] + word[:1]

    def _sub_word(self, word):
        return [self.S[b] for b in word]

    def _bytes_to_matrix(self, text: bytes) -> list[list[int]]:
        return [[text[j * 4 + i] for j in range(4)] for i in range(4)]

    def _matrix_to_bytes(self, block: list[list[int]]) -> bytes:
        return bytes([block[j][i] for i in range(4) for j in range(4)])

    def _generate_sbox(self):
        S = [0x63] + [0] * 255
        r = lambda x, s: (x << s | x >> 8 - s) % 256
        p, q = 1, 1
        for _ in range(255):
            p = (p ^ (p * 2) ^ [27, 0][p < 128]) % 256
            q ^= q * 2
            q ^= q * 4
            q ^= q * 16
            q &= 255
            q ^= [9, 0][q < 128]
            S[p] = q ^ r(q, 1) ^ r(q, 2) ^ r(q, 3) ^ r(q, 4) ^ 99
        return S

    def _generate_sinvbox(self):
        S_inv = [0] * 256
        for i in range(256):
            S_inv[self.S[i]] = i
        return S_inv

    def _key_expansion(self):
        for i in range(self.nk, (self.rounds + 1) * 4):
            temp = self.W[i - 1].copy()
            if i % self.nk == 0:
                temp = self._sub_word(self._rot_word(temp))
                temp[0] ^= self.Rcon[i // self.nk - 1]
            elif self.nk == 8 and i % self.nk == 4:
                temp = self._sub_word(temp)
            self.W.append(self._xor(self.W[i - self.nk], temp))

    def _row_shift(self, block: list[list[int]]) -> list[list[int]]:
        return [self._left_shift(block[i], i) for i in range(4)]

    def _inv_row_shift(self, block: list[list[int]]) -> list[list[int]]:
        return [self._left_shift(block[i], 4 - i) for i in range(4)]

    def _column_mix(self, block: list[list[int]]) -> list[list[int]]:
        return self._matrix_multiply([self._left_shift([2, 3, 1, 1], 4 - i) for i in range(4)], block)

    def _inv_column_mix(self, block: list[list[int]]) -> list[list[int]]:
        return self._matrix_multiply([self._left_shift([14, 11, 13, 0x9], 4 - i) for i in range(4)], block)

    def _round_key_add(self, block: list[list[int]], key: list[list[int]]) -> list[list[int]]:
        return [self._xor(block[i], [key[j][i] for j in range(4)]) for i in range(4)]

    def _encrypt(self, block: list[list[int]]) -> list[list[int]]:
        block = self._round_key_add(block, [self.W[i] for i in range(4)])
        for i in range(1, self.rounds):
            block = self._block_permutation(block, self.S)
            block = self._row_shift(block)
            block = self._column_mix(block)
            block = self._round_key_add(block, [self.W[i * 4 + j] for j in range(4)])
        block = self._block_permutation(block, self.S)
        block = self._row_shift(block)
        block = self._round_key_add(block, [self.W[i] for i in range(self.rounds * 4, (self.rounds + 1) * 4)])
        return block

    def _decrypt(self, block: list[list[int]]) -> list[list[int]]:
        block = self._round_key_add(block, [self.W[i] for i in range(self.rounds * 4, (self.rounds + 1) * 4)])
        for i in range(self.rounds - 1, 0, -1):
            block = self._inv_row_shift(block)
            block = self._block_permutation(block, self.S_inv)
            block = self._round_key_add(block, [self.W[i * 4 + j] for j in range(4)])
            block = self._inv_column_mix(block)
        block = self._inv_row_shift(block)
        block = self._block_permutation(block, self.S_inv)
        block = self._round_key_add(block, [self.W[i] for i in range(4)])
        return block

    def encrypt(self, plaintext: bytes) -> bytes:
        if len(plaintext) % 16 != 0:
            raise ValueError(f"Incorrect AES plaintext length ({len(plaintext)} bytes)")
        ciphertext = b''
        for i in range(0, len(plaintext), 16):
            block = plaintext[i:i + 16]
            block = self._bytes_to_matrix(block)
            block = self._encrypt(block)
            ciphertext += self._matrix_to_bytes(block) 
        return ciphertext

    def decrypt(self, ciphertext: bytes) -> bytes:
        if len(ciphertext) % 16 != 0:
            raise ValueError(f"Incorrect AES ciphertext length ({len(ciphertext)} bytes)")
        plaintext = b''
        for i in range(0, len(ciphertext), 16):
            block = ciphertext[i:i + 16]
            block = self._bytes_to_matrix(block)
            block = self._decrypt(block)
            plaintext += self._matrix_to_bytes(block)
        return plaintext


if __name__ == "__main__":

    key = b'chen_xing_AES_py'
    aes = AES(key)
    # Quote from the Tao Te Ching (English rendering)
    test_plaintext = b'He who conquers others is strong; he who conquers himself is mighty.'
    # Zero-pad to a multiple of 16 bytes (the class itself does not pad)
    if len(test_plaintext) % 16 != 0:
        test_plaintext = test_plaintext.ljust(16 * ((len(test_plaintext) // 16) + 1), b'\x00')
    
    encrypted = aes.encrypt(test_plaintext)
    decrypted = aes.decrypt(encrypted)
    print(f"Decrypted: {decrypted}")
    print(f"Test - Match: {test_plaintext == decrypted}")

Java

import java.util.Arrays;

public class AES {
    private static final int[] RCON = {
            0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D
    };

    private final int Nk;        // number of 32-bit words in the key
    private final int Nr;        // number of rounds
    private final int Nb = 4;    // number of columns in the state
    private final byte[] key;    // original key
    private final int[][] w;     // expanded key
    private final byte[] S = new byte[256];     // S-box
    private final byte[] S_INV = new byte[256]; // inverse S-box

    public AES(byte[] key) {
        // Choose the parameters from the key length
        int len = key.length;
        if (len == 16) {
            Nk = 4;
            Nr = 10;
        } else if (len == 24) {
            Nk = 6;
            Nr = 12;
        } else if (len == 32) {
            Nk = 8;
            Nr = 14;
        } else {
            throw new IllegalArgumentException("Invalid AES key length: " + len);
        }

        this.key = Arrays.copyOf(key, key.length);
        generateSBox();
        generateSInvBox();

        // Key expansion: total words = Nb * (Nr + 1)
        int totalWords = Nb * (Nr + 1);
        w = new int[totalWords][4];
        keyExpansion();
    }

    // ---------- GF(2^8) arithmetic helpers ----------

    /**
     * Multiply by x (i.e. by 2) in GF(2^8)
     */
    private static int xtime(int a) {
        a <<= 1;
        if ((a & 0x100) != 0) {
            a ^= 0x11b; // AES irreducible polynomial x^8 + x^4 + x^3 + x + 1
        }
        return a & 0xff;
    }

    /**
     * Multiplication in GF(2^8) (shift-and-add using xtime)
     */
    private static int mul(int a, int b) {
        int res = 0;
        while (b != 0) {
            if ((b & 1) != 0) {
                res ^= a;
            }
            a = xtime(a);
            b >>>= 1;
        }
        return res & 0xff;
    }

    // ---------- S-box generation ----------

    /**
     * Generate the AES S-box
     */
    private void generateSBox() {
        // S(0) = 0x63; for a != 0, take the inverse a^254 in GF(2^8), then apply the affine map
        S[0] = 0x63;
        for (int i = 1; i < 256; i++) {
            int inv = gfInv(i);
            int x = inv;
            // Affine transformation (bits above 0xff are masked off below)
            int y = x ^ ((x << 1) | (x >>> 7))
                    ^ ((x << 2) | (x >>> 6))
                    ^ ((x << 3) | (x >>> 5))
                    ^ ((x << 4) | (x >>> 4))
                    ^ 0x63;
            S[i] = (byte) (y & 0xff);
        }
    }

    /**
     * Build the inverse S-box by inverting the S-box permutation
     */
    private void generateSInvBox() {
        for (int i = 0; i < 256; i++) {
            S_INV[S[i] & 0xFF] = (byte) i;
        }
    }

    /**
     * Multiplicative inverse in GF(2^8) (0 maps to 0)
     */
    private int gfInv(int a) {
        if (a == 0) return 0;
        int result = 1;
        int base = a;
        int exp = 254; // a^254 = a^-1, because a^255 = 1 for every non-zero a
        while (exp > 0) {
            if ((exp & 1) != 0) {
                result = mul(result, base);
            }
            base = mul(base, base);
            exp >>= 1;
        }
        return result & 0xff;
    }

    // ---------- Key expansion ----------

    /**
     * AES key expansion
     */
    private void keyExpansion() {
        // The first Nk words are the original key
        for (int i = 0; i < Nk; i++) {
            for (int j = 0; j < 4; j++) {
                w[i][j] = key[4 * i + j] & 0xFF;
            }
        }

        // Derive the remaining words
        for (int i = Nk; i < w.length; i++) {
            int[] temp = Arrays.copyOf(w[i - 1], 4);

            if (i % Nk == 0) {
                temp = subWord(rotWord(temp));
                temp[0] ^= RCON[i / Nk - 1];
            } else if (Nk > 6 && (i % Nk) == 4) {
                temp = subWord(temp);
            }

            for (int j = 0; j < 4; j++) {
                w[i][j] = w[i - Nk][j] ^ temp[j];
            }
        }
    }

    /**
     * RotWord: rotate a word left by one byte
     */
    private int[] rotWord(int[] word) {
        return new int[]{word[1], word[2], word[3], word[0]};
    }

    /**
     * SubWord: apply the S-box to each byte of a word
     */
    private int[] subWord(int[] word) {
        int[] out = new int[4];
        for (int i = 0; i < 4; i++) {
            out[i] = S[word[i]] & 0xFF;
        }
        return out;
    }

    // ---------- State-matrix helpers ----------

    /**
     * Load 16 bytes into the state matrix, column by column
     */
    private int[][] bytesToState(byte[] in, int offset) {
        int[][] state = new int[4][4];
        for (int i = 0; i < 16; i++) {
            state[i % 4][i / 4] = in[offset + i] & 0xFF;
        }
        return state;
    }

    /**
     * Write the state matrix back to bytes, column by column
     */
    private void stateToBytes(int[][] state, byte[] out, int offset) {
        for (int i = 0; i < 16; i++) {
            out[offset + i] = (byte) (state[i % 4][i / 4]);
        }
    }

    // ---------- Round transformations ----------

    /**
     * SubBytes
     */
    private void subBytes(int[][] state) {
        for (int row = 0; row < 4; row++) {
            for (int col = 0; col < 4; col++) {
                state[row][col] = S[state[row][col]] & 0xff;
            }
        }
    }

    /**
     * InvSubBytes
     */
    private void invSubBytes(int[][] state) {
        for (int row = 0; row < 4; row++) {
            for (int col = 0; col < 4; col++) {
                state[row][col] = S_INV[state[row][col]] & 0xff;
            }
        }
    }

    /**
     * ShiftRows: rotate row r left by r positions
     */
    private void shiftRows(int[][] state) {
        for (int row = 1; row < 4; row++) {
            int[] temp = Arrays.copyOf(state[row], 4);
            for (int col = 0; col < 4; col++) {
                state[row][col] = temp[(col + row) % 4];
            }
        }
    }

    /**
     * InvShiftRows: rotate row r right by r positions
     */
    private void invShiftRows(int[][] state) {
        for (int row = 1; row < 4; row++) {
            int[] temp = Arrays.copyOf(state[row], 4);
            for (int col = 0; col < 4; col++) {
                state[row][(col + row) % 4] = temp[col];
            }
        }
    }

    /**
     * MixColumns
     */
    private void mixColumns(int[][] state) {
        for (int col = 0; col < 4; col++) {
            int a0 = state[0][col];
            int a1 = state[1][col];
            int a2 = state[2][col];
            int a3 = state[3][col];

            state[0][col] = (mul(2, a0) ^ mul(3, a1) ^ a2 ^ a3) & 0xff;
            state[1][col] = (a0 ^ mul(2, a1) ^ mul(3, a2) ^ a3) & 0xff;
            state[2][col] = (a0 ^ a1 ^ mul(2, a2) ^ mul(3, a3)) & 0xff;
            state[3][col] = (mul(3, a0) ^ a1 ^ a2 ^ mul(2, a3)) & 0xff;
        }
    }

    /**
     * InvMixColumns
     */
    private void invMixColumns(int[][] state) {
        for (int col = 0; col < 4; col++) {
            int a0 = state[0][col];
            int a1 = state[1][col];
            int a2 = state[2][col];
            int a3 = state[3][col];

            state[0][col] = (mul(0x0e, a0) ^ mul(0x0b, a1) ^ mul(0x0d, a2) ^ mul(0x09, a3)) & 0xff;
            state[1][col] = (mul(0x09, a0) ^ mul(0x0e, a1) ^ mul(0x0b, a2) ^ mul(0x0d, a3)) & 0xff;
            state[2][col] = (mul(0x0d, a0) ^ mul(0x09, a1) ^ mul(0x0e, a2) ^ mul(0x0b, a3)) & 0xff;
            state[3][col] = (mul(0x0b, a0) ^ mul(0x0d, a1) ^ mul(0x09, a2) ^ mul(0x0e, a3)) & 0xff;
        }
    }

    /**
     * AddRoundKey
     */
    private void addRoundKey(int[][] state, int round) {
        for (int col = 0; col < 4; col++) {
            for (int row = 0; row < 4; row++) {
                state[row][col] ^= w[round * 4 + col][row];
            }
        }
    }

    // ---------- Single-block encryption/decryption ----------

    /**
     * Encrypt one 16-byte block
     */
    public byte[] encryptBlock(byte[] inputBlock) {
        if (inputBlock.length != 16) {
            throw new IllegalArgumentException("Block size must be 16 bytes");
        }

        int[][] state = bytesToState(inputBlock, 0);

        // Initial AddRoundKey
        addRoundKey(state, 0);

        // Main rounds 1 .. Nr-1
        for (int round = 1; round < Nr; round++) {
            subBytes(state);
            shiftRows(state);
            mixColumns(state);
            addRoundKey(state, round);
        }

        // Final round (no MixColumns)
        subBytes(state);
        shiftRows(state);
        addRoundKey(state, Nr);

        byte[] output = new byte[16];
        stateToBytes(state, output, 0);
        return output;
    }

    /**
     * Decrypt one 16-byte block
     */
    public byte[] decryptBlock(byte[] inputBlock) {
        if (inputBlock.length != 16) {
            throw new IllegalArgumentException("Block size must be 16 bytes");
        }

        int[][] state = bytesToState(inputBlock, 0);

        // Undo the final round: AddRoundKey(Nr), InvShiftRows, InvSubBytes
        addRoundKey(state, Nr);
        invShiftRows(state);
        invSubBytes(state);

        // Main rounds Nr-1 down to 1
        for (int round = Nr - 1; round > 0; round--) {
            addRoundKey(state, round);
            invMixColumns(state);
            invShiftRows(state);
            invSubBytes(state);
        }

        // Undo the initial AddRoundKey
        addRoundKey(state, 0);

        byte[] output = new byte[16];
        stateToBytes(state, output, 0);
        return output;
    }

    // ---------- PKCS#7 padding ----------

    /**
     * PKCS#7 padding: always appends 1..16 bytes
     */
    private byte[] pkcs7Pad(byte[] data) {
        // Always in 1..16: already-aligned input gets a full extra block
        int padLength = 16 - (data.length % 16);
        byte[] padded = Arrays.copyOf(data, data.length + padLength);
        Arrays.fill(padded, data.length, padded.length, (byte) padLength);
        return padded;
    }

    /**
     * Remove and validate PKCS#7 padding
     */
    private byte[] pkcs7Unpad(byte[] data) {
        if (data.length == 0 || data.length % 16 != 0) {
            throw new IllegalArgumentException("Invalid padded data length");
        }

        int paddingLength = data[data.length - 1] & 0xFF;
        if (paddingLength < 1 || paddingLength > 16) {
            throw new IllegalArgumentException("Invalid padding length");
        }

        // Every padding byte must equal the padding length
        for (int i = data.length - paddingLength; i < data.length; i++) {
            if ((data[i] & 0xFF) != paddingLength) {
                throw new IllegalArgumentException("Invalid padding bytes");
            }
        }

        return Arrays.copyOf(data, data.length - paddingLength);
    }

    // ---------- Multi-block encryption/decryption (ECB mode) ----------

    /**
     * Encrypt data of any length (ECB mode, PKCS#7 padding)
     */
    public byte[] encrypt(byte[] plaintext) {
        byte[] paddedData = pkcs7Pad(plaintext);
        byte[] ciphertext = new byte[paddedData.length];

        for (int i = 0; i < paddedData.length; i += 16) {
            byte[] block = Arrays.copyOfRange(paddedData, i, i + 16);
            byte[] encryptedBlock = encryptBlock(block);
            System.arraycopy(encryptedBlock, 0, ciphertext, i, 16);
        }

        return ciphertext;
    }

    /**
     * Decrypt ECB-mode ciphertext and strip the PKCS#7 padding
     */
    public byte[] decrypt(byte[] ciphertext) {
        if (ciphertext.length % 16 != 0) {
            throw new IllegalArgumentException("Ciphertext length must be multiple of 16 bytes");
        }

        byte[] decryptedData = new byte[ciphertext.length];

        for (int i = 0; i < ciphertext.length; i += 16) {
            byte[] block = Arrays.copyOfRange(ciphertext, i, i + 16);
            byte[] decryptedBlock = decryptBlock(block);
            System.arraycopy(decryptedBlock, 0, decryptedData, i, 16);
        }

        return pkcs7Unpad(decryptedData);
    }

    // ---------- Utilities ----------

    /**
     * Convert a byte array to a lowercase hex string
     */
    public static String bytesToHex(byte[] bytes) {
        StringBuilder hexString = new StringBuilder();
        for (byte b : bytes) {
            hexString.append(String.format("%02x", b & 0xFF));
        }
        return hexString.toString();
    }

    // ---------- Demo ----------
    public static void main(String[] args) throws Exception {
        byte[] key = "chenxingAES_JAVA".getBytes("UTF-8");
        AES aes = new AES(key);

        String plainText = "He who conquers others is strong; he who conquers himself is mighty.";
        byte[] plaintextBytes = plainText.getBytes("UTF-8");

        // Encrypt
        byte[] ciphertext = aes.encrypt(plaintextBytes);
        System.out.println("Ciphertext (hex): " + bytesToHex(ciphertext));

        // Decrypt
        byte[] decryptedBytes = aes.decrypt(ciphertext);
        String decryptedText = new String(decryptedBytes, "UTF-8");
        System.out.println("Decrypted text: " + decryptedText);

        // Verify the round trip
        boolean isMatch = Arrays.equals(plaintextBytes, decryptedBytes);
        System.out.println("Match: " + isMatch);

        if (!isMatch) {
            throw new RuntimeException("Encryption/decryption mismatch!");
        }
    }
}
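The `xtime`/`mul` helpers above, and `gmul` in the Rust version, implement multiplication in \(GF(2^8)\) modulo \(x^8+x^4+x^3+x+1\) (0x11b). The Rust code XORs 0x1b after the carry bit has already fallen off the `u8`, which is the same reduction. The minimal Python sketch below mirrors the Java logic and checks it against the FIPS-197 worked example \(\{57\}\cdot\{83\}=\{c1\}\). It also checks a MixColumns column that can be verified by hand (db 13 53 45 maps to 8e 4d a1 bc) and confirms that the InvMixColumns coefficients undo it. `mix_column`, `MIX` and `INV_MIX` are illustrative names:

```python
def xtime(a: int) -> int:
    """Multiply by x ({02}) in GF(2^8), reducing by 0x11b."""
    a <<= 1
    if a & 0x100:
        a ^= 0x11B
    return a & 0xFF


def mul(a: int, b: int) -> int:
    """Shift-and-add multiplication in GF(2^8)."""
    res = 0
    while b:
        if b & 1:
            res ^= a
        a = xtime(a)
        b >>= 1
    return res


MIX = [[2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]]
INV_MIX = [[0x0E, 0x0B, 0x0D, 0x09], [0x09, 0x0E, 0x0B, 0x0D],
           [0x0D, 0x09, 0x0E, 0x0B], [0x0B, 0x0D, 0x09, 0x0E]]


def mix_column(col, matrix=MIX):
    """Matrix-vector product over GF(2^8); addition is XOR."""
    out = []
    for row in matrix:
        acc = 0
        for coef, byte in zip(row, col):
            acc ^= mul(coef, byte)
        out.append(acc)
    return out


print(hex(mul(0x57, 0x83)))               # 0xc1
col = [0xDB, 0x13, 0x53, 0x45]
mixed = mix_column(col)
print([hex(b) for b in mixed])            # ['0x8e', '0x4d', '0xa1', '0xbc']
print(mix_column(mixed, INV_MIX) == col)  # True
```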

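`generateSBox` (and `generate_sbox` in Rust) computes the S-box instead of hard-coding it. For \(a\neq 0\) it takes the inverse \(a^{-1}=a^{254}\), which works because every non-zero \(a\) satisfies \(a^{255}=1\). It then applies the affine map \(b\oplus \mathrm{rotl}(b,1)\oplus \mathrm{rotl}(b,2)\oplus \mathrm{rotl}(b,3)\oplus \mathrm{rotl}(b,4)\oplus 0x63\), and the byte 0 maps to 0x63. The Python sketch below reproduces that construction so its output can be compared with the S-box table at the beginning of the article and with the FIPS-197 example S(0x53) = 0xed. The function names are hypothetical:

```python
def gmul(a: int, b: int) -> int:
    """GF(2^8) multiplication modulo x^8 + x^4 + x^3 + x + 1."""
    p = 0
    while b:
        if b & 1:
            p ^= a
        a = ((a << 1) ^ 0x11B) if a & 0x80 else (a << 1)
        b >>= 1
    return p


def gf_pow(a: int, e: int) -> int:
    """Square-and-multiply exponentiation, as in gfInv / gf_pow."""
    result = 1
    while e:
        if e & 1:
            result = gmul(result, a)
        a = gmul(a, a)
        e >>= 1
    return result


def rotl8(x: int, n: int) -> int:
    """Rotate an 8-bit value left by n bits."""
    return ((x << n) | (x >> (8 - n))) & 0xFF


def build_sbox() -> list:
    sbox = [0] * 256
    for a in range(256):
        b = gf_pow(a, 254) if a else 0  # multiplicative inverse; 0 stays 0
        sbox[a] = b ^ rotl8(b, 1) ^ rotl8(b, 2) ^ rotl8(b, 3) ^ rotl8(b, 4) ^ 0x63
    return sbox


SBOX = build_sbox()
INV_SBOX = [0] * 256
for i, s in enumerate(SBOX):
    INV_SBOX[s] = i  # invert the permutation, like generateSInvBox

print(" ".join(f"{v:02x}" for v in SBOX[:16]))
# 63 7c 77 7b f2 6b 6f c5 30 01 67 2b fe d7 ab 76
print(hex(SBOX[0x53]))  # 0xed
```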
Rust

use std::convert::TryInto;

struct AES {
    rounds: usize,
    nk: usize,
    w: Vec<[u8; 4]>,
    sbox: [u8; 256],
    inv_sbox: [u8; 256],
}

impl AES {
    const RCON: [u8; 14] = [
        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D
    ];

    pub fn new(key: &[u8]) -> Self {
        let (nk, rounds) = match key.len() {
            16 => (4, 10),
            24 => (6, 12),
            32 => (8, 14),
            _ => panic!("Invalid AES key length: {} bytes", key.len()),
        };
        
        let mut aes = AES {
            rounds,
            nk,
            w: Vec::new(),
            sbox: [0u8; 256],
            inv_sbox: [0u8; 256],
        };
        
        aes.generate_sbox();
        aes.generate_inv_sbox();
        aes.key_expansion(key);
        aes
    }

    /// Multiplication in GF(2^8) (shift-and-add; 0x1b is 0x11b with the dropped carry bit)
    fn gmul(mut a: u8, mut b: u8) -> u8 {
        let mut product: u8 = 0;
        while b != 0 {
            if b & 1 != 0 {
                product ^= a;
            }
            let high_bit = a & 0x80;
            a <<= 1;
            if high_bit != 0 {
                a ^= 0x1b;
            }
            b >>= 1;
        }
        product
    }

    /// Exponentiation in GF(2^8) (square-and-multiply)
    fn gf_pow(mut base: u8, mut exponent: u32) -> u8 {
        let mut result: u8 = 1;
        while exponent != 0 {
            if exponent & 1 != 0 {
                result = AES::gmul(result, base);
            }
            base = AES::gmul(base, base);
            exponent >>= 1;
        }
        result
    }

    /// Generate the AES S-box
    fn generate_sbox(&mut self) {
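        // 0 has no inverse and is mapped to 0; the affine map then gives 0x63
        self.sbox[0] = 0x63;
        for i in 1..256 {
            // a^254 = a^-1, because a^255 = 1 for every non-zero a
            let inv = AES::gf_pow(i as u8, 254);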
            let x = inv;
            // Affine transformation: b ^ rotl(b,1) ^ rotl(b,2) ^ rotl(b,3) ^ rotl(b,4) ^ 0x63
            let y = x ^ x.rotate_left(1) 
                    ^ x.rotate_left(2) 
                    ^ x.rotate_left(3) 
                    ^ x.rotate_left(4) 
                    ^ 0x63;
            self.sbox[i] = y;
        }
    }

    /// Build the inverse S-box by inverting the S-box permutation
    fn generate_inv_sbox(&mut self) {
        for i in 0..256 {
            self.inv_sbox[self.sbox[i] as usize] = i as u8;
        }
    }

    /// RotWord: rotate a word left by one byte
    fn rot_word(word: [u8; 4]) -> [u8; 4] {
        [word[1], word[2], word[3], word[0]]
    }

    /// SubWord: apply the S-box to each byte of a word
    fn sub_word(&self, word: [u8; 4]) -> [u8; 4] {
        [
            self.sbox[word[0] as usize],
            self.sbox[word[1] as usize],
            self.sbox[word[2] as usize],
            self.sbox[word[3] as usize],
        ]
    }

    /// AES key expansion
    fn key_expansion(&mut self, key: &[u8]) {
        let nb = 4; // number of columns in the state
        let total_words = nb * (self.rounds + 1);
        self.w = vec![[0u8; 4]; total_words];
        
        // The first Nk words are the original key
        for i in 0..self.nk {
            self.w[i].copy_from_slice(&key[4 * i..4 * i + 4]);
        }
        
        // Derive the remaining words
        for i in self.nk..total_words {
            let mut temp = self.w[i - 1];
            
            if i % self.nk == 0 {
                temp = Self::rot_word(temp);
                temp = self.sub_word(temp);
                temp[0] ^= Self::RCON[i / self.nk - 1];
            } else if self.nk > 6 && i % self.nk == 4 {
                temp = self.sub_word(temp);
            }
            
            for j in 0..4 {
                self.w[i][j] = self.w[i - self.nk][j] ^ temp[j];
            }
        }
    }

    /// AddRoundKey
    fn add_round_key(state: &mut [[u8; 4]; 4], expanded_key: &[[u8; 4]], round: usize) {
        for col in 0..4 {
            for row in 0..4 {
                state[row][col] ^= expanded_key[round * 4 + col][row];
            }
        }
    }

    /// SubBytes
    fn sub_bytes(&self, state: &mut [[u8; 4]; 4]) {
        for row in 0..4 {
            for col in 0..4 {
                state[row][col] = self.sbox[state[row][col] as usize];
            }
        }
    }

    /// InvSubBytes
    fn inv_sub_bytes(&self, state: &mut [[u8; 4]; 4]) {
        for row in 0..4 {
            for col in 0..4 {
                state[row][col] = self.inv_sbox[state[row][col] as usize];
            }
        }
    }

    /// ShiftRows: rotate row r left by r positions
    fn shift_rows(&self, state: &mut [[u8; 4]; 4]) {
        for row in 1..4 {
            state[row].rotate_left(row);
        }
    }

    /// InvShiftRows: rotate row r right by r positions
    fn inv_shift_rows(&self, state: &mut [[u8; 4]; 4]) {
        for row in 1..4 {
            state[row].rotate_right(row);
        }
    }

    /// Mix a single column
    fn mix_single_column(&self, column: [u8; 4]) -> [u8; 4] {
        [
            Self::gmul(2, column[0]) ^ Self::gmul(3, column[1]) ^ column[2] ^ column[3],
            column[0] ^ Self::gmul(2, column[1]) ^ Self::gmul(3, column[2]) ^ column[3],
            column[0] ^ column[1] ^ Self::gmul(2, column[2]) ^ Self::gmul(3, column[3]),
            Self::gmul(3, column[0]) ^ column[1] ^ column[2] ^ Self::gmul(2, column[3]),
        ]
    }

    /// MixColumns
    fn mix_columns(&self, state: &mut [[u8; 4]; 4]) {
        for col in 0..4 {
            let column = [state[0][col], state[1][col], state[2][col], state[3][col]];
            let mixed_column = self.mix_single_column(column);
            for row in 0..4 {
                state[row][col] = mixed_column[row];
            }
        }
    }

    /// InvMixColumns
    fn inv_mix_columns(&self, state: &mut [[u8; 4]; 4]) {
        for col in 0..4 {
            let column = [state[0][col], state[1][col], state[2][col], state[3][col]];
            state[0][col] = Self::gmul(0x0e, column[0]) 
                          ^ Self::gmul(0x0b, column[1]) 
                          ^ Self::gmul(0x0d, column[2]) 
                          ^ Self::gmul(0x09, column[3]);
            state[1][col] = Self::gmul(0x09, column[0]) 
                          ^ Self::gmul(0x0e, column[1]) 
                          ^ Self::gmul(0x0b, column[2]) 
                          ^ Self::gmul(0x0d, column[3]);
            state[2][col] = Self::gmul(0x0d, column[0]) 
                          ^ Self::gmul(0x09, column[1]) 
                          ^ Self::gmul(0x0e, column[2]) 
                          ^ Self::gmul(0x0b, column[3]);
            state[3][col] = Self::gmul(0x0b, column[0]) 
                          ^ Self::gmul(0x0d, column[1]) 
                          ^ Self::gmul(0x09, column[2]) 
                          ^ Self::gmul(0x0e, column[3]);
        }
    }

    /// Load a 16-byte block into the 4x4 state matrix, column by column
    fn bytes_to_state(block: &[u8; 16]) -> [[u8; 4]; 4] {
        let mut state = [[0u8; 4]; 4];
        for i in 0..16 {
            state[i % 4][i / 4] = block[i];
        }
        state
    }

    /// Write the 4x4 state matrix back to a 16-byte block, column by column
    fn state_to_bytes(state: &[[u8; 4]; 4]) -> [u8; 16] {
        let mut block = [0u8; 16];
        for i in 0..16 {
            block[i] = state[i % 4][i / 4];
        }
        block
    }

    /// Encrypt one 16-byte block
    pub fn encrypt_block(&self, block: &[u8; 16]) -> [u8; 16] {
        let mut state = Self::bytes_to_state(block);
        
        // Initial AddRoundKey
        Self::add_round_key(&mut state, &self.w, 0);
        
        // Main rounds 1 .. Nr-1
        for round in 1..self.rounds {
            self.sub_bytes(&mut state);
            self.shift_rows(&mut state);
            self.mix_columns(&mut state);
            Self::add_round_key(&mut state, &self.w, round);
        }
        
        // Final round (no MixColumns)
        self.sub_bytes(&mut state);
        self.shift_rows(&mut state);
        Self::add_round_key(&mut state, &self.w, self.rounds);
        
        Self::state_to_bytes(&state)
    }

    /// Decrypt one 16-byte block
    pub fn decrypt_block(&self, block: &[u8; 16]) -> [u8; 16] {
        let mut state = Self::bytes_to_state(block);
        
        // Undo the final round: AddRoundKey(Nr), InvShiftRows, InvSubBytes
        Self::add_round_key(&mut state, &self.w, self.rounds);
        self.inv_shift_rows(&mut state);
        self.inv_sub_bytes(&mut state);
        
        // Main rounds Nr-1 down to 1
        for round in (1..self.rounds).rev() {
            Self::add_round_key(&mut state, &self.w, round);
            self.inv_mix_columns(&mut state);
            self.inv_shift_rows(&mut state);
            self.inv_sub_bytes(&mut state);
        }
        
        // Undo the initial AddRoundKey
        Self::add_round_key(&mut state, &self.w, 0);
        
        Self::state_to_bytes(&state)
    }

    // ---------- PKCS#7 padding helpers ----------
    
    /// PKCS#7 padding: always appends 1..16 bytes
    fn pkcs7_pad(data: &[u8]) -> Vec<u8> {
        let padding_length = 16 - (data.len() % 16);
        let mut padded_data = Vec::with_capacity(data.len() + padding_length);
        padded_data.extend_from_slice(data);
        padded_data.extend(std::iter::repeat(padding_length as u8).take(padding_length));
        padded_data
    }

    /// Remove and validate PKCS#7 padding
    fn pkcs7_unpad(data: &[u8]) -> Vec<u8> {
        if data.is_empty() || data.len() % 16 != 0 {
            panic!("Invalid padded data: length must be non-zero and multiple of 16");
        }
        
        let padding_length = data[data.len() - 1] as usize;
        if padding_length == 0 || padding_length > 16 {
            panic!("Invalid padding length: {}", padding_length);
        }
        
        // Every padding byte must equal the padding length
        for &byte in &data[data.len() - padding_length..] {
            if byte as usize != padding_length {
                panic!("Invalid padding bytes");
            }
        }
        
        data[..data.len() - padding_length].to_vec()
    }

    /// Encrypt data of any length in ECB mode with PKCS#7 padding
    pub fn encrypt(&self, plaintext: &[u8]) -> Vec<u8> {
        let padded_data = Self::pkcs7_pad(plaintext);
        let mut ciphertext = Vec::with_capacity(padded_data.len());
        
        for chunk in padded_data.chunks(16) {
            let block: [u8; 16] = chunk.try_into().unwrap();
            ciphertext.extend_from_slice(&self.encrypt_block(&block));
        }
        
        ciphertext
    }

    /// Decrypt ECB-mode ciphertext and strip the PKCS#7 padding (panics on malformed input)
    pub fn decrypt(&self, ciphertext: &[u8]) -> Vec<u8> {
        if ciphertext.len() % 16 != 0 {
            panic!("Ciphertext length must be multiple of 16 bytes");
        }
        
        let mut decrypted_data = Vec::with_capacity(ciphertext.len());
        
        for chunk in ciphertext.chunks(16) {
            let block: [u8; 16] = chunk.try_into().unwrap();
            decrypted_data.extend_from_slice(&self.decrypt_block(&block));
        }
        
        Self::pkcs7_unpad(&decrypted_data)
    }
}

/// Convert a byte slice to a lowercase hex string
fn bytes_to_hex(bytes: &[u8]) -> String {
    bytes.iter()
        .map(|byte| format!("{:02x}", byte))
        .collect::<Vec<String>>()
        .join("")
}

fn main() {
    
    let key = b"chenxingAES_Rust";
    let aes = AES::new(key);

    let plaintext = b"He who conquers others is strong; he who conquers himself is mighty.";

    // Encrypt
    let ciphertext = aes.encrypt(plaintext);
    println!("Ciphertext (hex): {}", bytes_to_hex(&ciphertext));
    
    // Decrypt
    let decrypted = aes.decrypt(&ciphertext);
    println!("Decrypted: {}", String::from_utf8_lossy(&decrypted));
    
    // Verify the round trip
    assert_eq!(plaintext, decrypted.as_slice());
    println!("Round trip succeeded");
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aes_round_trip() {
        let key = b"chenxingAES_Rust";
        let aes = AES::new(key);
        
        let test_data = b"Test AES encryption and decryption";
        let encrypted = aes.encrypt(test_data);
        let decrypted = aes.decrypt(&encrypted);
        
        assert_eq!(test_data, decrypted.as_slice());
    }

    #[test]
    fn test_single_block() {
        let key = b"chenxingAES_Rust";
        let aes = AES::new(key);
        
        let block = [0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
                     0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff];
        
        let encrypted = aes.encrypt_block(&block);
        let decrypted = aes.decrypt_block(&encrypted);
        
        assert_eq!(block, decrypted);
    }
}
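Both the Java and Rust listings hard-code 14 round constants. Each entry is the previous one multiplied by {02} (`xtime`) in \(GF(2^8)\), starting from 0x01. The key schedule reads one entry every Nk words (`i % Nk == 0`). If you count those reads over the loop bounds used above (`i` from Nk to Nb·(Nr+1) - 1), AES-128/192/256 need only 10, 8 and 7 entries, so no key size ever reads past the tenth value. The short Python sketch below derives both facts. The function names are illustrative:

```python
def xtime(a: int) -> int:
    """Multiply by {02} in GF(2^8), reducing by 0x11b."""
    a <<= 1
    return a ^ 0x11B if a & 0x100 else a


def rcon_table(n: int) -> list:
    """First n round constants: start at 0x01 and apply xtime repeatedly."""
    values, r = [], 0x01
    for _ in range(n):
        values.append(r)
        r = xtime(r)
    return values


def rcon_used(nk: int, nr: int) -> int:
    """Words i in [Nk, 4*(Nr+1)) with i % Nk == 0 each consume one RCON entry."""
    return sum(1 for i in range(nk, 4 * (nr + 1)) if i % nk == 0)


print([f"{r:02x}" for r in rcon_table(14)])
# ['01', '02', '04', '08', '10', '20', '40', '80', '1b', '36', '6c', 'd8', 'ab', '4d']
for name, nk, nr in [("AES-128", 4, 10), ("AES-192", 6, 12), ("AES-256", 8, 14)]:
    print(name, rcon_used(nk, nr))
# AES-128 10
# AES-192 8
# AES-256 7
```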

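The Rust tests above, like the Java and Python demos, only confirm that `decrypt` inverts `encrypt`. A round trip would still pass with a wrongly transcribed but self-consistent S-box, or with a mis-ordered key schedule, because both directions share the same tables. Comparing against a published test vector closes that gap. Below is a compact, self-contained Python sketch of AES-128 block encryption. It uses the same column-major state and round order as the listings, but for brevity it finds S-box inverses by brute force instead of using \(a^{254}\). It is checked against FIPS-197 Appendix C.1: key 000102...0f and plaintext 00112233445566778899aabbccddeeff should give ciphertext 69c4e0d86a7b0430d8cdb78070b4c55a.

```python
# Minimal AES-128, encryption only, written to check a known-answer vector.

def xtime(a: int) -> int:
    a <<= 1
    return a ^ 0x11B if a & 0x100 else a


def gmul(a: int, b: int) -> int:
    p = 0
    while b:
        if b & 1:
            p ^= a
        a = xtime(a)
        b >>= 1
    return p


def rotl8(x: int, n: int) -> int:
    return ((x << n) | (x >> (8 - n))) & 0xFF


def build_sbox() -> list:
    sbox = []
    for a in range(256):
        inv = next((b for b in range(1, 256) if gmul(a, b) == 1), 0)  # brute-force inverse; 0 -> 0
        sbox.append(inv ^ rotl8(inv, 1) ^ rotl8(inv, 2) ^ rotl8(inv, 3) ^ rotl8(inv, 4) ^ 0x63)
    return sbox


SBOX = build_sbox()


def key_expansion(key: bytes) -> list:
    """AES-128 only (Nk = 4, Nr = 10): returns 44 four-byte words."""
    w = [list(key[4 * i:4 * i + 4]) for i in range(4)]
    rcon = 0x01
    for i in range(4, 44):
        temp = list(w[i - 1])
        if i % 4 == 0:
            temp = [SBOX[b] for b in temp[1:] + temp[:1]]  # RotWord, then SubWord
            temp[0] ^= rcon
            rcon = xtime(rcon)
        w.append([x ^ y for x, y in zip(w[i - 4], temp)])
    return w


def add_round_key(state, w, rnd):
    for c in range(4):
        for r in range(4):
            state[r][c] ^= w[4 * rnd + c][r]


def sub_bytes(state):
    for r in range(4):
        state[r] = [SBOX[b] for b in state[r]]


def shift_rows(state):
    for r in range(1, 4):
        state[r] = state[r][r:] + state[r][:r]


def mix_columns(state):
    for c in range(4):
        a = [state[r][c] for r in range(4)]
        state[0][c] = gmul(2, a[0]) ^ gmul(3, a[1]) ^ a[2] ^ a[3]
        state[1][c] = a[0] ^ gmul(2, a[1]) ^ gmul(3, a[2]) ^ a[3]
        state[2][c] = a[0] ^ a[1] ^ gmul(2, a[2]) ^ gmul(3, a[3])
        state[3][c] = gmul(3, a[0]) ^ a[1] ^ a[2] ^ gmul(2, a[3])


def encrypt_block(block: bytes, key: bytes) -> bytes:
    w = key_expansion(key)
    state = [[block[r + 4 * c] for c in range(4)] for r in range(4)]  # column-major
    add_round_key(state, w, 0)
    for rnd in range(1, 10):
        sub_bytes(state)
        shift_rows(state)
        mix_columns(state)
        add_round_key(state, w, rnd)
    sub_bytes(state)  # final round: no MixColumns
    shift_rows(state)
    add_round_key(state, w, 10)
    return bytes(state[i % 4][i // 4] for i in range(16))


ct = encrypt_block(bytes.fromhex("00112233445566778899aabbccddeeff"), bytes(range(16)))
print(ct.hex())  # 69c4e0d86a7b0430d8cdb78070b4c55a
```

The same vector could be added as an extra assertion in the Rust `#[cfg(test)]` module by calling `encrypt_block` with this key and plaintext.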
References

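Thanks to the work of those who came before us, the essentials of this topic are easy to find. I also recommend reading other authors' blog posts on AES; they are thorough and analyze each step in great detail.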

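[1] AES加密算法原理的详细介绍与实现 (A detailed introduction to the principles and implementation of the AES encryption algorithm), CSDN blog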

posted @ 2025-10-26 18:48 chen_xing