1/*
2 * Cryptographic API.
3 *
4 * s390 implementation of the AES Cipher Algorithm.
5 *
6 * s390 Version:
7 *   Copyright IBM Corp. 2005, 2007
8 *   Author(s): Jan Glauber (jang@de.ibm.com)
 *		Sebastian Siewior (sebastian@breakpoint.cc) SW-Fallback
10 *
11 * Derived from "crypto/aes_generic.c"
12 *
13 * This program is free software; you can redistribute it and/or modify it
14 * under the terms of the GNU General Public License as published by the Free
15 * Software Foundation; either version 2 of the License, or (at your option)
16 * any later version.
17 *
18 */
19
/* Prefix for pr_*() messages from this file: "aes_s390: <message>". */
#define KMSG_COMPONENT "aes_s390"
#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt
22
23#include <crypto/aes.h>
24#include <crypto/algapi.h>
25#include <linux/err.h>
26#include <linux/module.h>
27#include <linux/init.h>
28#include <linux/spinlock.h>
29#include "crypt_s390.h"
30
/*
 * Bits for keylen_flag: one per AES key length for which aes_s390_init()
 * found the corresponding KM encrypt function available.
 */
#define AES_KEYLEN_128		1
#define AES_KEYLEN_192		2
#define AES_KEYLEN_256		4

/*
 * One page of counter blocks for ctr(aes), allocated in aes_s390_init().
 * Only the request that wins spin_trylock(&ctrblk_lock) uses it; other
 * requests fall back to a single on-stack counter block.
 */
static u8 *ctrblk;
static DEFINE_SPINLOCK(ctrblk_lock);
/* OR of AES_KEYLEN_* bits, set once in aes_s390_init(). */
static char keylen_flag;
38
/*
 * Per-transform context shared by the aes, ecb(aes), cbc(aes) and ctr(aes)
 * implementations in this file.
 */
struct s390_aes_ctx {
	u8 key[AES_MAX_KEY_SIZE];	/* raw key, passed to KM/KMCTR as the parameter block */
	long enc;			/* function code used for encryption (set by *_set_key) */
	long dec;			/* function code used for decryption (set by *_set_key) */
	int key_len;			/* key length in bytes (16, 24 or 32 once a key is set) */
	union {
		struct crypto_blkcipher *blk;	/* software fallback for ecb(aes)/cbc(aes) */
		struct crypto_cipher *cip;	/* software fallback for the single-block "aes" */
	} fallback;
};
49
/*
 * Parameter block handed to the PCC instruction by xts_aes_crypt().
 * For 128-bit XTS keys the block is passed starting at key + 16, so only
 * the upper half of key[] is used in that case.  tweak is loaded from the
 * IV; block and bit are zeroed; xts receives the initial XTS value.
 *
 * NOTE(review): the field order presumably mirrors the hardware-defined
 * PCC parameter block layout - do not reorder.
 */
