/*
 * linux/arch/arm/crypto/aesbs-glue.c - glue code for NEON bit sliced AES
 *
 * Copyright (C) 2013 Linaro Ltd <ard.biesheuvel@linaro.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#include <asm/neon.h>
#include <crypto/aes.h>
#include <crypto/ablk_helper.h>
#include <crypto/algapi.h>
#include <linux/module.h>

#include "aes_glue.h"

#define BIT_SLICED_KEY_MAXSIZE	(128 * (AES_MAXNR - 1) + 2 * AES_BLOCK_SIZE)

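/*
 * Combined key schedule: 'rk' holds the regular round key material, 'bs'
 * the bit sliced representation used by the NEON routines. The bit sliced
 * form is (re)generated from 'rk' on first use after setkey clears
 * 'converted'.
 */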
struct BS_KEY {
	struct AES_KEY	rk;
	int		converted;
	u8 __aligned(8)	bs[BIT_SLICED_KEY_MAXSIZE];
} __aligned(8);

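/* routines implemented in the bit sliced NEON assembly */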
asmlinkage void bsaes_enc_key_convert(u8 out[], struct AES_KEY const *in);
asmlinkage void bsaes_dec_key_convert(u8 out[], struct AES_KEY const *in);

asmlinkage void bsaes_cbc_encrypt(u8 const in[], u8 out[], u32 bytes,
				  struct BS_KEY *key, u8 iv[]);

asmlinkage void bsaes_ctr32_encrypt_blocks(u8 const in[], u8 out[], u32 blocks,
					   struct BS_KEY *key, u8 const iv[]);

asmlinkage void bsaes_xts_encrypt(u8 const in[], u8 out[], u32 bytes,
				  struct BS_KEY *key, u8 tweak[]);

asmlinkage void bsaes_xts_decrypt(u8 const in[], u8 out[], u32 bytes,
				  struct BS_KEY *key, u8 tweak[]);

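/*
 * Context layouts: CBC encryption is strictly serial and therefore runs on
 * the scalar AES code, so it only needs a plain AES_KEY; the parallelisable
 * paths (CBC decrypt, CTR, XTS) use bit sliced BS_KEYs. XTS additionally
 * keeps a scalar key schedule for computing the initial tweak.
 */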
struct aesbs_cbc_ctx {
	struct AES_KEY	enc;
	struct BS_KEY	dec;
};

struct aesbs_ctr_ctx {
	struct BS_KEY	enc;
};

struct aesbs_xts_ctx {
	struct BS_KEY	enc;
	struct BS_KEY	dec;
	struct AES_KEY	twkey;
};

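/*
 * Expand the scalar encryption schedule and the decryption schedule used by
 * the NEON path; only the 'converted' flag needs resetting for the latter.
 */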
static int aesbs_cbc_set_key(struct crypto_tfm *tfm, const u8 *in_key,
			     unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);
	int bits = key_len * 8;

	if (private_AES_set_encrypt_key(in_key, bits, &ctx->enc)) {
		tfm->crt_flags |= CRYPTO_TFM_RES_BAD_KEY_LEN;
		return -EINVAL;
	}
	ctx->dec.rk = ctx->enc;
	private_AES_set_decrypt_key(in_key, bits, &ctx->dec.rk);
	ctx->dec.converted = 0;
	return 0;
}

static int aesbs_ctr_set_key(struct crypto_tfm *tfm, const u8 *in_key,
			     unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_tfm_ctx(tfm);
	int bits = key_len * 8;

	if (private_AES_set_encrypt_key(in_key, bits, &ctx->enc.rk)) {
		tfm->crt_flags |= CRYPTO_TFM_RES_BAD_KEY_LEN;
		return -EINVAL;
	}
	ctx->enc.converted = 0;
	return 0;
}

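/*
 * XTS takes a double length key; the first half keys the data cipher and
 * the second half the tweak cipher, so each schedule uses key_len * 4 bits.
 */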
static int aesbs_xts_set_key(struct crypto_tfm *tfm, const u8 *in_key,
			     unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);
	int bits = key_len * 4;

	if (private_AES_set_encrypt_key(in_key, bits, &ctx->enc.rk)) {
		tfm->crt_flags |= CRYPTO_TFM_RES_BAD_KEY_LEN;
		return -EINVAL;
	}
	ctx->dec.rk = ctx->enc.rk;
	private_AES_set_decrypt_key(in_key, bits, &ctx->dec.rk);
	private_AES_set_encrypt_key(in_key + key_len / 2, bits, &ctx->twkey);
	ctx->enc.converted = ctx->dec.converted = 0;
	return 0;
}

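/*
 * CBC encryption cannot be parallelised (each block depends on the previous
 * ciphertext block), so it is done one block at a time with the scalar AES
 * code, with separate loops for the in-place and out-of-place cases.
 */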
static int aesbs_cbc_encrypt(struct blkcipher_desc *desc,
			     struct scatterlist *dst,
			     struct scatterlist *src, unsigned int nbytes)
{
	struct aesbs_cbc_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
	struct blkcipher_walk walk;
	int err;

	blkcipher_walk_init(&walk, dst, src, nbytes);
	err = blkcipher_walk_virt(desc, &walk);

	while (walk.nbytes) {
		u32 blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *src = walk.src.virt.addr;

		if (walk.dst.virt.addr == walk.src.virt.addr) {
			u8 *iv = walk.iv;

			do {
				crypto_xor(src, iv, AES_BLOCK_SIZE);
				AES_encrypt(src, src, &ctx->enc);
				iv = src;
				src += AES_BLOCK_SIZE;
			} while (--blocks);
			memcpy(walk.iv, iv, AES_BLOCK_SIZE);
		} else {
			u8 *dst = walk.dst.virt.addr;

			do {
				crypto_xor(walk.iv, src, AES_BLOCK_SIZE);
				AES_encrypt(walk.iv, dst, &ctx->enc);
				memcpy(walk.iv, dst, AES_BLOCK_SIZE);
				src += AES_BLOCK_SIZE;
				dst += AES_BLOCK_SIZE;
			} while (--blocks);
		}
		err = blkcipher_walk_done(desc, &walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

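/*
 * CBC decryption is parallelisable: chunks of eight or more blocks go
 * through the bit sliced NEON code, and any remainder is handled by the
 * scalar loop below, which uses a pair of bounce buffers to preserve the
 * previous ciphertext block when operating in place.
 */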
static int aesbs_cbc_decrypt(struct blkcipher_desc *desc,
			     struct scatterlist *dst,
			     struct scatterlist *src, unsigned int nbytes)
{
	struct aesbs_cbc_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
	struct blkcipher_walk walk;
	int err;

	blkcipher_walk_init(&walk, dst, src, nbytes);
	err = blkcipher_walk_virt_block(desc, &walk, 8 * AES_BLOCK_SIZE);

	while ((walk.nbytes / AES_BLOCK_SIZE) >= 8) {
		kernel_neon_begin();
		bsaes_cbc_encrypt(walk.src.virt.addr, walk.dst.virt.addr,
				  walk.nbytes, &ctx->dec, walk.iv);
		kernel_neon_end();
		err = blkcipher_walk_done(desc, &walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	while (walk.nbytes) {
		u32 blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *dst = walk.dst.virt.addr;
		u8 *src = walk.src.virt.addr;
		u8 bk[2][AES_BLOCK_SIZE];
		u8 *iv = walk.iv;

		do {
			if (walk.dst.virt.addr == walk.src.virt.addr)
				memcpy(bk[blocks & 1], src, AES_BLOCK_SIZE);

			AES_decrypt(src, dst, &ctx->dec.rk);
			crypto_xor(dst, iv, AES_BLOCK_SIZE);

			if (walk.dst.virt.addr == walk.src.virt.addr)
				iv = bk[blocks & 1];
			else
				iv = src;

			dst += AES_BLOCK_SIZE;
			src += AES_BLOCK_SIZE;
		} while (--blocks);
		err = blkcipher_walk_done(desc, &walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

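/* add 'addend' to the 128-bit big-endian counter, propagating the carry */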
static void inc_be128_ctr(__be32 ctr[], u32 addend)
{
	int i;

	for (i = 3; i >= 0; i--, addend = 1) {
		u32 n = be32_to_cpu(ctr[i]) + addend;

		ctr[i] = cpu_to_be32(n);
		if (n >= addend)
			break;
	}
}

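/*
 * The NEON code only increments the low 32 bits of the counter, so each
 * chunk is capped at the headroom left before that word wraps; the full
 * 128-bit counter is advanced here afterwards. A trailing partial block is
 * handled by generating one block of keystream with the scalar AES code.
 */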
static int aesbs_ctr_encrypt(struct blkcipher_desc *desc,
			     struct scatterlist *dst, struct scatterlist *src,
			     unsigned int nbytes)
{
	struct aesbs_ctr_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
	struct blkcipher_walk walk;
	u32 blocks;
	int err;

	blkcipher_walk_init(&walk, dst, src, nbytes);
	err = blkcipher_walk_virt_block(desc, &walk, 8 * AES_BLOCK_SIZE);

	while ((blocks = walk.nbytes / AES_BLOCK_SIZE)) {
		u32 tail = walk.nbytes % AES_BLOCK_SIZE;
		__be32 *ctr = (__be32 *)walk.iv;
		u32 headroom = UINT_MAX - be32_to_cpu(ctr[3]);

		/* avoid 32 bit counter overflow in the NEON code */
		if (unlikely(headroom < blocks)) {
			blocks = headroom + 1;
			tail = walk.nbytes - blocks * AES_BLOCK_SIZE;
		}
		kernel_neon_begin();
		bsaes_ctr32_encrypt_blocks(walk.src.virt.addr,
					   walk.dst.virt.addr, blocks,
					   &ctx->enc, walk.iv);
		kernel_neon_end();
		inc_be128_ctr(ctr, blocks);

		nbytes -= blocks * AES_BLOCK_SIZE;
		if (nbytes && nbytes == tail && nbytes <= AES_BLOCK_SIZE)
			break;

		err = blkcipher_walk_done(desc, &walk, tail);
	}
	if (walk.nbytes) {
		u8 *tdst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
		u8 *tsrc = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;
		u8 ks[AES_BLOCK_SIZE];

		AES_encrypt(walk.iv, ks, &ctx->enc.rk);
		if (tdst != tsrc)
			memcpy(tdst, tsrc, nbytes);
		crypto_xor(tdst, ks, nbytes);
		err = blkcipher_walk_done(desc, &walk, 0);
	}
	return err;
}

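/*
 * For XTS, the initial tweak is derived by encrypting the IV with the tweak
 * key using the scalar AES code; the NEON code consumes and advances the
 * tweak as it processes the data. Encryption and decryption below differ
 * only in the bsaes routine and key schedule they use.
 */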
static int aesbs_xts_encrypt(struct blkcipher_desc *desc,
			     struct scatterlist *dst,
			     struct scatterlist *src, unsigned int nbytes)
{
	struct aesbs_xts_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
	struct blkcipher_walk walk;
	int err;

	blkcipher_walk_init(&walk, dst, src, nbytes);
	err = blkcipher_walk_virt_block(desc, &walk, 8 * AES_BLOCK_SIZE);

	/* generate the initial tweak */
	AES_encrypt(walk.iv, walk.iv, &ctx->twkey);

	while (walk.nbytes) {
		kernel_neon_begin();
		bsaes_xts_encrypt(walk.src.virt.addr, walk.dst.virt.addr,
				  walk.nbytes, &ctx->enc, walk.iv);
		kernel_neon_end();
		err = blkcipher_walk_done(desc, &walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int aesbs_xts_decrypt(struct blkcipher_desc *desc,
			     struct scatterlist *dst,
			     struct scatterlist *src, unsigned int nbytes)
{
	struct aesbs_xts_ctx *ctx = crypto_blkcipher_ctx(desc->tfm);
	struct blkcipher_walk walk;
	int err;

	blkcipher_walk_init(&walk, dst, src, nbytes);
	err = blkcipher_walk_virt_block(desc, &walk, 8 * AES_BLOCK_SIZE);

	/* generate the initial tweak */
	AES_encrypt(walk.iv, walk.iv, &ctx->twkey);

	while (walk.nbytes) {
		kernel_neon_begin();
		bsaes_xts_decrypt(walk.src.virt.addr, walk.dst.virt.addr,
				  walk.nbytes, &ctx->dec, walk.iv);
		kernel_neon_end();
		err = blkcipher_walk_done(desc, &walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

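/*
 * The "__" prefixed blkcipher algorithms are internal helpers that must run
 * where the NEON register file is usable. The cbc(aes), ctr(aes) and
 * xts(aes) entries wrap them via the ablk_helper, which hands the request
 * off to a cryptd thread whenever NEON cannot be used directly. CBC
 * encryption never touches NEON, so that entry calls __ablk_encrypt and
 * always runs synchronously.
 */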
static struct crypto_alg aesbs_algs[] = { {
	.cra_name		= "__cbc-aes-neonbs",
	.cra_driver_name	= "__driver-cbc-aes-neonbs",
	.cra_priority		= 0,
	.cra_flags		= CRYPTO_ALG_TYPE_BLKCIPHER |
				  CRYPTO_ALG_INTERNAL,
	.cra_blocksize		= AES_BLOCK_SIZE,
	.cra_ctxsize		= sizeof(struct aesbs_cbc_ctx),
	.cra_alignmask		= 7,
	.cra_type		= &crypto_blkcipher_type,
	.cra_module		= THIS_MODULE,
	.cra_blkcipher = {
		.min_keysize	= AES_MIN_KEY_SIZE,
		.max_keysize	= AES_MAX_KEY_SIZE,
		.ivsize		= AES_BLOCK_SIZE,
		.setkey		= aesbs_cbc_set_key,
		.encrypt	= aesbs_cbc_encrypt,
		.decrypt	= aesbs_cbc_decrypt,
	},
}, {
	.cra_name		= "__ctr-aes-neonbs",
	.cra_driver_name	= "__driver-ctr-aes-neonbs",
	.cra_priority		= 0,
	.cra_flags		= CRYPTO_ALG_TYPE_BLKCIPHER |
				  CRYPTO_ALG_INTERNAL,
	.cra_blocksize		= 1,
	.cra_ctxsize		= sizeof(struct aesbs_ctr_ctx),
	.cra_alignmask		= 7,
	.cra_type		= &crypto_blkcipher_type,
	.cra_module		= THIS_MODULE,
	.cra_blkcipher = {
		.min_keysize	= AES_MIN_KEY_SIZE,
		.max_keysize	= AES_MAX_KEY_SIZE,
		.ivsize		= AES_BLOCK_SIZE,
		.setkey		= aesbs_ctr_set_key,
		.encrypt	= aesbs_ctr_encrypt,
		.decrypt	= aesbs_ctr_encrypt,
	},
}, {
	.cra_name		= "__xts-aes-neonbs",
	.cra_driver_name	= "__driver-xts-aes-neonbs",
	.cra_priority		= 0,
	.cra_flags		= CRYPTO_ALG_TYPE_BLKCIPHER |
				  CRYPTO_ALG_INTERNAL,
	.cra_blocksize		= AES_BLOCK_SIZE,
	.cra_ctxsize		= sizeof(struct aesbs_xts_ctx),
	.cra_alignmask		= 7,
	.cra_type		= &crypto_blkcipher_type,
	.cra_module		= THIS_MODULE,
	.cra_blkcipher = {
		.min_keysize	= 2 * AES_MIN_KEY_SIZE,
		.max_keysize	= 2 * AES_MAX_KEY_SIZE,
		.ivsize		= AES_BLOCK_SIZE,
		.setkey		= aesbs_xts_set_key,
		.encrypt	= aesbs_xts_encrypt,
		.decrypt	= aesbs_xts_decrypt,
	},
}, {
	.cra_name		= "cbc(aes)",
	.cra_driver_name	= "cbc-aes-neonbs",
	.cra_priority		= 300,
	.cra_flags		= CRYPTO_ALG_TYPE_ABLKCIPHER|CRYPTO_ALG_ASYNC,
	.cra_blocksize		= AES_BLOCK_SIZE,
	.cra_ctxsize		= sizeof(struct async_helper_ctx),
	.cra_alignmask		= 7,
	.cra_type		= &crypto_ablkcipher_type,
	.cra_module		= THIS_MODULE,
	.cra_init		= ablk_init,
	.cra_exit		= ablk_exit,
	.cra_ablkcipher = {
		.min_keysize	= AES_MIN_KEY_SIZE,
		.max_keysize	= AES_MAX_KEY_SIZE,
		.ivsize		= AES_BLOCK_SIZE,
		.setkey		= ablk_set_key,
		.encrypt	= __ablk_encrypt,
		.decrypt	= ablk_decrypt,
	}
}, {
	.cra_name		= "ctr(aes)",
	.cra_driver_name	= "ctr-aes-neonbs",
	.cra_priority		= 300,
	.cra_flags		= CRYPTO_ALG_TYPE_ABLKCIPHER|CRYPTO_ALG_ASYNC,
	.cra_blocksize		= 1,
	.cra_ctxsize		= sizeof(struct async_helper_ctx),
	.cra_alignmask		= 7,
	.cra_type		= &crypto_ablkcipher_type,
	.cra_module		= THIS_MODULE,
	.cra_init		= ablk_init,
	.cra_exit		= ablk_exit,
	.cra_ablkcipher = {
		.min_keysize	= AES_MIN_KEY_SIZE,
		.max_keysize	= AES_MAX_KEY_SIZE,
		.ivsize		= AES_BLOCK_SIZE,
		.setkey		= ablk_set_key,
		.encrypt	= ablk_encrypt,
		.decrypt	= ablk_decrypt,
	}
}, {
	.cra_name		= "xts(aes)",
	.cra_driver_name	= "xts-aes-neonbs",
	.cra_priority		= 300,
	.cra_flags		= CRYPTO_ALG_TYPE_ABLKCIPHER|CRYPTO_ALG_ASYNC,
	.cra_blocksize		= AES_BLOCK_SIZE,
	.cra_ctxsize		= sizeof(struct async_helper_ctx),
	.cra_alignmask		= 7,
	.cra_type		= &crypto_ablkcipher_type,
	.cra_module		= THIS_MODULE,
	.cra_init		= ablk_init,
	.cra_exit		= ablk_exit,
	.cra_ablkcipher = {
		.min_keysize	= 2 * AES_MIN_KEY_SIZE,
		.max_keysize	= 2 * AES_MAX_KEY_SIZE,
		.ivsize		= AES_BLOCK_SIZE,
		.setkey		= ablk_set_key,
		.encrypt	= ablk_encrypt,
		.decrypt	= ablk_decrypt,
	}
} };

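/* only register the algorithms if the CPU actually has NEON */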
static int __init aesbs_mod_init(void)
{
	if (!cpu_has_neon())
		return -ENODEV;

	return crypto_register_algs(aesbs_algs, ARRAY_SIZE(aesbs_algs));
}

static void __exit aesbs_mod_exit(void)
{
	crypto_unregister_algs(aesbs_algs, ARRAY_SIZE(aesbs_algs));
}

module_init(aesbs_mod_init);
module_exit(aesbs_mod_exit);

MODULE_DESCRIPTION("Bit sliced AES in CBC/CTR/XTS modes using NEON");
MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL");