1  // SPDX-License-Identifier: GPL-2.0-or-later
2  /*
3   * Symmetric key cipher operations.
4   *
5   * Generic encrypt/decrypt wrapper for ciphers, handles operations across
6   * multiple page boundaries by using temporary blocks.  In user context,
7   * the kernel is given a chance to schedule us once per page.
8   *
9   * Copyright (c) 2015 Herbert Xu <herbert@gondor.apana.org.au>
10   */
11  
12  #include <crypto/internal/aead.h>
13  #include <crypto/internal/cipher.h>
14  #include <crypto/internal/skcipher.h>
15  #include <crypto/scatterwalk.h>
16  #include <linux/bug.h>
17  #include <linux/cryptouser.h>
18  #include <linux/err.h>
19  #include <linux/kernel.h>
20  #include <linux/list.h>
21  #include <linux/mm.h>
22  #include <linux/module.h>
23  #include <linux/seq_file.h>
24  #include <linux/slab.h>
25  #include <linux/string.h>
26  #include <net/netlink.h>
27  #include "skcipher.h"
28  
/* Used as the .maskset of crypto_skcipher_type below. */
#define CRYPTO_ALG_TYPE_SKCIPHER_MASK	0x0000000e

/* State bits kept in walk->flags. */
enum {
	/* Async walk: steps are described by page/offset and are not mapped. */
	SKCIPHER_WALK_PHYS = 1 << 0,
	/* Current step goes through a small bounce buffer (skcipher_next_slow). */
	SKCIPHER_WALK_SLOW = 1 << 1,
	/* Current step goes through the bounce page walk->page. */
	SKCIPHER_WALK_COPY = 1 << 2,
	/* Current step has source and destination mapped separately. */
	SKCIPHER_WALK_DIFF = 1 << 3,
	/* Walk may sleep: GFP_KERNEL allocations and crypto_yield(). */
	SKCIPHER_WALK_SLEEP = 1 << 4,
};
38  
/*
 * Deferred write-back entry used by PHYS-mode walks.  Entries are queued on
 * walk->buffers by skcipher_queue_write() and drained by
 * skcipher_walk_complete().
 */
struct skcipher_walk_buffer {
	/* Link in walk->buffers. */
	struct list_head entry;
	/* Output position captured when the entry was queued. */
	struct scatter_walk dst;
	/* Number of bytes to write back. */
	unsigned int len;
	/*
	 * Region inside a bounce page (copy steps), or NULL when the data
	 * lives in buffer[] (slow steps, allocated zeroed).
	 */
	u8 *data;
	/* Inline storage for a slow-step block, including alignment slack. */
	u8 buffer[];
};
46  
/* Defined further below; compared against cra_type throughout this file. */
static const struct crypto_type crypto_skcipher_type;