struct pcc_param {
	u8 key[32];
	u8 tweak[16];
	u8 block[16];
	u8 bit[16];
	u8 xts[16];
};
57
/* Per-transform context for xts(aes). */
struct s390_xts_ctx {
	u8 key[32];		/* first key half (KM key); 128-bit keys occupy bytes 16..31 */
	u8 pcc_key[32];		/* second key half (PCC key); same placement as key[] */
	long enc;		/* KM XTS encrypt function code, 0 for 48-byte keys */
	long dec;		/* KM XTS decrypt function code, 0 for 48-byte keys */
	int key_len;		/* total XTS key length in bytes: 32, 48 or 64 */
	struct crypto_blkcipher *fallback;	/* software xts(aes), used for 48-byte keys */
};
66
67/*
68 * Check if the key_len is supported by the HW.
69 * Returns 0 if it is, a positive number if it is not and software fallback is
70 * required or a negative number in case the key size is not valid
71 */
72static int need_fallback(unsigned int key_len)
73{
74	switch (key_len) {
75	case 16:
76		if (!(keylen_flag & AES_KEYLEN_128))
77			return 1;
78		break;
79	case 24:
80		if (!(keylen_flag & AES_KEYLEN_192))
81			return 1;
82		break;
83	case 32:
84		if (!(keylen_flag & AES_KEYLEN_256))
85			return 1;
86		break;
87	default:
88		return -1;
89		break;
90	}
91	return 0;
92}
93
94static int setkey_fallback_cip(struct crypto_tfm *tfm, const u8 *in_key,
95		unsigned int key_len)
96{
97	struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
98	int ret;
99
100	sctx->fallback.cip->base.crt_flags &= ~CRYPTO_TFM_REQ_MASK;
101	sctx->fallback.cip->base.crt_flags |= (tfm->crt_flags &
102			CRYPTO_TFM_REQ_MASK);
103
104	ret = crypto_cipher_setkey(sctx->fallback.cip, in_key, key_len);
105	if (ret) {
106		tfm->crt_flags &= ~CRYPTO_TFM_RES_MASK;
107		tfm->crt_flags |= (sctx->fallback.cip->base.crt_flags &
108				CRYPTO_TFM_RES_MASK);
109	}
110	return ret;
111}
112
113static int aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
114		       unsigned int key_len)
115{
116	struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
117	u32 *flags = &tfm->crt_flags;
118	int ret;
119
120	ret = need_fallback(key_len);
121	if (ret < 0) {
122		*flags |= CRYPTO_TFM_RES_BAD_KEY_LEN;
123		return -EINVAL;
124	}
125
126	sctx->key_len = key_len;
127	if (!ret) {
128		memcpy(sctx->key, in_key, key_len);
129		return 0;
130	}
131
132	return setkey_fallback_cip(tfm, in_key, key_len);
133}
134
135static void aes_encrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in)
136{
137	struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
138
139	if (unlikely(need_fallback(sctx->key_len))) {
140		crypto_cipher_encrypt_one(sctx->fallback.cip, out, in);
141		return;
142	}
143
144	switch (sctx->key_len) {
145	case 16:
146		crypt_s390_km(KM_AES_128_ENCRYPT, &sctx->key, out, in,
147			      AES_BLOCK_SIZE);
148		break;
149	case 24:
150		crypt_s390_km(KM_AES_192_ENCRYPT, &sctx->key, out, in,
151			      AES_BLOCK_SIZE);
152		break;
153	case 32:
154		crypt_s390_km(KM_AES_256_ENCRYPT, &sctx->key, out, in,
155			      AES_BLOCK_SIZE);
156		break;
157	}
158}
159
160static void aes_decrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in)
161{
162	struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
163
164	if (unlikely(need_fallback(sctx->key_len))) {
165		crypto_cipher_decrypt_one(sctx->fallback.cip, out, in);
166		return;
167	}
168
169	switch (sctx->key_len) {
170	case 16:
171		crypt_s390_km(KM_AES_128_DECRYPT, &sctx->key, out, in,
172			      AES_BLOCK_SIZE);
173		break;
174	case 24:
175		crypt_s390_km(KM_AES_192_DECRYPT, &sctx->key, out, in,
176			      AES_BLOCK_SIZE);
177		break;
178	case 32:
179		crypt_s390_km(KM_AES_256_DECRYPT, &sctx->key, out, in,
180			      AES_BLOCK_SIZE);
181		break;
182	}
183}
184
185static int fallback_init_cip(struct crypto_tfm *tfm)
186{
187	const char *name = tfm->__crt_alg->cra_name;
188	struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
189
190	sctx->fallback.cip = crypto_alloc_cipher(name, 0,
191			CRYPTO_ALG_ASYNC | CRYPTO_ALG_NEED_FALLBACK);
192
193	if (IS_ERR(sctx->fallback.cip)) {
194		pr_err("Allocating AES fallback algorithm %s failed\n",
195		       name);
196		return PTR_ERR(sctx->fallback.cip);
197	}
198
199	return 0;
200}
201
202static void fallback_exit_cip(struct crypto_tfm *tfm)
203{
204	struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
205
206	crypto_free_cipher(sctx->fallback.cip);
207	sctx->fallback.cip = NULL;
208}
209
/*
 * Single-block "aes" cipher.  CRYPTO_ALG_NEED_FALLBACK is set because key
 * lengths without KM support are delegated to the fallback cipher
 * allocated in fallback_init_cip().
 */
static struct crypto_alg aes_alg = {
	.cra_name		=	"aes",
	.cra_driver_name	=	"aes-s390",
	.cra_priority		=	CRYPT_S390_PRIORITY,
	.cra_flags		=	CRYPTO_ALG_TYPE_CIPHER |
					CRYPTO_ALG_NEED_FALLBACK,
	.cra_blocksize		=	AES_BLOCK_SIZE,
	.cra_ctxsize		=	sizeof(struct s390_aes_ctx),
	.cra_module		=	THIS_MODULE,
	.cra_init               =       fallback_init_cip,
	.cra_exit               =       fallback_exit_cip,
	.cra_u			=	{
		.cipher = {
			.cia_min_keysize	=	AES_MIN_KEY_SIZE,
			.cia_max_keysize	=	AES_MAX_KEY_SIZE,
			.cia_setkey		=	aes_set_key,
			.cia_encrypt		=	aes_encrypt,
			.cia_decrypt		=	aes_decrypt,
		}
	}
};
231
232static int setkey_fallback_blk(struct crypto_tfm *tfm, const u8 *key,
233		unsigned int len)
234{
235	struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
236	unsigned int ret;
237
238	sctx->fallback.blk->base.crt_flags &= ~CRYPTO_TFM_REQ_MASK;
239	sctx->fallback.blk->base.crt_flags |= (tfm->crt_flags &
240			CRYPTO_TFM_REQ_MASK);
241
242	ret = crypto_blkcipher_setkey(sctx->fallback.blk, key, len);
243	if (ret) {
244		tfm->crt_flags &= ~CRYPTO_TFM_RES_MASK;
245		tfm->crt_flags |= (sctx->fallback.blk->base.crt_flags &
246				CRYPTO_TFM_RES_MASK);
247	}
248	return ret;
249}
250
251static int fallback_blk_dec(struct blkcipher_desc *desc,
252		struct scatterlist *dst, struct scatterlist *src,
253		unsigned int nbytes)
254{
255	unsigned int ret;
256	struct crypto_blkcipher *tfm;
257	struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
258
259	tfm = desc->tfm;
260	desc->tfm = sctx->fallback.blk;
261
262	ret = crypto_blkcipher_decrypt_iv(desc, dst, src, nbytes);
263
264	desc->tfm = tfm;
265	return ret;
266}
267
268static int fallback_blk_enc(struct blkcipher_desc *desc,
269		struct scatterlist *dst, struct scatterlist *src,
270		unsigned int nbytes)
271{
272	unsigned int ret;
273	struct crypto_blkcipher *tfm;
274	struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
275
276	tfm = desc->tfm;
277	desc->tfm = sctx->fallback.blk;
278
279	ret = crypto_blkcipher_encrypt_iv(desc, dst, src, nbytes);
280
281	desc->tfm = tfm;
282	return ret;
283}
284
285static int ecb_aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
286			   unsigned int key_len)
287{
288	struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
289	int ret;
290
291	ret = need_fallback(key_len);
292	if (ret > 0) {
293		sctx->key_len = key_len;
294		return setkey_fallback_blk(tfm, in_key, key_len);
295	}
296
297	switch (key_len) {
298	case 16:
299		sctx->enc = KM_AES_128_ENCRYPT;
300		sctx->dec = KM_AES_128_DECRYPT;
301		break;
302	case 24:
303		sctx->enc = KM_AES_192_ENCRYPT;
304		sctx->dec = KM_AES_192_DECRYPT;
305		break;
306	case 32:
307		sctx->enc = KM_AES_256_ENCRYPT;
308		sctx->dec = KM_AES_256_DECRYPT;
309		break;
310	}
311
312	return aes_set_key(tfm, in_key, key_len);
313}
314
315static int ecb_aes_crypt(struct blkcipher_desc *desc, long func, void *param,
316			 struct blkcipher_walk *walk)
317{
318	int ret = blkcipher_walk_virt(desc, walk);
319	unsigned int nbytes;
320
321	while ((nbytes = walk->nbytes)) {
322		/* only use complete blocks */
323		unsigned int n = nbytes & ~(AES_BLOCK_SIZE - 1);
324		u8 *out = walk->dst.virt.addr;
325		u8 *in = walk->src.virt.addr;
326
327		ret = crypt_s390_km(func, param, out, in, n);
328		if (ret < 0 || ret != n)
329			return -EIO;
330
331		nbytes &= AES_BLOCK_SIZE - 1;
332		ret = blkcipher_walk_done(desc, walk, nbytes);
333	}
334
335	return ret;
336}
337
338static int ecb_aes_encrypt(struct blkcipher_desc *desc,
339			   struct scatterlist *dst, struct scatterlist *src,
340			   unsigned int nbytes)
341{
342	struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
343	struct blkcipher_walk walk;
344
345	if (unlikely(need_fallback(sctx->key_len)))
346		return fallback_blk_enc(desc, dst, src, nbytes);
347
348	blkcipher_walk_init(&walk, dst, src, nbytes);
349	return ecb_aes_crypt(desc, sctx->enc, sctx->key, &walk);
350}
351
352static int ecb_aes_decrypt(struct blkcipher_desc *desc,
353			   struct scatterlist *dst, struct scatterlist *src,
354			   unsigned int nbytes)
355{
356	struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
357	struct blkcipher_walk walk;
358
359	if (unlikely(need_fallback(sctx->key_len)))
360		return fallback_blk_dec(desc, dst, src, nbytes);
361
362	blkcipher_walk_init(&walk, dst, src, nbytes);
363	return ecb_aes_crypt(desc, sctx->dec, sctx->key, &walk);
364}
365
366static int fallback_init_blk(struct crypto_tfm *tfm)
367{
368	const char *name = tfm->__crt_alg->cra_name;
369	struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
370
371	sctx->fallback.blk = crypto_alloc_blkcipher(name, 0,
372			CRYPTO_ALG_ASYNC | CRYPTO_ALG_NEED_FALLBACK);
373
374	if (IS_ERR(sctx->fallback.blk)) {
375		pr_err("Allocating AES fallback algorithm %s failed\n",
376		       name);
377		return PTR_ERR(sctx->fallback.blk);
378	}
379
380	return 0;
381}
382
383static void fallback_exit_blk(struct crypto_tfm *tfm)
384{
385	struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
386
387	crypto_free_blkcipher(sctx->fallback.blk);
388	sctx->fallback.blk = NULL;
389}
390
/*
 * ecb(aes) blkcipher backed by KM.  Key lengths without KM support are
 * delegated to the blkcipher fallback allocated in fallback_init_blk().
 */
static struct crypto_alg ecb_aes_alg = {
	.cra_name		=	"ecb(aes)",
	.cra_driver_name	=	"ecb-aes-s390",
	.cra_priority		=	CRYPT_S390_COMPOSITE_PRIORITY,
	.cra_flags		=	CRYPTO_ALG_TYPE_BLKCIPHER |
					CRYPTO_ALG_NEED_FALLBACK,
	.cra_blocksize		=	AES_BLOCK_SIZE,
	.cra_ctxsize		=	sizeof(struct s390_aes_ctx),
	.cra_type		=	&crypto_blkcipher_type,
	.cra_module		=	THIS_MODULE,
	.cra_init		=	fallback_init_blk,
	.cra_exit		=	fallback_exit_blk,
	.cra_u			=	{
		.blkcipher = {
			.min_keysize		=	AES_MIN_KEY_SIZE,
			.max_keysize		=	AES_MAX_KEY_SIZE,
			.setkey			=	ecb_aes_set_key,
			.encrypt		=	ecb_aes_encrypt,
			.decrypt		=	ecb_aes_decrypt,
		}
	}
};
413
414static int cbc_aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
415			   unsigned int key_len)
416{
417	struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
418	int ret;
419
420	ret = need_fallback(key_len);
421	if (ret > 0) {
422		sctx->key_len = key_len;
423		return setkey_fallback_blk(tfm, in_key, key_len);
424	}
425
426	switch (key_len) {
427	case 16:
428		sctx->enc = KMC_AES_128_ENCRYPT;
429		sctx->dec = KMC_AES_128_DECRYPT;
430		break;
431	case 24:
432		sctx->enc = KMC_AES_192_ENCRYPT;
433		sctx->dec = KMC_AES_192_DECRYPT;
434		break;
435	case 32:
436		sctx->enc = KMC_AES_256_ENCRYPT;
437		sctx->dec = KMC_AES_256_DECRYPT;
438		break;
439	}
440
441	return aes_set_key(tfm, in_key, key_len);
442}
443
444static int cbc_aes_crypt(struct blkcipher_desc *desc, long func,
445			 struct blkcipher_walk *walk)
446{
447	struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
448	int ret = blkcipher_walk_virt(desc, walk);
449	unsigned int nbytes = walk->nbytes;
450	struct {
451		u8 iv[AES_BLOCK_SIZE];
452		u8 key[AES_MAX_KEY_SIZE];
453	} param;
454
455	if (!nbytes)
456		goto out;
457
458	memcpy(param.iv, walk->iv, AES_BLOCK_SIZE);
459	memcpy(param.key, sctx->key, sctx->key_len);
460	do {
461		/* only use complete blocks */
462		unsigned int n = nbytes & ~(AES_BLOCK_SIZE - 1);
463		u8 *out = walk->dst.virt.addr;
464		u8 *in = walk->src.virt.addr;
465
466		ret = crypt_s390_kmc(func, &param, out, in, n);
467		if (ret < 0 || ret != n)
468			return -EIO;
469
470		nbytes &= AES_BLOCK_SIZE - 1;
471		ret = blkcipher_walk_done(desc, walk, nbytes);
472	} while ((nbytes = walk->nbytes));
473	memcpy(walk->iv, param.iv, AES_BLOCK_SIZE);
474
475out:
476	return ret;
477}
478
479static int cbc_aes_encrypt(struct blkcipher_desc *desc,
480			   struct scatterlist *dst, struct scatterlist *src,
481			   unsigned int nbytes)
482{
483	struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
484	struct blkcipher_walk walk;
485
486	if (unlikely(need_fallback(sctx->key_len)))
487		return fallback_blk_enc(desc, dst, src, nbytes);
488
489	blkcipher_walk_init(&walk, dst, src, nbytes);
490	return cbc_aes_crypt(desc, sctx->enc, &walk);
491}
492
493static int cbc_aes_decrypt(struct blkcipher_desc *desc,
494			   struct scatterlist *dst, struct scatterlist *src,
495			   unsigned int nbytes)
496{
497	struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
498	struct blkcipher_walk walk;
499
500	if (unlikely(need_fallback(sctx->key_len)))
501		return fallback_blk_dec(desc, dst, src, nbytes);
502
503	blkcipher_walk_init(&walk, dst, src, nbytes);
504	return cbc_aes_crypt(desc, sctx->dec, &walk);
505}
506
/*
 * cbc(aes) blkcipher backed by KMC.  Key lengths without KM support are
 * delegated to the blkcipher fallback allocated in fallback_init_blk().
 */
