1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Amazon Nitro Secure Module driver.
4  *
5  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
6  *
7  * The Nitro Secure Module implements commands via CBOR over virtio.
 * This driver exposes a raw-message ioctl on /dev/nsm that user
9  * space can use to issue these commands.
10  */
11 
12 #include <linux/file.h>
13 #include <linux/fs.h>
14 #include <linux/interrupt.h>
15 #include <linux/hw_random.h>
16 #include <linux/miscdevice.h>
17 #include <linux/module.h>
18 #include <linux/mutex.h>
19 #include <linux/slab.h>
20 #include <linux/string.h>
21 #include <linux/uaccess.h>
22 #include <linux/uio.h>
23 #include <linux/virtio_config.h>
24 #include <linux/virtio_ids.h>
25 #include <linux/virtio.h>
26 #include <linux/wait.h>
27 #include <uapi/linux/nsm.h>
28 
/*
 * Timeout for an NSM virtqueue response, in milliseconds. Used by
 * nsm_sendrecv_msg_locked() when waiting for the device to complete a
 * request.
 */
#define NSM_DEFAULT_TIMEOUT_MSECS (120000) /* 2 minutes */
31 
/*
 * Request buffer sized for the largest request the driver accepts.
 * @len is the number of valid bytes in @data.
 * NSM_REQUEST_MAX_SIZE presumably comes from <uapi/linux/nsm.h> — confirm.
 */
struct nsm_data_req {
	u32 len;
	u8  data[NSM_REQUEST_MAX_SIZE];
};
37 
/*
 * Response buffer handed to the device. The whole of @data is offered to
 * the device, and @len is set from the length the device reports back.
 * NSM_RESPONSE_MAX_SIZE presumably comes from <uapi/linux/nsm.h> — confirm.
 */
struct nsm_data_resp {
	u32 len;
	u8  data[NSM_RESPONSE_MAX_SIZE];
};
43 
/*
 * One complete NSM exchange: the request sent to the device and the buffer
 * the device writes its response into.
 */
struct nsm_msg {
	struct nsm_data_req req;
	struct nsm_data_resp resp;
};
49 
/*
 * Per-device driver state. Allocated with devm_kzalloc() in probe.
 */
struct nsm {
	/* Backing virtio device. */
	struct virtio_device *vdev;
	/* The single virtqueue used for both requests and responses. */
	struct virtqueue     *vq;
	/* Serializes every user of @msg and @vq (hwrng reads and ioctls). */
	struct mutex          lock;
	/* Signalled by nsm_vq_callback() when the device returns buffers. */
	struct completion     cmd_done;
	/* /dev/nsm character device. */
	struct miscdevice     misc;
	/* Hardware RNG interface backed by the NSM GetRandom command. */
	struct hwrng          hwrng;
	/* NOTE(review): not referenced anywhere in this file — possibly dead. */
	struct work_struct    misc_init;
	/* The one shared request/response buffer; only touch with @lock held. */
	struct nsm_msg        msg;
};
60 
/* Virtio device IDs this driver binds to: the Nitro Secure Module, any vendor. */
static const struct virtio_device_id id_table[] = {
	{ VIRTIO_ID_NITRO_SEC_MOD, VIRTIO_DEV_ANY_ID },
	{ 0 },
};
66 
file_to_nsm(struct file * file)67 static struct nsm *file_to_nsm(struct file *file)
68 {
69 	return container_of(file->private_data, struct nsm, misc);
70 }
71 
hwrng_to_nsm(struct hwrng * rng)72 static struct nsm *hwrng_to_nsm(struct hwrng *rng)
73 {
74 	return container_of(rng, struct nsm, hwrng);
75 }
76 
/*
 * Minimal CBOR (RFC 8949) initial-byte helpers. The top three bits of the
 * first byte (CBOR_TYPE_MASK) hold the major type. The low five bits hold
 * the "additional information" value: either a short length directly, or a
 * selector for a following big-endian length field.
 *
 * NOTE(review): 0x40 is CBOR major type 2, which RFC 8949 defines as a byte
 * string (a CBOR array is 0x80). CBOR_TYPE_ARRAY is used below to decode a
 * "byte array" payload, so the name refers to that, not to a CBOR array.
 */