/* skcipher_walk_done() chains into this to set up the following step. */
static int skcipher_walk_next(struct skcipher_walk *walk);
50  
skcipher_map_src(struct skcipher_walk * walk)51  static inline void skcipher_map_src(struct skcipher_walk *walk)
52  {
53  	walk->src.virt.addr = scatterwalk_map(&walk->in);
54  }
55  
skcipher_map_dst(struct skcipher_walk * walk)56  static inline void skcipher_map_dst(struct skcipher_walk *walk)
57  {
58  	walk->dst.virt.addr = scatterwalk_map(&walk->out);
59  }
60  
skcipher_unmap_src(struct skcipher_walk * walk)61  static inline void skcipher_unmap_src(struct skcipher_walk *walk)
62  {
63  	scatterwalk_unmap(walk->src.virt.addr);
64  }
65  
skcipher_unmap_dst(struct skcipher_walk * walk)66  static inline void skcipher_unmap_dst(struct skcipher_walk *walk)
67  {
68  	scatterwalk_unmap(walk->dst.virt.addr);
69  }
70  
skcipher_walk_gfp(struct skcipher_walk * walk)71  static inline gfp_t skcipher_walk_gfp(struct skcipher_walk *walk)
72  {
73  	return walk->flags & SKCIPHER_WALK_SLEEP ? GFP_KERNEL : GFP_ATOMIC;
74  }
75  
76  /* Get a spot of the specified length that does not straddle a page.
77   * The caller needs to ensure that there is enough space for this operation.
78   */
skcipher_get_spot(u8 * start,unsigned int len)79  static inline u8 *skcipher_get_spot(u8 *start, unsigned int len)
80  {
81  	u8 *end_page = (u8 *)(((unsigned long)(start + len - 1)) & PAGE_MASK);
82  
83  	return max(start, end_page);
84  }
85  
__crypto_skcipher_alg(struct crypto_alg * alg)86  static inline struct skcipher_alg *__crypto_skcipher_alg(
87  	struct crypto_alg *alg)
88  {
89  	return container_of(alg, struct skcipher_alg, base);
90  }
91  
skcipher_done_slow(struct skcipher_walk * walk,unsigned int bsize)92  static int skcipher_done_slow(struct skcipher_walk *walk, unsigned int bsize)
93  {
94  	u8 *addr;
95  
96  	addr = (u8 *)ALIGN((unsigned long)walk->buffer, walk->alignmask + 1);
97  	addr = skcipher_get_spot(addr, bsize);
98  	scatterwalk_copychunks(addr, &walk->out, bsize,
99  			       (walk->flags & SKCIPHER_WALK_PHYS) ? 2 : 1);
100  	return 0;
101  }
102  
skcipher_walk_done(struct skcipher_walk * walk,int err)103  int skcipher_walk_done(struct skcipher_walk *walk, int err)
104  {
105  	unsigned int n = walk->nbytes;
106  	unsigned int nbytes = 0;
107  
108  	if (!n)
109  		goto finish;
110  
111  	if (likely(err >= 0)) {
112  		n -= err;
113  		nbytes = walk->total - n;
114  	}
115  
116  	if (likely(!(walk->flags & (SKCIPHER_WALK_PHYS |
117  				    SKCIPHER_WALK_SLOW |
118  				    SKCIPHER_WALK_COPY |
119  				    SKCIPHER_WALK_DIFF)))) {
120  unmap_src:
121  		skcipher_unmap_src(walk);
122  	} else if (walk->flags & SKCIPHER_WALK_DIFF) {
123  		skcipher_unmap_dst(walk);
124  		goto unmap_src;
125  	} else if (walk->flags & SKCIPHER_WALK_COPY) {
126  		skcipher_map_dst(walk);
127  		memcpy(walk->dst.virt.addr, walk->page, n);
128  		skcipher_unmap_dst(walk);
129  	} else if (unlikely(walk->flags & SKCIPHER_WALK_SLOW)) {
130  		if (err > 0) {
131  			/*
132  			 * Didn't process all bytes.  Either the algorithm is
133  			 * broken, or this was the last step and it turned out
134  			 * the message wasn't evenly divisible into blocks but
135  			 * the algorithm requires it.
136  			 */
137  			err = -EINVAL;
138  			nbytes = 0;
139  		} else
140  			n = skcipher_done_slow(walk, n);
141  	}
142  
143  	if (err > 0)
144  		err = 0;
145  
146  	walk->total = nbytes;
147  	walk->nbytes = 0;
148  
149  	scatterwalk_advance(&walk->in, n);
150  	scatterwalk_advance(&walk->out, n);
151  	scatterwalk_done(&walk->in, 0, nbytes);
152  	scatterwalk_done(&walk->out, 1, nbytes);
153  
154  	if (nbytes) {
155  		crypto_yield(walk->flags & SKCIPHER_WALK_SLEEP ?
156  			     CRYPTO_TFM_REQ_MAY_SLEEP : 0);
157  		return skcipher_walk_next(walk);
158  	}
159  
160  finish:
161  	/* Short-circuit for the common/fast path. */
162  	if (!((unsigned long)walk->buffer | (unsigned long)walk->page))
163  		goto out;
164  
165  	if (walk->flags & SKCIPHER_WALK_PHYS)
166  		goto out;
167  
168  	if (walk->iv != walk->oiv)
169  		memcpy(walk->oiv, walk->iv, walk->ivsize);
170  	if (walk->buffer != walk->page)
171  		kfree(walk->buffer);
172  	if (walk->page)
173  		free_page((unsigned long)walk->page);
174  
175  out:
176  	return err;
177  }
178  EXPORT_SYMBOL_GPL(skcipher_walk_done);
179  
skcipher_walk_complete(struct skcipher_walk * walk,int err)180  void skcipher_walk_complete(struct skcipher_walk *walk, int err)
181  {
182  	struct skcipher_walk_buffer *p, *tmp;
183  
184  	list_for_each_entry_safe(p, tmp, &walk->buffers, entry) {
185  		u8 *data;
186  
187  		if (err)
188  			goto done;
189  
190  		data = p->data;
191  		if (!data) {
192  			data = PTR_ALIGN(&p->buffer[0], walk->alignmask + 1);
193  			data = skcipher_get_spot(data, walk->stride);
194  		}
195  
196  		scatterwalk_copychunks(data, &p->dst, p->len, 1);
197  
198  		if (offset_in_page(p->data) + p->len + walk->stride >
199  		    PAGE_SIZE)
200  			free_page((unsigned long)p->data);
201  
202  done:
203  		list_del(&p->entry);
204  		kfree(p);
205  	}
206  
207  	if (!err && walk->iv != walk->oiv)
208  		memcpy(walk->oiv, walk->iv, walk->ivsize);
209  	if (walk->buffer != walk->page)
210  		kfree(walk->buffer);
211  	if (walk->page)
212  		free_page((unsigned long)walk->page);
213  }
214  EXPORT_SYMBOL_GPL(skcipher_walk_complete);
215  
skcipher_queue_write(struct skcipher_walk * walk,struct skcipher_walk_buffer * p)216  static void skcipher_queue_write(struct skcipher_walk *walk,
217  				 struct skcipher_walk_buffer *p)
218  {
219  	p->dst = walk->out;
220  	list_add_tail(&p->entry, &walk->buffers);
221  }
222  
skcipher_next_slow(struct skcipher_walk * walk,unsigned int bsize)223  static int skcipher_next_slow(struct skcipher_walk *walk, unsigned int bsize)
224  {
225  	bool phys = walk->flags & SKCIPHER_WALK_PHYS;
226  	unsigned alignmask = walk->alignmask;
227  	struct skcipher_walk_buffer *p;
228  	unsigned a;
229  	unsigned n;
230  	u8 *buffer;
231  	void *v;
232  
233  	if (!phys) {
234  		if (!walk->buffer)
235  			walk->buffer = walk->page;
236  		buffer = walk->buffer;
237  		if (buffer)
238  			goto ok;
239  	}
240  
241  	/* Start with the minimum alignment of kmalloc. */
242  	a = crypto_tfm_ctx_alignment() - 1;
243  	n = bsize;
244  
245  	if (phys) {
246  		/* Calculate the minimum alignment of p->buffer. */
247  		a &= (sizeof(*p) ^ (sizeof(*p) - 1)) >> 1;
248  		n += sizeof(*p);
249  	}
250  
251  	/* Minimum size to align p->buffer by alignmask. */
252  	n += alignmask & ~a;
253  
254  	/* Minimum size to ensure p->buffer does not straddle a page. */
255  	n += (bsize - 1) & ~(alignmask | a);
256  
257  	v = kzalloc(n, skcipher_walk_gfp(walk));
258  	if (!v)
259  		return skcipher_walk_done(walk, -ENOMEM);
260  
261  	if (phys) {
262  		p = v;
263  		p->len = bsize;
264  		skcipher_queue_write(walk, p);
265  		buffer = p->buffer;
266  	} else {
267  		walk->buffer = v;
268  		buffer = v;
269  	}
270  
271  ok:
272  	walk->dst.virt.addr = PTR_ALIGN(buffer, alignmask + 1);
273  	walk->dst.virt.addr = skcipher_get_spot(walk->dst.virt.addr, bsize);
274  	walk->src.virt.addr = walk->dst.virt.addr;
275  
276  	scatterwalk_copychunks(walk->src.virt.addr, &walk->in, bsize, 0);
277  
278  	walk->nbytes = bsize;
279  	walk->flags |= SKCIPHER_WALK_SLOW;
280  
281  	return 0;
282  }
283  
skcipher_next_copy(struct skcipher_walk * walk)284  static int skcipher_next_copy(struct skcipher_walk *walk)
285  {
286  	struct skcipher_walk_buffer *p;
287  	u8 *tmp = walk->page;
288  
289  	skcipher_map_src(walk);
290  	memcpy(tmp, walk->src.virt.addr, walk->nbytes);
291  	skcipher_unmap_src(walk);
292  
293  	walk->src.virt.addr = tmp;
294  	walk->dst.virt.addr = tmp;
295  
296  	if (!(walk->flags & SKCIPHER_WALK_PHYS))
297  		return 0;
298  
299  	p = kmalloc(sizeof(*p), skcipher_walk_gfp(walk));
300  	if (!p)
301  		return -ENOMEM;
302  
303  	p->data = walk->page;
304  	p->len = walk->nbytes;
305  	skcipher_queue_write(walk, p);
306  
307  	if (offset_in_page(walk->page) + walk->nbytes + walk->stride >
308  	    PAGE_SIZE)
309  		walk->page = NULL;
310  	else
311  		walk->page += walk->nbytes;
312  
313  	return 0;
314  }
315  
skcipher_next_fast(struct skcipher_walk * walk)316  static int skcipher_next_fast(struct skcipher_walk *walk)
317  {
318  	unsigned long diff;
319  
320  	walk->src.phys.page = scatterwalk_page(&walk->in);
321  	walk->src.phys.offset = offset_in_page(walk->in.offset);
322  	walk->dst.phys.page = scatterwalk_page(&walk->out);
323  	walk->dst.phys.offset = offset_in_page(walk->out.offset);
324  
325  	if (walk->flags & SKCIPHER_WALK_PHYS)
326  		return 0;
327  
328  	diff = walk->src.phys.offset - walk->dst.phys.offset;
329  	diff |= walk->src.virt.page - walk->dst.virt.page;
330  
331  	skcipher_map_src(walk);
332  	walk->dst.virt.addr = walk->src.virt.addr;
333  
334  	if (diff) {
335  		walk->flags |= SKCIPHER_WALK_DIFF;
336  		skcipher_map_dst(walk);
337  	}
338  
339  	return 0;
340  }
341  
skcipher_walk_next(struct skcipher_walk * walk)342  static int skcipher_walk_next(struct skcipher_walk *walk)
343  {
344  	unsigned int bsize;
345  	unsigned int n;
346  	int err;
347  
348  	walk->flags &= ~(SKCIPHER_WALK_SLOW | SKCIPHER_WALK_COPY |
349  			 SKCIPHER_WALK_DIFF);
350  
351  	n = walk->total;
352  	bsize = min(walk->stride, max(n, walk->blocksize));
353  	n = scatterwalk_clamp(&walk->in, n);
354  	n = scatterwalk_clamp(&walk->out, n);
355  
356  	if (unlikely(n < bsize)) {
357  		if (unlikely(walk->total < walk->blocksize))
358  			return skcipher_walk_done(walk, -EINVAL);
359  
360  slow_path:
361  		err = skcipher_next_slow(walk, bsize);
362  		goto set_phys_lowmem;
363  	}
364  
365  	if (unlikely((walk->in.offset | walk->out.offset) & walk->alignmask)) {
366  		if (!walk->page) {
367  			gfp_t gfp = skcipher_walk_gfp(walk);
368  
369  			walk->page = (void *)__get_free_page(gfp);
370  			if (!walk->page)
371  				goto slow_path;
372  		}
373  
374  		walk->nbytes = min_t(unsigned, n,
375  				     PAGE_SIZE - offset_in_page(walk->page));
376  		walk->flags |= SKCIPHER_WALK_COPY;
377  		err = skcipher_next_copy(walk);
378  		goto set_phys_lowmem;
379  	}
380  
381  	walk->nbytes = n;
382  
383  	return skcipher_next_fast(walk);
384  
385  set_phys_lowmem:
386  	if (!err && (walk->flags & SKCIPHER_WALK_PHYS)) {
387  		walk->src.phys.page = virt_to_page(walk->src.virt.addr);
388  		walk->dst.phys.page = virt_to_page(walk->dst.virt.addr);
389  		walk->src.phys.offset &= PAGE_SIZE - 1;
390  		walk->dst.phys.offset &= PAGE_SIZE - 1;
391  	}
392  	return err;
393  }
394  
skcipher_copy_iv(struct skcipher_walk * walk)395  static int skcipher_copy_iv(struct skcipher_walk *walk)
396  {
397  	unsigned a = crypto_tfm_ctx_alignment() - 1;
398  	unsigned alignmask = walk->alignmask;
399  	unsigned ivsize = walk->ivsize;
400  	unsigned bs = walk->stride;
401  	unsigned aligned_bs;
402  	unsigned size;
403  	u8 *iv;
404  
405  	aligned_bs = ALIGN(bs, alignmask + 1);
406  
407  	/* Minimum size to align buffer by alignmask. */
408  	size = alignmask & ~a;
409  
410  	if (walk->flags & SKCIPHER_WALK_PHYS)
411  		size += ivsize;
412  	else {
413  		size += aligned_bs + ivsize;
414  
415  		/* Minimum size to ensure buffer does not straddle a page. */
416  		size += (bs - 1) & ~(alignmask | a);
417  	}
418  
419  	walk->buffer = kmalloc(size, skcipher_walk_gfp(walk));
420  	if (!walk->buffer)
421  		return -ENOMEM;
422  
423  	iv = PTR_ALIGN(walk->buffer, alignmask + 1);
424  	iv = skcipher_get_spot(iv, bs) + aligned_bs;
425  
426  	walk->iv = memcpy(iv, walk->iv, walk->ivsize);
427  	return 0;
428  }
429  
skcipher_walk_first(struct skcipher_walk * walk)430  static int skcipher_walk_first(struct skcipher_walk *walk)
431  {
432  	if (WARN_ON_ONCE(in_hardirq()))
433  		return -EDEADLK;
434  
435  	walk->buffer = NULL;
436  	if (unlikely(((unsigned long)walk->iv & walk->alignmask))) {
437  		int err = skcipher_copy_iv(walk);
438  		if (err)
439  			return err;
440  	}
441  
442  	walk->page = NULL;
443  
444  	return skcipher_walk_next(walk);
445  }
446  
skcipher_walk_skcipher(struct skcipher_walk * walk,struct skcipher_request * req)447  static int skcipher_walk_skcipher(struct skcipher_walk *walk,
448  				  struct skcipher_request *req)
449  {
450  	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
451  	struct skcipher_alg *alg = crypto_skcipher_alg(tfm);
452  
453  	walk->total = req->cryptlen;
454  	walk->nbytes = 0;
455  	walk->iv = req->iv;
456  	walk->oiv = req->iv;
457  
458  	if (unlikely(!walk->total))
459  		return 0;
460  
461  	scatterwalk_start(&walk->in, req->src);
462  	scatterwalk_start(&walk->out, req->dst);
463  
464  	walk->flags &= ~SKCIPHER_WALK_SLEEP;
465  	walk->flags |= req->base.flags & CRYPTO_TFM_REQ_MAY_SLEEP ?
466  		       SKCIPHER_WALK_SLEEP : 0;
467  
468  	walk->blocksize = crypto_skcipher_blocksize(tfm);
469  	walk->ivsize = crypto_skcipher_ivsize(tfm);
470  	walk->alignmask = crypto_skcipher_alignmask(tfm);
471  
472  	if (alg->co.base.cra_type != &crypto_skcipher_type)
473  		walk->stride = alg->co.chunksize;
474  	else
475  		walk->stride = alg->walksize;
476  
477  	return skcipher_walk_first(walk);
478  }
479  
skcipher_walk_virt(struct skcipher_walk * walk,struct skcipher_request * req,bool atomic)480  int skcipher_walk_virt(struct skcipher_walk *walk,
481  		       struct skcipher_request *req, bool atomic)
482  {
483  	int err;
484  
485  	might_sleep_if(req->base.flags & CRYPTO_TFM_REQ_MAY_SLEEP);
486  
487  	walk->flags &= ~SKCIPHER_WALK_PHYS;
488  
489  	err = skcipher_walk_skcipher(walk, req);
490  
491  	walk->flags &= atomic ? ~SKCIPHER_WALK_SLEEP : ~0;
492  
493  	return err;
494  }
495  EXPORT_SYMBOL_GPL(skcipher_walk_virt);
496  
skcipher_walk_async(struct skcipher_walk * walk,struct skcipher_request * req)497  int skcipher_walk_async(struct skcipher_walk *walk,
498  			struct skcipher_request *req)
499  {
500  	walk->flags |= SKCIPHER_WALK_PHYS;
501  
502  	INIT_LIST_HEAD(&walk->buffers);
503  
504  	return skcipher_walk_skcipher(walk, req);
505  }
506  EXPORT_SYMBOL_GPL(skcipher_walk_async);
507  
skcipher_walk_aead_common(struct skcipher_walk * walk,struct aead_request * req,bool atomic)508  static int skcipher_walk_aead_common(struct skcipher_walk *walk,
509  				     struct aead_request *req, bool atomic)
510  {
511  	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
512  	int err;
513  
514  	walk->nbytes = 0;
515  	walk->iv = req->iv;
516  	walk->oiv = req->iv;
517  
518  	if (unlikely(!walk->total))
519  		return 0;
520  
521  	walk->flags &= ~SKCIPHER_WALK_PHYS;
522  
523  	scatterwalk_start(&walk->in, req->src);
524  	scatterwalk_start(&walk->out, req->dst);
525  
526  	scatterwalk_copychunks(NULL, &walk->in, req->assoclen, 2);
527  	scatterwalk_copychunks(NULL, &walk->out, req->assoclen, 2);
528  
529  	scatterwalk_done(&walk->in, 0, walk->total);
530  	scatterwalk_done(&walk->out, 0, walk->total);
531  
532  	if (req->base.flags & CRYPTO_TFM_REQ_MAY_SLEEP)
533  		walk->flags |= SKCIPHER_WALK_SLEEP;
534  	else
535  		walk->flags &= ~SKCIPHER_WALK_SLEEP;
536  
537  	walk->blocksize = crypto_aead_blocksize(tfm);
538  	walk->stride = crypto_aead_chunksize(tfm);
539  	walk->ivsize = crypto_aead_ivsize(tfm);
540  	walk->alignmask = crypto_aead_alignmask(tfm);
541  
542  	err = skcipher_walk_first(walk);
543  
544  	if (atomic)
545  		walk->flags &= ~SKCIPHER_WALK_SLEEP;
546  
547  	return err;
548  }
549  
skcipher_walk_aead_encrypt(struct skcipher_walk * walk,struct aead_request * req,bool atomic)550  int skcipher_walk_aead_encrypt(struct skcipher_walk *walk,
551  			       struct aead_request *req, bool atomic)
552  {
553  	walk->total = req->cryptlen;
554  
555  	return skcipher_walk_aead_common(walk, req, atomic);
556  }
557  EXPORT_SYMBOL_GPL(skcipher_walk_aead_encrypt);
558  
skcipher_walk_aead_decrypt(struct skcipher_walk * walk,struct aead_request * req,bool atomic)559  int skcipher_walk_aead_decrypt(struct skcipher_walk *walk,
560  			       struct aead_request *req, bool atomic)
561  {
562  	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
563  
564  	walk->total = req->cryptlen - crypto_aead_authsize(tfm);
565  
566  	return skcipher_walk_aead_common(walk, req, atomic);
567  }
568  EXPORT_SYMBOL_GPL(skcipher_walk_aead_decrypt);
569  
skcipher_set_needkey(struct crypto_skcipher * tfm)570  static void skcipher_set_needkey(struct crypto_skcipher *tfm)
571  {
572  	if (crypto_skcipher_max_keysize(tfm) != 0)
573  		crypto_skcipher_set_flags(tfm, CRYPTO_TFM_NEED_KEY);
574  }
575  
skcipher_setkey_unaligned(struct crypto_skcipher * tfm,const u8 * key,unsigned int keylen)576  static int skcipher_setkey_unaligned(struct crypto_skcipher *tfm,
577  				     const u8 *key, unsigned int keylen)
578  {
579  	unsigned long alignmask = crypto_skcipher_alignmask(tfm);
580  	struct skcipher_alg *cipher = crypto_skcipher_alg(tfm);
581  	u8 *buffer, *alignbuffer;
582  	unsigned long absize;
583  	int ret;
584  
585  	absize = keylen + alignmask;
586  	buffer = kmalloc(absize, GFP_ATOMIC);
587  	if (!buffer)
588  		return -ENOMEM;
589  
590  	alignbuffer = (u8 *)ALIGN((unsigned long)buffer, alignmask + 1);
591  	memcpy(alignbuffer, key, keylen);
592  	ret = cipher->setkey(tfm, alignbuffer, keylen);
593  	kfree_sensitive(buffer);
594  	return ret;
595  }
596  
crypto_skcipher_setkey(struct crypto_skcipher * tfm,const u8 * key,unsigned int keylen)597  int crypto_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
598  			   unsigned int keylen)
599  {
600  	struct skcipher_alg *cipher = crypto_skcipher_alg(tfm);
601  	unsigned long alignmask = crypto_skcipher_alignmask(tfm);
602  	int err;
603  
604  	if (cipher->co.base.cra_type != &crypto_skcipher_type) {
605  		struct crypto_lskcipher **ctx = crypto_skcipher_ctx(tfm);
606  
607  		crypto_lskcipher_clear_flags(*ctx, CRYPTO_TFM_REQ_MASK);
608  		crypto_lskcipher_set_flags(*ctx,
609  					   crypto_skcipher_get_flags(tfm) &
610  					   CRYPTO_TFM_REQ_MASK);
611  		err = crypto_lskcipher_setkey(*ctx, key, keylen);
612  		goto out;
613  	}
614  
615  	if (keylen < cipher->min_keysize || keylen > cipher->max_keysize)
616  		return -EINVAL;
617  
618  	if ((unsigned long)key & alignmask)
619  		err = skcipher_setkey_unaligned(tfm, key, keylen);
620  	else
621  		err = cipher->setkey(tfm, key, keylen);
622  
623  out:
624  	if (unlikely(err)) {
625  		skcipher_set_needkey(tfm);
626  		return err;
627  	}
628  
629  	crypto_skcipher_clear_flags(tfm, CRYPTO_TFM_NEED_KEY);
630  	return 0;
631  }
632  EXPORT_SYMBOL_GPL(crypto_skcipher_setkey);
633  
crypto_skcipher_encrypt(struct skcipher_request * req)634  int crypto_skcipher_encrypt(struct skcipher_request *req)
635  {
636  	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
637  	struct skcipher_alg *alg = crypto_skcipher_alg(tfm);
638  
639  	if (crypto_skcipher_get_flags(tfm) & CRYPTO_TFM_NEED_KEY)
640  		return -ENOKEY;
641  	if (alg->co.base.cra_type != &crypto_skcipher_type)
642  		return crypto_lskcipher_encrypt_sg(req);
643  	return alg->encrypt(req);
644  }
645  EXPORT_SYMBOL_GPL(crypto_skcipher_encrypt);
646  
crypto_skcipher_decrypt(struct skcipher_request * req)647  int crypto_skcipher_decrypt(struct skcipher_request *req)
648  {
649  	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
650  	struct skcipher_alg *alg = crypto_skcipher_alg(tfm);
651  
652  	if (crypto_skcipher_get_flags(tfm) & CRYPTO_TFM_NEED_KEY)
653  		return -ENOKEY;
654  	if (alg->co.base.cra_type != &crypto_skcipher_type)
655  		return crypto_lskcipher_decrypt_sg(req);
656  	return alg->decrypt(req);
657  }
658  EXPORT_SYMBOL_GPL(crypto_skcipher_decrypt);
659  
crypto_lskcipher_export(struct skcipher_request * req,void * out)660  static int crypto_lskcipher_export(struct skcipher_request *req, void *out)
661  {
662  	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
663  	u8 *ivs = skcipher_request_ctx(req);
664  
665  	ivs = PTR_ALIGN(ivs, crypto_skcipher_alignmask(tfm) + 1);
666  
667  	memcpy(out, ivs + crypto_skcipher_ivsize(tfm),
668  	       crypto_skcipher_statesize(tfm));
669  
670  	return 0;
671  }
672  
crypto_lskcipher_import(struct skcipher_request * req,const void * in)673  static int crypto_lskcipher_import(struct skcipher_request *req, const void *in)
674  {
675  	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
676  	u8 *ivs = skcipher_request_ctx(req);
677  
678  	ivs = PTR_ALIGN(ivs, crypto_skcipher_alignmask(tfm) + 1);
679  
680  	memcpy(ivs + crypto_skcipher_ivsize(tfm), in,
681  	       crypto_skcipher_statesize(tfm));
682  
683  	return 0;
684  }
685  
skcipher_noexport(struct skcipher_request * req,void * out)686  static int skcipher_noexport(struct skcipher_request *req, void *out)
687  {
688  	return 0;
689  }
690  
skcipher_noimport(struct skcipher_request * req,const void * in)691  static int skcipher_noimport(struct skcipher_request *req, const void *in)
692  {
693  	return 0;
694  }
695  
crypto_skcipher_export(struct skcipher_request * req,void * out)696  int crypto_skcipher_export(struct skcipher_request *req, void *out)
697  {
698  	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
699  	struct skcipher_alg *alg = crypto_skcipher_alg(tfm);
700  
701  	if (alg->co.base.cra_type != &crypto_skcipher_type)
702  		return crypto_lskcipher_export(req, out);
703  	return alg->export(req, out);
704  }
705  EXPORT_SYMBOL_GPL(crypto_skcipher_export);
706  
crypto_skcipher_import(struct skcipher_request * req,const void * in)707  int crypto_skcipher_import(struct skcipher_request *req, const void *in)
708  {
709  	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
710  	struct skcipher_alg *alg = crypto_skcipher_alg(tfm);
711  
712  	if (alg->co.base.cra_type != &crypto_skcipher_type)
713  		return crypto_lskcipher_import(req, in);
714  	return alg->import(req, in);
715  }
716  EXPORT_SYMBOL_GPL(crypto_skcipher_import);
717  
crypto_skcipher_exit_tfm(struct crypto_tfm * tfm)718  static void crypto_skcipher_exit_tfm(struct crypto_tfm *tfm)
719  {
720  	struct crypto_skcipher *skcipher = __crypto_skcipher_cast(tfm);
721  	struct skcipher_alg *alg = crypto_skcipher_alg(skcipher);
722  
723  	alg->exit(skcipher);
724  }
725  
crypto_skcipher_init_tfm(struct crypto_tfm * tfm)726  static int crypto_skcipher_init_tfm(struct crypto_tfm *tfm)
727  {
728  	struct crypto_skcipher *skcipher = __crypto_skcipher_cast(tfm);
729  	struct skcipher_alg *alg = crypto_skcipher_alg(skcipher);
730  
731  	skcipher_set_needkey(skcipher);
732  
733  	if (tfm->__crt_alg->cra_type != &crypto_skcipher_type) {
734  		unsigned am = crypto_skcipher_alignmask(skcipher);
735  		unsigned reqsize;
736  
737  		reqsize = am & ~(crypto_tfm_ctx_alignment() - 1);
738  		reqsize += crypto_skcipher_ivsize(skcipher);
739  		reqsize += crypto_skcipher_statesize(skcipher);
740  		crypto_skcipher_set_reqsize(skcipher, reqsize);
741  
742  		return crypto_init_lskcipher_ops_sg(tfm);
743  	}
744  
745  	if (alg->exit)
746  		skcipher->base.exit = crypto_skcipher_exit_tfm;
747  
748  	if (alg->init)
749  		return alg->init(skcipher);
750  
751  	return 0;
752  }
753  
crypto_skcipher_extsize(struct crypto_alg * alg)754  static unsigned int crypto_skcipher_extsize(struct crypto_alg *alg)
755  {
756  	if (alg->cra_type != &crypto_skcipher_type)
757  		return sizeof(struct crypto_lskcipher *);
758  
759  	return crypto_alg_extsize(alg);
760  }
761  
crypto_skcipher_free_instance(struct crypto_instance * inst)762  static void crypto_skcipher_free_instance(struct crypto_instance *inst)
763  {
764  	struct skcipher_instance *skcipher =
765  		container_of(inst, struct skcipher_instance, s.base);
766  
767  	skcipher->free(skcipher);
768  }
769  
770  static void crypto_skcipher_show(struct seq_file *m, struct crypto_alg *alg)
771  	__maybe_unused;
crypto_skcipher_show(struct seq_file * m,struct crypto_alg * alg)772  static void crypto_skcipher_show(struct seq_file *m, struct crypto_alg *alg)
773  {
774  	struct skcipher_alg *skcipher = __crypto_skcipher_alg(alg);
775  
776  	seq_printf(m, "type         : skcipher\n");
777  	seq_printf(m, "async        : %s\n",
778  		   alg->cra_flags & CRYPTO_ALG_ASYNC ?  "yes" : "no");
779  	seq_printf(m, "blocksize    : %u\n", alg->cra_blocksize);
780  	seq_printf(m, "min keysize  : %u\n", skcipher->min_keysize);
781  	seq_printf(m, "max keysize  : %u\n", skcipher->max_keysize);
782  	seq_printf(m, "ivsize       : %u\n", skcipher->ivsize);
783  	seq_printf(m, "chunksize    : %u\n", skcipher->chunksize);
784  	seq_printf(m, "walksize     : %u\n", skcipher->walksize);
785  	seq_printf(m, "statesize    : %u\n", skcipher->statesize);
786  }
787  
crypto_skcipher_report(struct sk_buff * skb,struct crypto_alg * alg)788  static int __maybe_unused crypto_skcipher_report(
789  	struct sk_buff *skb, struct crypto_alg *alg)
790  {
791  	struct skcipher_alg *skcipher = __crypto_skcipher_alg(alg);
792  	struct crypto_report_blkcipher rblkcipher;
793  
794  	memset(&rblkcipher, 0, sizeof(rblkcipher));
795  
796  	strscpy(rblkcipher.type, "skcipher", sizeof(rblkcipher.type));
797  	strscpy(rblkcipher.geniv, "<none>", sizeof(rblkcipher.geniv));
798  
799  	rblkcipher.blocksize = alg->cra_blocksize;
800  	rblkcipher.min_keysize = skcipher->min_keysize;
801  	rblkcipher.max_keysize = skcipher->max_keysize;
802  	rblkcipher.ivsize = skcipher->ivsize;
803  
804  	return nla_put(skb, CRYPTOCFGA_REPORT_BLKCIPHER,
805  		       sizeof(rblkcipher), &rblkcipher);
806  }
807  
/*
 * Type object for skcipher tfms.  skcipher_prepare_alg() installs it as the
 * cra_type of algorithms registered through this file.  The code above
 * treats any other cra_type as an lskcipher exposed through the skcipher
 * interface.
 */