static struct crypto_alg cbc_aes_alg = {
	.cra_name		=	"cbc(aes)",
	.cra_driver_name	=	"cbc-aes-s390",
	.cra_priority		=	CRYPT_S390_COMPOSITE_PRIORITY,
	.cra_flags		=	CRYPTO_ALG_TYPE_BLKCIPHER |
					CRYPTO_ALG_NEED_FALLBACK,
	.cra_blocksize		=	AES_BLOCK_SIZE,
	.cra_ctxsize		=	sizeof(struct s390_aes_ctx),
	.cra_type		=	&crypto_blkcipher_type,
	.cra_module		=	THIS_MODULE,
	.cra_init		=	fallback_init_blk,
	.cra_exit		=	fallback_exit_blk,
	.cra_u			=	{
		.blkcipher = {
			.min_keysize		=	AES_MIN_KEY_SIZE,
			.max_keysize		=	AES_MAX_KEY_SIZE,
			.ivsize			=	AES_BLOCK_SIZE,
			.setkey			=	cbc_aes_set_key,
			.encrypt		=	cbc_aes_encrypt,
			.decrypt		=	cbc_aes_decrypt,
		}
	}
};
530
531static int xts_fallback_setkey(struct crypto_tfm *tfm, const u8 *key,
532				   unsigned int len)
533{
534	struct s390_xts_ctx *xts_ctx = crypto_tfm_ctx(tfm);
535	unsigned int ret;
536
537	xts_ctx->fallback->base.crt_flags &= ~CRYPTO_TFM_REQ_MASK;
538	xts_ctx->fallback->base.crt_flags |= (tfm->crt_flags &
539			CRYPTO_TFM_REQ_MASK);
540
541	ret = crypto_blkcipher_setkey(xts_ctx->fallback, key, len);
542	if (ret) {
543		tfm->crt_flags &= ~CRYPTO_TFM_RES_MASK;
544		tfm->crt_flags |= (xts_ctx->fallback->base.crt_flags &
545				CRYPTO_TFM_RES_MASK);
546	}
547	return ret;
548}
549
550static int xts_fallback_decrypt(struct blkcipher_desc *desc,
551		struct scatterlist *dst, struct scatterlist *src,
552		unsigned int nbytes)
553{
554	struct s390_xts_ctx *xts_ctx = crypto_blkcipher_ctx(desc->tfm);
555	struct crypto_blkcipher *tfm;
556	unsigned int ret;
557
558	tfm = desc->tfm;
559	desc->tfm = xts_ctx->fallback;
560
561	ret = crypto_blkcipher_decrypt_iv(desc, dst, src, nbytes);
562
563	desc->tfm = tfm;
564	return ret;
565}
566
567static int xts_fallback_encrypt(struct blkcipher_desc *desc,
568		struct scatterlist *dst, struct scatterlist *src,
569		unsigned int nbytes)
570{
571	struct s390_xts_ctx *xts_ctx = crypto_blkcipher_ctx(desc->tfm);
572	struct crypto_blkcipher *tfm;
573	unsigned int ret;
574
575	tfm = desc->tfm;
576	desc->tfm = xts_ctx->fallback;
577
578	ret = crypto_blkcipher_encrypt_iv(desc, dst, src, nbytes);
579
580	desc->tfm = tfm;
581	return ret;
582}
583
584static int xts_aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
585			   unsigned int key_len)
586{
587	struct s390_xts_ctx *xts_ctx = crypto_tfm_ctx(tfm);
588	u32 *flags = &tfm->crt_flags;
589
590	switch (key_len) {
591	case 32:
592		xts_ctx->enc = KM_XTS_128_ENCRYPT;
593		xts_ctx->dec = KM_XTS_128_DECRYPT;
594		memcpy(xts_ctx->key + 16, in_key, 16);
595		memcpy(xts_ctx->pcc_key + 16, in_key + 16, 16);
596		break;
597	case 48:
598		xts_ctx->enc = 0;
599		xts_ctx->dec = 0;
600		xts_fallback_setkey(tfm, in_key, key_len);
601		break;
602	case 64:
603		xts_ctx->enc = KM_XTS_256_ENCRYPT;
604		xts_ctx->dec = KM_XTS_256_DECRYPT;
605		memcpy(xts_ctx->key, in_key, 32);
606		memcpy(xts_ctx->pcc_key, in_key + 32, 32);
607		break;
608	default:
609		*flags |= CRYPTO_TFM_RES_BAD_KEY_LEN;
610		return -EINVAL;
611	}
612	xts_ctx->key_len = key_len;
613	return 0;
614}
615
/*
 * Hardware XTS for 32- and 64-byte xts keys (2 x AES-128 / 2 x AES-256).
 *
 * First PCC, invoked with @func, derives the initial XTS value from the
 * tweak (walk->iv) and xts_ctx->pcc_key.  Then KM with @func processes
 * every chunk of @walk, using xts_ctx->key plus that initial value as the
 * parameter block.  Only whole AES blocks are processed per chunk.
 *
 * offset skips the unused first 16 bytes of each 32-byte key field for
 * 32-byte xts keys ((32 >> 1) & 0x10 == 16).  It is 0 for 64-byte keys
 * ((64 >> 1) & 0x10 == 0).
 *
 * Unlike cbc_aes_crypt(), this function does not write anything back to
 * walk->iv.
 * NOTE(review): pcc_param and xts_param hold key material on the stack
 * and are not cleared before returning - consider wiping them.
 *
 * Return: result of the last walk step, or -EIO if PCC fails or KM
 * processes fewer bytes than requested.
 */
static int xts_aes_crypt(struct blkcipher_desc *desc, long func,
			 struct s390_xts_ctx *xts_ctx,
			 struct blkcipher_walk *walk)
{
	unsigned int offset = (xts_ctx->key_len >> 1) & 0x10;
	int ret = blkcipher_walk_virt(desc, walk);
	unsigned int nbytes = walk->nbytes;
	unsigned int n;
	u8 *in, *out;
	struct pcc_param pcc_param;
	struct {
		u8 key[32];
		u8 init[16];
	} xts_param;

	if (!nbytes)
		goto out;

	/* Build the PCC parameter block: zeroed state, tweak from the IV. */
	memset(pcc_param.block, 0, sizeof(pcc_param.block));
	memset(pcc_param.bit, 0, sizeof(pcc_param.bit));
	memset(pcc_param.xts, 0, sizeof(pcc_param.xts));
	memcpy(pcc_param.tweak, walk->iv, sizeof(pcc_param.tweak));
	memcpy(pcc_param.key, xts_ctx->pcc_key, 32);
	/* PCC leaves the initial XTS value in pcc_param.xts (copied below). */
	ret = crypt_s390_pcc(func, &pcc_param.key[offset]);
	if (ret < 0)
		return -EIO;