#define CBOR_TYPE_MASK  0xE0
#define CBOR_TYPE_MAP 0xA0
#define CBOR_TYPE_TEXT 0x60
#define CBOR_TYPE_ARRAY 0x40
/* Size in bytes of the initial byte on its own. */
#define CBOR_HEADER_SIZE_SHORT 1

/* Largest length that is encoded directly in the additional-info bits. */
#define CBOR_SHORT_SIZE_MAX_VALUE 23
/* Additional-info values selecting a 1/2/4/8-byte length field. */
#define CBOR_LONG_SIZE_U8  24
#define CBOR_LONG_SIZE_U16 25
#define CBOR_LONG_SIZE_U32 26
#define CBOR_LONG_SIZE_U64 27
88 
cbor_object_is_array(const u8 * cbor_object,size_t cbor_object_size)89 static bool cbor_object_is_array(const u8 *cbor_object, size_t cbor_object_size)
90 {
91 	if (cbor_object_size == 0 || cbor_object == NULL)
92 		return false;
93 
94 	return (cbor_object[0] & CBOR_TYPE_MASK) == CBOR_TYPE_ARRAY;
95 }
96 
cbor_object_get_array(u8 * cbor_object,size_t cbor_object_size,u8 ** cbor_array)97 static int cbor_object_get_array(u8 *cbor_object, size_t cbor_object_size, u8 **cbor_array)
98 {
99 	u8 cbor_short_size;
100 	void *array_len_p;
101 	u64 array_len;
102 	u64 array_offset;
103 
104 	if (!cbor_object_is_array(cbor_object, cbor_object_size))
105 		return -EFAULT;
106 
107 	cbor_short_size = (cbor_object[0] & 0x1F);
108 
109 	/* Decoding byte array length */
110 	array_offset = CBOR_HEADER_SIZE_SHORT;
111 	if (cbor_short_size >= CBOR_LONG_SIZE_U8)
112 		array_offset += BIT(cbor_short_size - CBOR_LONG_SIZE_U8);
113 
114 	if (cbor_object_size < array_offset)
115 		return -EFAULT;
116 
117 	array_len_p = &cbor_object[1];
118 
119 	switch (cbor_short_size) {
120 	case CBOR_SHORT_SIZE_MAX_VALUE: /* short encoding */
121 		array_len = cbor_short_size;
122 		break;
123 	case CBOR_LONG_SIZE_U8:
124 		array_len = *(u8 *)array_len_p;
125 		break;
126 	case CBOR_LONG_SIZE_U16:
127 		array_len = be16_to_cpup((__be16 *)array_len_p);
128 		break;
129 	case CBOR_LONG_SIZE_U32:
130 		array_len = be32_to_cpup((__be32 *)array_len_p);
131 		break;
132 	case CBOR_LONG_SIZE_U64:
133 		array_len = be64_to_cpup((__be64 *)array_len_p);
134 		break;
135 	}
136 
137 	if (cbor_object_size < array_offset)
138 		return -EFAULT;
139 
140 	if (cbor_object_size - array_offset < array_len)
141 		return -EFAULT;
142 
143 	if (array_len > INT_MAX)
144 		return -EFAULT;
145 
146 	*cbor_array = cbor_object + array_offset;
147 	return array_len;
148 }
149 
150 /* Copy the request of a raw message to kernel space */
fill_req_raw(struct nsm * nsm,struct nsm_data_req * req,struct nsm_raw * raw)151 static int fill_req_raw(struct nsm *nsm, struct nsm_data_req *req,
152 			struct nsm_raw *raw)
153 {
154 	/* Verify the user input size. */
155 	if (raw->request.len > sizeof(req->data))
156 		return -EMSGSIZE;
157 
158 	/* Copy the request payload */
159 	if (copy_from_user(req->data, u64_to_user_ptr(raw->request.addr),
160 			   raw->request.len))
161 		return -EFAULT;
162 
163 	req->len = raw->request.len;
164 
165 	return 0;
166 }
167 
168 /* Copy the response of a raw message back to user-space */
parse_resp_raw(struct nsm * nsm,struct nsm_data_resp * resp,struct nsm_raw * raw)169 static int parse_resp_raw(struct nsm *nsm, struct nsm_data_resp *resp,
170 			  struct nsm_raw *raw)
171 {
172 	/* Truncate any message that does not fit. */
173 	raw->response.len = min_t(u64, raw->response.len, resp->len);
174 
175 	/* Copy the response content to user space */
176 	if (copy_to_user(u64_to_user_ptr(raw->response.addr),
177 			 resp->data, raw->response.len))
178 		return -EFAULT;
179 
180 	return 0;
181 }
182 
183 /* Virtqueue interrupt handler */
nsm_vq_callback(struct virtqueue * vq)184 static void nsm_vq_callback(struct virtqueue *vq)
185 {
186 	struct nsm *nsm = vq->vdev->priv;
187 
188 	complete(&nsm->cmd_done);
189 }
190 
191 /* Forward a message to the NSM device and wait for the response from it */
nsm_sendrecv_msg_locked(struct nsm * nsm)192 static int nsm_sendrecv_msg_locked(struct nsm *nsm)
193 {
194 	struct device *dev = &nsm->vdev->dev;
195 	struct scatterlist sg_in, sg_out;
196 	struct nsm_msg *msg = &nsm->msg;
197 	struct virtqueue *vq = nsm->vq;
198 	unsigned int len;
199 	void *queue_buf;
200 	bool kicked;
201 	int rc;
202 
203 	/* Initialize scatter-gather lists with request and response buffers. */
204 	sg_init_one(&sg_out, msg->req.data, msg->req.len);
205 	sg_init_one(&sg_in, msg->resp.data, sizeof(msg->resp.data));
206 
207 	init_completion(&nsm->cmd_done);
208 	/* Add the request buffer (read by the device). */
209 	rc = virtqueue_add_outbuf(vq, &sg_out, 1, msg->req.data, GFP_KERNEL);
210 	if (rc)
211 		return rc;
212 
213 	/* Add the response buffer (written by the device). */
214 	rc = virtqueue_add_inbuf(vq, &sg_in, 1, msg->resp.data, GFP_KERNEL);
215 	if (rc)
216 		goto cleanup;
217 
218 	kicked = virtqueue_kick(vq);
219 	if (!kicked) {
220 		/* Cannot kick the virtqueue. */
221 		rc = -EIO;
222 		goto cleanup;
223 	}
224 
225 	/* If the kick succeeded, wait for the device's response. */
226 	if (!wait_for_completion_io_timeout(&nsm->cmd_done,
227 		msecs_to_jiffies(NSM_DEFAULT_TIMEOUT_MSECS))) {
228 		rc = -ETIMEDOUT;
229 		goto cleanup;
230 	}
231 
232 	queue_buf = virtqueue_get_buf(vq, &len);
233 	if (!queue_buf || (queue_buf != msg->req.data)) {
234 		dev_err(dev, "wrong request buffer.");
235 		rc = -ENODATA;
236 		goto cleanup;
237 	}
238 
239 	queue_buf = virtqueue_get_buf(vq, &len);
240 	if (!queue_buf || (queue_buf != msg->resp.data)) {
241 		dev_err(dev, "wrong response buffer.");
242 		rc = -ENODATA;
243 		goto cleanup;
244 	}
245 
246 	msg->resp.len = len;
247 
248 	rc = 0;
249 
250 cleanup:
251 	if (rc) {
252 		/* Clean the virtqueue. */
253 		while (virtqueue_get_buf(vq, &len) != NULL)
254 			;
255 	}
256 
257 	return rc;
258 }
259 
fill_req_get_random(struct nsm * nsm,struct nsm_data_req * req)260 static int fill_req_get_random(struct nsm *nsm, struct nsm_data_req *req)
261 {
262 	/*
263 	 * 69                          # text(9)
264 	 *     47657452616E646F6D      # "GetRandom"
265 	 */
266 	const u8 request[] = { CBOR_TYPE_TEXT + strlen("GetRandom"),
267 			       'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm' };
268 
269 	memcpy(req->data, request, sizeof(request));
270 	req->len = sizeof(request);
271 
272 	return 0;
273 }
274 
parse_resp_get_random(struct nsm * nsm,struct nsm_data_resp * resp,void * out,size_t max)275 static int parse_resp_get_random(struct nsm *nsm, struct nsm_data_resp *resp,
276 				 void *out, size_t max)
277 {
278 	/*
279 	 * A1                          # map(1)
280 	 *     69                      # text(9) - Name of field
281 	 *         47657452616E646F6D  # "GetRandom"
282 	 * A1                          # map(1) - The field itself
283 	 *     66                      # text(6)
284 	 *         72616E646F6D        # "random"
285 	 *	# The rest of the response is random data
286 	 */
287 	const u8 response[] = { CBOR_TYPE_MAP + 1,
288 				CBOR_TYPE_TEXT + strlen("GetRandom"),
289 				'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm',
290 				CBOR_TYPE_MAP + 1,
291 				CBOR_TYPE_TEXT + strlen("random"),
292 				'r', 'a', 'n', 'd', 'o', 'm' };
293 	struct device *dev = &nsm->vdev->dev;
294 	u8 *rand_data = NULL;
295 	u8 *resp_ptr = resp->data;
296 	u64 resp_len = resp->len;
297 	int rc;
298 
299 	if ((resp->len < sizeof(response) + 1) ||
300 	    (memcmp(resp_ptr, response, sizeof(response)) != 0)) {
301 		dev_err(dev, "Invalid response for GetRandom");
302 		return -EFAULT;
303 	}
304 
305 	resp_ptr += sizeof(response);
306 	resp_len -= sizeof(response);
307 
308 	rc = cbor_object_get_array(resp_ptr, resp_len, &rand_data);
309 	if (rc < 0) {
310 		dev_err(dev, "GetRandom: Invalid CBOR encoding\n");
311 		return rc;
312 	}
313 
314 	rc = min_t(size_t, rc, max);
315 	memcpy(out, rand_data, rc);
316 
317 	return rc;
318 }
319 
320 /*
321  * HwRNG implementation
322  */
nsm_rng_read(struct hwrng * rng,void * data,size_t max,bool wait)323 static int nsm_rng_read(struct hwrng *rng, void *data, size_t max, bool wait)
324 {
325 	struct nsm *nsm = hwrng_to_nsm(rng);
326 	struct device *dev = &nsm->vdev->dev;
327 	int rc = 0;
328 
329 	/* NSM always needs to wait for a response */
330 	if (!wait)
331 		return 0;
332 
333 	mutex_lock(&nsm->lock);
334 
335 	rc = fill_req_get_random(nsm, &nsm->msg.req);
336 	if (rc != 0)
337 		goto out;
338 
339 	rc = nsm_sendrecv_msg_locked(nsm);
340 	if (rc != 0)
341 		goto out;
342 
343 	rc = parse_resp_get_random(nsm, &nsm->msg.resp, data, max);
344 	if (rc < 0)
345 		goto out;
346 
347 	dev_dbg(dev, "RNG: returning rand bytes = %d", rc);
348 out:
349 	mutex_unlock(&nsm->lock);
350 	return rc;
351 }
352 
nsm_dev_ioctl(struct file * file,unsigned int cmd,unsigned long arg)353 static long nsm_dev_ioctl(struct file *file, unsigned int cmd,
354 	unsigned long arg)
355 {
356 	void __user *argp = u64_to_user_ptr((u64)arg);
357 	struct nsm *nsm = file_to_nsm(file);
358 	struct nsm_raw raw;
359 	int r = 0;
360 
361 	if (cmd != NSM_IOCTL_RAW)
362 		return -EINVAL;
363 
364 	if (_IOC_SIZE(cmd) != sizeof(raw))
365 		return -EINVAL;
366 
367 	/* Copy user argument struct to kernel argument struct */
368 	r = -EFAULT;
369 	if (copy_from_user(&raw, argp, _IOC_SIZE(cmd)))
370 		goto out;
371 
372 	mutex_lock(&nsm->lock);
373 
374 	/* Convert kernel argument struct to device request */
375 	r = fill_req_raw(nsm, &nsm->msg.req, &raw);
376 	if (r)
377 		goto out;
378 
379 	/* Send message to NSM and read reply */
380 	r = nsm_sendrecv_msg_locked(nsm);
381 	if (r)
382 		goto out;
383 
384 	/* Parse device response into kernel argument struct */
385 	r = parse_resp_raw(nsm, &nsm->msg.resp, &raw);
386 	if (r)
387 		goto out;
388 
389 	/* Copy kernel argument struct back to user argument struct */
390 	r = -EFAULT;
391 	if (copy_to_user(argp, &raw, sizeof(raw)))
392 		goto out;
393 
394 	r = 0;
395 
396 out:
397 	mutex_unlock(&nsm->lock);
398 	return r;
399 }
400 
nsm_device_init_vq(struct virtio_device * vdev)401 static int nsm_device_init_vq(struct virtio_device *vdev)
402 {
403 	struct virtqueue *vq = virtio_find_single_vq(vdev,
404 		nsm_vq_callback, "nsm.vq.0");
405 	struct nsm *nsm = vdev->priv;
406 
407 	if (IS_ERR(vq))
408 		return PTR_ERR(vq);
409 
410 	nsm->vq = vq;
411 
412 	return 0;
413 }
414 
/*
 * /dev/nsm only supports ioctl. 32-bit callers go through compat_ptr_ioctl,
 * which converts the pointer argument and forwards to nsm_dev_ioctl().
 */