static const struct crypto_type crypto_skcipher_type = {
	.extsize = crypto_skcipher_extsize,
	.init_tfm = crypto_skcipher_init_tfm,
	.free = crypto_skcipher_free_instance,
#ifdef CONFIG_PROC_FS
	.show = crypto_skcipher_show,
#endif
#if IS_ENABLED(CONFIG_CRYPTO_USER)
	.report = crypto_skcipher_report,
#endif
	.maskclear = ~CRYPTO_ALG_TYPE_MASK,
	.maskset = CRYPTO_ALG_TYPE_SKCIPHER_MASK,
	.type = CRYPTO_ALG_TYPE_SKCIPHER,
	.tfmsize = offsetof(struct crypto_skcipher, base),
};
823  
crypto_grab_skcipher(struct crypto_skcipher_spawn * spawn,struct crypto_instance * inst,const char * name,u32 type,u32 mask)824  int crypto_grab_skcipher(struct crypto_skcipher_spawn *spawn,
825  			 struct crypto_instance *inst,
826  			 const char *name, u32 type, u32 mask)
827  {
828  	spawn->base.frontend = &crypto_skcipher_type;
829  	return crypto_grab_spawn(&spawn->base, inst, name, type, mask);
830  }
831  EXPORT_SYMBOL_GPL(crypto_grab_skcipher);
832  
crypto_alloc_skcipher(const char * alg_name,u32 type,u32 mask)833  struct crypto_skcipher *crypto_alloc_skcipher(const char *alg_name,
834  					      u32 type, u32 mask)
835  {
836  	return crypto_alloc_tfm(alg_name, &crypto_skcipher_type, type, mask);
837  }
838  EXPORT_SYMBOL_GPL(crypto_alloc_skcipher);
839  
crypto_alloc_sync_skcipher(const char * alg_name,u32 type,u32 mask)840  struct crypto_sync_skcipher *crypto_alloc_sync_skcipher(
841  				const char *alg_name, u32 type, u32 mask)
842  {
843  	struct crypto_skcipher *tfm;
844  
845  	/* Only sync algorithms allowed. */
846  	mask |= CRYPTO_ALG_ASYNC | CRYPTO_ALG_SKCIPHER_REQSIZE_LARGE;
847  
848  	tfm = crypto_alloc_tfm(alg_name, &crypto_skcipher_type, type, mask);
849  
850  	/*
851  	 * Make sure we do not allocate something that might get used with
852  	 * an on-stack request: check the request size.
853  	 */
854  	if (!IS_ERR(tfm) && WARN_ON(crypto_skcipher_reqsize(tfm) >
855  				    MAX_SYNC_SKCIPHER_REQSIZE)) {
856  		crypto_free_skcipher(tfm);
857  		return ERR_PTR(-EINVAL);
858  	}
859  
860  	return (struct crypto_sync_skcipher *)tfm;
861  }
862  EXPORT_SYMBOL_GPL(crypto_alloc_sync_skcipher);
863  
crypto_has_skcipher(const char * alg_name,u32 type,u32 mask)864  int crypto_has_skcipher(const char *alg_name, u32 type, u32 mask)
865  {
866  	return crypto_type_has_alg(alg_name, &crypto_skcipher_type, type, mask);
867  }
868  EXPORT_SYMBOL_GPL(crypto_has_skcipher);
869  
skcipher_prepare_alg_common(struct skcipher_alg_common * alg)870  int skcipher_prepare_alg_common(struct skcipher_alg_common *alg)
871  {
872  	struct crypto_alg *base = &alg->base;
873  
874  	if (alg->ivsize > PAGE_SIZE / 8 || alg->chunksize > PAGE_SIZE / 8 ||
875  	    alg->statesize > PAGE_SIZE / 2 ||
876  	    (alg->ivsize + alg->statesize) > PAGE_SIZE / 2)
877  		return -EINVAL;
878  
879  	if (!alg->chunksize)
880  		alg->chunksize = base->cra_blocksize;
881  
882  	base->cra_flags &= ~CRYPTO_ALG_TYPE_MASK;
883  
884  	return 0;
885  }
886  
skcipher_prepare_alg(struct skcipher_alg * alg)887  static int skcipher_prepare_alg(struct skcipher_alg *alg)
888  {
889  	struct crypto_alg *base = &alg->base;
890  	int err;
891  
892  	err = skcipher_prepare_alg_common(&alg->co);
893  	if (err)
894  		return err;
895  
896  	if (alg->walksize > PAGE_SIZE / 8)
897  		return -EINVAL;
898  
899  	if (!alg->walksize)
900  		alg->walksize = alg->chunksize;
901  
902  	if (!alg->statesize) {
903  		alg->import = skcipher_noimport;
904  		alg->export = skcipher_noexport;
905  	} else if (!(alg->import && alg->export))
906  		return -EINVAL;
907  
908  	base->cra_type = &crypto_skcipher_type;
909  	base->cra_flags |= CRYPTO_ALG_TYPE_SKCIPHER;
910  
911  	return 0;
912  }
913  
crypto_register_skcipher(struct skcipher_alg * alg)914  int crypto_register_skcipher(struct skcipher_alg *alg)
915  {
916  	struct crypto_alg *base = &alg->base;
917  	int err;
918  
919  	err = skcipher_prepare_alg(alg);
920  	if (err)
921  		return err;
922  
923  	return crypto_register_alg(base);
924  }
925  EXPORT_SYMBOL_GPL(crypto_register_skcipher);
926  
crypto_unregister_skcipher(struct skcipher_alg * alg)927  void crypto_unregister_skcipher(struct skcipher_alg *alg)
928  {
929  	crypto_unregister_alg(&alg->base);
930  }
931  EXPORT_SYMBOL_GPL(crypto_unregister_skcipher);
932  
crypto_register_skciphers(struct skcipher_alg * algs,int count)933  int crypto_register_skciphers(struct skcipher_alg *algs, int count)
934  {
935  	int i, ret;
936  
937  	for (i = 0; i < count; i++) {
938  		ret = crypto_register_skcipher(&algs[i]);
939  		if (ret)
940  			goto err;
941  	}
942  
943  	return 0;
944  
945  err:
946  	for (--i; i >= 0; --i)
947  		crypto_unregister_skcipher(&algs[i]);
948  
949  	return ret;
950  }
951  EXPORT_SYMBOL_GPL(crypto_register_skciphers);
952  
crypto_unregister_skciphers(struct skcipher_alg * algs,int count)953  void crypto_unregister_skciphers(struct skcipher_alg *algs, int count)
954  {
955  	int i;
956  
957  	for (i = count - 1; i >= 0; --i)
958  		crypto_unregister_skcipher(&algs[i]);
959  }
960  EXPORT_SYMBOL_GPL(crypto_unregister_skciphers);
961  
skcipher_register_instance(struct crypto_template * tmpl,struct skcipher_instance * inst)962  int skcipher_register_instance(struct crypto_template *tmpl,
963  			   struct skcipher_instance *inst)
964  {
965  	int err;
966  
967  	if (WARN_ON(!inst->free))
968  		return -EINVAL;
969  
970  	err = skcipher_prepare_alg(&inst->alg);
971  	if (err)
972  		return err;
973  
974  	return crypto_register_instance(tmpl, skcipher_crypto_instance(inst));
975  }
976  EXPORT_SYMBOL_GPL(skcipher_register_instance);
977  
skcipher_setkey_simple(struct crypto_skcipher * tfm,const u8 * key,unsigned int keylen)978  static int skcipher_setkey_simple(struct crypto_skcipher *tfm, const u8 *key,
979  				  unsigned int keylen)
980  {
981  	struct crypto_cipher *cipher = skcipher_cipher_simple(tfm);
982  
983  	crypto_cipher_clear_flags(cipher, CRYPTO_TFM_REQ_MASK);
984  	crypto_cipher_set_flags(cipher, crypto_skcipher_get_flags(tfm) &
985  				CRYPTO_TFM_REQ_MASK);
986  	return crypto_cipher_setkey(cipher, key, keylen);
987  }
988  
skcipher_init_tfm_simple(struct crypto_skcipher * tfm)989  static int skcipher_init_tfm_simple(struct crypto_skcipher *tfm)
990  {
991  	struct skcipher_instance *inst = skcipher_alg_instance(tfm);
992  	struct crypto_cipher_spawn *spawn = skcipher_instance_ctx(inst);
993  	struct skcipher_ctx_simple *ctx = crypto_skcipher_ctx(tfm);
994  	struct crypto_cipher *cipher;
995  
996  	cipher = crypto_spawn_cipher(spawn);
997  	if (IS_ERR(cipher))
998  		return PTR_ERR(cipher);
999  
1000  	ctx->cipher = cipher;
1001  	return 0;
1002  }
1003  
skcipher_exit_tfm_simple(struct crypto_skcipher * tfm)1004  static void skcipher_exit_tfm_simple(struct crypto_skcipher *tfm)
1005  {
1006  	struct skcipher_ctx_simple *ctx = crypto_skcipher_ctx(tfm);
1007  
1008  	crypto_free_cipher(ctx->cipher);
1009  }
1010  
/*
 * Default ->free() for "simple" mode instances.  Drops the cipher spawn held
 * in the instance context, then releases the instance allocation itself.
 */