	/* KM parameter block: data key followed by the initial XTS value. */
	memcpy(xts_param.key, xts_ctx->key, 32);
	memcpy(xts_param.init, pcc_param.xts, 16);
	do {
		/* only use complete blocks */
		n = nbytes & ~(AES_BLOCK_SIZE - 1);
		out = walk->dst.virt.addr;
		in = walk->src.virt.addr;

		ret = crypt_s390_km(func, &xts_param.key[offset], out, in, n);
		if (ret < 0 || ret != n)
			return -EIO;

		nbytes &= AES_BLOCK_SIZE - 1;
		ret = blkcipher_walk_done(desc, walk, nbytes);
	} while ((nbytes = walk->nbytes));
out:
	return ret;
}
661
662static int xts_aes_encrypt(struct blkcipher_desc *desc,
663			   struct scatterlist *dst, struct scatterlist *src,
664			   unsigned int nbytes)
665{
666	struct s390_xts_ctx *xts_ctx = crypto_blkcipher_ctx(desc->tfm);
667	struct blkcipher_walk walk;
668
669	if (unlikely(xts_ctx->key_len == 48))
670		return xts_fallback_encrypt(desc, dst, src, nbytes);
671
672	blkcipher_walk_init(&walk, dst, src, nbytes);
673	return xts_aes_crypt(desc, xts_ctx->enc, xts_ctx, &walk);
674}
675
676static int xts_aes_decrypt(struct blkcipher_desc *desc,
677			   struct scatterlist *dst, struct scatterlist *src,
678			   unsigned int nbytes)
679{
680	struct s390_xts_ctx *xts_ctx = crypto_blkcipher_ctx(desc->tfm);
681	struct blkcipher_walk walk;
682
683	if (unlikely(xts_ctx->key_len == 48))
684		return xts_fallback_decrypt(desc, dst, src, nbytes);
685
686	blkcipher_walk_init(&walk, dst, src, nbytes);
687	return xts_aes_crypt(desc, xts_ctx->dec, xts_ctx, &walk);
688}
689
690static int xts_fallback_init(struct crypto_tfm *tfm)
691{
692	const char *name = tfm->__crt_alg->cra_name;
693	struct s390_xts_ctx *xts_ctx = crypto_tfm_ctx(tfm);
694
695	xts_ctx->fallback = crypto_alloc_blkcipher(name, 0,
696			CRYPTO_ALG_ASYNC | CRYPTO_ALG_NEED_FALLBACK);
697
698	if (IS_ERR(xts_ctx->fallback)) {
699		pr_err("Allocating XTS fallback algorithm %s failed\n",
700		       name);
701		return PTR_ERR(xts_ctx->fallback);
702	}
703	return 0;
704}
705
706static void xts_fallback_exit(struct crypto_tfm *tfm)
707{
708	struct s390_xts_ctx *xts_ctx = crypto_tfm_ctx(tfm);
709
710	crypto_free_blkcipher(xts_ctx->fallback);
711	xts_ctx->fallback = NULL;
712}
713
/*
 * xts(aes) blkcipher backed by PCC + KM.  Key sizes are doubled because an
 * XTS key is two AES keys.  48-byte keys use the fallback allocated in
 * xts_fallback_init().
 */
static struct crypto_alg xts_aes_alg = {
	.cra_name		=	"xts(aes)",
	.cra_driver_name	=	"xts-aes-s390",
	.cra_priority		=	CRYPT_S390_COMPOSITE_PRIORITY,
	.cra_flags		=	CRYPTO_ALG_TYPE_BLKCIPHER |
					CRYPTO_ALG_NEED_FALLBACK,
	.cra_blocksize		=	AES_BLOCK_SIZE,
	.cra_ctxsize		=	sizeof(struct s390_xts_ctx),
	.cra_type		=	&crypto_blkcipher_type,
	.cra_module		=	THIS_MODULE,
	.cra_init		=	xts_fallback_init,
	.cra_exit		=	xts_fallback_exit,
	.cra_u			=	{
		.blkcipher = {
			.min_keysize		=	2 * AES_MIN_KEY_SIZE,
			.max_keysize		=	2 * AES_MAX_KEY_SIZE,
			.ivsize			=	AES_BLOCK_SIZE,
			.setkey			=	xts_aes_set_key,
			.encrypt		=	xts_aes_encrypt,
			.decrypt		=	xts_aes_decrypt,
		}
	}
};

