1/*
2 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
3 *
4 * Copyright (C) 2013 Linaro Ltd <ard.biesheuvel@linaro.org>
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License version 2 as
8 * published by the Free Software Foundation.
9 */
10
11#include <asm/neon.h>
12#include <asm/hwcap.h>
13#include <crypto/aes.h>
14#include <crypto/ablk_helper.h>
15#include <crypto/algapi.h>
16#include <linux/module.h>
17#include <linux/cpufeature.h>
18
19#include "aes-ce-setkey.h"
20
21#ifdef USE_V8_CRYPTO_EXTENSIONS
22#define MODE			"ce"
23#define PRIO			300
24#define aes_setkey		ce_aes_setkey
25#define aes_expandkey		ce_aes_expandkey
26#define aes_ecb_encrypt		ce_aes_ecb_encrypt
27#define aes_ecb_decrypt		ce_aes_ecb_decrypt
28#define aes_cbc_encrypt		ce_aes_cbc_encrypt
29#define aes_cbc_decrypt		ce_aes_cbc_decrypt
30#define aes_ctr_encrypt		ce_aes_ctr_encrypt
31#define aes_xts_encrypt		ce_aes_xts_encrypt
32#define aes_xts_decrypt		ce_aes_xts_decrypt
33MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
34#else
35#define MODE			"neon"
36#define PRIO			200
37#define aes_setkey		crypto_aes_set_key
38#define aes_expandkey		crypto_aes_expand_key
39#define aes_ecb_encrypt		neon_aes_ecb_encrypt
40#define aes_ecb_decrypt		neon_aes_ecb_decrypt
41#define aes_cbc_encrypt		neon_aes_cbc_encrypt
42#define aes_cbc_decrypt		neon_aes_cbc_decrypt
43#define aes_ctr_encrypt		neon_aes_ctr_encrypt
44#define aes_xts_encrypt		neon_aes_xts_encrypt
45#define aes_xts_decrypt		neon_aes_xts_decrypt
46MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
47MODULE_ALIAS_CRYPTO("ecb(aes)");
48MODULE_ALIAS_CRYPTO("cbc(aes)");
49MODULE_ALIAS_CRYPTO("ctr(aes)");
50MODULE_ALIAS_CRYPTO("xts(aes)");
51#endif
52
53MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
54MODULE_LICENSE("GPL v2");
55
56/* defined in aes-modes.S */
57asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
58				int rounds, int blocks, int first);
59asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
60				int rounds, int blocks, int first);
61
62asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[],
63				int rounds, int blocks, u8 iv[], int first);
64asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
65				int rounds, int blocks, u8 iv[], int first);
66
67asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
68				int rounds, int blocks, u8 ctr[], int first);
69
70asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[],
71				int rounds, int blocks, u8 const rk2[], u8 iv[],
72				int first);
73asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[],
74				int rounds, int blocks, u8 const rk2[], u8 iv[],
75				int first);
76
/*
 * Context for the XTS transforms: two independently expanded AES key
 * schedules, filled in by xts_set_key() from the two halves of the user
 * key. key1 is handed to the assembler as rk1 and key2 as rk2 by
 * xts_encrypt()/xts_decrypt().
 */
