1  // SPDX-License-Identifier: GPL-2.0-only
2  /* Copyright (C) 2009 Red Hat, Inc.
3   * Copyright (C) 2006 Rusty Russell IBM Corporation
4   *
5   * Author: Michael S. Tsirkin <mst@redhat.com>
6   *
7   * Inspiration, some code, and most witty comments come from
8   * Documentation/virtual/lguest/lguest.c, by Rusty Russell
9   *
10   * Generic code for virtio server in host kernel.
11   */
12  
13  #include <linux/eventfd.h>
14  #include <linux/vhost.h>
15  #include <linux/uio.h>
16  #include <linux/mm.h>
17  #include <linux/miscdevice.h>
18  #include <linux/mutex.h>
19  #include <linux/poll.h>
20  #include <linux/file.h>
21  #include <linux/highmem.h>
22  #include <linux/slab.h>
23  #include <linux/vmalloc.h>
24  #include <linux/kthread.h>
25  #include <linux/module.h>
26  #include <linux/sort.h>
27  #include <linux/sched/mm.h>
28  #include <linux/sched/signal.h>
29  #include <linux/sched/vhost_task.h>
30  #include <linux/interval_tree_generic.h>
31  #include <linux/nospec.h>
32  #include <linux/kcov.h>
33  
34  #include "vhost.h"
35  
36  static ushort max_mem_regions = 64;
37  module_param(max_mem_regions, ushort, 0444);
38  MODULE_PARM_DESC(max_mem_regions,
39  	"Maximum number of memory regions in memory map. (default: 64)");
40  static int max_iotlb_entries = 2048;
41  module_param(max_iotlb_entries, int, 0444);
42  MODULE_PARM_DESC(max_iotlb_entries,
43  	"Maximum number of iotlb entries. (default: 2048)");
44  
45  enum {
46  	VHOST_MEMORY_F_LOG = 0x1,
47  };
48  
49  #define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
50  #define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])
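/* With VIRTIO_RING_F_EVENT_IDX negotiated, the driver's used_event field sits
 * in the two bytes right after the available ring and the device's avail_event
 * sits right after the used ring, which is what the two casts above reach for.
 * Rough layout sketch (illustrative, not a structure definition):
 *
 *	vring_avail: flags | idx | ring[num] | used_event
 *	vring_used:  flags | idx | ring[num] | avail_event
 */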
51  
52  #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
53  static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
54  {
55  	vq->user_be = !virtio_legacy_is_little_endian();
56  }
57  
58  static void vhost_enable_cross_endian_big(struct vhost_virtqueue *vq)
59  {
60  	vq->user_be = true;
61  }
62  
63  static void vhost_enable_cross_endian_little(struct vhost_virtqueue *vq)
64  {
65  	vq->user_be = false;
66  }
67  
68  static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
69  {
70  	struct vhost_vring_state s;
71  
72  	if (vq->private_data)
73  		return -EBUSY;
74  
75  	if (copy_from_user(&s, argp, sizeof(s)))
76  		return -EFAULT;
77  
78  	if (s.num != VHOST_VRING_LITTLE_ENDIAN &&
79  	    s.num != VHOST_VRING_BIG_ENDIAN)
80  		return -EINVAL;
81  
82  	if (s.num == VHOST_VRING_BIG_ENDIAN)
83  		vhost_enable_cross_endian_big(vq);
84  	else
85  		vhost_enable_cross_endian_little(vq);
86  
87  	return 0;
88  }
89  
90  static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
91  				   int __user *argp)
92  {
93  	struct vhost_vring_state s = {
94  		.index = idx,
95  		.num = vq->user_be
96  	};
97  
98  	if (copy_to_user(argp, &s, sizeof(s)))
99  		return -EFAULT;
100  
101  	return 0;
102  }
103  
104  static void vhost_init_is_le(struct vhost_virtqueue *vq)
105  {
106  	/* Note for legacy virtio: user_be is initialized at reset time
107  	 * according to the host endianness. If userspace does not set an
108  	 * explicit endianness, the default behavior is native endian, as
109  	 * expected by legacy virtio.
110  	 */
111  	vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1) || !vq->user_be;
112  }
113  #else
114  static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
115  {
116  }
117  
118  static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
119  {
120  	return -ENOIOCTLCMD;
121  }
122  
123  static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
124  				   int __user *argp)
125  {
126  	return -ENOIOCTLCMD;
127  }
128  
129  static void vhost_init_is_le(struct vhost_virtqueue *vq)
130  {
131  	vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1)
132  		|| virtio_legacy_is_little_endian();
133  }
134  #endif /* CONFIG_VHOST_CROSS_ENDIAN_LEGACY */
135  
136  static void vhost_reset_is_le(struct vhost_virtqueue *vq)
137  {
138  	vhost_init_is_le(vq);
139  }
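/* How is_le ends up being chosen, summarised (descriptive only):
 *
 *	VIRTIO_F_VERSION_1 negotiated          -> is_le = true
 *	legacy, CROSS_ENDIAN_LEGACY disabled   -> is_le follows the host's
 *						  native endianness
 *	legacy, CROSS_ENDIAN_LEGACY enabled    -> is_le = !user_be, where
 *						  user_be defaults to the host
 *						  endianness and can be changed
 *						  with VHOST_SET_VRING_ENDIAN
 */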
140  
141  struct vhost_flush_struct {
142  	struct vhost_work work;
143  	struct completion wait_event;
144  };
145  
146  static void vhost_flush_work(struct vhost_work *work)
147  {
148  	struct vhost_flush_struct *s;
149  
150  	s = container_of(work, struct vhost_flush_struct, work);
151  	complete(&s->wait_event);
152  }
153  
154  static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
155  			    poll_table *pt)
156  {
157  	struct vhost_poll *poll;
158  
159  	poll = container_of(pt, struct vhost_poll, table);
160  	poll->wqh = wqh;
161  	add_wait_queue(wqh, &poll->wait);
162  }
163  
164  static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync,
165  			     void *key)
166  {
167  	struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);
168  	struct vhost_work *work = &poll->work;
169  
170  	if (!(key_to_poll(key) & poll->mask))
171  		return 0;
172  
173  	if (!poll->dev->use_worker)
174  		work->fn(work);
175  	else
176  		vhost_poll_queue(poll);
177  
178  	return 0;
179  }
180  
181  void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
182  {
183  	clear_bit(VHOST_WORK_QUEUED, &work->flags);
184  	work->fn = fn;
185  }
186  EXPORT_SYMBOL_GPL(vhost_work_init);
187  
188  /* Init poll structure */
189  void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
190  		     __poll_t mask, struct vhost_dev *dev,
191  		     struct vhost_virtqueue *vq)
192  {
193  	init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
194  	init_poll_funcptr(&poll->table, vhost_poll_func);
195  	poll->mask = mask;
196  	poll->dev = dev;
197  	poll->wqh = NULL;
198  	poll->vq = vq;
199  
200  	vhost_work_init(&poll->work, fn);
201  }
202  EXPORT_SYMBOL_GPL(vhost_poll_init);
203  
204  /* Start polling a file. We add ourselves to the file's wait queue. The caller
205   * must keep a reference to the file until after vhost_poll_stop is called. */
206  int vhost_poll_start(struct vhost_poll *poll, struct file *file)
207  {
208  	__poll_t mask;
209  
210  	if (poll->wqh)
211  		return 0;
212  
213  	mask = vfs_poll(file, &poll->table);
214  	if (mask)
215  		vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask));
216  	if (mask & EPOLLERR) {
217  		vhost_poll_stop(poll);
218  		return -EINVAL;
219  	}
220  
221  	return 0;
222  }
223  EXPORT_SYMBOL_GPL(vhost_poll_start);
224  
225  /* Stop polling a file. After this function returns, it becomes safe to drop the
226   * file reference. You must also flush afterwards. */
227  void vhost_poll_stop(struct vhost_poll *poll)
228  {
229  	if (poll->wqh) {
230  		remove_wait_queue(poll->wqh, &poll->wait);
231  		poll->wqh = NULL;
232  	}
233  }
234  EXPORT_SYMBOL_GPL(vhost_poll_stop);
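/* Typical lifecycle for polling a backend fd, as a hedged sketch (vq,
 * handle_kick and kick_file are placeholders for whatever the caller has):
 *
 *	vhost_poll_init(&vq->poll, handle_kick, EPOLLIN, dev, vq);
 *	...
 *	r = vhost_poll_start(&vq->poll, kick_file);	// start watching the fd
 *	...
 *	vhost_poll_stop(&vq->poll);			// stop watching
 *	vhost_dev_flush(dev);				// then flush pending work
 */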
235  
236  static void vhost_worker_queue(struct vhost_worker *worker,
237  			       struct vhost_work *work)
238  {
239  	if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
240  		/* We can only add the work to the list after we're
241  		 * sure it was not in the list.
242  		 * test_and_set_bit() implies a memory barrier.
243  		 */
244  		llist_add(&work->node, &worker->work_list);
245  		vhost_task_wake(worker->vtsk);
246  	}
247  }
248  
249  bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work)
250  {
251  	struct vhost_worker *worker;
252  	bool queued = false;
253  
254  	rcu_read_lock();
255  	worker = rcu_dereference(vq->worker);
256  	if (worker) {
257  		queued = true;
258  		vhost_worker_queue(worker, work);
259  	}
260  	rcu_read_unlock();
261  
262  	return queued;
263  }
264  EXPORT_SYMBOL_GPL(vhost_vq_work_queue);
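/* Minimal sketch of defining and queueing a work item (my_handler and my_vq
 * are placeholders, not real vhost symbols):
 *
 *	static void my_handler(struct vhost_work *work)
 *	{
 *		// runs in the context of the vhost worker task
 *	}
 *
 *	struct vhost_work w;
 *	vhost_work_init(&w, my_handler);
 *	if (!vhost_vq_work_queue(my_vq, &w))
 *		;	// no worker attached, nothing was queued
 */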
265  
266  /**
267   * __vhost_worker_flush - flush a worker
268   * @worker: worker to flush
269   *
270   * The worker's flush_mutex must be held.
271   */
272  static void __vhost_worker_flush(struct vhost_worker *worker)
273  {
274  	struct vhost_flush_struct flush;
275  
276  	if (!worker->attachment_cnt || worker->killed)
277  		return;
278  
279  	init_completion(&flush.wait_event);
280  	vhost_work_init(&flush.work, vhost_flush_work);
281  
282  	vhost_worker_queue(worker, &flush.work);
283  	/*
284  	 * Drop mutex in case our worker is killed and it needs to take the
285  	 * mutex to force cleanup.
286  	 */
287  	mutex_unlock(&worker->mutex);
288  	wait_for_completion(&flush.wait_event);
289  	mutex_lock(&worker->mutex);
290  }
291  
292  static void vhost_worker_flush(struct vhost_worker *worker)
293  {
294  	mutex_lock(&worker->mutex);
295  	__vhost_worker_flush(worker);
296  	mutex_unlock(&worker->mutex);
297  }
298  
299  void vhost_dev_flush(struct vhost_dev *dev)
300  {
301  	struct vhost_worker *worker;
302  	unsigned long i;
303  
304  	xa_for_each(&dev->worker_xa, i, worker)
305  		vhost_worker_flush(worker);
306  }
307  EXPORT_SYMBOL_GPL(vhost_dev_flush);
308  
309  /* A lockless hint for busy polling code to exit the loop */
310  bool vhost_vq_has_work(struct vhost_virtqueue *vq)
311  {
312  	struct vhost_worker *worker;
313  	bool has_work = false;
314  
315  	rcu_read_lock();
316  	worker = rcu_dereference(vq->worker);
317  	if (worker && !llist_empty(&worker->work_list))
318  		has_work = true;
319  	rcu_read_unlock();
320  
321  	return has_work;
322  }
323  EXPORT_SYMBOL_GPL(vhost_vq_has_work);
324  
325  void vhost_poll_queue(struct vhost_poll *poll)
326  {
327  	vhost_vq_work_queue(poll->vq, &poll->work);
328  }
329  EXPORT_SYMBOL_GPL(vhost_poll_queue);
330  
331  static void __vhost_vq_meta_reset(struct vhost_virtqueue *vq)
332  {
333  	int j;
334  
335  	for (j = 0; j < VHOST_NUM_ADDRS; j++)
336  		vq->meta_iotlb[j] = NULL;
337  }
338  
339  static void vhost_vq_meta_reset(struct vhost_dev *d)
340  {
341  	int i;
342  
343  	for (i = 0; i < d->nvqs; ++i)
344  		__vhost_vq_meta_reset(d->vqs[i]);
345  }
346  
347  static void vhost_vring_call_reset(struct vhost_vring_call *call_ctx)
348  {
349  	call_ctx->ctx = NULL;
350  	memset(&call_ctx->producer, 0x0, sizeof(struct irq_bypass_producer));
351  }
352  
353  bool vhost_vq_is_setup(struct vhost_virtqueue *vq)
354  {
355  	return vq->avail && vq->desc && vq->used && vhost_vq_access_ok(vq);
356  }
357  EXPORT_SYMBOL_GPL(vhost_vq_is_setup);
358  
359  static void vhost_vq_reset(struct vhost_dev *dev,
360  			   struct vhost_virtqueue *vq)
361  {
362  	vq->num = 1;
363  	vq->desc = NULL;
364  	vq->avail = NULL;
365  	vq->used = NULL;
366  	vq->last_avail_idx = 0;
367  	vq->avail_idx = 0;
368  	vq->last_used_idx = 0;
369  	vq->signalled_used = 0;
370  	vq->signalled_used_valid = false;
371  	vq->used_flags = 0;
372  	vq->log_used = false;
373  	vq->log_addr = -1ull;
374  	vq->private_data = NULL;
375  	vq->acked_features = 0;
376  	vq->acked_backend_features = 0;
377  	vq->log_base = NULL;
378  	vq->error_ctx = NULL;
379  	vq->kick = NULL;
380  	vq->log_ctx = NULL;
381  	vhost_disable_cross_endian(vq);
382  	vhost_reset_is_le(vq);
383  	vq->busyloop_timeout = 0;
384  	vq->umem = NULL;
385  	vq->iotlb = NULL;
386  	rcu_assign_pointer(vq->worker, NULL);
387  	vhost_vring_call_reset(&vq->call_ctx);
388  	__vhost_vq_meta_reset(vq);
389  }
390  
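/* Worker task body: drain the lockless work list once, run each entry in FIFO
 * order (llist_del_all hands the entries back newest-first, hence the
 * llist_reverse_order below), and report via the return value whether any
 * work was found so the vhost_task loop can decide whether to sleep.
 */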
391  static bool vhost_run_work_list(void *data)
392  {
393  	struct vhost_worker *worker = data;
394  	struct vhost_work *work, *work_next;
395  	struct llist_node *node;
396  
397  	node = llist_del_all(&worker->work_list);
398  	if (node) {
399  		__set_current_state(TASK_RUNNING);
400  
401  		node = llist_reverse_order(node);
402  		/* make sure flag is seen after deletion */
403  		smp_wmb();
404  		llist_for_each_entry_safe(work, work_next, node, node) {
405  			clear_bit(VHOST_WORK_QUEUED, &work->flags);
406  			kcov_remote_start_common(worker->kcov_handle);
407  			work->fn(work);
408  			kcov_remote_stop();
409  			cond_resched();
410  		}
411  	}
412  
413  	return !!node;
414  }
415  
416  static void vhost_worker_killed(void *data)
417  {
418  	struct vhost_worker *worker = data;
419  	struct vhost_dev *dev = worker->dev;
420  	struct vhost_virtqueue *vq;
421  	int i, attach_cnt = 0;
422  
423  	mutex_lock(&worker->mutex);
424  	worker->killed = true;
425  
426  	for (i = 0; i < dev->nvqs; i++) {
427  		vq = dev->vqs[i];
428  
429  		mutex_lock(&vq->mutex);
430  		if (worker ==
431  		    rcu_dereference_check(vq->worker,
432  					  lockdep_is_held(&vq->mutex))) {
433  			rcu_assign_pointer(vq->worker, NULL);
434  			attach_cnt++;
435  		}
436  		mutex_unlock(&vq->mutex);
437  	}
438  
439  	worker->attachment_cnt -= attach_cnt;
440  	if (attach_cnt)
441  		synchronize_rcu();
442  	/*
443  	 * Finish vhost_worker_flush calls and any other work that snuck in
444  	 * before the synchronize_rcu.
445  	 */
446  	vhost_run_work_list(worker);
447  	mutex_unlock(&worker->mutex);
448  }
449  
450  static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
451  {
452  	kfree(vq->indirect);
453  	vq->indirect = NULL;
454  	kfree(vq->log);
455  	vq->log = NULL;
456  	kfree(vq->heads);
457  	vq->heads = NULL;
458  }
459  
460  /* Helper to allocate iovec buffers for all vqs. */
461  static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
462  {
463  	struct vhost_virtqueue *vq;
464  	int i;
465  
466  	for (i = 0; i < dev->nvqs; ++i) {
467  		vq = dev->vqs[i];
468  		vq->indirect = kmalloc_array(UIO_MAXIOV,
469  					     sizeof(*vq->indirect),
470  					     GFP_KERNEL);
471  		vq->log = kmalloc_array(dev->iov_limit, sizeof(*vq->log),
472  					GFP_KERNEL);
473  		vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads),
474  					  GFP_KERNEL);
475  		if (!vq->indirect || !vq->log || !vq->heads)
476  			goto err_nomem;
477  	}
478  	return 0;
479  
480  err_nomem:
481  	for (; i >= 0; --i)
482  		vhost_vq_free_iovecs(dev->vqs[i]);
483  	return -ENOMEM;
484  }
485  
486  static void vhost_dev_free_iovecs(struct vhost_dev *dev)
487  {
488  	int i;
489  
490  	for (i = 0; i < dev->nvqs; ++i)
491  		vhost_vq_free_iovecs(dev->vqs[i]);
492  }
493  
494  bool vhost_exceeds_weight(struct vhost_virtqueue *vq,
495  			  int pkts, int total_len)
496  {
497  	struct vhost_dev *dev = vq->dev;
498  
499  	if ((dev->byte_weight && total_len >= dev->byte_weight) ||
500  	    pkts >= dev->weight) {
501  		vhost_poll_queue(&vq->poll);
502  		return true;
503  	}
504  
505  	return false;
506  }
507  EXPORT_SYMBOL_GPL(vhost_exceeds_weight);
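/* A handler loop typically bounds its own runtime with this helper; a hedged
 * sketch (pkts and total_len are the caller's own counters):
 *
 *	for (;;) {
 *		... handle one descriptor, add its bytes to total_len ...
 *		if (vhost_exceeds_weight(vq, ++pkts, total_len))
 *			break;	// vq->poll was requeued, yield the worker
 *	}
 */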
508  
509  static size_t vhost_get_avail_size(struct vhost_virtqueue *vq,
510  				   unsigned int num)
511  {
512  	size_t event __maybe_unused =
513  	       vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
514  
515  	return size_add(struct_size(vq->avail, ring, num), event);
516  }
517  
518  static size_t vhost_get_used_size(struct vhost_virtqueue *vq,
519  				  unsigned int num)
520  {
521  	size_t event __maybe_unused =
522  	       vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
523  
524  	return size_add(struct_size(vq->used, ring, num), event);
525  }
526  
527  static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
528  				  unsigned int num)
529  {
530  	return sizeof(*vq->desc) * num;
531  }
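/* Worked example for the three sizes above, assuming the standard split ring
 * with num = 256 and VIRTIO_RING_F_EVENT_IDX negotiated:
 *
 *	desc:  16 * 256         = 4096 bytes
 *	avail:  4 + 2 * 256 + 2 =  518 bytes
 *	used:   4 + 8 * 256 + 2 = 2054 bytes
 */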
532  
533  void vhost_dev_init(struct vhost_dev *dev,
534  		    struct vhost_virtqueue **vqs, int nvqs,
535  		    int iov_limit, int weight, int byte_weight,
536  		    bool use_worker,
537  		    int (*msg_handler)(struct vhost_dev *dev, u32 asid,
538  				       struct vhost_iotlb_msg *msg))
539  {
540  	struct vhost_virtqueue *vq;
541  	int i;
542  
543  	dev->vqs = vqs;
544  	dev->nvqs = nvqs;
545  	mutex_init(&dev->mutex);
546  	dev->log_ctx = NULL;
547  	dev->umem = NULL;
548  	dev->iotlb = NULL;
549  	dev->mm = NULL;
550  	dev->iov_limit = iov_limit;
551  	dev->weight = weight;
552  	dev->byte_weight = byte_weight;
553  	dev->use_worker = use_worker;
554  	dev->msg_handler = msg_handler;
555  	init_waitqueue_head(&dev->wait);
556  	INIT_LIST_HEAD(&dev->read_list);
557  	INIT_LIST_HEAD(&dev->pending_list);
558  	spin_lock_init(&dev->iotlb_lock);
559  	xa_init_flags(&dev->worker_xa, XA_FLAGS_ALLOC);
560  
561  	for (i = 0; i < dev->nvqs; ++i) {
562  		vq = dev->vqs[i];
563  		vq->log = NULL;
564  		vq->indirect = NULL;
565  		vq->heads = NULL;
566  		vq->dev = dev;
567  		mutex_init(&vq->mutex);
568  		vhost_vq_reset(dev, vq);
569  		if (vq->handle_kick)
570  			vhost_poll_init(&vq->poll, vq->handle_kick,
571  					EPOLLIN, dev, vq);
572  	}
573  }
574  EXPORT_SYMBOL_GPL(vhost_dev_init);
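/* Hedged sketch of how a backend driver might wire this up; my_dev, my_vqs,
 * handle_kick and the limits are placeholders only, and real drivers such as
 * vhost-net differ in the details:
 *
 *	struct vhost_virtqueue *vqs[2] = { &my_vqs[0], &my_vqs[1] };
 *
 *	my_vqs[0].handle_kick = handle_kick;
 *	my_vqs[1].handle_kick = handle_kick;
 *	vhost_dev_init(&my_dev->dev, vqs, 2, UIO_MAXIOV, 64, 64 * 1024,
 *		       true, NULL);
 */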
575  
576  /* Caller should have device mutex */
577  long vhost_dev_check_owner(struct vhost_dev *dev)
578  {
579  	/* Are you the owner? If not, I don't think you mean to do that */
580  	return dev->mm == current->mm ? 0 : -EPERM;
581  }
582  EXPORT_SYMBOL_GPL(vhost_dev_check_owner);
583  
584  /* Caller should have device mutex */
585  bool vhost_dev_has_owner(struct vhost_dev *dev)
586  {
587  	return dev->mm;
588  }
589  EXPORT_SYMBOL_GPL(vhost_dev_has_owner);
590  
591  static void vhost_attach_mm(struct vhost_dev *dev)
592  {
593  	/* No owner, become one */
594  	if (dev->use_worker) {
595  		dev->mm = get_task_mm(current);
596  	} else {
597  		/* vDPA devices do not use a worker thread, so there's
598  		 * no need to hold the address space for mm. This helps
599  		 * to avoid a deadlock in the case of mmap(), which may
600  		 * hold the refcnt of the file and depend on the release
601  		 * method to remove the vma.
602  		 */
603  		dev->mm = current->mm;
604  		mmgrab(dev->mm);
605  	}
606  }
607  
608  static void vhost_detach_mm(struct vhost_dev *dev)
609  {
610  	if (!dev->mm)
611  		return;
612  
613  	if (dev->use_worker)
614  		mmput(dev->mm);
615  	else
616  		mmdrop(dev->mm);
617  
618  	dev->mm = NULL;
619  }
620  
621  static void vhost_worker_destroy(struct vhost_dev *dev,
622  				 struct vhost_worker *worker)
623  {
624  	if (!worker)
625  		return;
626  
627  	WARN_ON(!llist_empty(&worker->work_list));
628  	xa_erase(&dev->worker_xa, worker->id);
629  	vhost_task_stop(worker->vtsk);
630  	kfree(worker);
631  }
632  
633  static void vhost_workers_free(struct vhost_dev *dev)
634  {
635  	struct vhost_worker *worker;
636  	unsigned long i;
637  
638  	if (!dev->use_worker)
639  		return;
640  
641  	for (i = 0; i < dev->nvqs; i++)
642  		rcu_assign_pointer(dev->vqs[i]->worker, NULL);
643  	/*
644  	 * Free the default worker we created and clean up any workers userspace
645  	 * created but couldn't clean up (it forgot or crashed).
646  	 */
647  	xa_for_each(&dev->worker_xa, i, worker)
648  		vhost_worker_destroy(dev, worker);
649  	xa_destroy(&dev->worker_xa);
650  }
651  
652  static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
653  {
654  	struct vhost_worker *worker;
655  	struct vhost_task *vtsk;
656  	char name[TASK_COMM_LEN];
657  	int ret;
658  	u32 id;
659  
660  	worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
661  	if (!worker)
662  		return NULL;
663  
664  	worker->dev = dev;
665  	snprintf(name, sizeof(name), "vhost-%d", current->pid);
666  
667  	vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
668  				 worker, name);
669  	if (!vtsk)
670  		goto free_worker;
671  
672  	mutex_init(&worker->mutex);
673  	init_llist_head(&worker->work_list);
674  	worker->kcov_handle = kcov_common_handle();
675  	worker->vtsk = vtsk;
676  
677  	vhost_task_start(vtsk);
678  
679  	ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
680  	if (ret < 0)
681  		goto stop_worker;
682  	worker->id = id;
683  
684  	return worker;
685  
686  stop_worker:
687  	vhost_task_stop(vtsk);
688  free_worker:
689  	kfree(worker);
690  	return NULL;
691  }
692  
693  /* Caller must have device mutex */
694  static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
695  				     struct vhost_worker *worker)
696  {
697  	struct vhost_worker *old_worker;
698  
699  	mutex_lock(&worker->mutex);
700  	if (worker->killed) {
701  		mutex_unlock(&worker->mutex);
702  		return;
703  	}
704  
705  	mutex_lock(&vq->mutex);
706  
707  	old_worker = rcu_dereference_check(vq->worker,
708  					   lockdep_is_held(&vq->mutex));
709  	rcu_assign_pointer(vq->worker, worker);
710  	worker->attachment_cnt++;
711  
712  	if (!old_worker) {
713  		mutex_unlock(&vq->mutex);
714  		mutex_unlock(&worker->mutex);
715  		return;
716  	}
717  	mutex_unlock(&vq->mutex);
718  	mutex_unlock(&worker->mutex);
719  
720  	/*
721  	 * Take the worker mutex to make sure we see the work queued from
722  	 * device wide flushes, which don't use RCU for execution.
723  	 */
724  	mutex_lock(&old_worker->mutex);
725  	if (old_worker->killed) {
726  		mutex_unlock(&old_worker->mutex);
727  		return;
728  	}
729  
730  	/*
731  	 * We don't want to call synchronize_rcu for every vq during setup
732  	 * because it will slow down VM startup. If we haven't done
733  	 * VHOST_SET_VRING_KICK and haven't done the driver specific
734  	 * SET_ENDPOINT/RUNNING then we can skip the sync since there will
735  	 * not be any work queued for scsi and net.
736  	 */
737  	mutex_lock(&vq->mutex);
738  	if (!vhost_vq_get_backend(vq) && !vq->kick) {
739  		mutex_unlock(&vq->mutex);
740  
741  		old_worker->attachment_cnt--;
742  		mutex_unlock(&old_worker->mutex);
743  		/*
744  		 * vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID.
745  		 * Warn if it adds support for multiple workers but forgets to
746  		 * handle the early queueing case.
747  		 */
748  		WARN_ON(!old_worker->attachment_cnt &&
749  			!llist_empty(&old_worker->work_list));
750  		return;
751  	}
752  	mutex_unlock(&vq->mutex);
753  
754  	/* Make sure new vq queue/flush/poll calls see the new worker */
755  	synchronize_rcu();
756  	/* Make sure whatever was queued gets run */
757  	__vhost_worker_flush(old_worker);
758  	old_worker->attachment_cnt--;
759  	mutex_unlock(&old_worker->mutex);
760  }
761  
762   /* Caller must have device mutex */
763  static int vhost_vq_attach_worker(struct vhost_virtqueue *vq,
764  				  struct vhost_vring_worker *info)
765  {
766  	unsigned long index = info->worker_id;
767  	struct vhost_dev *dev = vq->dev;
768  	struct vhost_worker *worker;
769  
770  	if (!dev->use_worker)
771  		return -EINVAL;
772  
773  	worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
774  	if (!worker || worker->id != info->worker_id)
775  		return -ENODEV;
776  
777  	__vhost_vq_attach_worker(vq, worker);
778  	return 0;
779  }
780  
781  /* Caller must have device mutex */
782  static int vhost_new_worker(struct vhost_dev *dev,
783  			    struct vhost_worker_state *info)
784  {
785  	struct vhost_worker *worker;
786  
787  	worker = vhost_worker_create(dev);
788  	if (!worker)
789  		return -ENOMEM;
790  
791  	info->worker_id = worker->id;
792  	return 0;
793  }
794  
795  /* Caller must have device mutex */
796  static int vhost_free_worker(struct vhost_dev *dev,
797  			     struct vhost_worker_state *info)
798  {
799  	unsigned long index = info->worker_id;
800  	struct vhost_worker *worker;
801  
802  	worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
803  	if (!worker || worker->id != info->worker_id)
804  		return -ENODEV;
805  
806  	mutex_lock(&worker->mutex);
807  	if (worker->attachment_cnt || worker->killed) {
808  		mutex_unlock(&worker->mutex);
809  		return -EBUSY;
810  	}
811  	/*
812  	 * A flush might have raced and snuck in before attachment_cnt was set
813  	 * to zero. Make sure flushes are flushed from the queue before
814  	 * freeing.
815  	 */
816  	__vhost_worker_flush(worker);
817  	mutex_unlock(&worker->mutex);
818  
819  	vhost_worker_destroy(dev, worker);
820  	return 0;
821  }
822  
823  static int vhost_get_vq_from_user(struct vhost_dev *dev, void __user *argp,
824  				  struct vhost_virtqueue **vq, u32 *id)
825  {
826  	u32 __user *idxp = argp;
827  	u32 idx;
828  	long r;
829  
830  	r = get_user(idx, idxp);
831  	if (r < 0)
832  		return r;
833  
834  	if (idx >= dev->nvqs)
835  		return -ENOBUFS;
836  
837  	idx = array_index_nospec(idx, dev->nvqs);
838  
839  	*vq = dev->vqs[idx];
840  	*id = idx;
841  	return 0;
842  }
843  
844  /* Caller must have device mutex */
845  long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl,
846  			void __user *argp)
847  {
848  	struct vhost_vring_worker ring_worker;
849  	struct vhost_worker_state state;
850  	struct vhost_worker *worker;
851  	struct vhost_virtqueue *vq;
852  	long ret;
853  	u32 idx;
854  
855  	if (!dev->use_worker)
856  		return -EINVAL;
857  
858  	if (!vhost_dev_has_owner(dev))
859  		return -EINVAL;
860  
861  	ret = vhost_dev_check_owner(dev);
862  	if (ret)
863  		return ret;
864  
865  	switch (ioctl) {
866  	/* dev worker ioctls */
867  	case VHOST_NEW_WORKER:
868  		ret = vhost_new_worker(dev, &state);
869  		if (!ret && copy_to_user(argp, &state, sizeof(state)))
870  			ret = -EFAULT;
871  		return ret;
872  	case VHOST_FREE_WORKER:
873  		if (copy_from_user(&state, argp, sizeof(state)))
874  			return -EFAULT;
875  		return vhost_free_worker(dev, &state);
876  	/* vring worker ioctls */
877  	case VHOST_ATTACH_VRING_WORKER:
878  	case VHOST_GET_VRING_WORKER:
879  		break;
880  	default:
881  		return -ENOIOCTLCMD;
882  	}
883  
884  	ret = vhost_get_vq_from_user(dev, argp, &vq, &idx);
885  	if (ret)
886  		return ret;
887  
888  	switch (ioctl) {
889  	case VHOST_ATTACH_VRING_WORKER:
890  		if (copy_from_user(&ring_worker, argp, sizeof(ring_worker))) {
891  			ret = -EFAULT;
892  			break;
893  		}
894  
895  		ret = vhost_vq_attach_worker(vq, &ring_worker);
896  		break;
897  	case VHOST_GET_VRING_WORKER:
898  		worker = rcu_dereference_check(vq->worker,
899  					       lockdep_is_held(&dev->mutex));
900  		if (!worker) {
901  			ret = -EINVAL;
902  			break;
903  		}
904  
905  		ring_worker.index = idx;
906  		ring_worker.worker_id = worker->id;
907  
908  		if (copy_to_user(argp, &ring_worker, sizeof(ring_worker)))
909  			ret = -EFAULT;
910  		break;
911  	default:
912  		ret = -ENOIOCTLCMD;
913  		break;
914  	}
915  
916  	return ret;
917  }
918  EXPORT_SYMBOL_GPL(vhost_worker_ioctl);
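/* Rough flow of the worker ioctls from userspace, as illustrative pseudo-code
 * (error handling omitted, dev_fd and vq_index are placeholders):
 *
 *	struct vhost_worker_state state = {};
 *	ioctl(dev_fd, VHOST_NEW_WORKER, &state);	// fills state.worker_id
 *
 *	struct vhost_vring_worker w = {
 *		.index = vq_index,
 *		.worker_id = state.worker_id,
 *	};
 *	ioctl(dev_fd, VHOST_ATTACH_VRING_WORKER, &w);
 *	...
 *	ioctl(dev_fd, VHOST_FREE_WORKER, &state);	// only once detached
 */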
919  
920  /* Caller should have device mutex */
921  long vhost_dev_set_owner(struct vhost_dev *dev)
922  {
923  	struct vhost_worker *worker;
924  	int err, i;
925  
926  	/* Is there an owner already? */
927  	if (vhost_dev_has_owner(dev)) {
928  		err = -EBUSY;
929  		goto err_mm;
930  	}
931  
932  	vhost_attach_mm(dev);
933  
934  	err = vhost_dev_alloc_iovecs(dev);
935  	if (err)
936  		goto err_iovecs;
937  
938  	if (dev->use_worker) {
939  		/*
940  		 * This should be done last, because vsock can queue work
941  		 * before VHOST_SET_OWNER so it simplifies the failure path
942  		 * below since we don't have to worry about vsock queueing
943  		 * while we free the worker.
944  		 */
945  		worker = vhost_worker_create(dev);
946  		if (!worker) {
947  			err = -ENOMEM;
948  			goto err_worker;
949  		}
950  
951  		for (i = 0; i < dev->nvqs; i++)
952  			__vhost_vq_attach_worker(dev->vqs[i], worker);
953  	}
954  
955  	return 0;
956  
957  err_worker:
958  	vhost_dev_free_iovecs(dev);
959  err_iovecs:
960  	vhost_detach_mm(dev);
961  err_mm:
962  	return err;
963  }
964  EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
965  
966  static struct vhost_iotlb *iotlb_alloc(void)
967  {
968  	return vhost_iotlb_alloc(max_iotlb_entries,
969  				 VHOST_IOTLB_FLAG_RETIRE);
970  }
971  
972  struct vhost_iotlb *vhost_dev_reset_owner_prepare(void)
973  {
974  	return iotlb_alloc();
975  }
976  EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
977  
978  /* Caller should have device mutex */
979  void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_iotlb *umem)
980  {
981  	int i;
982  
983  	vhost_dev_cleanup(dev);
984  
985  	dev->umem = umem;
986  	/* We don't need VQ locks below since vhost_dev_cleanup makes sure
987  	 * VQs aren't running.
988  	 */
989  	for (i = 0; i < dev->nvqs; ++i)
990  		dev->vqs[i]->umem = umem;
991  }
992  EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
993  
994  void vhost_dev_stop(struct vhost_dev *dev)
995  {
996  	int i;
997  
998  	for (i = 0; i < dev->nvqs; ++i) {
999  		if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick)
1000  			vhost_poll_stop(&dev->vqs[i]->poll);
1001  	}
1002  
1003  	vhost_dev_flush(dev);
1004  }
1005  EXPORT_SYMBOL_GPL(vhost_dev_stop);
1006  
1007  void vhost_clear_msg(struct vhost_dev *dev)
1008  {
1009  	struct vhost_msg_node *node, *n;
1010  
1011  	spin_lock(&dev->iotlb_lock);
1012  
1013  	list_for_each_entry_safe(node, n, &dev->read_list, node) {
1014  		list_del(&node->node);
1015  		kfree(node);
1016  	}
1017  
1018  	list_for_each_entry_safe(node, n, &dev->pending_list, node) {
1019  		list_del(&node->node);
1020  		kfree(node);
1021  	}
1022  
1023  	spin_unlock(&dev->iotlb_lock);
1024  }
1025  EXPORT_SYMBOL_GPL(vhost_clear_msg);
1026  
1027  void vhost_dev_cleanup(struct vhost_dev *dev)
1028  {
1029  	int i;
1030  
1031  	for (i = 0; i < dev->nvqs; ++i) {
1032  		if (dev->vqs[i]->error_ctx)
1033  			eventfd_ctx_put(dev->vqs[i]->error_ctx);
1034  		if (dev->vqs[i]->kick)
1035  			fput(dev->vqs[i]->kick);
1036  		if (dev->vqs[i]->call_ctx.ctx)
1037  			eventfd_ctx_put(dev->vqs[i]->call_ctx.ctx);
1038  		vhost_vq_reset(dev, dev->vqs[i]);
1039  	}
1040  	vhost_dev_free_iovecs(dev);
1041  	if (dev->log_ctx)
1042  		eventfd_ctx_put(dev->log_ctx);
1043  	dev->log_ctx = NULL;
1044  	/* No one will access memory at this point */
1045  	vhost_iotlb_free(dev->umem);
1046  	dev->umem = NULL;
1047  	vhost_iotlb_free(dev->iotlb);
1048  	dev->iotlb = NULL;
1049  	vhost_clear_msg(dev);
1050  	wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
1051  	vhost_workers_free(dev);
1052  	vhost_detach_mm(dev);
1053  }
1054  EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
1055  
1056  static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
1057  {
1058  	u64 a = addr / VHOST_PAGE_SIZE / 8;
1059  
1060  	/* Make sure 64 bit math will not overflow. */
1061  	if (a > ULONG_MAX - (unsigned long)log_base ||
1062  	    a + (unsigned long)log_base > ULONG_MAX)
1063  		return false;
1064  
1065  	return access_ok(log_base + a,
1066  			 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
1067  }
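/* The dirty log is a bitmap with one bit per VHOST_PAGE_SIZE page, so guest
 * address addr is tracked by bit (addr / VHOST_PAGE_SIZE) and therefore byte
 * (addr / VHOST_PAGE_SIZE / 8) of the log, which is exactly the offset 'a'
 * computed above; the size check rounds sz up to whole bitmap bytes.
 */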
1068  
1069  /* Make sure 64 bit math will not overflow. */
1070  static bool vhost_overflow(u64 uaddr, u64 size)
1071  {
1072  	if (uaddr > ULONG_MAX || size > ULONG_MAX)
1073  		return true;
1074  
1075  	if (!size)
1076  		return false;
1077  
1078  	return uaddr > ULONG_MAX - size + 1;
1079  }
1080  
1081  /* Caller should have vq mutex and device mutex. */
1082  static bool vq_memory_access_ok(void __user *log_base, struct vhost_iotlb *umem,
1083  				int log_all)
1084  {
1085  	struct vhost_iotlb_map *map;
1086  
1087  	if (!umem)
1088  		return false;
1089  
1090  	list_for_each_entry(map, &umem->list, link) {
1091  		unsigned long a = map->addr;
1092  
1093  		if (vhost_overflow(map->addr, map->size))
1094  			return false;
1095  
1096  
1097  		if (!access_ok((void __user *)a, map->size))
1098  			return false;
1099  		else if (log_all && !log_access_ok(log_base,
1100  						   map->start,
1101  						   map->size))
1102  			return false;
1103  	}
1104  	return true;
1105  }
1106  
1107  static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
1108  					       u64 addr, unsigned int size,
1109  					       int type)
1110  {
1111  	const struct vhost_iotlb_map *map = vq->meta_iotlb[type];
1112  
1113  	if (!map)
1114  		return NULL;
1115  
1116  	return (void __user *)(uintptr_t)(map->addr + addr - map->start);
1117  }
1118  
1119  /* Can we switch to this memory table? */
1120  /* Caller should have device mutex but not vq mutex */
1121  static bool memory_access_ok(struct vhost_dev *d, struct vhost_iotlb *umem,
1122  			     int log_all)
1123  {
1124  	int i;
1125  
1126  	for (i = 0; i < d->nvqs; ++i) {
1127  		bool ok;
1128  		bool log;
1129  
1130  		mutex_lock(&d->vqs[i]->mutex);
1131  		log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
1132  		/* If ring is inactive, will check when it's enabled. */
1133  		if (d->vqs[i]->private_data)
1134  			ok = vq_memory_access_ok(d->vqs[i]->log_base,
1135  						 umem, log);
1136  		else
1137  			ok = true;
1138  		mutex_unlock(&d->vqs[i]->mutex);
1139  		if (!ok)
1140  			return false;
1141  	}
1142  	return true;
1143  }
1144  
1145  static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
1146  			  struct iovec iov[], int iov_size, int access);
1147  
1148  static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
1149  			      const void *from, unsigned size)
1150  {
1151  	int ret;
1152  
1153  	if (!vq->iotlb)
1154  		return __copy_to_user(to, from, size);
1155  	else {
1156  		/* This function should be called after iotlb
1157  		 * prefetch, which means we're sure that the whole vq
1158  		 * can be accessed through the iotlb. So -EAGAIN should
1159  		 * not happen in this case.
1160  		 */
1161  		struct iov_iter t;
1162  		void __user *uaddr = vhost_vq_meta_fetch(vq,
1163  				     (u64)(uintptr_t)to, size,
1164  				     VHOST_ADDR_USED);
1165  
1166  		if (uaddr)
1167  			return __copy_to_user(uaddr, from, size);
1168  
1169  		ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
1170  				     ARRAY_SIZE(vq->iotlb_iov),
1171  				     VHOST_ACCESS_WO);
1172  		if (ret < 0)
1173  			goto out;
1174  		iov_iter_init(&t, ITER_DEST, vq->iotlb_iov, ret, size);
1175  		ret = copy_to_iter(from, size, &t);
1176  		if (ret == size)
1177  			ret = 0;
1178  	}
1179  out:
1180  	return ret;
1181  }
1182  
1183  static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
1184  				void __user *from, unsigned size)
1185  {
1186  	int ret;
1187  
1188  	if (!vq->iotlb)
1189  		return __copy_from_user(to, from, size);
1190  	else {
1191  		/* This function should be called after iotlb
1192  		 * prefetch, which means we're sure that the vq
1193  		 * can be accessed through the iotlb. So -EAGAIN should
1194  		 * not happen in this case.
1195  		 */
1196  		void __user *uaddr = vhost_vq_meta_fetch(vq,
1197  				     (u64)(uintptr_t)from, size,
1198  				     VHOST_ADDR_DESC);
1199  		struct iov_iter f;
1200  
1201  		if (uaddr)
1202  			return __copy_from_user(to, uaddr, size);
1203  
1204  		ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
1205  				     ARRAY_SIZE(vq->iotlb_iov),
1206  				     VHOST_ACCESS_RO);
1207  		if (ret < 0) {
1208  			vq_err(vq, "IOTLB translation failure: uaddr "
1209  			       "%p size 0x%llx\n", from,
1210  			       (unsigned long long) size);
1211  			goto out;
1212  		}
1213  		iov_iter_init(&f, ITER_SOURCE, vq->iotlb_iov, ret, size);
1214  		ret = copy_from_iter(to, size, &f);
1215  		if (ret == size)
1216  			ret = 0;
1217  	}
1218  
1219  out:
1220  	return ret;
1221  }
1222  
1223  static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq,
1224  					  void __user *addr, unsigned int size,
1225  					  int type)
1226  {
1227  	int ret;
1228  
1229  	ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
1230  			     ARRAY_SIZE(vq->iotlb_iov),
1231  			     VHOST_ACCESS_RO);
1232  	if (ret < 0) {
1233  		vq_err(vq, "IOTLB translation failure: uaddr "
1234  			"%p size 0x%llx\n", addr,
1235  			(unsigned long long) size);
1236  		return NULL;
1237  	}
1238  
1239  	if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
1240  		vq_err(vq, "Non atomic userspace memory access: uaddr "
1241  			"%p size 0x%llx\n", addr,
1242  			(unsigned long long) size);
1243  		return NULL;
1244  	}
1245  
1246  	return vq->iotlb_iov[0].iov_base;
1247  }
1248  
1249  /* This function should be called after iotlb
1250   * prefetch, which means we're sure that the vq
1251   * can be accessed through the iotlb. So -EAGAIN should
1252   * not happen in this case.
1253   */
1254  static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
1255  					    void __user *addr, unsigned int size,
1256  					    int type)
1257  {
1258  	void __user *uaddr = vhost_vq_meta_fetch(vq,
1259  			     (u64)(uintptr_t)addr, size, type);
1260  	if (uaddr)
1261  		return uaddr;
1262  
1263  	return __vhost_get_user_slow(vq, addr, size, type);
1264  }
1265  
1266  #define vhost_put_user(vq, x, ptr)		\
1267  ({ \
1268  	int ret; \
1269  	if (!vq->iotlb) { \
1270  		ret = __put_user(x, ptr); \
1271  	} else { \
1272  		__typeof__(ptr) to = \
1273  			(__typeof__(ptr)) __vhost_get_user(vq, ptr,	\
1274  					  sizeof(*ptr), VHOST_ADDR_USED); \
1275  		if (to != NULL) \
1276  			ret = __put_user(x, to); \
1277  		else \
1278  			ret = -EFAULT;	\
1279  	} \
1280  	ret; \
1281  })
1282  
1283  static inline int vhost_put_avail_event(struct vhost_virtqueue *vq)
1284  {
1285  	return vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
1286  			      vhost_avail_event(vq));
1287  }
1288  
1289  static inline int vhost_put_used(struct vhost_virtqueue *vq,
1290  				 struct vring_used_elem *head, int idx,
1291  				 int count)
1292  {
1293  	return vhost_copy_to_user(vq, vq->used->ring + idx, head,
1294  				  count * sizeof(*head));
1295  }
1296  
1297  static inline int vhost_put_used_flags(struct vhost_virtqueue *vq)
1298  
1299  {
1300  	return vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
1301  			      &vq->used->flags);
1302  }
1303  
1304  static inline int vhost_put_used_idx(struct vhost_virtqueue *vq)
1305  
1306  {
1307  	return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
1308  			      &vq->used->idx);
1309  }
1310  
1311  #define vhost_get_user(vq, x, ptr, type)		\
1312  ({ \
1313  	int ret; \
1314  	if (!vq->iotlb) { \
1315  		ret = __get_user(x, ptr); \
1316  	} else { \
1317  		__typeof__(ptr) from = \
1318  			(__typeof__(ptr)) __vhost_get_user(vq, ptr, \
1319  							   sizeof(*ptr), \
1320  							   type); \
1321  		if (from != NULL) \
1322  			ret = __get_user(x, from); \
1323  		else \
1324  			ret = -EFAULT; \
1325  	} \
1326  	ret; \
1327  })
1328  
1329  #define vhost_get_avail(vq, x, ptr) \
1330  	vhost_get_user(vq, x, ptr, VHOST_ADDR_AVAIL)
1331  
1332  #define vhost_get_used(vq, x, ptr) \
1333  	vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)
1334  
1335  static void vhost_dev_lock_vqs(struct vhost_dev *d)
1336  {
1337  	int i = 0;
1338  	for (i = 0; i < d->nvqs; ++i)
1339  		mutex_lock_nested(&d->vqs[i]->mutex, i);
1340  }
1341  
1342  static void vhost_dev_unlock_vqs(struct vhost_dev *d)
1343  {
1344  	int i = 0;
1345  	for (i = 0; i < d->nvqs; ++i)
1346  		mutex_unlock(&d->vqs[i]->mutex);
1347  }
1348  
1349  static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq)
1350  {
1351  	__virtio16 idx;
1352  	int r;
1353  
1354  	r = vhost_get_avail(vq, idx, &vq->avail->idx);
1355  	if (unlikely(r < 0)) {
1356  		vq_err(vq, "Failed to access available index at %p (%d)\n",
1357  		       &vq->avail->idx, r);
1358  		return r;
1359  	}
1360  
1361  	/* Check it isn't doing very strange things with available indexes */
1362  	vq->avail_idx = vhost16_to_cpu(vq, idx);
1363  	if (unlikely((u16)(vq->avail_idx - vq->last_avail_idx) > vq->num)) {
1364  		vq_err(vq, "Invalid available index change from %u to %u",
1365  		       vq->last_avail_idx, vq->avail_idx);
1366  		return -EINVAL;
1367  	}
1368  
1369  	/* We're done if there is nothing new */
1370  	if (vq->avail_idx == vq->last_avail_idx)
1371  		return 0;
1372  
1373  	/*
1374  	 * We updated vq->avail_idx so we need a memory barrier between
1375  	 * the index read above and the caller reading avail ring entries.
1376  	 */
1377  	smp_rmb();
1378  	return 1;
1379  }
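/* Return convention of vhost_get_avail_idx() above for its callers: a negative
 * errno on an access failure or a bogus index, 0 when there is nothing new,
 * and 1 when new available entries exist (with the read barrier already
 * issued).
 */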
1380  
1381  static inline int vhost_get_avail_head(struct vhost_virtqueue *vq,
1382  				       __virtio16 *head, int idx)
1383  {
1384  	return vhost_get_avail(vq, *head,
1385  			       &vq->avail->ring[idx & (vq->num - 1)]);
1386  }
1387  
1388  static inline int vhost_get_avail_flags(struct vhost_virtqueue *vq,
1389  					__virtio16 *flags)
1390  {
1391  	return vhost_get_avail(vq, *flags, &vq->avail->flags);
1392  }
1393  
1394  static inline int vhost_get_used_event(struct vhost_virtqueue *vq,
1395  				       __virtio16 *event)
1396  {
1397  	return vhost_get_avail(vq, *event, vhost_used_event(vq));
1398  }
1399  
1400  static inline int vhost_get_used_idx(struct vhost_virtqueue *vq,
1401  				     __virtio16 *idx)
1402  {
1403  	return vhost_get_used(vq, *idx, &vq->used->idx);
1404  }
1405  
1406  static inline int vhost_get_desc(struct vhost_virtqueue *vq,
1407  				 struct vring_desc *desc, int idx)
1408  {
1409  	return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc));
1410  }
1411  
1412  static void vhost_iotlb_notify_vq(struct vhost_dev *d,
1413  				  struct vhost_iotlb_msg *msg)
1414  {
1415  	struct vhost_msg_node *node, *n;
1416  
1417  	spin_lock(&d->iotlb_lock);
1418  
1419  	list_for_each_entry_safe(node, n, &d->pending_list, node) {
1420  		struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
1421  		if (msg->iova <= vq_msg->iova &&
1422  		    msg->iova + msg->size - 1 >= vq_msg->iova &&
1423  		    vq_msg->type == VHOST_IOTLB_MISS) {
1424  			vhost_poll_queue(&node->vq->poll);
1425  			list_del(&node->node);
1426  			kfree(node);
1427  		}
1428  	}
1429  
1430  	spin_unlock(&d->iotlb_lock);
1431  }
1432  
1433  static bool umem_access_ok(u64 uaddr, u64 size, int access)
1434  {
1435  	unsigned long a = uaddr;
1436  
1437  	/* Make sure 64 bit math will not overflow. */
1438  	if (vhost_overflow(uaddr, size))
1439  		return false;
1440  
1441  	if ((access & VHOST_ACCESS_RO) &&
1442  	    !access_ok((void __user *)a, size))
1443  		return false;
1444  	if ((access & VHOST_ACCESS_WO) &&
1445  	    !access_ok((void __user *)a, size))
1446  		return false;
1447  	return true;
1448  }
1449  
1450  static int vhost_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
1451  				   struct vhost_iotlb_msg *msg)
1452  {
1453  	int ret = 0;
1454  
1455  	if (asid != 0)
1456  		return -EINVAL;
1457  
1458  	mutex_lock(&dev->mutex);
1459  	vhost_dev_lock_vqs(dev);
1460  	switch (msg->type) {
1461  	case VHOST_IOTLB_UPDATE:
1462  		if (!dev->iotlb) {
1463  			ret = -EFAULT;
1464  			break;
1465  		}
1466  		if (!umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
1467  			ret = -EFAULT;
1468  			break;
1469  		}
1470  		vhost_vq_meta_reset(dev);
1471  		if (vhost_iotlb_add_range(dev->iotlb, msg->iova,
1472  					  msg->iova + msg->size - 1,
1473  					  msg->uaddr, msg->perm)) {
1474  			ret = -ENOMEM;
1475  			break;
1476  		}
1477  		vhost_iotlb_notify_vq(dev, msg);
1478  		break;
1479  	case VHOST_IOTLB_INVALIDATE:
1480  		if (!dev->iotlb) {
1481  			ret = -EFAULT;
1482  			break;
1483  		}
1484  		vhost_vq_meta_reset(dev);
1485  		vhost_iotlb_del_range(dev->iotlb, msg->iova,
1486  				      msg->iova + msg->size - 1);
1487  		break;
1488  	default:
1489  		ret = -EINVAL;
1490  		break;
1491  	}
1492  
1493  	vhost_dev_unlock_vqs(dev);
1494  	mutex_unlock(&dev->mutex);
1495  
1496  	return ret;
1497  }
1498  ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
1499  			     struct iov_iter *from)
1500  {
1501  	struct vhost_iotlb_msg msg;
1502  	size_t offset;
1503  	int type, ret;
1504  	u32 asid = 0;
1505  
1506  	ret = copy_from_iter(&type, sizeof(type), from);
1507  	if (ret != sizeof(type)) {
1508  		ret = -EINVAL;
1509  		goto done;
1510  	}
1511  
1512  	switch (type) {
1513  	case VHOST_IOTLB_MSG:
1514  		/* There may be a hole after the type field for the V1
1515  		 * message format, so skip it here.
1516  		 */
1517  		offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
1518  		break;
1519  	case VHOST_IOTLB_MSG_V2:
1520  		if (vhost_backend_has_feature(dev->vqs[0],
1521  					      VHOST_BACKEND_F_IOTLB_ASID)) {
1522  			ret = copy_from_iter(&asid, sizeof(asid), from);
1523  			if (ret != sizeof(asid)) {
1524  				ret = -EINVAL;
1525  				goto done;
1526  			}
1527  			offset = 0;
1528  		} else
1529  			offset = sizeof(__u32);
1530  		break;
1531  	default:
1532  		ret = -EINVAL;
1533  		goto done;
1534  	}
1535  
1536  	iov_iter_advance(from, offset);
1537  	ret = copy_from_iter(&msg, sizeof(msg), from);
1538  	if (ret != sizeof(msg)) {
1539  		ret = -EINVAL;
1540  		goto done;
1541  	}
1542  
1543  	if (msg.type == VHOST_IOTLB_UPDATE && msg.size == 0) {
1544  		ret = -EINVAL;
1545  		goto done;
1546  	}
1547  
1548  	if (dev->msg_handler)
1549  		ret = dev->msg_handler(dev, asid, &msg);
1550  	else
1551  		ret = vhost_process_iotlb_msg(dev, asid, &msg);
1552  	if (ret) {
1553  		ret = -EFAULT;
1554  		goto done;
1555  	}
1556  
1557  	ret = (type == VHOST_IOTLB_MSG) ? sizeof(struct vhost_msg) :
1558  	      sizeof(struct vhost_msg_v2);
1559  done:
1560  	return ret;
1561  }
1562  EXPORT_SYMBOL(vhost_chr_write_iter);
1563  
1564  __poll_t vhost_chr_poll(struct file *file, struct vhost_dev *dev,
1565  			    poll_table *wait)
1566  {
1567  	__poll_t mask = 0;
1568  
1569  	poll_wait(file, &dev->wait, wait);
1570  
1571  	if (!list_empty(&dev->read_list))
1572  		mask |= EPOLLIN | EPOLLRDNORM;
1573  
1574  	return mask;
1575  }
1576  EXPORT_SYMBOL(vhost_chr_poll);
1577  
1578  ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
1579  			    int noblock)
1580  {
1581  	DEFINE_WAIT(wait);
1582  	struct vhost_msg_node *node;
1583  	ssize_t ret = 0;
1584  	unsigned size = sizeof(struct vhost_msg);
1585  
1586  	if (iov_iter_count(to) < size)
1587  		return 0;
1588  
1589  	while (1) {
1590  		if (!noblock)
1591  			prepare_to_wait(&dev->wait, &wait,
1592  					TASK_INTERRUPTIBLE);
1593  
1594  		node = vhost_dequeue_msg(dev, &dev->read_list);
1595  		if (node)
1596  			break;
1597  		if (noblock) {
1598  			ret = -EAGAIN;
1599  			break;
1600  		}
1601  		if (signal_pending(current)) {
1602  			ret = -ERESTARTSYS;
1603  			break;
1604  		}
1605  		if (!dev->iotlb) {
1606  			ret = -EBADFD;
1607  			break;
1608  		}
1609  
1610  		schedule();
1611  	}
1612  
1613  	if (!noblock)
1614  		finish_wait(&dev->wait, &wait);
1615  
1616  	if (node) {
1617  		struct vhost_iotlb_msg *msg;
1618  		void *start = &node->msg;
1619  
1620  		switch (node->msg.type) {
1621  		case VHOST_IOTLB_MSG:
1622  			size = sizeof(node->msg);
1623  			msg = &node->msg.iotlb;
1624  			break;
1625  		case VHOST_IOTLB_MSG_V2:
1626  			size = sizeof(node->msg_v2);
1627  			msg = &node->msg_v2.iotlb;
1628  			break;
1629  		default:
1630  			BUG();
1631  			break;
1632  		}
1633  
1634  		ret = copy_to_iter(start, size, to);
1635  		if (ret != size || msg->type != VHOST_IOTLB_MISS) {
1636  			kfree(node);
1637  			return ret;
1638  		}
1639  		vhost_enqueue_msg(dev, &dev->pending_list, node);
1640  	}
1641  
1642  	return ret;
1643  }
1644  EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
1645  
1646  static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
1647  {
1648  	struct vhost_dev *dev = vq->dev;
1649  	struct vhost_msg_node *node;
1650  	struct vhost_iotlb_msg *msg;
1651  	bool v2 = vhost_backend_has_feature(vq, VHOST_BACKEND_F_IOTLB_MSG_V2);
1652  
1653  	node = vhost_new_msg(vq, v2 ? VHOST_IOTLB_MSG_V2 : VHOST_IOTLB_MSG);
1654  	if (!node)
1655  		return -ENOMEM;
1656  
1657  	if (v2) {
1658  		node->msg_v2.type = VHOST_IOTLB_MSG_V2;
1659  		msg = &node->msg_v2.iotlb;
1660  	} else {
1661  		msg = &node->msg.iotlb;
1662  	}
1663  
1664  	msg->type = VHOST_IOTLB_MISS;
1665  	msg->iova = iova;
1666  	msg->perm = access;
1667  
1668  	vhost_enqueue_msg(dev, &dev->read_list, node);
1669  
1670  	return 0;
1671  }
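/* The miss/update protocol in brief (descriptive only): the miss message
 * queued above is delivered to userspace via vhost_chr_read_iter() and moved
 * to pending_list; userspace is expected to answer with a VHOST_IOTLB_UPDATE
 * (or VHOST_IOTLB_INVALIDATE) write that lands in vhost_process_iotlb_msg(),
 * which then wakes the waiting virtqueue through vhost_iotlb_notify_vq().
 */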
1672  
1673  static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
1674  			 vring_desc_t __user *desc,
1675  			 vring_avail_t __user *avail,
1676  			 vring_used_t __user *used)
1677  
1678  {
1679  	/* If an IOTLB device is present, the vring addresses are
1680  	 * GIOVAs. Access validation occurs at prefetch time. */
1681  	if (vq->iotlb)
1682  		return true;
1683  
1684  	return access_ok(desc, vhost_get_desc_size(vq, num)) &&
1685  	       access_ok(avail, vhost_get_avail_size(vq, num)) &&
1686  	       access_ok(used, vhost_get_used_size(vq, num));
1687  }
1688  
1689  static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
1690  				 const struct vhost_iotlb_map *map,
1691  				 int type)
1692  {
1693  	int access = (type == VHOST_ADDR_USED) ?
1694  		     VHOST_ACCESS_WO : VHOST_ACCESS_RO;
1695  
1696  	if (likely(map->perm & access))
1697  		vq->meta_iotlb[type] = map;
1698  }
1699  
1700  static bool iotlb_access_ok(struct vhost_virtqueue *vq,
1701  			    int access, u64 addr, u64 len, int type)
1702  {
1703  	const struct vhost_iotlb_map *map;
1704  	struct vhost_iotlb *umem = vq->iotlb;
1705  	u64 s = 0, size, orig_addr = addr, last = addr + len - 1;
1706  
1707  	if (vhost_vq_meta_fetch(vq, addr, len, type))
1708  		return true;
1709  
1710  	while (len > s) {
1711  		map = vhost_iotlb_itree_first(umem, addr, last);
1712  		if (map == NULL || map->start > addr) {
1713  			vhost_iotlb_miss(vq, addr, access);
1714  			return false;
1715  		} else if (!(map->perm & access)) {
1716  			/* Report the possible access violation by
1717  			 * requesting another translation from userspace.
1718  			 */
1719  			return false;
1720  		}
1721  
1722  		size = map->size - addr + map->start;
1723  
1724  		if (orig_addr == addr && size >= len)
1725  			vhost_vq_meta_update(vq, map, type);
1726  
1727  		s += size;
1728  		addr += size;
1729  	}
1730  
1731  	return true;
1732  }
1733  
1734  int vq_meta_prefetch(struct vhost_virtqueue *vq)
1735  {
1736  	unsigned int num = vq->num;
1737  
1738  	if (!vq->iotlb)
1739  		return 1;
1740  
1741  	return iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->desc,
1742  			       vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) &&
1743  	       iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->avail,
1744  			       vhost_get_avail_size(vq, num),
1745  			       VHOST_ADDR_AVAIL) &&
1746  	       iotlb_access_ok(vq, VHOST_MAP_WO, (u64)(uintptr_t)vq->used,
1747  			       vhost_get_used_size(vq, num), VHOST_ADDR_USED);
1748  }
1749  EXPORT_SYMBOL_GPL(vq_meta_prefetch);
1750  
1751  /* Can we log writes? */
1752  /* Caller should have device mutex but not vq mutex */
1753  bool vhost_log_access_ok(struct vhost_dev *dev)
1754  {
1755  	return memory_access_ok(dev, dev->umem, 1);
1756  }
1757  EXPORT_SYMBOL_GPL(vhost_log_access_ok);
1758  
1759  static bool vq_log_used_access_ok(struct vhost_virtqueue *vq,
1760  				  void __user *log_base,
1761  				  bool log_used,
1762  				  u64 log_addr)
1763  {
1764  	/* If an IOTLB device is present, log_addr is a GIOVA that
1765  	 * will never be logged by log_used(). */
1766  	if (vq->iotlb)
1767  		return true;
1768  
1769  	return !log_used || log_access_ok(log_base, log_addr,
1770  					  vhost_get_used_size(vq, vq->num));
1771  }
1772  
1773  /* Verify access for write logging. */
1774  /* Caller should have vq mutex and device mutex */
1775  static bool vq_log_access_ok(struct vhost_virtqueue *vq,
1776  			     void __user *log_base)
1777  {
1778  	return vq_memory_access_ok(log_base, vq->umem,
1779  				   vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
1780  		vq_log_used_access_ok(vq, log_base, vq->log_used, vq->log_addr);
1781  }
1782  
1783  /* Can we start vq? */
1784  /* Caller should have vq mutex and device mutex */
1785  bool vhost_vq_access_ok(struct vhost_virtqueue *vq)
1786  {
1787  	if (!vq_log_access_ok(vq, vq->log_base))
1788  		return false;
1789  
1790  	return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used);
1791  }
1792  EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
1793  
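/* VHOST_SET_MEM_TABLE: replace the device memory map.  The regions supplied
 * by userspace are copied in, converted into an IOTLB, validated, and then
 * published to every virtqueue under its mutex before the old map is freed. */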
1794  static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
1795  {
1796  	struct vhost_memory mem, *newmem;
1797  	struct vhost_memory_region *region;
1798  	struct vhost_iotlb *newumem, *oldumem;
1799  	unsigned long size = offsetof(struct vhost_memory, regions);
1800  	int i;
1801  
1802  	if (copy_from_user(&mem, m, size))
1803  		return -EFAULT;
1804  	if (mem.padding)
1805  		return -EOPNOTSUPP;
1806  	if (mem.nregions > max_mem_regions)
1807  		return -E2BIG;
1808  	newmem = kvzalloc(struct_size(newmem, regions, mem.nregions),
1809  			GFP_KERNEL);
1810  	if (!newmem)
1811  		return -ENOMEM;
1812  
1813  	memcpy(newmem, &mem, size);
1814  	if (copy_from_user(newmem->regions, m->regions,
1815  			   flex_array_size(newmem, regions, mem.nregions))) {
1816  		kvfree(newmem);
1817  		return -EFAULT;
1818  	}
1819  
1820  	newumem = iotlb_alloc();
1821  	if (!newumem) {
1822  		kvfree(newmem);
1823  		return -ENOMEM;
1824  	}
1825  
1826  	for (region = newmem->regions;
1827  	     region < newmem->regions + mem.nregions;
1828  	     region++) {
1829  		if (vhost_iotlb_add_range(newumem,
1830  					  region->guest_phys_addr,
1831  					  region->guest_phys_addr +
1832  					  region->memory_size - 1,
1833  					  region->userspace_addr,
1834  					  VHOST_MAP_RW))
1835  			goto err;
1836  	}
1837  
1838  	if (!memory_access_ok(d, newumem, 0))
1839  		goto err;
1840  
1841  	oldumem = d->umem;
1842  	d->umem = newumem;
1843  
1844  	/* All memory accesses are done under some VQ mutex. */
1845  	for (i = 0; i < d->nvqs; ++i) {
1846  		mutex_lock(&d->vqs[i]->mutex);
1847  		d->vqs[i]->umem = newumem;
1848  		mutex_unlock(&d->vqs[i]->mutex);
1849  	}
1850  
1851  	kvfree(newmem);
1852  	vhost_iotlb_free(oldumem);
1853  	return 0;
1854  
1855  err:
1856  	vhost_iotlb_free(newumem);
1857  	kvfree(newmem);
1858  	return -EFAULT;
1859  }
1860  
1861  static long vhost_vring_set_num(struct vhost_dev *d,
1862  				struct vhost_virtqueue *vq,
1863  				void __user *argp)
1864  {
1865  	struct vhost_vring_state s;
1866  
1867  	/* Resizing ring with an active backend?
1868  	 * You don't want to do that. */
1869  	if (vq->private_data)
1870  		return -EBUSY;
1871  
1872  	if (copy_from_user(&s, argp, sizeof s))
1873  		return -EFAULT;
1874  
1875  	if (!s.num || s.num > 0xffff || (s.num & (s.num - 1)))
1876  		return -EINVAL;
1877  	vq->num = s.num;
1878  
1879  	return 0;
1880  }
1881  
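/* VHOST_SET_VRING_ADDR: set the userspace addresses of the descriptor,
 * avail and used rings and the guest address used for used-ring logging.
 * Alignment is always checked; accessibility is only checked when a
 * backend is already attached (otherwise the ring size may not be final). */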
1882  static long vhost_vring_set_addr(struct vhost_dev *d,
1883  				 struct vhost_virtqueue *vq,
1884  				 void __user *argp)
1885  {
1886  	struct vhost_vring_addr a;
1887  
1888  	if (copy_from_user(&a, argp, sizeof a))
1889  		return -EFAULT;
1890  	if (a.flags & ~(0x1 << VHOST_VRING_F_LOG))
1891  		return -EOPNOTSUPP;
1892  
1893  	/* For 32bit, verify that the top 32bits of the user
1894  	   data are set to zero. */
1895  	if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
1896  	    (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
1897  	    (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr)
1898  		return -EFAULT;
1899  
1900  	/* Make sure it's safe to cast pointers to vring types. */
1901  	BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
1902  	BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
1903  	if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
1904  	    (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
1905  	    (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1)))
1906  		return -EINVAL;
1907  
1908  	/* We only verify access here if a backend is configured.
1909  	   If it is not, we don't verify, as the size might not have
1910  	   been set up yet.  We will verify when the backend is configured. */
1911  	if (vq->private_data) {
1912  		if (!vq_access_ok(vq, vq->num,
1913  			(void __user *)(unsigned long)a.desc_user_addr,
1914  			(void __user *)(unsigned long)a.avail_user_addr,
1915  			(void __user *)(unsigned long)a.used_user_addr))
1916  			return -EINVAL;
1917  
1918  		/* Also validate log access for used ring if enabled. */
1919  		if (!vq_log_used_access_ok(vq, vq->log_base,
1920  				a.flags & (0x1 << VHOST_VRING_F_LOG),
1921  				a.log_guest_addr))
1922  			return -EINVAL;
1923  	}
1924  
1925  	vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
1926  	vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
1927  	vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
1928  	vq->log_addr = a.log_guest_addr;
1929  	vq->used = (void __user *)(unsigned long)a.used_user_addr;
1930  
1931  	return 0;
1932  }
1933  
1934  static long vhost_vring_set_num_addr(struct vhost_dev *d,
1935  				     struct vhost_virtqueue *vq,
1936  				     unsigned int ioctl,
1937  				     void __user *argp)
1938  {
1939  	long r;
1940  
1941  	mutex_lock(&vq->mutex);
1942  
1943  	switch (ioctl) {
1944  	case VHOST_SET_VRING_NUM:
1945  		r = vhost_vring_set_num(d, vq, argp);
1946  		break;
1947  	case VHOST_SET_VRING_ADDR:
1948  		r = vhost_vring_set_addr(d, vq, argp);
1949  		break;
1950  	default:
1951  		BUG();
1952  	}
1953  
1954  	mutex_unlock(&vq->mutex);
1955  
1956  	return r;
1957  }
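
/* Per-virtqueue ioctls; the leading __u32 of @argp selects the queue.
 *
 * As a rough, illustrative sketch only (names are made up and the exact
 * ordering and extra steps depend on the backend), userspace typically
 * configures each queue after VHOST_SET_OWNER / VHOST_SET_MEM_TABLE with:
 *
 *	struct vhost_vring_state num  = { .index = i, .num = 256 };
 *	struct vhost_vring_state base = { .index = i, .num = 0 };
 *	struct vhost_vring_addr  addr = { .index = i, .desc_user_addr = ...,
 *					  .avail_user_addr = ...,
 *					  .used_user_addr  = ... };
 *	struct vhost_vring_file  kick = { .index = i, .fd = kick_eventfd };
 *	struct vhost_vring_file  call = { .index = i, .fd = call_eventfd };
 *
 *	ioctl(vhost_fd, VHOST_SET_VRING_NUM,  &num);
 *	ioctl(vhost_fd, VHOST_SET_VRING_BASE, &base);
 *	ioctl(vhost_fd, VHOST_SET_VRING_ADDR, &addr);
 *	ioctl(vhost_fd, VHOST_SET_VRING_KICK, &kick);
 *	ioctl(vhost_fd, VHOST_SET_VRING_CALL, &call);
 */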
1958  long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
1959  {
1960  	struct file *eventfp, *filep = NULL;
1961  	bool pollstart = false, pollstop = false;
1962  	struct eventfd_ctx *ctx = NULL;
1963  	struct vhost_virtqueue *vq;
1964  	struct vhost_vring_state s;
1965  	struct vhost_vring_file f;
1966  	u32 idx;
1967  	long r;
1968  
1969  	r = vhost_get_vq_from_user(d, argp, &vq, &idx);
1970  	if (r < 0)
1971  		return r;
1972  
1973  	if (ioctl == VHOST_SET_VRING_NUM ||
1974  	    ioctl == VHOST_SET_VRING_ADDR) {
1975  		return vhost_vring_set_num_addr(d, vq, ioctl, argp);
1976  	}
1977  
1978  	mutex_lock(&vq->mutex);
1979  
1980  	switch (ioctl) {
1981  	case VHOST_SET_VRING_BASE:
1982  		/* Moving base with an active backend?
1983  		 * You don't want to do that. */
1984  		if (vq->private_data) {
1985  			r = -EBUSY;
1986  			break;
1987  		}
1988  		if (copy_from_user(&s, argp, sizeof s)) {
1989  			r = -EFAULT;
1990  			break;
1991  		}
1992  		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
1993  			vq->last_avail_idx = s.num & 0xffff;
1994  			vq->last_used_idx = (s.num >> 16) & 0xffff;
1995  		} else {
1996  			if (s.num > 0xffff) {
1997  				r = -EINVAL;
1998  				break;
1999  			}
2000  			vq->last_avail_idx = s.num;
2001  		}
2002  		/* Forget the cached index value. */
2003  		vq->avail_idx = vq->last_avail_idx;
2004  		break;
2005  	case VHOST_GET_VRING_BASE:
2006  		s.index = idx;
2007  		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
2008  			s.num = (u32)vq->last_avail_idx | ((u32)vq->last_used_idx << 16);
2009  		else
2010  			s.num = vq->last_avail_idx;
2011  		if (copy_to_user(argp, &s, sizeof s))
2012  			r = -EFAULT;
2013  		break;
2014  	case VHOST_SET_VRING_KICK:
2015  		if (copy_from_user(&f, argp, sizeof f)) {
2016  			r = -EFAULT;
2017  			break;
2018  		}
2019  		eventfp = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_fget(f.fd);
2020  		if (IS_ERR(eventfp)) {
2021  			r = PTR_ERR(eventfp);
2022  			break;
2023  		}
2024  		if (eventfp != vq->kick) {
2025  			pollstop = (filep = vq->kick) != NULL;
2026  			pollstart = (vq->kick = eventfp) != NULL;
2027  		} else
2028  			filep = eventfp;
2029  		break;
2030  	case VHOST_SET_VRING_CALL:
2031  		if (copy_from_user(&f, argp, sizeof f)) {
2032  			r = -EFAULT;
2033  			break;
2034  		}
2035  		ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd);
2036  		if (IS_ERR(ctx)) {
2037  			r = PTR_ERR(ctx);
2038  			break;
2039  		}
2040  
2041  		swap(ctx, vq->call_ctx.ctx);
2042  		break;
2043  	case VHOST_SET_VRING_ERR:
2044  		if (copy_from_user(&f, argp, sizeof f)) {
2045  			r = -EFAULT;
2046  			break;
2047  		}
2048  		ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd);
2049  		if (IS_ERR(ctx)) {
2050  			r = PTR_ERR(ctx);
2051  			break;
2052  		}
2053  		swap(ctx, vq->error_ctx);
2054  		break;
2055  	case VHOST_SET_VRING_ENDIAN:
2056  		r = vhost_set_vring_endian(vq, argp);
2057  		break;
2058  	case VHOST_GET_VRING_ENDIAN:
2059  		r = vhost_get_vring_endian(vq, idx, argp);
2060  		break;
2061  	case VHOST_SET_VRING_BUSYLOOP_TIMEOUT:
2062  		if (copy_from_user(&s, argp, sizeof(s))) {
2063  			r = -EFAULT;
2064  			break;
2065  		}
2066  		vq->busyloop_timeout = s.num;
2067  		break;
2068  	case VHOST_GET_VRING_BUSYLOOP_TIMEOUT:
2069  		s.index = idx;
2070  		s.num = vq->busyloop_timeout;
2071  		if (copy_to_user(argp, &s, sizeof(s)))
2072  			r = -EFAULT;
2073  		break;
2074  	default:
2075  		r = -ENOIOCTLCMD;
2076  	}
2077  
2078  	if (pollstop && vq->handle_kick)
2079  		vhost_poll_stop(&vq->poll);
2080  
2081  	if (!IS_ERR_OR_NULL(ctx))
2082  		eventfd_ctx_put(ctx);
2083  	if (filep)
2084  		fput(filep);
2085  
2086  	if (pollstart && vq->handle_kick)
2087  		r = vhost_poll_start(&vq->poll, vq->kick);
2088  
2089  	mutex_unlock(&vq->mutex);
2090  
2091  	if (pollstop && vq->handle_kick)
2092  		vhost_dev_flush(vq->poll.dev);
2093  	return r;
2094  }
2095  EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
2096  
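/* Switch the device to IOTLB mode: allocate a fresh IOTLB, install it in
 * the device and in every virtqueue (resetting any cached metadata
 * translations), then free the old one. */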
2097  int vhost_init_device_iotlb(struct vhost_dev *d)
2098  {
2099  	struct vhost_iotlb *niotlb, *oiotlb;
2100  	int i;
2101  
2102  	niotlb = iotlb_alloc();
2103  	if (!niotlb)
2104  		return -ENOMEM;
2105  
2106  	oiotlb = d->iotlb;
2107  	d->iotlb = niotlb;
2108  
2109  	for (i = 0; i < d->nvqs; ++i) {
2110  		struct vhost_virtqueue *vq = d->vqs[i];
2111  
2112  		mutex_lock(&vq->mutex);
2113  		vq->iotlb = niotlb;
2114  		__vhost_vq_meta_reset(vq);
2115  		mutex_unlock(&vq->mutex);
2116  	}
2117  
2118  	vhost_iotlb_free(oiotlb);
2119  
2120  	return 0;
2121  }
2122  EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
2123  
2124  /* Caller must have device mutex */
2125  long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
2126  {
2127  	struct eventfd_ctx *ctx;
2128  	u64 p;
2129  	long r;
2130  	int i, fd;
2131  
2132  	/* If you are not the owner, you can become one */
2133  	if (ioctl == VHOST_SET_OWNER) {
2134  		r = vhost_dev_set_owner(d);
2135  		goto done;
2136  	}
2137  
2138  	/* You must be the owner to do anything else */
2139  	r = vhost_dev_check_owner(d);
2140  	if (r)
2141  		goto done;
2142  
2143  	switch (ioctl) {
2144  	case VHOST_SET_MEM_TABLE:
2145  		r = vhost_set_memory(d, argp);
2146  		break;
2147  	case VHOST_SET_LOG_BASE:
2148  		if (copy_from_user(&p, argp, sizeof p)) {
2149  			r = -EFAULT;
2150  			break;
2151  		}
2152  		if ((u64)(unsigned long)p != p) {
2153  			r = -EFAULT;
2154  			break;
2155  		}
2156  		for (i = 0; i < d->nvqs; ++i) {
2157  			struct vhost_virtqueue *vq;
2158  			void __user *base = (void __user *)(unsigned long)p;
2159  			vq = d->vqs[i];
2160  			mutex_lock(&vq->mutex);
2161  			/* If ring is inactive, will check when it's enabled. */
2162  			if (vq->private_data && !vq_log_access_ok(vq, base))
2163  				r = -EFAULT;
2164  			else
2165  				vq->log_base = base;
2166  			mutex_unlock(&vq->mutex);
2167  		}
2168  		break;
2169  	case VHOST_SET_LOG_FD:
2170  		r = get_user(fd, (int __user *)argp);
2171  		if (r < 0)
2172  			break;
2173  		ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
2174  		if (IS_ERR(ctx)) {
2175  			r = PTR_ERR(ctx);
2176  			break;
2177  		}
2178  		swap(ctx, d->log_ctx);
2179  		for (i = 0; i < d->nvqs; ++i) {
2180  			mutex_lock(&d->vqs[i]->mutex);
2181  			d->vqs[i]->log_ctx = d->log_ctx;
2182  			mutex_unlock(&d->vqs[i]->mutex);
2183  		}
2184  		if (ctx)
2185  			eventfd_ctx_put(ctx);
2186  		break;
2187  	default:
2188  		r = -ENOIOCTLCMD;
2189  		break;
2190  	}
2191  done:
2192  	return r;
2193  }
2194  EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
2195  
2196  /* TODO: This is really inefficient.  We need something like get_user()
2197   * (instruction directly accesses the data, with an exception table entry
2198   * returning -EFAULT). See Documentation/arch/x86/exception-tables.rst.
2199   */
2200  static int set_bit_to_user(int nr, void __user *addr)
2201  {
2202  	unsigned long log = (unsigned long)addr;
2203  	struct page *page;
2204  	void *base;
2205  	int bit = nr + (log % PAGE_SIZE) * 8;
2206  	int r;
2207  
2208  	r = pin_user_pages_fast(log, 1, FOLL_WRITE, &page);
2209  	if (r < 0)
2210  		return r;
2211  	BUG_ON(r != 1);
2212  	base = kmap_atomic(page);
2213  	set_bit(bit, base);
2214  	kunmap_atomic(base);
2215  	unpin_user_pages_dirty_lock(&page, 1, true);
2216  	return 0;
2217  }
2218  
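/* Mark the pages covering [write_address, write_address + write_length) as
 * dirty in the userspace log bitmap at @log_base, one bit per
 * VHOST_PAGE_SIZE page. */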
2219  static int log_write(void __user *log_base,
2220  		     u64 write_address, u64 write_length)
2221  {
2222  	u64 write_page = write_address / VHOST_PAGE_SIZE;
2223  	int r;
2224  
2225  	if (!write_length)
2226  		return 0;
2227  	write_length += write_address % VHOST_PAGE_SIZE;
2228  	for (;;) {
2229  		u64 base = (u64)(unsigned long)log_base;
2230  		u64 log = base + write_page / 8;
2231  		int bit = write_page % 8;
2232  		if ((u64)(unsigned long)log != log)
2233  			return -EFAULT;
2234  		r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
2235  		if (r < 0)
2236  			return r;
2237  		if (write_length <= VHOST_PAGE_SIZE)
2238  			break;
2239  		write_length -= VHOST_PAGE_SIZE;
2240  		write_page += 1;
2241  	}
2242  	return r;
2243  }
2244  
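/* Log a write that was done through a host virtual address.  Since several
 * GPA ranges may map to the same HVA, every region intersecting the range
 * is logged; -EFAULT is returned if part of the range has no mapping. */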
2245  static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
2246  {
2247  	struct vhost_iotlb *umem = vq->umem;
2248  	struct vhost_iotlb_map *u;
2249  	u64 start, end, l, min;
2250  	int r;
2251  	bool hit = false;
2252  
2253  	while (len) {
2254  		min = len;
2255  		/* More than one GPA can be mapped into a single HVA, so
2256  		 * iterate over all possible mappings here to be safe.
2257  		 */
2258  		list_for_each_entry(u, &umem->list, link) {
2259  			if (u->addr > hva - 1 + len ||
2260  			    u->addr - 1 + u->size < hva)
2261  				continue;
2262  			start = max(u->addr, hva);
2263  			end = min(u->addr - 1 + u->size, hva - 1 + len);
2264  			l = end - start + 1;
2265  			r = log_write(vq->log_base,
2266  				      u->start + start - u->addr,
2267  				      l);
2268  			if (r < 0)
2269  				return r;
2270  			hit = true;
2271  			min = min(l, min);
2272  		}
2273  
2274  		if (!hit)
2275  			return -EFAULT;
2276  
2277  		len -= min;
2278  		hva += min;
2279  	}
2280  
2281  	return 0;
2282  }
2283  
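/* Log a write to the used ring at @used_offset.  Without an IOTLB the
 * ring's guest address (vq->log_addr) is used directly; with an IOTLB the
 * userspace address is first translated and logged via log_write_hva(). */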
2284  static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
2285  {
2286  	struct iovec *iov = vq->log_iov;
2287  	int i, ret;
2288  
2289  	if (!vq->iotlb)
2290  		return log_write(vq->log_base, vq->log_addr + used_offset, len);
2291  
2292  	ret = translate_desc(vq, (uintptr_t)vq->used + used_offset,
2293  			     len, iov, 64, VHOST_ACCESS_WO);
2294  	if (ret < 0)
2295  		return ret;
2296  
2297  	for (i = 0; i < ret; i++) {
2298  		ret = log_write_hva(vq,	(uintptr_t)iov[i].iov_base,
2299  				    iov[i].iov_len);
2300  		if (ret)
2301  			return ret;
2302  	}
2303  
2304  	return 0;
2305  }
2306  
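/* Log the writes performed for a request.  With an IOTLB the caller passes
 * the iovec that was actually written (@iov/@count); otherwise the
 * (@log, @log_num) entries collected by vhost_get_vq_desc() are used, with
 * @len bounding how much of them was really written. */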
2307  int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
2308  		    unsigned int log_num, u64 len, struct iovec *iov, int count)
2309  {
2310  	int i, r;
2311  
2312  	/* Make sure data written is seen before log. */
2313  	smp_wmb();
2314  
2315  	if (vq->iotlb) {
2316  		for (i = 0; i < count; i++) {
2317  			r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
2318  					  iov[i].iov_len);
2319  			if (r < 0)
2320  				return r;
2321  		}
2322  		return 0;
2323  	}
2324  
2325  	for (i = 0; i < log_num; ++i) {
2326  		u64 l = min(log[i].len, len);
2327  		r = log_write(vq->log_base, log[i].addr, l);
2328  		if (r < 0)
2329  			return r;
2330  		len -= l;
2331  		if (!len) {
2332  			if (vq->log_ctx)
2333  				eventfd_signal(vq->log_ctx);
2334  			return 0;
2335  		}
2336  	}
2337  	/* Length written exceeds what we have stored. This is a bug. */
2338  	BUG();
2339  	return 0;
2340  }
2341  EXPORT_SYMBOL_GPL(vhost_log_write);
2342  
2343  static int vhost_update_used_flags(struct vhost_virtqueue *vq)
2344  {
2345  	void __user *used;
2346  	if (vhost_put_used_flags(vq))
2347  		return -EFAULT;
2348  	if (unlikely(vq->log_used)) {
2349  		/* Make sure the flag is seen before log. */
2350  		smp_wmb();
2351  		/* Log used flag write. */
2352  		used = &vq->used->flags;
2353  		log_used(vq, (used - (void __user *)vq->used),
2354  			 sizeof vq->used->flags);
2355  		if (vq->log_ctx)
2356  			eventfd_signal(vq->log_ctx);
2357  	}
2358  	return 0;
2359  }
2360  
2361  static int vhost_update_avail_event(struct vhost_virtqueue *vq)
2362  {
2363  	if (vhost_put_avail_event(vq))
2364  		return -EFAULT;
2365  	if (unlikely(vq->log_used)) {
2366  		void __user *used;
2367  		/* Make sure the event is seen before log. */
2368  		smp_wmb();
2369  		/* Log avail event write */
2370  		used = vhost_avail_event(vq);
2371  		log_used(vq, (used - (void __user *)vq->used),
2372  			 sizeof *vhost_avail_event(vq));
2373  		if (vq->log_ctx)
2374  			eventfd_signal(vq->log_ctx);
2375  	}
2376  	return 0;
2377  }
2378  
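/* Prepare ring access once a backend is attached: settle endianness,
 * publish the used flags, and read back the used index so that
 * vq->last_used_idx matches the value currently in the ring. */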
2379  int vhost_vq_init_access(struct vhost_virtqueue *vq)
2380  {
2381  	__virtio16 last_used_idx;
2382  	int r;
2383  	bool is_le = vq->is_le;
2384  
2385  	if (!vq->private_data)
2386  		return 0;
2387  
2388  	vhost_init_is_le(vq);
2389  
2390  	r = vhost_update_used_flags(vq);
2391  	if (r)
2392  		goto err;
2393  	vq->signalled_used_valid = false;
2394  	if (!vq->iotlb &&
2395  	    !access_ok(&vq->used->idx, sizeof vq->used->idx)) {
2396  		r = -EFAULT;
2397  		goto err;
2398  	}
2399  	r = vhost_get_used_idx(vq, &last_used_idx);
2400  	if (r) {
2401  		vq_err(vq, "Can't access used idx at %p\n",
2402  		       &vq->used->idx);
2403  		goto err;
2404  	}
2405  	vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
2406  	return 0;
2407  
2408  err:
2409  	vq->is_le = is_le;
2410  	return r;
2411  }
2412  EXPORT_SYMBOL_GPL(vhost_vq_init_access);
2413  
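/* Translate a guest range into an iovec of userspace addresses using the
 * IOTLB (or the memory table when no IOTLB is attached).  Returns the
 * number of iovec entries used, or a negative error: -ENOBUFS if @iov_size
 * is too small, -EPERM on a permission mismatch, -EFAULT for an unmapped
 * range without IOTLB, or -EAGAIN after queueing an IOTLB miss. */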
2414  static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
2415  			  struct iovec iov[], int iov_size, int access)
2416  {
2417  	const struct vhost_iotlb_map *map;
2418  	struct vhost_dev *dev = vq->dev;
2419  	struct vhost_iotlb *umem = dev->iotlb ? dev->iotlb : dev->umem;
2420  	struct iovec *_iov;
2421  	u64 s = 0, last = addr + len - 1;
2422  	int ret = 0;
2423  
2424  	while ((u64)len > s) {
2425  		u64 size;
2426  		if (unlikely(ret >= iov_size)) {
2427  			ret = -ENOBUFS;
2428  			break;
2429  		}
2430  
2431  		map = vhost_iotlb_itree_first(umem, addr, last);
2432  		if (map == NULL || map->start > addr) {
2433  			if (umem != dev->iotlb) {
2434  				ret = -EFAULT;
2435  				break;
2436  			}
2437  			ret = -EAGAIN;
2438  			break;
2439  		} else if (!(map->perm & access)) {
2440  			ret = -EPERM;
2441  			break;
2442  		}
2443  
2444  		_iov = iov + ret;
2445  		size = map->size - addr + map->start;
2446  		_iov->iov_len = min((u64)len - s, size);
2447  		_iov->iov_base = (void __user *)(unsigned long)
2448  				 (map->addr + addr - map->start);
2449  		s += size;
2450  		addr += size;
2451  		++ret;
2452  	}
2453  
2454  	if (ret == -EAGAIN)
2455  		vhost_iotlb_miss(vq, addr, access);
2456  	return ret;
2457  }
2458  
2459  /* Each buffer in the virtqueues is actually a chain of descriptors.  This
2460   * function returns the next descriptor in the chain,
2461   * or -1U if we're at the end. */
2462  static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc)
2463  {
2464  	unsigned int next;
2465  
2466  	/* If this descriptor says it doesn't chain, we're done. */
2467  	if (!(desc->flags & cpu_to_vhost16(vq, VRING_DESC_F_NEXT)))
2468  		return -1U;
2469  
2470  	/* Check they're not leading us off the end of the descriptors. */
2471  	next = vhost16_to_cpu(vq, READ_ONCE(desc->next));
2472  	return next;
2473  }
2474  
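/* Walk a VRING_DESC_F_INDIRECT table: translate the table itself, then
 * iterate its descriptors, appending them to @iov and updating the out/in
 * counts (and, when logging, the dirty-log entries for writable buffers). */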
2475  static int get_indirect(struct vhost_virtqueue *vq,
2476  			struct iovec iov[], unsigned int iov_size,
2477  			unsigned int *out_num, unsigned int *in_num,
2478  			struct vhost_log *log, unsigned int *log_num,
2479  			struct vring_desc *indirect)
2480  {
2481  	struct vring_desc desc;
2482  	unsigned int i = 0, count, found = 0;
2483  	u32 len = vhost32_to_cpu(vq, indirect->len);
2484  	struct iov_iter from;
2485  	int ret, access;
2486  
2487  	/* Sanity check */
2488  	if (unlikely(len % sizeof desc)) {
2489  		vq_err(vq, "Invalid length in indirect descriptor: "
2490  		       "len 0x%llx not multiple of 0x%zx\n",
2491  		       (unsigned long long)len,
2492  		       sizeof desc);
2493  		return -EINVAL;
2494  	}
2495  
2496  	ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
2497  			     UIO_MAXIOV, VHOST_ACCESS_RO);
2498  	if (unlikely(ret < 0)) {
2499  		if (ret != -EAGAIN)
2500  			vq_err(vq, "Translation failure %d in indirect.\n", ret);
2501  		return ret;
2502  	}
2503  	iov_iter_init(&from, ITER_SOURCE, vq->indirect, ret, len);
2504  	count = len / sizeof desc;
2505  	/* Buffers are chained via a 16 bit next field, so
2506  	 * we can have at most 2^16 of these. */
2507  	if (unlikely(count > USHRT_MAX + 1)) {
2508  		vq_err(vq, "Indirect buffer length too big: %d\n",
2509  		       indirect->len);
2510  		return -E2BIG;
2511  	}
2512  
2513  	do {
2514  		unsigned iov_count = *in_num + *out_num;
2515  		if (unlikely(++found > count)) {
2516  			vq_err(vq, "Loop detected: last one at %u "
2517  			       "indirect size %u\n",
2518  			       i, count);
2519  			return -EINVAL;
2520  		}
2521  		if (unlikely(!copy_from_iter_full(&desc, sizeof(desc), &from))) {
2522  			vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n",
2523  			       i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
2524  			return -EINVAL;
2525  		}
2526  		if (unlikely(desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT))) {
2527  			vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n",
2528  			       i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
2529  			return -EINVAL;
2530  		}
2531  
2532  		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2533  			access = VHOST_ACCESS_WO;
2534  		else
2535  			access = VHOST_ACCESS_RO;
2536  
2537  		ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
2538  				     vhost32_to_cpu(vq, desc.len), iov + iov_count,
2539  				     iov_size - iov_count, access);
2540  		if (unlikely(ret < 0)) {
2541  			if (ret != -EAGAIN)
2542  				vq_err(vq, "Translation failure %d indirect idx %d\n",
2543  					ret, i);
2544  			return ret;
2545  		}
2546  		/* If this is an input descriptor, increment that count. */
2547  		if (access == VHOST_ACCESS_WO) {
2548  			*in_num += ret;
2549  			if (unlikely(log && ret)) {
2550  				log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
2551  				log[*log_num].len = vhost32_to_cpu(vq, desc.len);
2552  				++*log_num;
2553  			}
2554  		} else {
2555  			/* If it's an output descriptor, they're all supposed
2556  			 * to come before any input descriptors. */
2557  			if (unlikely(*in_num)) {
2558  				vq_err(vq, "Indirect descriptor "
2559  				       "has out after in: idx %d\n", i);
2560  				return -EINVAL;
2561  			}
2562  			*out_num += ret;
2563  		}
2564  	} while ((i = next_desc(vq, &desc)) != -1);
2565  	return 0;
2566  }
2567  
2568  /* This looks in the virtqueue for the first available buffer, and converts
2569   * it to an iovec for convenient access.  Since descriptors consist of some
2570   * number of output then some number of input descriptors, it's actually two
2571   * iovecs, but we pack them into one and note how many of each there were.
2572   *
2573   * This function returns the descriptor number found, or vq->num (which is
2574   * never a valid descriptor number) if none was found.  A negative code is
2575   * returned on error. */
2576  int vhost_get_vq_desc(struct vhost_virtqueue *vq,
2577  		      struct iovec iov[], unsigned int iov_size,
2578  		      unsigned int *out_num, unsigned int *in_num,
2579  		      struct vhost_log *log, unsigned int *log_num)
2580  {
2581  	struct vring_desc desc;
2582  	unsigned int i, head, found = 0;
2583  	u16 last_avail_idx = vq->last_avail_idx;
2584  	__virtio16 ring_head;
2585  	int ret, access;
2586  
2587  	if (vq->avail_idx == vq->last_avail_idx) {
2588  		ret = vhost_get_avail_idx(vq);
2589  		if (unlikely(ret < 0))
2590  			return ret;
2591  
2592  		if (!ret)
2593  			return vq->num;
2594  	}
2595  
2596  	/* Grab the next descriptor number they're advertising, and increment
2597  	 * the index we've seen. */
2598  	if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) {
2599  		vq_err(vq, "Failed to read head: idx %d address %p\n",
2600  		       last_avail_idx,
2601  		       &vq->avail->ring[last_avail_idx % vq->num]);
2602  		return -EFAULT;
2603  	}
2604  
2605  	head = vhost16_to_cpu(vq, ring_head);
2606  
2607  	/* If their number is silly, that's an error. */
2608  	if (unlikely(head >= vq->num)) {
2609  		vq_err(vq, "Guest says index %u > %u is available",
2610  		       head, vq->num);
2611  		return -EINVAL;
2612  	}
2613  
2614  	/* When we start there are none of either input nor output. */
2615  	*out_num = *in_num = 0;
2616  	if (unlikely(log))
2617  		*log_num = 0;
2618  
2619  	i = head;
2620  	do {
2621  		unsigned iov_count = *in_num + *out_num;
2622  		if (unlikely(i >= vq->num)) {
2623  			vq_err(vq, "Desc index is %u > %u, head = %u",
2624  			       i, vq->num, head);
2625  			return -EINVAL;
2626  		}
2627  		if (unlikely(++found > vq->num)) {
2628  			vq_err(vq, "Loop detected: last one at %u "
2629  			       "vq size %u head %u\n",
2630  			       i, vq->num, head);
2631  			return -EINVAL;
2632  		}
2633  		ret = vhost_get_desc(vq, &desc, i);
2634  		if (unlikely(ret)) {
2635  			vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
2636  			       i, vq->desc + i);
2637  			return -EFAULT;
2638  		}
2639  		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) {
2640  			ret = get_indirect(vq, iov, iov_size,
2641  					   out_num, in_num,
2642  					   log, log_num, &desc);
2643  			if (unlikely(ret < 0)) {
2644  				if (ret != -EAGAIN)
2645  					vq_err(vq, "Failure detected "
2646  						"in indirect descriptor at idx %d\n", i);
2647  				return ret;
2648  			}
2649  			continue;
2650  		}
2651  
2652  		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2653  			access = VHOST_ACCESS_WO;
2654  		else
2655  			access = VHOST_ACCESS_RO;
2656  		ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
2657  				     vhost32_to_cpu(vq, desc.len), iov + iov_count,
2658  				     iov_size - iov_count, access);
2659  		if (unlikely(ret < 0)) {
2660  			if (ret != -EAGAIN)
2661  				vq_err(vq, "Translation failure %d descriptor idx %d\n",
2662  					ret, i);
2663  			return ret;
2664  		}
2665  		if (access == VHOST_ACCESS_WO) {
2666  			/* If this is an input descriptor,
2667  			 * increment that count. */
2668  			*in_num += ret;
2669  			if (unlikely(log && ret)) {
2670  				log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
2671  				log[*log_num].len = vhost32_to_cpu(vq, desc.len);
2672  				++*log_num;
2673  			}
2674  		} else {
2675  			/* If it's an output descriptor, they're all supposed
2676  			 * to come before any input descriptors. */
2677  			if (unlikely(*in_num)) {
2678  				vq_err(vq, "Descriptor has out after in: "
2679  				       "idx %d\n", i);
2680  				return -EINVAL;
2681  			}
2682  			*out_num += ret;
2683  		}
2684  	} while ((i = next_desc(vq, &desc)) != -1);
2685  
2686  	/* On success, increment avail index. */
2687  	vq->last_avail_idx++;
2688  
2689  	/* Assume notifications from the guest are disabled at this point;
2690  	 * if they aren't, we would need to update the avail_event index. */
2691  	BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
2692  	return head;
2693  }
2694  EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
2695  
2696  /* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
2697  void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
2698  {
2699  	vq->last_avail_idx -= n;
2700  }
2701  EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);
2702  
2703  /* After we've used one of their buffers, we tell them about it.  We'll then
2704   * want to notify the guest, using eventfd. */
2705  int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
2706  {
2707  	struct vring_used_elem heads = {
2708  		cpu_to_vhost32(vq, head),
2709  		cpu_to_vhost32(vq, len)
2710  	};
2711  
2712  	return vhost_add_used_n(vq, &heads, 1);
2713  }
2714  EXPORT_SYMBOL_GPL(vhost_add_used);
2715  
2716  static int __vhost_add_used_n(struct vhost_virtqueue *vq,
2717  			    struct vring_used_elem *heads,
2718  			    unsigned count)
2719  {
2720  	vring_used_elem_t __user *used;
2721  	u16 old, new;
2722  	int start;
2723  
2724  	start = vq->last_used_idx & (vq->num - 1);
2725  	used = vq->used->ring + start;
2726  	if (vhost_put_used(vq, heads, start, count)) {
2727  		vq_err(vq, "Failed to write used");
2728  		return -EFAULT;
2729  	}
2730  	if (unlikely(vq->log_used)) {
2731  		/* Make sure data is seen before log. */
2732  		smp_wmb();
2733  		/* Log used ring entry write. */
2734  		log_used(vq, ((void __user *)used - (void __user *)vq->used),
2735  			 count * sizeof *used);
2736  	}
2737  	old = vq->last_used_idx;
2738  	new = (vq->last_used_idx += count);
2739  	/* If the driver never bothers to signal in a very long while,
2740  	 * the used index might wrap around. If that happens, invalidate
2741  	 * the signalled_used index we stored. TODO: make sure the driver
2742  	 * signals at least once in 2^16 and remove this. */
2743  	if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
2744  		vq->signalled_used_valid = false;
2745  	return 0;
2746  }
2747  
2748  /* After we've used one of their buffers, we tell them about it.  We'll then
2749   * want to notify the guest, using eventfd. */
2750  int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
2751  		     unsigned count)
2752  {
2753  	int start, n, r;
2754  
2755  	start = vq->last_used_idx & (vq->num - 1);
2756  	n = vq->num - start;
2757  	if (n < count) {
2758  		r = __vhost_add_used_n(vq, heads, n);
2759  		if (r < 0)
2760  			return r;
2761  		heads += n;
2762  		count -= n;
2763  	}
2764  	r = __vhost_add_used_n(vq, heads, count);
2765  
2766  	/* Make sure buffer is written before we update index. */
2767  	smp_wmb();
2768  	if (vhost_put_used_idx(vq)) {
2769  		vq_err(vq, "Failed to increment used idx");
2770  		return -EFAULT;
2771  	}
2772  	if (unlikely(vq->log_used)) {
2773  		/* Make sure used idx is seen before log. */
2774  		smp_wmb();
2775  		/* Log used index update. */
2776  		log_used(vq, offsetof(struct vring_used, idx),
2777  			 sizeof vq->used->idx);
2778  		if (vq->log_ctx)
2779  			eventfd_signal(vq->log_ctx);
2780  	}
2781  	return r;
2782  }
2783  EXPORT_SYMBOL_GPL(vhost_add_used_n);
2784  
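/* Decide whether the guest should be interrupted about newly used buffers,
 * honouring VIRTIO_RING_F_EVENT_IDX or the VRING_AVAIL_F_NO_INTERRUPT flag. */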
2785  static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2786  {
2787  	__u16 old, new;
2788  	__virtio16 event;
2789  	bool v;
2790  	/* Flush out used index updates. This is paired
2791  	 * with the barrier that the Guest executes when enabling
2792  	 * interrupts. */
2793  	smp_mb();
2794  
2795  	if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&
2796  	    unlikely(vq->avail_idx == vq->last_avail_idx))
2797  		return true;
2798  
2799  	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2800  		__virtio16 flags;
2801  		if (vhost_get_avail_flags(vq, &flags)) {
2802  			vq_err(vq, "Failed to get flags");
2803  			return true;
2804  		}
2805  		return !(flags & cpu_to_vhost16(vq, VRING_AVAIL_F_NO_INTERRUPT));
2806  	}
2807  	old = vq->signalled_used;
2808  	v = vq->signalled_used_valid;
2809  	new = vq->signalled_used = vq->last_used_idx;
2810  	vq->signalled_used_valid = true;
2811  
2812  	if (unlikely(!v))
2813  		return true;
2814  
2815  	if (vhost_get_used_event(vq, &event)) {
2816  		vq_err(vq, "Failed to get used event idx");
2817  		return true;
2818  	}
2819  	return vring_need_event(vhost16_to_cpu(vq, event), new, old);
2820  }
2821  
2822  /* This actually signals the guest, using eventfd. */
2823  void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2824  {
2825  	/* Signal the Guest to tell them we used something up. */
2826  	if (vq->call_ctx.ctx && vhost_notify(dev, vq))
2827  		eventfd_signal(vq->call_ctx.ctx);
2828  }
2829  EXPORT_SYMBOL_GPL(vhost_signal);
2830  
2831  /* And here's the combo meal deal.  Supersize me! */
2832  void vhost_add_used_and_signal(struct vhost_dev *dev,
2833  			       struct vhost_virtqueue *vq,
2834  			       unsigned int head, int len)
2835  {
2836  	vhost_add_used(vq, head, len);
2837  	vhost_signal(dev, vq);
2838  }
2839  EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);
2840  
2841  /* multi-buffer version of vhost_add_used_and_signal */
2842  void vhost_add_used_and_signal_n(struct vhost_dev *dev,
2843  				 struct vhost_virtqueue *vq,
2844  				 struct vring_used_elem *heads, unsigned count)
2845  {
2846  	vhost_add_used_n(vq, heads, count);
2847  	vhost_signal(dev, vq);
2848  }
2849  EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);
2850  
2851  /* Return true if we're sure that the available ring is empty. */
2852  bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2853  {
2854  	int r;
2855  
2856  	if (vq->avail_idx != vq->last_avail_idx)
2857  		return false;
2858  
2859  	r = vhost_get_avail_idx(vq);
2860  
2861  	/* Note: we treat error as non-empty here */
2862  	return r == 0;
2863  }
2864  EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);
2865  
2866  /* OK, now we need to know about added descriptors. */
2867  bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2868  {
2869  	int r;
2870  
2871  	if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
2872  		return false;
2873  	vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
2874  	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2875  		r = vhost_update_used_flags(vq);
2876  		if (r) {
2877  			vq_err(vq, "Failed to enable notification at %p: %d\n",
2878  			       &vq->used->flags, r);
2879  			return false;
2880  		}
2881  	} else {
2882  		r = vhost_update_avail_event(vq);
2883  		if (r) {
2884  			vq_err(vq, "Failed to update avail event index at %p: %d\n",
2885  			       vhost_avail_event(vq), r);
2886  			return false;
2887  		}
2888  	}
2889  	/* They could have slipped one in as we were doing that: make
2890  	 * sure it's written, then check again. */
2891  	smp_mb();
2892  
2893  	r = vhost_get_avail_idx(vq);
2894  	/* Note: we treat error as empty here */
2895  	if (unlikely(r < 0))
2896  		return false;
2897  
2898  	return r;
2899  }
2900  EXPORT_SYMBOL_GPL(vhost_enable_notify);
2901  
2902  /* We don't need to be notified again. */
2903  void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2904  {
2905  	int r;
2906  
2907  	if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
2908  		return;
2909  	vq->used_flags |= VRING_USED_F_NO_NOTIFY;
2910  	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2911  		r = vhost_update_used_flags(vq);
2912  		if (r)
2913  			vq_err(vq, "Failed to disable notification at %p: %d\n",
2914  			       &vq->used->flags, r);
2915  	}
2916  }
2917  EXPORT_SYMBOL_GPL(vhost_disable_notify);
2918  
2919  /* Create a new message. */
2920  struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
2921  {
2922  	/* Make sure all padding within the structure is initialized. */
2923  	struct vhost_msg_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
2924  	if (!node)
2925  		return NULL;
2926  
2927  	node->vq = vq;
2928  	node->msg.type = type;
2929  	return node;
2930  }
2931  EXPORT_SYMBOL_GPL(vhost_new_msg);
2932  
2933  void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
2934  		       struct vhost_msg_node *node)
2935  {
2936  	spin_lock(&dev->iotlb_lock);
2937  	list_add_tail(&node->node, head);
2938  	spin_unlock(&dev->iotlb_lock);
2939  
2940  	wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
2941  }
2942  EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
2943  
2944  struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
2945  					 struct list_head *head)
2946  {
2947  	struct vhost_msg_node *node = NULL;
2948  
2949  	spin_lock(&dev->iotlb_lock);
2950  	if (!list_empty(head)) {
2951  		node = list_first_entry(head, struct vhost_msg_node,
2952  					node);
2953  		list_del(&node->node);
2954  	}
2955  	spin_unlock(&dev->iotlb_lock);
2956  
2957  	return node;
2958  }
2959  EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
2960  
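/* Record the backend feature bits acked by userspace in every virtqueue. */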
2961  void vhost_set_backend_features(struct vhost_dev *dev, u64 features)
2962  {
2963  	struct vhost_virtqueue *vq;
2964  	int i;
2965  
2966  	mutex_lock(&dev->mutex);
2967  	for (i = 0; i < dev->nvqs; ++i) {
2968  		vq = dev->vqs[i];
2969  		mutex_lock(&vq->mutex);
2970  		vq->acked_backend_features = features;
2971  		mutex_unlock(&vq->mutex);
2972  	}
2973  	mutex_unlock(&dev->mutex);
2974  }
2975  EXPORT_SYMBOL_GPL(vhost_set_backend_features);
2976  
2977  static int __init vhost_init(void)
2978  {
2979  	return 0;
2980  }
2981  
2982  static void __exit vhost_exit(void)
2983  {
2984  }
2985  
2986  module_init(vhost_init);
2987  module_exit(vhost_exit);
2988  
2989  MODULE_VERSION("0.0.1");
2990  MODULE_LICENSE("GPL v2");
2991  MODULE_AUTHOR("Michael S. Tsirkin");
2992  MODULE_DESCRIPTION("Host kernel accelerator for virtio");
2993