static const struct file_operations nsm_dev_fops = {
	.unlocked_ioctl = nsm_dev_ioctl,
	.compat_ioctl = compat_ptr_ioctl,
};
419 
420 /* Handler for probing the NSM device */
nsm_device_probe(struct virtio_device * vdev)421 static int nsm_device_probe(struct virtio_device *vdev)
422 {
423 	struct device *dev = &vdev->dev;
424 	struct nsm *nsm;
425 	int rc;
426 
427 	nsm = devm_kzalloc(&vdev->dev, sizeof(*nsm), GFP_KERNEL);
428 	if (!nsm)
429 		return -ENOMEM;
430 
431 	vdev->priv = nsm;
432 	nsm->vdev = vdev;
433 
434 	rc = nsm_device_init_vq(vdev);
435 	if (rc) {
436 		dev_err(dev, "queue failed to initialize: %d.\n", rc);
437 		goto err_init_vq;
438 	}
439 
440 	mutex_init(&nsm->lock);
441 
442 	/* Register as hwrng provider */
443 	nsm->hwrng = (struct hwrng) {
444 		.read = nsm_rng_read,
445 		.name = "nsm-hwrng",
446 		.quality = 1000,
447 	};
448 
449 	rc = hwrng_register(&nsm->hwrng);
450 	if (rc) {
451 		dev_err(dev, "RNG initialization error: %d.\n", rc);
452 		goto err_hwrng;
453 	}
454 
455 	/* Register /dev/nsm device node */
456 	nsm->misc = (struct miscdevice) {
457 		.minor	= MISC_DYNAMIC_MINOR,
458 		.name	= "nsm",
459 		.fops	= &nsm_dev_fops,
460 		.mode	= 0666,
461 	};
462 
463 	rc = misc_register(&nsm->misc);
464 	if (rc) {
465 		dev_err(dev, "misc device registration error: %d.\n", rc);
466 		goto err_misc;
467 	}
468 
469 	return 0;
470 
471 err_misc:
472 	hwrng_unregister(&nsm->hwrng);
473 err_hwrng:
474 	vdev->config->del_vqs(vdev);
475 err_init_vq:
476 	return rc;
477 }
478 
479 /* Handler for removing the NSM device */
nsm_device_remove(struct virtio_device * vdev)480 static void nsm_device_remove(struct virtio_device *vdev)
481 {
482 	struct nsm *nsm = vdev->priv;
483 
484 	hwrng_unregister(&nsm->hwrng);
485 
486 	vdev->config->del_vqs(vdev);
487 	misc_deregister(&nsm->misc);
488 }
489 
/*
 * NSM virtio driver description. The driver negotiates no feature bits:
 * both the modern and legacy feature tables are empty.
 */
static struct virtio_driver virtio_nsm_driver = {
	.feature_table             = 0,
	.feature_table_size        = 0,
	.feature_table_legacy      = 0,
	.feature_table_size_legacy = 0,
	.driver.name               = KBUILD_MODNAME,
	.id_table                  = id_table,
	.probe                     = nsm_device_probe,
	.remove                    = nsm_device_remove,
};
501 
/* Register/unregister the driver on module load/unload and export its IDs. */
module_virtio_driver(virtio_nsm_driver);
MODULE_DEVICE_TABLE(virtio, id_table);
MODULE_DESCRIPTION("Virtio NSM driver");
MODULE_LICENSE("GPL");
506