/* Set to 1 by aes_s390_init() once xts_aes_alg has been registered. */
static int xts_aes_alg_reg;
739
740static int ctr_aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
741			   unsigned int key_len)
742{
743	struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
744
745	switch (key_len) {
746	case 16:
747		sctx->enc = KMCTR_AES_128_ENCRYPT;
748		sctx->dec = KMCTR_AES_128_DECRYPT;
749		break;
750	case 24:
751		sctx->enc = KMCTR_AES_192_ENCRYPT;
752		sctx->dec = KMCTR_AES_192_DECRYPT;
753		break;
754	case 32:
755		sctx->enc = KMCTR_AES_256_ENCRYPT;
756		sctx->dec = KMCTR_AES_256_DECRYPT;
757		break;
758	}
759
760	return aes_set_key(tfm, in_key, key_len);
761}
762
763static unsigned int __ctrblk_init(u8 *ctrptr, unsigned int nbytes)
764{
765	unsigned int i, n;
766
767	/* only use complete blocks, max. PAGE_SIZE */
768	n = (nbytes > PAGE_SIZE) ? PAGE_SIZE : nbytes & ~(AES_BLOCK_SIZE - 1);
769	for (i = AES_BLOCK_SIZE; i < n; i += AES_BLOCK_SIZE) {
770		memcpy(ctrptr + i, ctrptr + i - AES_BLOCK_SIZE,
771		       AES_BLOCK_SIZE);
772		crypto_inc(ctrptr + i, AES_BLOCK_SIZE);
773	}
774	return n;
775}
776
/*
 * CTR en/decryption with KMCTR (function code @func).
 *
 * If ctrblk_lock can be taken without spinning, the shared ctrblk page is
 * filled with up to PAGE_SIZE bytes of consecutive counters, so each KMCTR
 * call can process many blocks.  Otherwise the on-stack ctrbuf holds a
 * single counter and KMCTR runs one block at a time.
 *
 * A trailing partial block is handled by encrypting a full counter block
 * into buf and copying only the needed bytes to the output.  The counter
 * following the last processed block ends up in walk->iv.
 *
 * NOTE(review): walk->iv is written after blkcipher_walk_done() may have
 * completed the walk.  This relies on walk->iv being desc->info itself
 * rather than an aligned temporary copy - confirm.
 *
 * Return: result of the last walk step, or -EIO if KMCTR processes fewer
 * bytes than requested.
 */
static int ctr_aes_crypt(struct blkcipher_desc *desc, long func,
			 struct s390_aes_ctx *sctx, struct blkcipher_walk *walk)
{
	int ret = blkcipher_walk_virt_block(desc, walk, AES_BLOCK_SIZE);
	unsigned int n, nbytes;
	u8 buf[AES_BLOCK_SIZE], ctrbuf[AES_BLOCK_SIZE];
	u8 *out, *in, *ctrptr = ctrbuf;

	if (!walk->nbytes)
		return ret;

	/* Use the shared counter page only if nobody else holds it. */
	if (spin_trylock(&ctrblk_lock))
		ctrptr = ctrblk;

	memcpy(ctrptr, walk->iv, AES_BLOCK_SIZE);
	while ((nbytes = walk->nbytes) >= AES_BLOCK_SIZE) {
		out = walk->dst.virt.addr;
		in = walk->src.virt.addr;
		while (nbytes >= AES_BLOCK_SIZE) {
			/* Page path: prepare many counters per KMCTR call. */
			if (ctrptr == ctrblk)
				n = __ctrblk_init(ctrptr, nbytes);
			else
				n = AES_BLOCK_SIZE;
			ret = crypt_s390_kmctr(func, sctx->key, out, in,
					       n, ctrptr);
			if (ret < 0 || ret != n) {
				if (ctrptr == ctrblk)
					spin_unlock(&ctrblk_lock);
				return -EIO;
			}
			/* Move the last used counter to block 0, then step past it. */
			if (n > AES_BLOCK_SIZE)
				memcpy(ctrptr, ctrptr + n - AES_BLOCK_SIZE,
				       AES_BLOCK_SIZE);
			crypto_inc(ctrptr, AES_BLOCK_SIZE);
			out += n;
			in += n;
			nbytes -= n;
		}
		ret = blkcipher_walk_done(desc, walk, nbytes);
	}
	/*
	 * Save the next counter.  If a partial block remains, keep it in
	 * ctrbuf (the page is released here); otherwise store it in walk->iv.
	 */
	if (ctrptr == ctrblk) {
		if (nbytes)
			memcpy(ctrbuf, ctrptr, AES_BLOCK_SIZE);
		else
			memcpy(walk->iv, ctrptr, AES_BLOCK_SIZE);
		spin_unlock(&ctrblk_lock);
	} else {
		if (!nbytes)
			memcpy(walk->iv, ctrptr, AES_BLOCK_SIZE);
	}
	/*
	 * final block may be < AES_BLOCK_SIZE, copy only nbytes
	 */
	if (nbytes) {
		out = walk->dst.virt.addr;
		in = walk->src.virt.addr;
		ret = crypt_s390_kmctr(func, sctx->key, buf, in,
				       AES_BLOCK_SIZE, ctrbuf);
		if (ret < 0 || ret != AES_BLOCK_SIZE)
			return -EIO;
		memcpy(out, buf, nbytes);
		crypto_inc(ctrbuf, AES_BLOCK_SIZE);
		ret = blkcipher_walk_done(desc, walk, 0);
		memcpy(walk->iv, ctrbuf, AES_BLOCK_SIZE);
	}

	return ret;
}
845
846static int ctr_aes_encrypt(struct blkcipher_desc *desc,
847			   struct scatterlist *dst, struct scatterlist *src,
848			   unsigned int nbytes)
849{
850	struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
851	struct blkcipher_walk walk;
852
853	blkcipher_walk_init(&walk, dst, src, nbytes);
854	return ctr_aes_crypt(desc, sctx->enc, sctx, &walk);
855}
856
857static int ctr_aes_decrypt(struct blkcipher_desc *desc,
858			   struct scatterlist *dst, struct scatterlist *src,
859			   unsigned int nbytes)
860{
861	struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
862	struct blkcipher_walk walk;
863
864	blkcipher_walk_init(&walk, dst, src, nbytes);
865	return ctr_aes_crypt(desc, sctx->dec, sctx, &walk);
866}
867
/*
 * ctr(aes) blkcipher backed by KMCTR.  cra_blocksize is 1 because a
 * trailing partial block is handled by ctr_aes_crypt().  There is no
 * software fallback.  aes_s390_init() registers this only when KMCTR is
 * available for all three AES key lengths.
 */