struct crypto_aes_xts_ctx {
	struct crypto_aes_ctx key1;
	/*
	 * NOTE(review): presumably aligned so key2's round keys get the same
	 * 8-byte alignment the algorithms request (cra_alignmask = 7), since
	 * sizeof(struct crypto_aes_ctx) need not be a multiple of 8 — confirm.
	 */
	struct crypto_aes_ctx __aligned(8) key2;
};
81
82static int xts_set_key(struct crypto_tfm *tfm, const u8 *in_key,
83		       unsigned int key_len)
84{
85	struct crypto_aes_xts_ctx *ctx = crypto_tfm_ctx(tfm);
86	int ret;
87
88	ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
89	if (!ret)
90		ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
91				    key_len / 2);
92	if (!ret)
93		return 0;
94
95	tfm->crt_flags |= CRYPTO_TFM_RES_BAD_KEY_LEN;
96	return -EINVAL;
97}
98
99static int ecb_encrypt(struct blkcipher_desc *desc, struct scatterlist *dst,
100		       struct scatterlist *src, unsigned int nbytes)
101{
102	struct crypto_aes_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
103	int err, first, rounds = 6 + ctx->key_length / 4;
104	struct blkcipher_walk walk;
105	unsigned int blocks;
106
107	desc->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
108	blkcipher_walk_init(&walk, dst, src, nbytes);
109	err = blkcipher_walk_virt(desc, &walk);
110
111	kernel_neon_begin();
112	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
113		aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
114				(u8 *)ctx->key_enc, rounds, blocks, first);
115		err = blkcipher_walk_done(desc, &walk, walk.nbytes % AES_BLOCK_SIZE);
116	}
117	kernel_neon_end();
118	return err;
119}
120
121static int ecb_decrypt(struct blkcipher_desc *desc, struct scatterlist *dst,
122		       struct scatterlist *src, unsigned int nbytes)
123{
124	struct crypto_aes_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
125	int err, first, rounds = 6 + ctx->key_length / 4;
126	struct blkcipher_walk walk;
127	unsigned int blocks;
128
129	desc->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
130	blkcipher_walk_init(&walk, dst, src, nbytes);
131	err = blkcipher_walk_virt(desc, &walk);
132
133	kernel_neon_begin();
134	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
135		aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
136				(u8 *)ctx->key_dec, rounds, blocks, first);
137		err = blkcipher_walk_done(desc, &walk, walk.nbytes % AES_BLOCK_SIZE);
138	}
139	kernel_neon_end();
140	return err;
141}
142
143static int cbc_encrypt(struct blkcipher_desc *desc, struct scatterlist *dst,
144		       struct scatterlist *src, unsigned int nbytes)
145{
146	struct crypto_aes_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
147	int err, first, rounds = 6 + ctx->key_length / 4;
148	struct blkcipher_walk walk;
149	unsigned int blocks;
150
151	desc->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
152	blkcipher_walk_init(&walk, dst, src, nbytes);
153	err = blkcipher_walk_virt(desc, &walk);
154
155	kernel_neon_begin();
156	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
157		aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
158				(u8 *)ctx->key_enc, rounds, blocks, walk.iv,
159				first);
160		err = blkcipher_walk_done(desc, &walk, walk.nbytes % AES_BLOCK_SIZE);
161	}
162	kernel_neon_end();
163	return err;
164}
165
166static int cbc_decrypt(struct blkcipher_desc *desc, struct scatterlist *dst,
167		       struct scatterlist *src, unsigned int nbytes)
168{
169	struct crypto_aes_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
170	int err, first, rounds = 6 + ctx->key_length / 4;
171	struct blkcipher_walk walk;
172	unsigned int blocks;
173
174	desc->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
175	blkcipher_walk_init(&walk, dst, src, nbytes);
176	err = blkcipher_walk_virt(desc, &walk);
177
178	kernel_neon_begin();
179	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
180		aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
181				(u8 *)ctx->key_dec, rounds, blocks, walk.iv,
182				first);
183		err = blkcipher_walk_done(desc, &walk, walk.nbytes % AES_BLOCK_SIZE);
184	}
185	kernel_neon_end();
186	return err;
187}
188
189static int ctr_encrypt(struct blkcipher_desc *desc, struct scatterlist *dst,
190		       struct scatterlist *src, unsigned int nbytes)
191{
192	struct crypto_aes_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
193	int err, first, rounds = 6 + ctx->key_length / 4;
194	struct blkcipher_walk walk;
195	int blocks;
196
197	desc->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
198	blkcipher_walk_init(&walk, dst, src, nbytes);
199	err = blkcipher_walk_virt_block(desc, &walk, AES_BLOCK_SIZE);
200
201	first = 1;
202	kernel_neon_begin();
203	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
204		aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
205				(u8 *)ctx->key_enc, rounds, blocks, walk.iv,
206				first);
207		first = 0;
208		nbytes -= blocks * AES_BLOCK_SIZE;
209		if (nbytes && nbytes == walk.nbytes % AES_BLOCK_SIZE)
210			break;
211		err = blkcipher_walk_done(desc, &walk,
212					  walk.nbytes % AES_BLOCK_SIZE);
213	}
214	if (nbytes) {
215		u8 *tdst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
216		u8 *tsrc = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;
217		u8 __aligned(8) tail[AES_BLOCK_SIZE];
218
219		/*
220		 * Minimum alignment is 8 bytes, so if nbytes is <= 8, we need
221		 * to tell aes_ctr_encrypt() to only read half a block.
222		 */
223		blocks = (nbytes <= 8) ? -1 : 1;
224
225		aes_ctr_encrypt(tail, tsrc, (u8 *)ctx->key_enc, rounds,
226				blocks, walk.iv, first);
227		memcpy(tdst, tail, nbytes);
228		err = blkcipher_walk_done(desc, &walk, 0);
229	}
230	kernel_neon_end();
231
232	return err;
233}
234
235static int xts_encrypt(struct blkcipher_desc *desc, struct scatterlist *dst,
236		       struct scatterlist *src, unsigned int nbytes)
237{
238	struct crypto_aes_xts_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
239	int err, first, rounds = 6 + ctx->key1.key_length / 4;
240	struct blkcipher_walk walk;
241	unsigned int blocks;
242
243	desc->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
244	blkcipher_walk_init(&walk, dst, src, nbytes);
245	err = blkcipher_walk_virt(desc, &walk);
246
247	kernel_neon_begin();
248	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
249		aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
250				(u8 *)ctx->key1.key_enc, rounds, blocks,
251				(u8 *)ctx->key2.key_enc, walk.iv, first);
252		err = blkcipher_walk_done(desc, &walk, walk.nbytes % AES_BLOCK_SIZE);
253	}
254	kernel_neon_end();
255
256	return err;
257}
258
259static int xts_decrypt(struct blkcipher_desc *desc, struct scatterlist *dst,
260		       struct scatterlist *src, unsigned int nbytes)
261{
262	struct crypto_aes_xts_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
263	int err, first, rounds = 6 + ctx->key1.key_length / 4;
264	struct blkcipher_walk walk;
265	unsigned int blocks;
266
267	desc->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
268	blkcipher_walk_init(&walk, dst, src, nbytes);
269	err = blkcipher_walk_virt(desc, &walk);
270
271	kernel_neon_begin();
272	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
273		aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
274				(u8 *)ctx->key1.key_dec, rounds, blocks,
275				(u8 *)ctx->key2.key_enc, walk.iv, first);
276		err = blkcipher_walk_done(desc, &walk, walk.nbytes % AES_BLOCK_SIZE);
277	}
278	kernel_neon_end();
279
280	return err;
281}
282
283static struct crypto_alg aes_algs[] = { {
284	.cra_name		= "__ecb-aes-" MODE,
285	.cra_driver_name	= "__driver-ecb-aes-" MODE,
286	.cra_priority		= 0,
287	.cra_flags		= CRYPTO_ALG_TYPE_BLKCIPHER |
288				  CRYPTO_ALG_INTERNAL,
289	.cra_blocksize		= AES_BLOCK_SIZE,
290	.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
291	.cra_alignmask		= 7,
292	.cra_type		= &crypto_blkcipher_type,
293	.cra_module		= THIS_MODULE,
294	.cra_blkcipher = {
295		.min_keysize	= AES_MIN_KEY_SIZE,
296		.max_keysize	= AES_MAX_KEY_SIZE,
297		.ivsize		= AES_BLOCK_SIZE,
298		.setkey		= aes_setkey,
299		.encrypt	= ecb_encrypt,
300		.decrypt	= ecb_decrypt,
301	},
302}, {
303	.cra_name		= "__cbc-aes-" MODE,
304	.cra_driver_name	= "__driver-cbc-aes-" MODE,
305	.cra_priority		= 0,
306	.cra_flags		= CRYPTO_ALG_TYPE_BLKCIPHER |
307				  CRYPTO_ALG_INTERNAL,
308	.cra_blocksize		= AES_BLOCK_SIZE,
309	.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
310	.cra_alignmask		= 7,
311	.cra_type		= &crypto_blkcipher_type,
312	.cra_module		= THIS_MODULE,
313	.cra_blkcipher = {
314		.min_keysize	= AES_MIN_KEY_SIZE,
315		.max_keysize	= AES_MAX_KEY_SIZE,
316		.ivsize		= AES_BLOCK_SIZE,
317		.setkey		= aes_setkey,
318		.encrypt	= cbc_encrypt,
319		.decrypt	= cbc_decrypt,
320	},
321}, {
322	.cra_name		= "__ctr-aes-" MODE,
323	.cra_driver_name	= "__driver-ctr-aes-" MODE,
324	.cra_priority		= 0,
325	.cra_flags		= CRYPTO_ALG_TYPE_BLKCIPHER |
326				  CRYPTO_ALG_INTERNAL,
327	.cra_blocksize		= 1,
328	.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
329	.cra_alignmask		= 7,
330	.cra_type		= &crypto_blkcipher_type,
331	.cra_module		= THIS_MODULE,
332	.cra_blkcipher = {
333		.min_keysize	= AES_MIN_KEY_SIZE,
334		.max_keysize	= AES_MAX_KEY_SIZE,
335		.ivsize		= AES_BLOCK_SIZE,
336		.setkey		= aes_setkey,
337		.encrypt	= ctr_encrypt,
338		.decrypt	= ctr_encrypt,
339	},
340}, {
341	.cra_name		= "__xts-aes-" MODE,
342	.cra_driver_name	= "__driver-xts-aes-" MODE,
343	.cra_priority		= 0,
344	.cra_flags		= CRYPTO_ALG_TYPE_BLKCIPHER |
345				  CRYPTO_ALG_INTERNAL,
346	.cra_blocksize		= AES_BLOCK_SIZE,
347	.cra_ctxsize		= sizeof(struct crypto_aes_xts_ctx),
348	.cra_alignmask		= 7,
349	.cra_type		= &crypto_blkcipher_type,
350	.cra_module		= THIS_MODULE,
351	.cra_blkcipher = {
352		.min_keysize	= 2 * AES_MIN_KEY_SIZE,
353		.max_keysize	= 2 * AES_MAX_KEY_SIZE,
354		.ivsize		= AES_BLOCK_SIZE,
355		.setkey		= xts_set_key,
356		.encrypt	= xts_encrypt,
357		.decrypt	= xts_decrypt,
358	},
359}, {
360	.cra_name		= "ecb(aes)",
361	.cra_driver_name	= "ecb-aes-" MODE,
362	.cra_priority		= PRIO,
363	.cra_flags		= CRYPTO_ALG_TYPE_ABLKCIPHER|CRYPTO_ALG_ASYNC,
364	.cra_blocksize		= AES_BLOCK_SIZE,
365	.cra_ctxsize		= sizeof(struct async_helper_ctx),
366	.cra_alignmask		= 7,
367	.cra_type		= &crypto_ablkcipher_type,
368	.cra_module		= THIS_MODULE,
369	.cra_init		= ablk_init,
370	.cra_exit		= ablk_exit,
371	.cra_ablkcipher = {
372		.min_keysize	= AES_MIN_KEY_SIZE,
373		.max_keysize	= AES_MAX_KEY_SIZE,
374		.ivsize		= AES_BLOCK_SIZE,
375		.setkey		= ablk_set_key,
376		.encrypt	= ablk_encrypt,
377		.decrypt	= ablk_decrypt,
378	}
379}, {
380	.cra_name		= "cbc(aes)",
381	.cra_driver_name	= "cbc-aes-" MODE,
382	.cra_priority		= PRIO,
383	.cra_flags		= CRYPTO_ALG_TYPE_ABLKCIPHER|CRYPTO_ALG_ASYNC,
384	.cra_blocksize		= AES_BLOCK_SIZE,
385	.cra_ctxsize		= sizeof(struct async_helper_ctx),
386	.cra_alignmask		= 7,
387	.cra_type		= &crypto_ablkcipher_type,
388	.cra_module		= THIS_MODULE,
389	.cra_init		= ablk_init,
390	.cra_exit		= ablk_exit,
391	.cra_ablkcipher = {
392		.min_keysize	= AES_MIN_KEY_SIZE,
393		.max_keysize	= AES_MAX_KEY_SIZE,
394		.ivsize		= AES_BLOCK_SIZE,
395		.setkey		= ablk_set_key,
396		.encrypt	= ablk_encrypt,
397		.decrypt	= ablk_decrypt,
398	}
399}, {
400	.cra_name		= "ctr(aes)",
401	.cra_driver_name	= "ctr-aes-" MODE,
402	.cra_priority		= PRIO,
403	.cra_flags		= CRYPTO_ALG_TYPE_ABLKCIPHER|CRYPTO_ALG_ASYNC,
404	.cra_blocksize		= 1,
405	.cra_ctxsize		= sizeof(struct async_helper_ctx),
406	.cra_alignmask		= 7,
407	.cra_type		= &crypto_ablkcipher_type,
408	.cra_module		= THIS_MODULE,
409	.cra_init		= ablk_init,
410	.cra_exit		= ablk_exit,
411	.cra_ablkcipher = {
412		.min_keysize	= AES_MIN_KEY_SIZE,
413		.max_keysize	= AES_MAX_KEY_SIZE,
414		.ivsize		= AES_BLOCK_SIZE,
415		.setkey		= ablk_set_key,
416		.encrypt	= ablk_encrypt,
417		.decrypt	= ablk_decrypt,
418	}
419}, {
420	.cra_name		= "xts(aes)",
421	.cra_driver_name	= "xts-aes-" MODE,
422	.cra_priority		= PRIO,
423	.cra_flags		= CRYPTO_ALG_TYPE_ABLKCIPHER|CRYPTO_ALG_ASYNC,
424	.cra_blocksize		= AES_BLOCK_SIZE,
425	.cra_ctxsize		= sizeof(struct async_helper_ctx),
426	.cra_alignmask		= 7,
427	.cra_type		= &crypto_ablkcipher_type,
428	.cra_module		= THIS_MODULE,
429	.cra_init		= ablk_init,
430	.cra_exit		= ablk_exit,
431	.cra_ablkcipher = {
432		.min_keysize	= 2 * AES_MIN_KEY_SIZE,
433		.max_keysize	= 2 * AES_MAX_KEY_SIZE,
434		.ivsize		= AES_BLOCK_SIZE,
435		.setkey		= ablk_set_key,
436		.encrypt	= ablk_encrypt,
437		.decrypt	= ablk_decrypt,
438	}
439} };
440
441static int __init aes_init(void)
442{
443	return crypto_register_algs(aes_algs, ARRAY_SIZE(aes_algs));
444}
445
446static void __exit aes_exit(void)
447{
448	crypto_unregister_algs(aes_algs, ARRAY_SIZE(aes_algs));
449}
450
#ifndef USE_V8_CRYPTO_EXTENSIONS
/* The plain NEON flavour registers unconditionally. */
module_init(aes_init);
#else
/*
 * The Crypto Extensions flavour runs aes_init() only on CPUs that report
 * the AES feature (see module_cpu_feature_match() in <linux/cpufeature.h>).
 */
module_cpu_feature_match(AES, aes_init);
#endif
module_exit(aes_exit);
457