static void skcipher_free_instance_simple(struct skcipher_instance *inst)
{
	struct crypto_cipher_spawn *spawn = skcipher_instance_ctx(inst);

	crypto_drop_cipher(spawn);
	kfree(inst);
}
1016  
1017  /**
1018   * skcipher_alloc_instance_simple - allocate instance of simple block cipher mode
1019   *
1020   * Allocate an skcipher_instance for a simple block cipher mode of operation,
1021   * e.g. cbc or ecb.  The instance context will have just a single crypto_spawn,
1022   * that for the underlying cipher.  The {min,max}_keysize, ivsize, blocksize,
1023   * alignmask, and priority are set from the underlying cipher but can be
1024   * overridden if needed.  The tfm context defaults to skcipher_ctx_simple, and
1025   * default ->setkey(), ->init(), and ->exit() methods are installed.
1026   *
1027   * @tmpl: the template being instantiated
1028   * @tb: the template parameters
1029   *
1030   * Return: a pointer to the new instance, or an ERR_PTR().  The caller still
1031   *	   needs to register the instance.
1032   */
skcipher_alloc_instance_simple(struct crypto_template * tmpl,struct rtattr ** tb)1033  struct skcipher_instance *skcipher_alloc_instance_simple(
1034  	struct crypto_template *tmpl, struct rtattr **tb)
1035  {
1036  	u32 mask;
1037  	struct skcipher_instance *inst;
1038  	struct crypto_cipher_spawn *spawn;
1039  	struct crypto_alg *cipher_alg;
1040  	int err;
1041  
1042  	err = crypto_check_attr_type(tb, CRYPTO_ALG_TYPE_SKCIPHER, &mask);
1043  	if (err)
1044  		return ERR_PTR(err);
1045  
1046  	inst = kzalloc(sizeof(*inst) + sizeof(*spawn), GFP_KERNEL);
1047  	if (!inst)
1048  		return ERR_PTR(-ENOMEM);
1049  	spawn = skcipher_instance_ctx(inst);
1050  
1051  	err = crypto_grab_cipher(spawn, skcipher_crypto_instance(inst),
1052  				 crypto_attr_alg_name(tb[1]), 0, mask);
1053  	if (err)
1054  		goto err_free_inst;
1055  	cipher_alg = crypto_spawn_cipher_alg(spawn);
1056  
1057  	err = crypto_inst_setname(skcipher_crypto_instance(inst), tmpl->name,
1058  				  cipher_alg);
1059  	if (err)
1060  		goto err_free_inst;
1061  
1062  	inst->free = skcipher_free_instance_simple;
1063  
1064  	/* Default algorithm properties, can be overridden */
1065  	inst->alg.base.cra_blocksize = cipher_alg->cra_blocksize;
1066  	inst->alg.base.cra_alignmask = cipher_alg->cra_alignmask;
1067  	inst->alg.base.cra_priority = cipher_alg->cra_priority;
1068  	inst->alg.min_keysize = cipher_alg->cra_cipher.cia_min_keysize;
1069  	inst->alg.max_keysize = cipher_alg->cra_cipher.cia_max_keysize;
1070  	inst->alg.ivsize = cipher_alg->cra_blocksize;
1071  
1072  	/* Use skcipher_ctx_simple by default, can be overridden */
1073  	inst->alg.base.cra_ctxsize = sizeof(struct skcipher_ctx_simple);
1074  	inst->alg.setkey = skcipher_setkey_simple;
1075  	inst->alg.init = skcipher_init_tfm_simple;
1076  	inst->alg.exit = skcipher_exit_tfm_simple;
1077  
1078  	return inst;
1079  
1080  err_free_inst:
1081  	skcipher_free_instance_simple(inst);
1082  	return ERR_PTR(err);
1083  }
1084  EXPORT_SYMBOL_GPL(skcipher_alloc_instance_simple);
1085  
1086  MODULE_LICENSE("GPL");
1087  MODULE_DESCRIPTION("Symmetric key cipher type");
1088  MODULE_IMPORT_NS(CRYPTO_INTERNAL);
1089