static struct crypto_alg ctr_aes_alg = {
	.cra_name		=	"ctr(aes)",
	.cra_driver_name	=	"ctr-aes-s390",
	.cra_priority		=	CRYPT_S390_COMPOSITE_PRIORITY,
	.cra_flags		=	CRYPTO_ALG_TYPE_BLKCIPHER,
	.cra_blocksize		=	1,
	.cra_ctxsize		=	sizeof(struct s390_aes_ctx),
	.cra_type		=	&crypto_blkcipher_type,
	.cra_module		=	THIS_MODULE,
	.cra_u			=	{
		.blkcipher = {
			.min_keysize		=	AES_MIN_KEY_SIZE,
			.max_keysize		=	AES_MAX_KEY_SIZE,
			.ivsize			=	AES_BLOCK_SIZE,
			.setkey			=	ctr_aes_set_key,
			.encrypt		=	ctr_aes_encrypt,
			.decrypt		=	ctr_aes_decrypt,
		}
	}
};

/* Set to 1 by aes_s390_init() once ctr_aes_alg has been registered. */
static int ctr_aes_alg_reg;
890
891static int __init aes_s390_init(void)
892{
893	int ret;
894
895	if (crypt_s390_func_available(KM_AES_128_ENCRYPT, CRYPT_S390_MSA))
896		keylen_flag |= AES_KEYLEN_128;
897	if (crypt_s390_func_available(KM_AES_192_ENCRYPT, CRYPT_S390_MSA))
898		keylen_flag |= AES_KEYLEN_192;
899	if (crypt_s390_func_available(KM_AES_256_ENCRYPT, CRYPT_S390_MSA))
900		keylen_flag |= AES_KEYLEN_256;
901
902	if (!keylen_flag)
903		return -EOPNOTSUPP;
904
905	/* z9 109 and z9 BC/EC only support 128 bit key length */
906	if (keylen_flag == AES_KEYLEN_128)
907		pr_info("AES hardware acceleration is only available for"
908			" 128-bit keys\n");
909
910	ret = crypto_register_alg(&aes_alg);
911	if (ret)
912		goto aes_err;
913
914	ret = crypto_register_alg(&ecb_aes_alg);
915	if (ret)
916		goto ecb_aes_err;
917
918	ret = crypto_register_alg(&cbc_aes_alg);
919	if (ret)
920		goto cbc_aes_err;
921
922	if (crypt_s390_func_available(KM_XTS_128_ENCRYPT,
923			CRYPT_S390_MSA | CRYPT_S390_MSA4) &&
924	    crypt_s390_func_available(KM_XTS_256_ENCRYPT,
925			CRYPT_S390_MSA | CRYPT_S390_MSA4)) {
926		ret = crypto_register_alg(&xts_aes_alg);
927		if (ret)
928			goto xts_aes_err;
929		xts_aes_alg_reg = 1;
930	}
931
932	if (crypt_s390_func_available(KMCTR_AES_128_ENCRYPT,
933				CRYPT_S390_MSA | CRYPT_S390_MSA4) &&
934	    crypt_s390_func_available(KMCTR_AES_192_ENCRYPT,
935				CRYPT_S390_MSA | CRYPT_S390_MSA4) &&
936	    crypt_s390_func_available(KMCTR_AES_256_ENCRYPT,
937				CRYPT_S390_MSA | CRYPT_S390_MSA4)) {
938		ctrblk = (u8 *) __get_free_page(GFP_KERNEL);
939		if (!ctrblk) {
940			ret = -ENOMEM;
941			goto ctr_aes_err;
942		}
943		ret = crypto_register_alg(&ctr_aes_alg);
944		if (ret) {
945			free_page((unsigned long) ctrblk);
946			goto ctr_aes_err;
947		}
948		ctr_aes_alg_reg = 1;
949	}
950
951out:
952	return ret;
953
954ctr_aes_err:
955	crypto_unregister_alg(&xts_aes_alg);
956xts_aes_err:
957	crypto_unregister_alg(&cbc_aes_alg);
958cbc_aes_err:
959	crypto_unregister_alg(&ecb_aes_alg);
960ecb_aes_err:
961	crypto_unregister_alg(&aes_alg);
962aes_err:
963	goto out;
964}
965
966static void __exit aes_s390_fini(void)
967{
968	if (ctr_aes_alg_reg) {
969		crypto_unregister_alg(&ctr_aes_alg);
970		free_page((unsigned long) ctrblk);
971	}
972	if (xts_aes_alg_reg)
973		crypto_unregister_alg(&xts_aes_alg);
974	crypto_unregister_alg(&cbc_aes_alg);
975	crypto_unregister_alg(&ecb_aes_alg);
976	crypto_unregister_alg(&aes_alg);
977}
978
/* Module entry/exit points and metadata. */
module_init(aes_s390_init);
module_exit(aes_s390_fini);

/* Crypto module alias; see MODULE_ALIAS_CRYPTO() for how it is used. */
MODULE_ALIAS_CRYPTO("aes-all");

MODULE_DESCRIPTION("Rijndael (AES) Cipher Algorithm");
MODULE_LICENSE("GPL");
986