初识AES
AES(1)
零、 前言
朝花夕拾杯中酒。从CTF开始攻击CBC模式下的加密漏洞,再到第四学期的《密码学基础》课程,时至今日,《密码工程》课程的需要以及自己对于对称密码的兴趣驱使,下定决心写下这篇笔记记录自己对于AES的认识以及学习。这篇文章工期拖得比较长,一方面是因为自己的拖延症,二是针对于密码学的工程应用以及编码硬实力的不足,自己翻阅了很多师傅的相关资料,也翻阅了部分教材和相关资料,会在结尾进行指明。由于作者本人的实力与阅历不足,本文在内容、专业知识、排版以及语法表达上或多或少存在错误,请读者仔细辨别并积极指出交流。很庆幸完成了这篇笔记。😊😊
一、AES背景介绍
\(1997\)年\(9\)月,美国国家标准与技术研究院\(NIST\)公开征集新的高级加密标准\((Advanced Encryption Standard, AES)\),取代\(DES\)算法.\(1998\)年\(6\)月,提交过程截至,共收到\(21\)个算法.经过简单筛选,\(1998\)年\(8\)月,公开来自\(12\)个国家的\(15\)个候选算法进入第一轮.\(1999\)年\(8\)月,\(5\)个候选算法进入第二轮,分别是\(MARS\)、\(RC6\)、\(Rijndael\)、\(Serpent\)和\(Twofish\)算法.\(2000\)年\(10\)月,由比利时密码学家\(Daemen\)和\(Rijmen\)共同设计的\(Rijndael\)算法获胜,经过讨论和规范后,于\(2001\)年\(12\)月作为\(NIPS 197\)发布,成为高级加密标准\(AES\)。
\(AES\)算法采用\(SPN\)结构,分组长度为\(128-bit(16-Byte)\).将输入按照字节划分,从\(0\)到\(15\)标号,并按照下图所示的顺序排列
| 0 | 4 | 8 | 12 |
|---|---|---|---|
| 1 | 5 | 9 | 13 |
| 2 | 6 | 10 | 14 |
| 3 | 7 | 11 | 15 |
\(AES\)算法共有三个版本,对应不同的密钥长度和安全强度:当密钥长度为\(128-bit\)时,迭代轮数为\(10\)轮;当密钥长度为\(192-bit\)时,迭代轮数为\(12\)轮;当密钥长度为\(256-bit\)时,迭代轮数为\(14\)轮,分别标记为\(AES-128/192/256\).
本文将介绍\(AES\)的内部结构、\(AES\)加密、解密算法的各种编程语言实现、\(AES\)的各种工作模式、\(AES\)的优化策略以及优化结果展示。
二、内部结构
\(AES\)为分组对称密码,针对于其内部结构而言,每轮加密/解密由(逆)字节代换、(逆)行移位、(逆)列混合、轮密钥加四部分组成。所谓分组,即将需要加密的明文\(message\)以\(128-bit\)分成组别,每一组都经过\(10/12/14\)轮上述的四部分操作进行加密。假设\(message = b'abcdefghijklmnopqrstuvwxyz'\),那么\(block0=b'abcdefghijklmnop',bloc1='qrstuvwxyz'\),而\(block1\)由于内容并不够\(128-bit\),所以我们需要进行填充,即block1=b'qrstuvwxyz\x00\x00\x00\x00\x00\x00'。注意,\(AES\)使用的四个组件都是可逆的,针对于解密而言,就是高度对称下的逆向组件的使用,接下来我们将从加密和解密中所用到的四个组件进行详细介绍。
为了方便\(AES\)的展示和实现,我们通过状态矩阵来展示一个分组的加密或解密的详细过程,例如\(block0\)会根据先列后行的顺序被分配到状态矩阵中进行后续的组件操作。
| a | e | i | m |
|---|---|---|---|
| b | f | j | n |
| c | g | k | o |
| d | h | l | p |
针对于明文信息,我们需要按照字节对应为相应的ASCII码进行后的加密/解密操作。
| 0x61 | 0x65 | 0x69 | 0x73 |
|---|---|---|---|
| 0x62 | 0x66 | 0x70 | 0x74 |
| 0x63 | 0x67 | 0x71 | 0x75 |
| 0x64 | 0x68 | 0x72 | 0x76 |
1. 字节代换
字节代换
所谓字节代换,其实就是一个查表操作,而这里的表,我们称之为\(S盒\)。
\(AES\)的\(S盒\)
| 行/列 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | A | B | C | D | E | F |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0x63 | 0x7c | 0x77 | 0x7b | 0xf2 | 0x6b | 0x6f | 0xc5 | 0x30 | 0x01 | 0x67 | 0x2b | 0xfe | 0xd7 | 0xab | 0x76 |
| 1 | 0xca | 0x82 | 0xc9 | 0x7d | 0xfa | 0x59 | 0x47 | 0xf0 | 0xad | 0xd4 | 0xa2 | 0xaf | 0x9c | 0xa4 | 0x72 | 0xc0 |
| 2 | 0xb7 | 0xfd | 0x93 | 0x26 | 0x36 | 0x3f | 0xf7 | 0xcc | 0x34 | 0xa5 | 0xe5 | 0xf1 | 0x71 | 0xd8 | 0x31 | 0x15 |
| 3 | 0x04 | 0xc7 | 0x23 | 0xc3 | 0x18 | 0x96 | 0x05 | 0x9a | 0x07 | 0x12 | 0x80 | 0xe2 | 0xeb | 0x27 | 0xb2 | 0x75 |
| 4 | 0x09 | 0x83 | 0x2c | 0x1a | 0x1b | 0x6e | 0x5a | 0xa0 | 0x52 | 0x3b | 0xd6 | 0xb3 | 0x29 | 0xe3 | 0x2f | 0x84 |
| 5 | 0x53 | 0xd1 | 0x00 | 0xed | 0x20 | 0xfc | 0xb1 | 0x5b | 0x6a | 0xcb | 0xbe | 0x39 | 0x4a | 0x4c | 0x58 | 0xcf |
| 6 | 0xd0 | 0xef | 0xaa | 0xfb | 0x43 | 0x4d | 0x33 | 0x85 | 0x45 | 0xf9 | 0x02 | 0x7f | 0x50 | 0x3c | 0x9f | 0xa8 |
| 7 | 0x51 | 0xa3 | 0x40 | 0x8f | 0x92 | 0x9d | 0x38 | 0xf5 | 0xbc | 0xb6, | 0xda | 0x21 | 0x10 | 0xff | 0xf3 | 0xd2 |
| 8 | 0xcd | 0x0c | 0x13 | 0xec | 0x5f | 0x97 | 0x44 | 0x17 | 0xc4 | 0xa7 | 0x7e | 0x3d | 0x64 | 0x5d | 0x19 | 0x73 |
| 9 | 0x60 | 0x81 | 0x4f | 0xdc | 0x22 | 0x2a | 0x90 | 0x88 | 0x46 | 0xee | 0xb8 | 0x14 | 0xde | 0x5e | 0x0b | 0xdb |
| A | 0xe0 | 0x32 | 0x3a | 0x0a | 0x49 | 0x06 | 0x24 | 0x5c | 0xc2 | 0xd3 | 0xac | 0x62 | 0x91 | 0x95 | 0xe4 | 0x78 |
| B | 0xe7 | 0xc8 | 0x37 | 0x6d | 0x8d | 0xd5 | 0x4e | 0xa9 | 0x6c | 0x56 | 0xf4 | 0xea | 0x65 | 0x7a | 0xae | 0x08 |
| C | 0xba | 0x78 | 0x25 | 0x2e | 0x1c | 0xa6 | 0xb4 | 0xc6 | 0xe8 | 0xdd | 0x74 | 0x1f | 0x4b | 0xbd | 0x8b | 0x8a |
| D | 0x70 | 0x3e | 0xb5 | 0x66 | 0x48 | 0a03 | 0xf6 | 0x0e | 0x61 | 0x35 | 0x57 | 0xb9 | 0x86 | 0xc1 | 0x1d | 0x9e |
| E | 0xe1 | 0xf8 | 0x98 | 0x11 | 0x69 | 0xd9 | 0x8e | 0x94 | 0x9b | 0x1e | 0x87 | 0xe9 | 0xce | 0x55 | 0x28 | 0xdf |
| F | 0x8c | 0xa1 | 0x89 | 0x0d | 0xbf | 0xe6 | 0x42 | 0x68 | 0x41 | 0x99 | 0x2d | 0x0f | 0xb0 | 0x54 | 0xbb | 0x16 |
状态矩阵中的元素按照下面的方式映射为一个新的字节:把该字节的高\(4\)位作为行值。低\(4\)位作为列值,去除\(S盒\)或者\(逆S盒\)中对应的行的元素作为输出。例如,加密时,输出的字节\(S1=b'0x12'\),则查\(S盒\)的第\(0x01行\)和\(0x02列\),得到值\(0xc9\),然后替换\(S1\)原有的\(0x12\)为\(0xc9\)。
逆字节代换
操作步骤与实现过程与\(S盒\)一直,是字节代换的逆操作,一般称之为\(逆S盒\)。
\(AES\)的\(逆S盒\)
| 行/列 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | A | B | C | D | E | F |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0x52 | 0x09 | 0x6a | 0xd5 | 0x30 | 0x36 | 0xa5 | 0x38 | 0xbf | 0x40 | 0xa3 | 0x9e | 0x81 | 0xf3 | 0xd7 | 0xfb |
| 1 | 0x7c | 0xe3 | 0x39 | 0x82 | 0x9b | 0x2f | 0xff | 0x87 | 0x34 | 0x8e | 0x43 | 0x44 | 0xc4 | 0xde | 0xe9 | 0xcb |
| 2 | 0x54 | 0x7b | 0x94 | 0x32 | 0xa6 | 0xc2 | 0x23 | 0x3d | 0xee | 0x4c | 0x95 | 0x0b | 0x42 | 0xfa | 0xc3 | 0x4e |
| 3 | 0x08 | 0x2e | 0xa1 | 0x66 | 0x28 | 0xd9 | 0x24 | 0xb2 | 0x76 | 0x5b | 0xa2 | 0x49 | 0x6d | 0x8b | 0xd1 | 0x25 |
| 4 | 0x72 | 0xf8 | 0xf6 | 0x64 | 0x86 | 0x68 | 0x98 | 0x16 | 0xd4 | 0xa4 | 0x5c | 0xcc | 0x5d | 0x65 | 0xb6 | 0x92 |
| 5 | 0x6c | 0x70 | 0x48 | 0x50 | 0xfd | 0xed | 0xb9 | 0xda | 0x5e | 0x15 | 0x46 | 0x57 | 0xa7 | 0x8d | 0x9d | 0x84 |
| 6 | 0x90 | 0xd8 | 0xab | 0x00 | 0x8c | 0xbc | 0xd3 | 0x0a | 0xf7 | 0xe4 | 0x58 | 0x05 | 0xb8 | 0xb3 | 0x45 | 0x06 |
| 7 | 0xd0 | 0x2c | 0x1e | 0x8f | 0xca | 0x3f | 0x0f | 0x02 | 0xc1 | 0xaf | 0xbd | 0x03 | 0x01 | 0x13 | 0x8a | 0x6b |
| 8 | 0x3a | 0x91 | 0x11 | 0x41 | 0x4f | 0x67 | 0xdc | 0xea | 0x97 | 0xf2 | 0xcf | 0xce | 0xf0 | 0xb4 | 0xe6 | 0x73 |
| 9 | 0x96 | 0xac | 0x74 | 0x22 | 0xe7 | 0xad | 0x35 | 0x85 | 0xe2 | 0xf9 | 0x37 | 0xe8 | 0x1c | 0x75 | 0xdf | 0x6e |
| A | 0x47 | 0xf1 | 0x1a | 0x71 | 0x1d | 0x29 | 0xc5 | 0x89 | 0x6f | 0xb7 | 0x62 | 0x0e | 0xaa | 0x18 | 0xbe | 0x1b |
| B | 0xfc | 0x56 | 0x3e | 0x4b | 0xc6 | 0xd2 | 0x79 | 0x20 | 0x9a | 0xdb | 0xc0 | 0xfe | 0x78 | 0xcd | 0x5a | 0xf4 |
| C | 0x1f | 0xdd | 0xa8 | 0x33 | 0x88 | 0x07 | 0xc7 | 0x31 | 0xb1 | 0x12 | 0x10 | 0x59 | 0x27 | 0x80 | 0xec | 0x5f |
| D | 0x60 | 0x51 | 0x7f | 0xa9 | 0x19 | 0xb5 | 0x4a | 0x0d | 0x2d | 0xe5 | 0x7a | 0x9f | 0x93 | 0xc9 | 0x9c | 0xef |
| E | 0xa0 | 0xe0 | 0x3b | 0x4d | 0xae | 0x2a | 0xf5 | 0xb0 | 0xc8 | 0xeb | 0xbb | 0x3c | 0x83 | 0x53 | 0x99 | 0x61 |
| F | 0x17 | 0x2b | 0x04 | 0x7e | 0xba | 0x77 | 0xd6 | 0x26 | 0xe1 | 0x69 | 0x14 | 0x63 | 0x55 | 0x21 | 0x0c | 0x7d |
2. 行移位
行移位
行移位是一个简单的做循环移位操作,当密钥长度为\(128-bit\)时,状态矩阵的第一行循环左移\(0-Byte\),第二行循环左移\(1-Byte\),第三行循环左移\(2-Byte\),第四行循环左移\(3-Byte\)。
设行移位前的状态矩阵为:
| S0 | S4 | S8 | SC |
|---|---|---|---|
| S1 | S5 | S9 | SD |
| S2 | S6 | SA | SE |
| S3 | S7 | SB | SF |
行移位后的状态矩阵为:
| S0 | S4 | S8 | SC |
|---|---|---|---|
| S5 | S9 | SD | S1 |
| SA | SE | S2 | S6 |
| SF | S3 | S7 | SB |
逆行移位
行移位的逆变换就是将状态矩阵中的每一行执行相反方向的循环移位即可,例如\(AES-128\)中的状态矩阵的第一行循环右移\(0-Byte\),第二行循环右移\(1-Byte\),第三行循环右移\(2-Byte\),第四行循环右移\(3-Byte\)。
3. 列混合
列变换 -- 矩阵乘法
列混合变换是通过在\(GF(2^8)\)上的矩阵相乘来实现的,经过行移位后的状态矩阵与固定的矩阵相乘,得到混淆后的状态矩阵。其中,加法运算\(a+b=a \oplus b\),乘法运算\(a \cdot b\)在模不可约多项式\(x ^ 8 + x ^ 4 + x ^ 3 + x + 1\).
列变换 -- GF(0xff)乘法
这里介绍一下在\(GF(2^8)\)中的乘法运算以及加法运算,其中加法运算就是按位异或的运算,而乘法相对来说有些复杂。首先,在\(GF(2 ^ 8)\)中每个元素都可以表示为\(8次二进制多项式\),形式为:\(a(x) = a_7x^7+a_6x^6+a_5x^5+a_4x^4+a_3x^3+a_2x^2+a_1x+1\),其中\(a_i\in \{0,1\}\)(系数遵循$GF(2)运算规则:\(0+0=0,0+1=1,1+1=0,0\cdot0=0,0\cdot1=0,1\cdot1=1\).
对于两个 \({GF}(2^8)\) 元素 \(A(x) = \sum_{i=0}^7 a_i x^i\) 和 \(B(x) = \sum_{i=0}^7 b_i x^i\),先执行普通多项式乘法(系数运算在 \({GF}(2)\) 下进行)。乘法结果是次数不超过 15 的多项式:\(C(x) = A(x) \cdot B(x) = \sum_{k=0}^{15} c_k x^k\).其中 \(c_k = \left( \sum_{\substack{i+j=k \\ 0 \leq i,j \leq 7}} a_i b_j \right) \mod 2\)(对每个次数 \(k\),合并所有 \(i+j=k\) 的项,系数模 \(2\))。
由于最终结果需要属于\(GF(2^8)\),需将\(C(x)\)对\(m(x)=x^8+x^4+x^3+x+1\)取模,得到次数<8的余式。利用模运算性质:因为\(m(x)\equiv 0 (mod\ m(x))\),所以\(x^8\equiv x^4 + x^ 3 + x + 1 (mod \ m(x))\ \ \ GF(2)\)中\(-1=1\),故\(x ^ 8 = -(x^4+x^3+x+1) \equiv x ^ 4 + x ^ 3 + x + 1)\)。对于更高次幂(如\(x^k,k≥8\)),可以通过反复替换\(x^8\)为\(x^4+x^3+x+1\)来降低次数。例如:
- \(x ^ 9 \equiv x \cdot x ^ 8 \equiv x \cdot (x ^ 4 + x ^ 3 + x + 1) = x ^ 5 + x ^ 4 + x ^ 2 + x\)
- \(x ^ {10} \equiv x \cdot x ^ 9 \equiv x \cdot (x ^ 5 + x ^ 4 + x ^ 2 + x) = x ^ 6 + x ^ 5 + x ^ 3 + x ^ 2\)
- 以此类推,直到所有次数\(≥8\)的项都被替换为次数\(<8\)的项。
因此,列变换的结果矩阵中的每一列元素与原矩阵的元素之间的关系如下:
\(AES\)中我们采用位操作快速约简:若乘法结果的多项式含\(x^k(k≥8)\),则将结果与\(m(x)\)左移\(k-8\)位的多项式异或(对应\(GF(2)\)的加法)
- 若结果含\(x^8\),则异或\(m(x)=x^8+x^4+x^3+x+1\)
- 若含\(x^9\),则异或\(x \cdot m(x) = x ^ 9 + x ^ 5 + x ^4 + x ^ 2 + x\)
- 重复此过程,知道消去所有次数\(≥8\)的项。
示例演示
对\(a, b \in GF(2 ^ 8)\),\(a + b = a \oplus b\), \(3 \cdot a = 2 \cdot a \oplus a\).记字节\(a\)按比特表示为\(a_7a_6a_5a_4a_3a_2a_1a_0\)(\(a_7\)为最高位)对应\(GF(2 ^ 8)\)中的多项式\(a_7x ^ 7 + a_6 x ^ 6 + a_5x^5+a_4x^4+a_3x^3+a_2x^2+a_1x+a_0\),\(2\)对应\(GF(2 ^ 8)\)中的\(x\),则\(2 \cdot a\)可看作:
由此,我们可以看出在加密过程中,列混合矩阵中出现最大的\(0x03\)我们至多需要两次异或运算与一次左移运算即可完成,也印证了我们前面提到的加密过程中的计算量<解密时的计算量,加密算法较解密算法而言更加常用。
逆列变换
注:加密算法中的列混合矩阵中参数最大为\(0x03\),具体实现时至多采用三次运算(两次异或和一次移位)即可完成\(3 \cdot a\),但是在解密时可能采用的计算就比较复杂(\(0x0E\)),这是因为在工程实践中,加密比解密更加常用,例如,\(CTF\)和\(OFB\)工作模式中,只用到加密算法;分组密码作为部件去构造杂凑函数或消息认证码时,大多数情况下只用到加密算法……
4. 轮密钥加
轮密钥加是\(128-bit\)的状态直接和\(128-bit\)的轮密钥进行逐比特异或运算,如下图所示:
5. 密钥扩展
由于\(AES-128/192/1256\)在密钥扩展中的流程是一致的,但是由于迭代次数不一致,所以扩展出的密钥长度并不一致。但其中的思想是一致的。我们在这里介绍的是\(AES-128\),我们输入的\(key(16-Byte)\),这里我们规定其中的密钥矩阵每一列的\(4-Byte\)为一个字,起初的四个字会被当成主密钥进行扩展,最终扩展出来的为长度为\(44-字\)的密钥,其中每\(4-字\)参与每轮的轮密钥加操作,其中前四个字会在迭代前与明文进行异或。
扩展规则:
设主密钥为\(w[0],w[1],w[2],w[3]\),对\(W\)数组扩充\(40\)个新列,构造总共\(44\)列的扩展密钥数组。
- \(i\)不是\(4\)的倍数
\(w[i]=w[i-4] \oplus w[i-1]\) - \(i\)是\(4\)的倍数
\(w[i]=w[i-4] \oplus T(w[i-1])\)
T函数
T函数由三部分组成:字循环、字节代换、轮常量异或
- 字循环:将字中的4个字节循环左移1个字节
[b₀, b₁, b₂, b₃] → [b₁, b₂, b₃, b₀] - 字节代换:使用\(S盒\)进行字节代换
- 轮常量异或:与轮常量\(Rcon[j]\)进行异或
轮常量表:
| 轮数 j | Rcon[j] | 轮数 j | Rcon[j] |
|---|---|---|---|
| 1 | 01 00 00 00 | 6 | 20 00 00 00 |
| 2 | 02 00 00 00 | 7 | 40 00 00 00 |
| 3 | 04 00 00 00 | 8 | 80 00 00 00 |
| 4 | 08 00 00 00 | 9 | 1B 00 00 00 |
| 5 | 10 00 00 00 | 10 | 36 00 00 00 |
示例演示
初始密钥:3C A1 0B 21 57 F0 19 16 90 2E 13 80 AC C1 07 BD
初始W数组:
- W[0] =
3C A1 0B 21 - W[1] =
57 F0 19 16 - W[2] =
90 2E 13 80 - W[3] =
AC C1 07 BD
计算第一轮子密钥(W[4] ~ W[7])
W[4] 计算 (\(i = 4\),是\(4\)的倍数)
T(W[3])计算:
1. 字循环:AC C1 07 BD → C1 07 BD AC
2. 字节代换:C1 07 BD AC → 78 C5 7A 91
3. 轮常量异或:78 C5 7A 91 ⊕ 01 00 00 00 = 79 C5 7A 91
W[4] = W[0] ⊕ T(W[3]) = 3C A1 0B 21 ⊕ 79 C5 7A 91 = 45 64 71 B0
W[5] 计算 (\(i = 5\),不是\(4\)的倍数)
W[5] = W[1] ⊕ W[4] = 57 F0 19 16 ⊕ 45 64 71 B0 = 12 94 68 A6
W[6] 计算 (\(i = 5\),不是\(4\)的倍数)
W[6] = W[2] ⊕ W[5] = 90 2E 13 80 ⊕ 12 94 68 A6 = 82 BA 7B 26
W[7] 计算 (\(i = 5\),不是\(4\)的倍数)
W[7] = W[3] ⊕ W[6] = AC C1 07 BD ⊕ 82 BA 7B 26 = 2E 7B 7C 9B
第一轮密钥结果:45 64 71 B0 12 94 68 A6 82 BA 7B 26 2E 7B 7C 9B
三、加密 and 解密流程
注:其中解密的流程并不固定,其中针对于每一轮中的逆组件使用,逆行移位操作与逆\(S盒\)操作可以调换先后顺序。
四、AES加密 and 解密代码实现
声明:为了适应大多数人的需求,减少重复知识的冗余。我们将对\(AES-128\)进行编码。其实针对于目前主流的编程语言而言,都已经有对应的库进行了支持。但是这避免不了我们会有使用\(AES\)源码的需求。便于博客管理以及避免文章篇幅很长,在这里只写下参照上述流程基本实现的主流语言编程下的加密与解密代码,针对于课程需要,还会有很多软件实现策略和优化策略,我们放在后面的文章中详细介绍,例如调用AESNI指令、T表法优化等软件实现的优化策略。
C/C++
/*
* File Name: AES_test.cpp
* Author: chen_xing
*/
#include <iostream>
#include <vector>
#include <iomanip>
#include <sstream>
#include <cstdint>
#include <algorithm>
#include <stdexcept>
using namespace std;
class AES
{
private:
static constexpr uint8_t Rcon[14] = {
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D};
vector<uint8_t> key;
vector<vector<uint8_t>> W;
vector<uint8_t> S, S_inv;
int rounds, nk;
// GF(2^8)乘法
uint8_t _gmul(uint8_t a, uint8_t b)
{
uint8_t p = 0;
for (int i = 0; i < 8; ++i)
{
if (b & 1)
p ^= a;
bool hi = a & 0x80;
a <<= 1;
if (hi)
a ^= 0x1B;
b >>= 1;
}
return p;
}
// 向量异或
vector<uint8_t> _xor_vec(const vector<uint8_t> &a, const vector<uint8_t> &b)
{
vector<uint8_t> res(a.size());
for (size_t i = 0; i < a.size(); ++i)
res[i] = a[i] ^ b[i];
return res;
}
// 字节转矩阵
vector<vector<uint8_t>> _bytes_to_matrix(const vector<uint8_t> &text)
{
vector<vector<uint8_t>> m(4, vector<uint8_t>(4));
for (int i = 0; i < 16; i++)
m[i % 4][i / 4] = text[i];
return m;
}
// 矩阵转字节
vector<uint8_t> _matrix_to_bytes(const vector<vector<uint8_t>> &m)
{
vector<uint8_t> res(16);
for (int i = 0; i < 16; i++)
res[i] = m[i % 4][i / 4];
return res;
}
// 字循环
vector<uint8_t> _rot_word(vector<uint8_t> w)
{
rotate(w.begin(), w.begin() + 1, w.end());
return w;
}
// 字替换
vector<uint8_t> _sub_word(vector<uint8_t> w)
{
for (auto &x : w)
x = S[x];
return w;
}
// 生成S盒
void _generate_sbox()
{
S.assign(256, 0);
S[0] = 0x63;
uint8_t p = 1, q = 1;
for (int i = 0; i < 255; ++i)
{
p = p ^ (p << 1) ^ ((p & 0x80) ? 0x1B : 0);
q ^= q << 1;
q ^= q << 2;
q ^= q << 4;
q ^= (q & 0x80) ? 0x09 : 0x00;
uint8_t xformed = q ^ ((q << 1) | (q >> 7)) ^
((q << 2) | (q >> 6)) ^
((q << 3) | (q >> 5)) ^
((q << 4) | (q >> 4)) ^ 0x63;
S[p] = xformed;
}
}
// 生成逆S盒
void _generate_sinvbox()
{
S_inv.assign(256, 0);
for (int i = 0; i < 256; ++i)
S_inv[S[i]] = i;
}
// 密钥扩展
void _key_expansion()
{
for (int i = nk; i < (rounds + 1) * 4; ++i)
{
vector<uint8_t> temp = W[i - 1];
if (i % nk == 0)
{
temp = _sub_word(_rot_word(temp));
temp[0] ^= Rcon[i / nk - 1];
}
else if (nk == 8 && i % nk == 4)
{
temp = _sub_word(temp);
}
W.push_back(_xor_vec(W[i - nk], temp));
}
}
// 轮密钥加
vector<vector<uint8_t>> _round_key_add(vector<vector<uint8_t>> &block, int round)
{
for (int c = 0; c < 4; c++)
for (int r = 0; r < 4; r++)
block[r][c] ^= W[round * 4 + c][r];
return block;
}
// 字节替换
vector<vector<uint8_t>> _sub_bytes(vector<vector<uint8_t>> &block)
{
for (auto &row : block)
for (auto &b : row)
b = S[b];
return block;
}
// 行移位
vector<vector<uint8_t>> _shift_rows(vector<vector<uint8_t>> &block)
{
for (int r = 1; r < 4; r++)
rotate(block[r].begin(), block[r].begin() + r, block[r].end());
return block;
}
// 列混合
vector<vector<uint8_t>> _mix_columns(vector<vector<uint8_t>> &s)
{
for (int c = 0; c < 4; c++)
{
uint8_t a0 = s[0][c], a1 = s[1][c], a2 = s[2][c], a3 = s[3][c];
s[0][c] = _gmul(a0, 2) ^ _gmul(a1, 3) ^ a2 ^ a3;
s[1][c] = a0 ^ _gmul(a1, 2) ^ _gmul(a2, 3) ^ a3;
s[2][c] = a0 ^ a1 ^ _gmul(a2, 2) ^ _gmul(a3, 3);
s[3][c] = _gmul(a0, 3) ^ a1 ^ a2 ^ _gmul(a3, 2);
}
return s;
}
// 逆字节替换
vector<vector<uint8_t>> _inv_sub_bytes(vector<vector<uint8_t>> &block)
{
for (auto &row : block)
for (auto &b : row)
b = S_inv[b];
return block;
}
// 逆行移位
vector<vector<uint8_t>> _inv_shift_rows(vector<vector<uint8_t>> &block)
{
for (int r = 1; r < 4; r++)
rotate(block[r].begin(), block[r].begin() + (4 - r), block[r].end());
return block;
}
// 逆列混合
vector<vector<uint8_t>> _inv_mix_columns(vector<vector<uint8_t>> &s)
{
for (int c = 0; c < 4; c++)
{
uint8_t a0 = s[0][c], a1 = s[1][c], a2 = s[2][c], a3 = s[3][c];
s[0][c] = _gmul(a0, 14) ^ _gmul(a1, 11) ^ _gmul(a2, 13) ^ _gmul(a3, 9);
s[1][c] = _gmul(a0, 9) ^ _gmul(a1, 14) ^ _gmul(a2, 11) ^ _gmul(a3, 13);
s[2][c] = _gmul(a0, 13) ^ _gmul(a1, 9) ^ _gmul(a2, 14) ^ _gmul(a3, 11);
s[3][c] = _gmul(a0, 11) ^ _gmul(a1, 13) ^ _gmul(a2, 9) ^ _gmul(a3, 14);
}
return s;
}
public:
AES(const vector<uint8_t> &key_)
{
key = key_;
size_t len = key.size();
if (len == 16)
{
rounds = 10;
nk = 4;
}
else if (len == 24)
{
rounds = 12;
nk = 6;
}
else if (len == 32)
{
rounds = 14;
nk = 8;
}
else
throw invalid_argument("Invalid AES key length");
_generate_sbox();
_generate_sinvbox();
// 初始化密钥扩展
for (size_t i = 0; i < len; i += 4)
W.push_back(vector<uint8_t>(key.begin() + i, key.begin() + i + 4));
_key_expansion();
}
// PKCS7 填充
vector<uint8_t> pad(const vector<uint8_t> &data)
{
size_t pad_len = 16 - (data.size() % 16);
vector<uint8_t> res = data;
res.insert(res.end(), pad_len, static_cast<uint8_t>(pad_len));
return res;
}
// 去除 PKCS7 填充
vector<uint8_t> unpad(const vector<uint8_t> &data)
{
if (data.empty() || data.size() % 16 != 0)
throw runtime_error("Invalid data length for unpad");
uint8_t pad_len = data.back();
if (pad_len == 0 || pad_len > 16)
throw runtime_error("Invalid padding");
for (size_t i = data.size() - pad_len; i < data.size(); ++i)
if (data[i] != pad_len)
throw runtime_error("Bad padding byte");
return vector<uint8_t>(data.begin(), data.end() - pad_len);
}
vector<uint8_t> encrypt(const vector<uint8_t> &plaintext)
{
vector<uint8_t> data = pad(plaintext);
vector<uint8_t> out;
for (size_t i = 0; i < data.size(); i += 16)
{
auto block = vector<uint8_t>(data.begin() + i, data.begin() + i + 16);
auto state = _bytes_to_matrix(block);
// 初始轮密钥加
state = _round_key_add(state, 0);
// 主轮
for (int r = 1; r < rounds; ++r)
{
state = _sub_bytes(state);
state = _shift_rows(state);
state = _mix_columns(state);
state = _round_key_add(state, r);
}
// 最终轮
state = _sub_bytes(state);
state = _shift_rows(state);
state = _round_key_add(state, rounds);
auto encrypted_block = _matrix_to_bytes(state);
out.insert(out.end(), encrypted_block.begin(), encrypted_block.end());
}
return out;
}
vector<uint8_t> decrypt(const vector<uint8_t> &ciphertext)
{
if (ciphertext.size() % 16 != 0)
throw invalid_argument("Ciphertext length must be multiple of 16");
vector<uint8_t> out;
for (size_t i = 0; i < ciphertext.size(); i += 16)
{
auto block = vector<uint8_t>(ciphertext.begin() + i, ciphertext.begin() + i + 16);
auto state = _bytes_to_matrix(block);
// 初始轮
state = _round_key_add(state, rounds);
state = _inv_shift_rows(state);
state = _inv_sub_bytes(state);
// 主轮
for (int r = rounds - 1; r > 0; --r)
{
state = _round_key_add(state, r);
state = _inv_mix_columns(state);
state = _inv_shift_rows(state);
state = _inv_sub_bytes(state);
}
// 最终轮
state = _round_key_add(state, 0);
auto decrypted_block = _matrix_to_bytes(state);
out.insert(out.end(), decrypted_block.begin(), decrypted_block.end());
}
return unpad(out);
}
};
int main()
{
string msg = "He who conquers others is strong; he who conquers himself is mighty.";
vector<uint8_t> data(msg.begin(), msg.end());
// key = chenxing_AES_c++
vector<uint8_t> key = {'c', 'h', 'e', 'n', 'x', 'i', 'n', 'g', '_', 'A', 'E', 'S', '_', 'c', '+', '+'};
AES aes(key);
// 加密
auto enc = aes.encrypt(data);
cout << "Encrypted (hex): ";
for (auto c : enc)
cout << hex << setw(2) << setfill('0') << (int)c;
cout << dec << endl;
// 解密
auto dec = aes.decrypt(enc);
cout << "Decrypted: " << string(dec.begin(), dec.end()) << endl;
// 验证
cout << "Match: " << (msg == string(dec.begin(), dec.end())) << endl;
return 0;
}
Python
class AES():
Rcon = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D]
def __init__(self, key: bytes):
key_length = len(key)
if key_length == 16: # 128位
self.rounds = 10
self.nk = 4 # 密钥字数
elif key_length == 24: # 192位
self.rounds = 12
self.nk = 6
elif key_length == 32: # 256位
self.rounds = 14
self.nk = 8
else:
raise ValueError(f"Invalid AES key length: {key_length} bytes")
self.key = key
self.S = self._generate_sbox()
self.S_inv = self._generate_sinvbox()
self.W = [list(key[i:i + 4]) for i in range(0, len(key), 4)]
self._key_expansion()
def _xor(self, a, b=None):
if b is None:
assert len(a) > 0
return a[0] ^ self._xor(a[1:]) if len(a) > 1 else a[0]
assert len(a) == len(b)
return [x ^ y for x, y in zip(a, b)]
def _gmul(self, a, b):
p = 0
while b:
if b & 1:
p ^= a
a = a << 1
if a >> 8:
a ^= 0x11B # AES不可约多项式
b >>= 1
return p
def _multiply(self, a, b):
return self._xor([self._gmul(x, y) for x, y in zip(a, b)])
def _matrix_multiply(self, const, state):
result = [[0] * 4 for _ in range(4)]
for i in range(4):
for j in range(4):
for k in range(4):
result[i][j] ^= self._gmul(const[i][k], state[k][j])
return result
def _permutation(self, lis, table):
return [table[i] for i in lis]
def _block_permutation(self, block, table):
return [[table[block[i][j]] for j in range(4)] for i in range(4)]
def _left_shift(self, block, n):
return block[n:] + block[:n]
def _rot_word(self, word):
return word[1:] + word[:1]
def _sub_word(self, word):
return [self.S[b] for b in word]
def _bytes_to_matrix(self, text: bytes) -> list[list[int]]:
return [[text[j * 4 + i] for j in range(4)] for i in range(4)]
def _matrix_to_bytes(self, block: list[list[int]]) -> bytes:
return bytes([block[j][i] for i in range(4) for j in range(4)])
def _generate_sbox(self):
S = [0x63] + [0] * 255
r = lambda x, s: (x << s | x >> 8 - s) % 256
p, q = 1, 1
for _ in range(255):
p = (p ^ (p * 2) ^ [27, 0][p < 128]) % 256
q ^= q * 2
q ^= q * 4
q ^= q * 16
q &= 255
q ^= [9, 0][q < 128]
S[p] = q ^ r(q, 1) ^ r(q, 2) ^ r(q, 3) ^ r(q, 4) ^ 99
return S
def _generate_sinvbox(self):
S_inv = [0] * 256
for i in range(256):
S_inv[self.S[i]] = i
return S_inv
def _key_expansion(self):
for i in range(self.nk, (self.rounds + 1) * 4):
temp = self.W[i - 1].copy()
if i % self.nk == 0:
temp = self._sub_word(self._rot_word(temp))
temp[0] ^= self.Rcon[i // self.nk - 1]
elif self.nk == 8 and i % self.nk == 4:
temp = self._sub_word(temp)
self.W.append(self._xor(self.W[i - self.nk], temp))
def _row_shift(self, block: list[list[int]]) -> list[list[int]]:
return [self._left_shift(block[i], i) for i in range(4)]
def _inv_row_shift(self, block: list[list[int]]) -> list[list[int]]:
return [self._left_shift(block[i], 4 - i) for i in range(4)]
def _column_mix(self, block: list[list[int]]) -> list[list[int]]:
return self._matrix_multiply([self._left_shift([2, 3, 1, 1], 4 - i) for i in range(4)], block)
def _inv_column_mix(self, block: list[list[int]]) -> list[list[int]]:
return self._matrix_multiply([self._left_shift([14, 11, 13, 0x9], 4 - i) for i in range(4)], block)
def _round_key_add(self, block: list[list[int]], key: list[list[int]]) -> list[list[int]]:
return [self._xor(block[i], [key[j][i] for j in range(4)]) for i in range(4)]
def _encrypt(self, block: list[list[int]]) -> list[list[int]]:
block = self._round_key_add(block, [self.W[i] for i in range(4)])
for i in range(1, self.rounds):
block = self._block_permutation(block, self.S)
block = self._row_shift(block)
block = self._column_mix(block)
block = self._round_key_add(block, [self.W[i * 4 + j] for j in range(4)])
block = self._block_permutation(block, self.S)
block = self._row_shift(block)
block = self._round_key_add(block, [self.W[i] for i in range(self.rounds * 4, (self.rounds + 1) * 4)])
return block
def _decrypt(self, block: list[list[int]]) -> list[list[int]]:
block = self._round_key_add(block, [self.W[i] for i in range(self.rounds * 4, (self.rounds + 1) * 4)])
for i in range(self.rounds - 1, 0, -1):
block = self._inv_row_shift(block)
block = self._block_permutation(block, self.S_inv)
block = self._round_key_add(block, [self.W[i * 4 + j] for j in range(4)])
block = self._inv_column_mix(block)
block = self._inv_row_shift(block)
block = self._block_permutation(block, self.S_inv)
block = self._round_key_add(block, [self.W[i] for i in range(4)])
return block
def encrypt(self, plaintext: bytes) -> bytes:
assert len(plaintext) % 16 == 0, ValueError(f"Incorrect AES plaintext length ({len(plaintext)} bytes)")
ciphertext = b''
for i in range(0, len(plaintext), 16):
block = plaintext[i:i + 16]
block = self._bytes_to_matrix(block)
block = self._encrypt(block)
ciphertext += self._matrix_to_bytes(block)
return ciphertext
def decrypt(self, ciphertext: bytes) -> bytes:
assert len(ciphertext) % 16 == 0, ValueError(f"Incorrect AES ciphertext length ({len(ciphertext)} bytes)")
plaintext = b''
for i in range(0, len(ciphertext), 16):
block = ciphertext[i:i + 16]
block = self._bytes_to_matrix(block)
block = self._decrypt(block)
plaintext += self._matrix_to_bytes(block)
return plaintext
if __name__ == "__main__":
key = b'chen_xing_AES_py'
aes = AES(key)
# 胜人者有力,自强者胜。
test_plaintext = b'He who conquers others is strong; he who conquers himself is mighty.'
# 填充到16字节
if len(test_plaintext) % 16 != 0:
test_plaintext = test_plaintext.ljust(16 * ((len(test_plaintext) // 16) + 1), b'\x00')
encrypted = aes.encrypt(test_plaintext)
decrypted = aes.decrypt(encrypted)
print(f"Decrypted: {decrypted}")
print(f"Test - Match: {test_plaintext == decrypted}")
Java
import java.util.Arrays;
public class AES {
private static final int[] RCON = {
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D
};
private final int Nk; // 密钥字数
private final int Nr; // 轮数
private final int Nb = 4; // 状态矩阵列数
private final byte[] key; // 原始密钥
private final int[][] w; // 扩展密钥
private final byte[] S = new byte[256]; // S盒
private final byte[] S_INV = new byte[256]; // 逆S盒
public AES(byte[] key) {
// 根据密钥长度确定参数
int len = key.length;
if (len == 16) {
Nk = 4;
Nr = 10;
} else if (len == 24) {
Nk = 6;
Nr = 12;
} else if (len == 32) {
Nk = 8;
Nr = 14;
} else {
throw new IllegalArgumentException("Invalid AES key length: " + len);
}
this.key = Arrays.copyOf(key, key.length);
generateSBox();
generateSInvBox();
// 密钥扩展:总字数 = Nb * (Nr + 1)
int totalWords = Nb * (Nr + 1);
w = new int[totalWords][4];
keyExpansion();
}
// ---------- GF(2^8) 运算辅助函数 ----------
/**
* 在GF(2^8)中乘以x(即2)
*/
private static int xtime(int a) {
a <<= 1;
if ((a & 0x100) != 0) {
a ^= 0x11b; // AES不可约多项式
}
return a & 0xff;
}
/**
* 在GF(2^8)中乘法运算
*/
private static int mul(int a, int b) {
int res = 0;
while (b != 0) {
if ((b & 1) != 0) {
res ^= a;
}
a = xtime(a);
b >>>= 1;
}
return res & 0xff;
}
// ---------- S盒生成 ----------
/**
* 生成AES S盒
*/
private void generateSBox() {
// 使用指数运算 a^254 在GF(2^8)中计算乘法逆元
S[0] = 0x63;
for (int i = 1; i < 256; i++) {
int inv = gfInv(i);
int x = inv;
// 仿射变换
int y = x ^ ((x << 1) | (x >>> 7))
^ ((x << 2) | (x >>> 6))
^ ((x << 3) | (x >>> 5))
^ ((x << 4) | (x >>> 4))
^ 0x63;
S[i] = (byte) (y & 0xff);
}
}
/**
* 生成逆S盒
*/
private void generateSInvBox() {
for (int i = 0; i < 256; i++) {
S_INV[S[i] & 0xFF] = (byte) i;
}
}
/**
* 在GF(2^8)中计算乘法逆元
*/
private int gfInv(int a) {
if (a == 0) return 0;
int result = 1;
int base = a;
int exp = 254; // a^254 是逆元
while (exp > 0) {
if ((exp & 1) != 0) {
result = mul(result, base);
}
base = mul(base, base);
exp >>= 1;
}
return result & 0xff;
}
// ---------- 密钥扩展 ----------
/**
* AES密钥扩展算法
*/
private void keyExpansion() {
// 用原始密钥填充前Nk个字
for (int i = 0; i < Nk; i++) {
for (int j = 0; j < 4; j++) {
w[i][j] = key[4 * i + j] & 0xFF;
}
}
// 扩展剩余的字
for (int i = Nk; i < w.length; i++) {
int[] temp = Arrays.copyOf(w[i - 1], 4);
if (i % Nk == 0) {
temp = subWord(rotWord(temp));
temp[0] ^= RCON[i / Nk - 1];
} else if (Nk > 6 && (i % Nk) == 4) {
temp = subWord(temp);
}
for (int j = 0; j < 4; j++) {
w[i][j] = w[i - Nk][j] ^ temp[j];
}
}
}
/**
* 字循环:将字向左循环移位
*/
private int[] rotWord(int[] word) {
return new int[]{word[1], word[2], word[3], word[0]};
}
/**
* 字替换:使用S盒替换字中的每个字节
*/
private int[] subWord(int[] word) {
int[] out = new int[4];
for (int i = 0; i < 4; i++) {
out[i] = S[word[i]] & 0xFF;
}
return out;
}
// ---------- 状态矩阵辅助函数 ----------
/**
* 将字节数组转换为状态矩阵
*/
private int[][] bytesToState(byte[] in, int offset) {
int[][] state = new int[4][4];
for (int i = 0; i < 16; i++) {
state[i % 4][i / 4] = in[offset + i] & 0xFF;
}
return state;
}
/**
* 将状态矩阵转换为字节数组
*/
private void stateToBytes(int[][] state, byte[] out, int offset) {
for (int i = 0; i < 16; i++) {
out[offset + i] = (byte) (state[i % 4][i / 4]);
}
}
// ---------- 核心变换操作 ----------
/**
* 字节替换
*/
private void subBytes(int[][] state) {
for (int row = 0; row < 4; row++) {
for (int col = 0; col < 4; col++) {
state[row][col] = S[state[row][col]] & 0xff;
}
}
}
/**
* 逆字节替换
*/
private void invSubBytes(int[][] state) {
for (int row = 0; row < 4; row++) {
for (int col = 0; col < 4; col++) {
state[row][col] = S_INV[state[row][col]] & 0xff;
}
}
}
/**
* 行移位
*/
private void shiftRows(int[][] state) {
for (int row = 1; row < 4; row++) {
int[] temp = Arrays.copyOf(state[row], 4);
for (int col = 0; col < 4; col++) {
state[row][col] = temp[(col + row) % 4];
}
}
}
/**
* 逆行移位
*/
private void invShiftRows(int[][] state) {
for (int row = 1; row < 4; row++) {
int[] temp = Arrays.copyOf(state[row], 4);
for (int col = 0; col < 4; col++) {
state[row][(col + row) % 4] = temp[col];
}
}
}
/**
* 列混合
*/
private void mixColumns(int[][] state) {
for (int col = 0; col < 4; col++) {
int a0 = state[0][col];
int a1 = state[1][col];
int a2 = state[2][col];
int a3 = state[3][col];
state[0][col] = (mul(2, a0) ^ mul(3, a1) ^ a2 ^ a3) & 0xff;
state[1][col] = (a0 ^ mul(2, a1) ^ mul(3, a2) ^ a3) & 0xff;
state[2][col] = (a0 ^ a1 ^ mul(2, a2) ^ mul(3, a3)) & 0xff;
state[3][col] = (mul(3, a0) ^ a1 ^ a2 ^ mul(2, a3)) & 0xff;
}
}
/**
* 逆列混合
*/
private void invMixColumns(int[][] state) {
for (int col = 0; col < 4; col++) {
int a0 = state[0][col];
int a1 = state[1][col];
int a2 = state[2][col];
int a3 = state[3][col];
state[0][col] = (mul(0x0e, a0) ^ mul(0x0b, a1) ^ mul(0x0d, a2) ^ mul(0x09, a3)) & 0xff;
state[1][col] = (mul(0x09, a0) ^ mul(0x0e, a1) ^ mul(0x0b, a2) ^ mul(0x0d, a3)) & 0xff;
state[2][col] = (mul(0x0d, a0) ^ mul(0x09, a1) ^ mul(0x0e, a2) ^ mul(0x0b, a3)) & 0xff;
state[3][col] = (mul(0x0b, a0) ^ mul(0x0d, a1) ^ mul(0x09, a2) ^ mul(0x0e, a3)) & 0xff;
}
}
/**
* 轮密钥加
*/
private void addRoundKey(int[][] state, int round) {
for (int col = 0; col < 4; col++) {
for (int row = 0; row < 4; row++) {
state[row][col] ^= w[round * 4 + col][row];
}
}
}
// ---------- 单块加密/解密 ----------
/**
* 加密单个16字节块
*/
public byte[] encryptBlock(byte[] inputBlock) {
if (inputBlock.length != 16) {
throw new IllegalArgumentException("Block size must be 16 bytes");
}
int[][] state = bytesToState(inputBlock, 0);
// 初始轮密钥加
addRoundKey(state, 0);
// 主轮(前Nr-1轮)
for (int round = 1; round < Nr; round++) {
subBytes(state);
shiftRows(state);
mixColumns(state);
addRoundKey(state, round);
}
// 最终轮(无列混合)
subBytes(state);
shiftRows(state);
addRoundKey(state, Nr);
byte[] output = new byte[16];
stateToBytes(state, output, 0);
return output;
}
/**
* 解密单个16字节块
*/
public byte[] decryptBlock(byte[] inputBlock) {
if (inputBlock.length != 16) {
throw new IllegalArgumentException("Block size must be 16 bytes");
}
int[][] state = bytesToState(inputBlock, 0);
// 初始轮
addRoundKey(state, Nr);
invShiftRows(state);
invSubBytes(state);
// 主轮(前Nr-1轮)
for (int round = Nr - 1; round > 0; round--) {
addRoundKey(state, round);
invMixColumns(state);
invShiftRows(state);
invSubBytes(state);
}
// 最终轮
addRoundKey(state, 0);
byte[] output = new byte[16];
stateToBytes(state, output, 0);
return output;
}
// ---------- PKCS#7 填充 ----------
/**
* PKCS#7填充
*/
private byte[] pkcs7Pad(byte[] data) {
int padLength = 16 - (data.length % 16);
if (padLength == 0) {
padLength = 16;
}
byte[] padded = Arrays.copyOf(data, data.length + padLength);
Arrays.fill(padded, data.length, padded.length, (byte) padLength);
return padded;
}
/**
* 去除PKCS#7填充
*/
private byte[] pkcs7Unpad(byte[] data) {
if (data.length == 0 || data.length % 16 != 0) {
throw new IllegalArgumentException("Invalid padded data length");
}
int paddingLength = data[data.length - 1] & 0xFF;
if (paddingLength < 1 || paddingLength > 16) {
throw new IllegalArgumentException("Invalid padding length");
}
// 验证填充字节
for (int i = data.length - paddingLength; i < data.length; i++) {
if ((data[i] & 0xFF) != paddingLength) {
throw new IllegalArgumentException("Invalid padding bytes");
}
}
return Arrays.copyOf(data, data.length - paddingLength);
}
// ---------- 多块加密/解密 (ECB模式) ----------
/**
* 加密任意长度数据
*/
public byte[] encrypt(byte[] plaintext) {
byte[] paddedData = pkcs7Pad(plaintext);
byte[] ciphertext = new byte[paddedData.length];
for (int i = 0; i < paddedData.length; i += 16) {
byte[] block = Arrays.copyOfRange(paddedData, i, i + 16);
byte[] encryptedBlock = encryptBlock(block);
System.arraycopy(encryptedBlock, 0, ciphertext, i, 16);
}
return ciphertext;
}
/**
* 解密任意长度数据
*/
public byte[] decrypt(byte[] ciphertext) {
if (ciphertext.length % 16 != 0) {
throw new IllegalArgumentException("Ciphertext length must be multiple of 16 bytes");
}
byte[] decryptedData = new byte[ciphertext.length];
for (int i = 0; i < ciphertext.length; i += 16) {
byte[] block = Arrays.copyOfRange(ciphertext, i, i + 16);
byte[] decryptedBlock = decryptBlock(block);
System.arraycopy(decryptedBlock, 0, decryptedData, i, 16);
}
return pkcs7Unpad(decryptedData);
}
// ---------- 辅助函数 ----------
/**
* 将字节数组转换为十六进制字符串
*/
public static String bytesToHex(byte[] bytes) {
StringBuilder hexString = new StringBuilder();
for (byte b : bytes) {
hexString.append(String.format("%02x", b & 0xFF));
}
return hexString.toString();
}
// ---------- 测试函数 ----------
public static void main(String[] args) throws Exception {
byte[] key = "chenxingAES_JAVA".getBytes("UTF-8");
AES aes = new AES(key);
String plainText = "He who conquers others is strong; he who conquers himself is mighty.";
byte[] plaintextBytes = plainText.getBytes("UTF-8");
// 加密
byte[] ciphertext = aes.encrypt(plaintextBytes);
System.out.println("Ciphertext (hex): " + bytesToHex(ciphertext));
// 解密
byte[] decryptedBytes = aes.decrypt(ciphertext);
String decryptedText = new String(decryptedBytes, "UTF-8");
System.out.println("Decrypted text: " + decryptedText);
// 验证
boolean isMatch = Arrays.equals(plaintextBytes, decryptedBytes);
System.out.println("Match: " + isMatch);
if (!isMatch) {
throw new RuntimeException("Encryption/decryption mismatch!");
}
}
}
Rust
use std::convert::TryInto;
struct AES {
rounds: usize,
nk: usize,
w: Vec<[u8; 4]>,
sbox: [u8; 256],
inv_sbox: [u8; 256],
}
impl AES {
const RCON: [u8; 14] = [
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D
];
pub fn new(key: &[u8]) -> Self {
let (nk, rounds) = match key.len() {
16 => (4, 10),
24 => (6, 12),
32 => (8, 14),
_ => panic!("Invalid AES key length: {} bytes", key.len()),
};
let mut aes = AES {
rounds,
nk,
w: Vec::new(),
sbox: [0u8; 256],
inv_sbox: [0u8; 256],
};
aes.generate_sbox();
aes.generate_inv_sbox();
aes.key_expansion(key);
aes
}
/// 在GF(2^8)中乘法运算
fn gmul(mut a: u8, mut b: u8) -> u8 {
let mut product: u8 = 0;
while b != 0 {
if b & 1 != 0 {
product ^= a;
}
let high_bit = a & 0x80;
a <<= 1;
if high_bit != 0 {
a ^= 0x1b;
}
b >>= 1;
}
product
}
/// 在GF(2^8)中指数运算
fn gf_pow(mut base: u8, mut exponent: u32) -> u8 {
let mut result: u8 = 1;
while exponent != 0 {
if exponent & 1 != 0 {
result = AES::gmul(result, base);
}
base = AES::gmul(base, base);
exponent >>= 1;
}
result
}
/// 生成AES S盒
fn generate_sbox(&mut self) {
self.sbox[0] = 0x63;
for i in 1..256 {
let inv = if i == 0 { 0 } else { AES::gf_pow(i as u8, 254) };
let x = inv;
// 仿射变换
let y = x ^ x.rotate_left(1)
^ x.rotate_left(2)
^ x.rotate_left(3)
^ x.rotate_left(4)
^ 0x63;
self.sbox[i] = y;
}
}
/// 生成逆S盒
fn generate_inv_sbox(&mut self) {
for i in 0..256 {
self.inv_sbox[self.sbox[i] as usize] = i as u8;
}
}
/// 字循环:将字向左循环移位
fn rot_word(word: [u8; 4]) -> [u8; 4] {
[word[1], word[2], word[3], word[0]]
}
/// 字替换:使用S盒替换字中的每个字节
fn sub_word(&self, word: [u8; 4]) -> [u8; 4] {
[
self.sbox[word[0] as usize],
self.sbox[word[1] as usize],
self.sbox[word[2] as usize],
self.sbox[word[3] as usize],
]
}
/// AES密钥扩展算法
fn key_expansion(&mut self, key: &[u8]) {
let nb = 4; // 状态矩阵列数
let total_words = nb * (self.rounds + 1);
self.w = vec![[0u8; 4]; total_words];
// 用原始密钥填充前Nk个字
for i in 0..self.nk {
self.w[i].copy_from_slice(&key[4 * i..4 * i + 4]);
}
// 扩展剩余的字
for i in self.nk..total_words {
let mut temp = self.w[i - 1];
if i % self.nk == 0 {
temp = Self::rot_word(temp);
temp = self.sub_word(temp);
temp[0] ^= Self::RCON[i / self.nk - 1];
} else if self.nk > 6 && i % self.nk == 4 {
temp = self.sub_word(temp);
}
for j in 0..4 {
self.w[i][j] = self.w[i - self.nk][j] ^ temp[j];
}
}
}
/// 轮密钥加
fn add_round_key(state: &mut [[u8; 4]; 4], expanded_key: &[[u8; 4]], round: usize) {
for col in 0..4 {
for row in 0..4 {
state[row][col] ^= expanded_key[round * 4 + col][row];
}
}
}
/// 字节替换
fn sub_bytes(&self, state: &mut [[u8; 4]; 4]) {
for row in 0..4 {
for col in 0..4 {
state[row][col] = self.sbox[state[row][col] as usize];
}
}
}
/// 逆字节替换
fn inv_sub_bytes(&self, state: &mut [[u8; 4]; 4]) {
for row in 0..4 {
for col in 0..4 {
state[row][col] = self.inv_sbox[state[row][col] as usize];
}
}
}
/// 行移位
fn shift_rows(&self, state: &mut [[u8; 4]; 4]) {
for row in 1..4 {
state[row].rotate_left(row);
}
}
/// 逆行移位
fn inv_shift_rows(&self, state: &mut [[u8; 4]; 4]) {
for row in 1..4 {
state[row].rotate_right(row);
}
}
/// 单列混合
fn mix_single_column(&self, column: [u8; 4]) -> [u8; 4] {
[
Self::gmul(2, column[0]) ^ Self::gmul(3, column[1]) ^ column[2] ^ column[3],
column[0] ^ Self::gmul(2, column[1]) ^ Self::gmul(3, column[2]) ^ column[3],
column[0] ^ column[1] ^ Self::gmul(2, column[2]) ^ Self::gmul(3, column[3]),
Self::gmul(3, column[0]) ^ column[1] ^ column[2] ^ Self::gmul(2, column[3]),
]
}
/// 列混合
fn mix_columns(&self, state: &mut [[u8; 4]; 4]) {
for col in 0..4 {
let column = [state[0][col], state[1][col], state[2][col], state[3][col]];
let mixed_column = self.mix_single_column(column);
for row in 0..4 {
state[row][col] = mixed_column[row];
}
}
}
/// 逆列混合
fn inv_mix_columns(&self, state: &mut [[u8; 4]; 4]) {
for col in 0..4 {
let column = [state[0][col], state[1][col], state[2][col], state[3][col]];
state[0][col] = Self::gmul(0x0e, column[0])
^ Self::gmul(0x0b, column[1])
^ Self::gmul(0x0d, column[2])
^ Self::gmul(0x09, column[3]);
state[1][col] = Self::gmul(0x09, column[0])
^ Self::gmul(0x0e, column[1])
^ Self::gmul(0x0b, column[2])
^ Self::gmul(0x0d, column[3]);
state[2][col] = Self::gmul(0x0d, column[0])
^ Self::gmul(0x09, column[1])
^ Self::gmul(0x0e, column[2])
^ Self::gmul(0x0b, column[3]);
state[3][col] = Self::gmul(0x0b, column[0])
^ Self::gmul(0x0d, column[1])
^ Self::gmul(0x09, column[2])
^ Self::gmul(0x0e, column[3]);
}
}
/// 将16字节块转换为4x4状态矩阵
fn bytes_to_state(block: &[u8; 16]) -> [[u8; 4]; 4] {
let mut state = [[0u8; 4]; 4];
for i in 0..16 {
state[i % 4][i / 4] = block[i];
}
state
}
/// 将4x4状态矩阵转换为16字节块
fn state_to_bytes(state: &[[u8; 4]; 4]) -> [u8; 16] {
let mut block = [0u8; 16];
for i in 0..16 {
block[i] = state[i % 4][i / 4];
}
block
}
/// 加密单个16字节块
pub fn encrypt_block(&self, block: &[u8; 16]) -> [u8; 16] {
let mut state = Self::bytes_to_state(block);
// 初始轮密钥加
Self::add_round_key(&mut state, &self.w, 0);
// 主轮(前Nr-1轮)
for round in 1..self.rounds {
self.sub_bytes(&mut state);
self.shift_rows(&mut state);
self.mix_columns(&mut state);
Self::add_round_key(&mut state, &self.w, round);
}
// 最终轮(无列混合)
self.sub_bytes(&mut state);
self.shift_rows(&mut state);
Self::add_round_key(&mut state, &self.w, self.rounds);
Self::state_to_bytes(&state)
}
/// 解密单个16字节块
pub fn decrypt_block(&self, block: &[u8; 16]) -> [u8; 16] {
let mut state = Self::bytes_to_state(block);
// 初始轮
Self::add_round_key(&mut state, &self.w, self.rounds);
self.inv_shift_rows(&mut state);
self.inv_sub_bytes(&mut state);
// 主轮(前Nr-1轮)
for round in (1..self.rounds).rev() {
Self::add_round_key(&mut state, &self.w, round);
self.inv_mix_columns(&mut state);
self.inv_shift_rows(&mut state);
self.inv_sub_bytes(&mut state);
}
// 最终轮
Self::add_round_key(&mut state, &self.w, 0);
Self::state_to_bytes(&state)
}
// ---------- PKCS7 填充辅助函数 ----------
/// PKCS7填充
fn pkcs7_pad(data: &[u8]) -> Vec<u8> {
let padding_length = 16 - (data.len() % 16);
let mut padded_data = Vec::with_capacity(data.len() + padding_length);
padded_data.extend_from_slice(data);
padded_data.extend(std::iter::repeat(padding_length as u8).take(padding_length));
padded_data
}
/// 去除PKCS7填充
fn pkcs7_unpad(data: &[u8]) -> Vec<u8> {
if data.is_empty() || data.len() % 16 != 0 {
panic!("Invalid padded data: length must be non-zero and multiple of 16");
}
let padding_length = data[data.len() - 1] as usize;
if padding_length == 0 || padding_length > 16 {
panic!("Invalid padding length: {}", padding_length);
}
// 验证所有填充字节
for &byte in &data[data.len() - padding_length..] {
if byte as usize != padding_length {
panic!("Invalid padding bytes");
}
}
data[..data.len() - padding_length].to_vec()
}
pub fn encrypt(&self, plaintext: &[u8]) -> Vec<u8> {
let padded_data = Self::pkcs7_pad(plaintext);
let mut ciphertext = Vec::with_capacity(padded_data.len());
for chunk in padded_data.chunks(16) {
let block: [u8; 16] = chunk.try_into().unwrap();
ciphertext.extend_from_slice(&self.encrypt_block(&block));
}
ciphertext
}
pub fn decrypt(&self, ciphertext: &[u8]) -> Vec<u8> {
if ciphertext.len() % 16 != 0 {
panic!("Ciphertext length must be multiple of 16 bytes");
}
let mut decrypted_data = Vec::with_capacity(ciphertext.len());
for chunk in ciphertext.chunks(16) {
let block: [u8; 16] = chunk.try_into().unwrap();
decrypted_data.extend_from_slice(&self.decrypt_block(&block));
}
Self::pkcs7_unpad(&decrypted_data)
}
}
/// 将字节数组转换为十六进制字符串
fn bytes_to_hex(bytes: &[u8]) -> String {
bytes.iter()
.map(|byte| format!("{:02x}", byte))
.collect::<Vec<String>>()
.join("")
}
fn main() {
let key = b"chenxingAES_Rust";
let aes = AES::new(key);
let plaintext = b"He who conquers others is strong; he who conquers himself is mighty.";
// 加密
let ciphertext = aes.encrypt(plaintext);
println!("Ciphertext (hex): {}", bytes_to_hex(&ciphertext));
// 解密
let decrypted = aes.decrypt(&ciphertext);
println!("Decrypted: {}", String::from_utf8_lossy(&decrypted));
// 验证
assert_eq!(plaintext, decrypted.as_slice());
println!("Successfully!!!");
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_aes_round_trip() {
let key = b"chenxingAES_Rust";
let aes = AES::new(key);
let test_data = b"Test AES encryption and decryption";
let encrypted = aes.encrypt(test_data);
let decrypted = aes.decrypt(&encrypted);
assert_eq!(test_data, decrypted.as_slice());
}
#[test]
fn test_single_block() {
let key = b"chenxingAES_Rust";
let aes = AES::new(key);
let block = [0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff];
let encrypted = aes.encrypt_block(&block);
let decrypted = aes.decrypt_block(&encrypted);
assert_eq!(block, decrypted);
}
}
参考资料
很庆幸,在各位前辈的基础上,我们得以很便捷地得到相关知识的干货,也很推荐读者读一下其他师傅的相关博客,完善程度以及过程分析超级高。

AES内部结构 + 基础代码实现
浙公网安备 33010602011771号