Diffstat (limited to 'drivers/vhost')
-rw-r--r--  drivers/vhost/net.c   |  19
-rw-r--r--  drivers/vhost/vhost.c | 165
-rw-r--r--  drivers/vhost/vhost.h |  28
3 files changed, 160 insertions, 52 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 29e850a..4b4da5b 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -127,7 +127,10 @@ static void handle_tx(struct vhost_net *net)
size_t len, total_len = 0;
int err, wmem;
size_t hdr_size;
- struct socket *sock = rcu_dereference(vq->private_data);
+ struct socket *sock;
+
+ sock = rcu_dereference_check(vq->private_data,
+ lockdep_is_held(&vq->mutex));
if (!sock)
return;
@@ -243,7 +246,7 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
int r, nlogs = 0;
while (datalen > 0) {
- if (unlikely(headcount >= VHOST_NET_MAX_SG)) {
+ if (unlikely(seg >= UIO_MAXIOV)) {
r = -ENOBUFS;
goto err;
}
@@ -582,7 +585,10 @@ static void vhost_net_disable_vq(struct vhost_net *n,
static void vhost_net_enable_vq(struct vhost_net *n,
struct vhost_virtqueue *vq)
{
- struct socket *sock = vq->private_data;
+ struct socket *sock;
+
+ sock = rcu_dereference_protected(vq->private_data,
+ lockdep_is_held(&vq->mutex));
if (!sock)
return;
if (vq == n->vqs + VHOST_NET_VQ_TX) {
@@ -598,7 +604,8 @@ static struct socket *vhost_net_stop_vq(struct vhost_net *n,
struct socket *sock;
mutex_lock(&vq->mutex);
- sock = vq->private_data;
+ sock = rcu_dereference_protected(vq->private_data,
+ lockdep_is_held(&vq->mutex));
vhost_net_disable_vq(n, vq);
rcu_assign_pointer(vq->private_data, NULL);
mutex_unlock(&vq->mutex);
@@ -736,7 +743,8 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
}
/* start polling new socket */
- oldsock = vq->private_data;
+ oldsock = rcu_dereference_protected(vq->private_data,
+ lockdep_is_held(&vq->mutex));
if (sock != oldsock) {
vhost_net_disable_vq(n, vq);
rcu_assign_pointer(vq->private_data, sock);
@@ -869,6 +877,7 @@ static const struct file_operations vhost_net_fops = {
.compat_ioctl = vhost_net_compat_ioctl,
#endif
.open = vhost_net_open,
+ .llseek = noop_llseek,
};
static struct miscdevice vhost_net_misc = {
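
The net.c hunks above all make the same change: every read of the RCU-protected vq->private_data pointer goes through a lockdep-checked accessor instead of a bare load. A minimal sketch of that pattern, with hypothetical names (struct ctx, priv), assuming the writer side is serialized by a mutex:

#include <linux/mutex.h>
#include <linux/rcupdate.h>

struct ctx {
	struct mutex lock;
	void __rcu *priv;	/* hypothetical RCU-protected pointer */
};

/* Reader that runs with ctx->lock held; lockdep verifies the claim. */
static void *ctx_priv_locked(struct ctx *c)
{
	return rcu_dereference_protected(c->priv, lockdep_is_held(&c->lock));
}

/* Writer: publish a new pointer so concurrent RCU readers see it atomically. */
static void ctx_set_priv(struct ctx *c, void *p)
{
	mutex_lock(&c->lock);
	rcu_assign_pointer(c->priv, p);
	mutex_unlock(&c->lock);
}

handle_tx() uses rcu_dereference_check() rather than rcu_dereference_protected() because it may be called either under vq->mutex or from the worker, which stands in for an RCU read-side critical section here.
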
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index e05557d..94701ff 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -60,22 +60,25 @@ static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync,
return 0;
}
+static void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
+{
+ INIT_LIST_HEAD(&work->node);
+ work->fn = fn;
+ init_waitqueue_head(&work->done);
+ work->flushing = 0;
+ work->queue_seq = work->done_seq = 0;
+}
+
/* Init poll structure */
void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
unsigned long mask, struct vhost_dev *dev)
{
- struct vhost_work *work = &poll->work;
-
init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
init_poll_funcptr(&poll->table, vhost_poll_func);
poll->mask = mask;
poll->dev = dev;
- INIT_LIST_HEAD(&work->node);
- work->fn = fn;
- init_waitqueue_head(&work->done);
- work->flushing = 0;
- work->queue_seq = work->done_seq = 0;
+ vhost_work_init(&poll->work, fn);
}
/* Start polling a file. We add ourselves to file's wait queue. The caller must
@@ -95,35 +98,38 @@ void vhost_poll_stop(struct vhost_poll *poll)
remove_wait_queue(poll->wqh, &poll->wait);
}
-/* Flush any work that has been scheduled. When calling this, don't hold any
- * locks that are also used by the callback. */
-void vhost_poll_flush(struct vhost_poll *poll)
+static void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
{
- struct vhost_work *work = &poll->work;
unsigned seq;
int left;
int flushing;
- spin_lock_irq(&poll->dev->work_lock);
+ spin_lock_irq(&dev->work_lock);
seq = work->queue_seq;
work->flushing++;
- spin_unlock_irq(&poll->dev->work_lock);
+ spin_unlock_irq(&dev->work_lock);
wait_event(work->done, ({
- spin_lock_irq(&poll->dev->work_lock);
+ spin_lock_irq(&dev->work_lock);
left = seq - work->done_seq <= 0;
- spin_unlock_irq(&poll->dev->work_lock);
+ spin_unlock_irq(&dev->work_lock);
left;
}));
- spin_lock_irq(&poll->dev->work_lock);
+ spin_lock_irq(&dev->work_lock);
flushing = --work->flushing;
- spin_unlock_irq(&poll->dev->work_lock);
+ spin_unlock_irq(&dev->work_lock);
BUG_ON(flushing < 0);
}
-void vhost_poll_queue(struct vhost_poll *poll)
+/* Flush any work that has been scheduled. When calling this, don't hold any
+ * locks that are also used by the callback. */
+void vhost_poll_flush(struct vhost_poll *poll)
+{
+ vhost_work_flush(poll->dev, &poll->work);
+}
+
+static inline void vhost_work_queue(struct vhost_dev *dev,
+ struct vhost_work *work)
{
- struct vhost_dev *dev = poll->dev;
- struct vhost_work *work = &poll->work;
unsigned long flags;
spin_lock_irqsave(&dev->work_lock, flags);
@@ -135,6 +141,11 @@ void vhost_poll_queue(struct vhost_poll *poll)
spin_unlock_irqrestore(&dev->work_lock, flags);
}
+void vhost_poll_queue(struct vhost_poll *poll)
+{
+ vhost_work_queue(poll->dev, &poll->work);
+}
+
static void vhost_vq_reset(struct vhost_dev *dev,
struct vhost_virtqueue *vq)
{
@@ -201,6 +212,45 @@ static int vhost_worker(void *data)
}
}
+/* Helper to allocate iovec buffers for all vqs. */
+static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
+{
+ int i;
+ for (i = 0; i < dev->nvqs; ++i) {
+ dev->vqs[i].indirect = kmalloc(sizeof *dev->vqs[i].indirect *
+ UIO_MAXIOV, GFP_KERNEL);
+ dev->vqs[i].log = kmalloc(sizeof *dev->vqs[i].log * UIO_MAXIOV,
+ GFP_KERNEL);
+ dev->vqs[i].heads = kmalloc(sizeof *dev->vqs[i].heads *
+ UIO_MAXIOV, GFP_KERNEL);
+
+ if (!dev->vqs[i].indirect || !dev->vqs[i].log ||
+ !dev->vqs[i].heads)
+ goto err_nomem;
+ }
+ return 0;
+err_nomem:
+ for (; i >= 0; --i) {
+ kfree(dev->vqs[i].indirect);
+ kfree(dev->vqs[i].log);
+ kfree(dev->vqs[i].heads);
+ }
+ return -ENOMEM;
+}
+
+static void vhost_dev_free_iovecs(struct vhost_dev *dev)
+{
+ int i;
+ for (i = 0; i < dev->nvqs; ++i) {
+ kfree(dev->vqs[i].indirect);
+ dev->vqs[i].indirect = NULL;
+ kfree(dev->vqs[i].log);
+ dev->vqs[i].log = NULL;
+ kfree(dev->vqs[i].heads);
+ dev->vqs[i].heads = NULL;
+ }
+}
+
long vhost_dev_init(struct vhost_dev *dev,
struct vhost_virtqueue *vqs, int nvqs)
{
@@ -218,6 +268,9 @@ long vhost_dev_init(struct vhost_dev *dev,
dev->worker = NULL;
for (i = 0; i < dev->nvqs; ++i) {
+ dev->vqs[i].log = NULL;
+ dev->vqs[i].indirect = NULL;
+ dev->vqs[i].heads = NULL;
dev->vqs[i].dev = dev;
mutex_init(&dev->vqs[i].mutex);
vhost_vq_reset(dev, dev->vqs + i);
@@ -236,6 +289,29 @@ long vhost_dev_check_owner(struct vhost_dev *dev)
return dev->mm == current->mm ? 0 : -EPERM;
}
+struct vhost_attach_cgroups_struct {
+ struct vhost_work work;
+ struct task_struct *owner;
+ int ret;
+};
+
+static void vhost_attach_cgroups_work(struct vhost_work *work)
+{
+ struct vhost_attach_cgroups_struct *s;
+ s = container_of(work, struct vhost_attach_cgroups_struct, work);
+ s->ret = cgroup_attach_task_all(s->owner, current);
+}
+
+static int vhost_attach_cgroups(struct vhost_dev *dev)
+{
+ struct vhost_attach_cgroups_struct attach;
+ attach.owner = current;
+ vhost_work_init(&attach.work, vhost_attach_cgroups_work);
+ vhost_work_queue(dev, &attach.work);
+ vhost_work_flush(dev, &attach.work);
+ return attach.ret;
+}
+
/* Caller should have device mutex */
static long vhost_dev_set_owner(struct vhost_dev *dev)
{
@@ -255,14 +331,20 @@ static long vhost_dev_set_owner(struct vhost_dev *dev)
}
dev->worker = worker;
- err = cgroup_attach_task_current_cg(worker);
+ wake_up_process(worker); /* avoid contributing to loadavg */
+
+ err = vhost_attach_cgroups(dev);
+ if (err)
+ goto err_cgroup;
+
+ err = vhost_dev_alloc_iovecs(dev);
if (err)
goto err_cgroup;
- wake_up_process(worker); /* avoid contributing to loadavg */
return 0;
err_cgroup:
kthread_stop(worker);
+ dev->worker = NULL;
err_worker:
if (dev->mm)
mmput(dev->mm);
@@ -284,7 +366,7 @@ long vhost_dev_reset_owner(struct vhost_dev *dev)
vhost_dev_cleanup(dev);
memory->nregions = 0;
- dev->memory = memory;
+ RCU_INIT_POINTER(dev->memory, memory);
return 0;
}
@@ -309,6 +391,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
fput(dev->vqs[i].call);
vhost_vq_reset(dev, dev->vqs + i);
}
+ vhost_dev_free_iovecs(dev);
if (dev->log_ctx)
eventfd_ctx_put(dev->log_ctx);
dev->log_ctx = NULL;
@@ -316,14 +399,18 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
fput(dev->log_file);
dev->log_file = NULL;
/* No one will access memory at this point */
- kfree(dev->memory);
- dev->memory = NULL;
+ kfree(rcu_dereference_protected(dev->memory,
+ lockdep_is_held(&dev->mutex)));
+ RCU_INIT_POINTER(dev->memory, NULL);
if (dev->mm)
mmput(dev->mm);
dev->mm = NULL;
WARN_ON(!list_empty(&dev->work_list));
- kthread_stop(dev->worker);
+ if (dev->worker) {
+ kthread_stop(dev->worker);
+ dev->worker = NULL;
+ }
}
static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
@@ -332,7 +419,7 @@ static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
/* Make sure 64 bit math will not overflow. */
if (a > ULONG_MAX - (unsigned long)log_base ||
a + (unsigned long)log_base > ULONG_MAX)
- return -EFAULT;
+ return 0;
return access_ok(VERIFY_WRITE, log_base + a,
(sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
@@ -401,14 +488,22 @@ static int vq_access_ok(unsigned int num,
/* Caller should have device mutex but not vq mutex */
int vhost_log_access_ok(struct vhost_dev *dev)
{
- return memory_access_ok(dev, dev->memory, 1);
+ struct vhost_memory *mp;
+
+ mp = rcu_dereference_protected(dev->memory,
+ lockdep_is_held(&dev->mutex));
+ return memory_access_ok(dev, mp, 1);
}
/* Verify access for write logging. */
/* Caller should have vq mutex and device mutex */
static int vq_log_access_ok(struct vhost_virtqueue *vq, void __user *log_base)
{
- return vq_memory_access_ok(log_base, vq->dev->memory,
+ struct vhost_memory *mp;
+
+ mp = rcu_dereference_protected(vq->dev->memory,
+ lockdep_is_held(&vq->mutex));
+ return vq_memory_access_ok(log_base, mp,
vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) &&
(!vq->log_used || log_access_ok(log_base, vq->log_addr,
sizeof *vq->used +
@@ -448,7 +543,8 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
kfree(newmem);
return -EFAULT;
}
- oldmem = d->memory;
+ oldmem = rcu_dereference_protected(d->memory,
+ lockdep_is_held(&d->mutex));
rcu_assign_pointer(d->memory, newmem);
synchronize_rcu();
kfree(oldmem);
@@ -819,11 +915,12 @@ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
if (r < 0)
return r;
len -= l;
- if (!len)
+ if (!len) {
+ if (vq->log_ctx)
+ eventfd_signal(vq->log_ctx, 1);
return 0;
+ }
}
- if (vq->log_ctx)
- eventfd_signal(vq->log_ctx, 1);
/* Length written exceeds what we have stored. This is a bug. */
BUG();
return 0;
@@ -907,7 +1004,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
}
ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect,
- ARRAY_SIZE(vq->indirect));
+ UIO_MAXIOV);
if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d in indirect.\n", ret);
return ret;
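
vhost_attach_cgroups() above relies on vhost_work_flush() not returning until the queued callback has finished on the worker thread, so the result can safely be read back from a stack-allocated work item. A minimal sketch of that "run on the worker and wait" idiom, using the helpers introduced in this patch (they are static to vhost.c, so such code only makes sense inside that file); do_in_worker() is a hypothetical callback:

struct sync_call {
	struct vhost_work work;
	int ret;
};

static void sync_call_fn(struct vhost_work *work)
{
	struct sync_call *s = container_of(work, struct sync_call, work);

	s->ret = do_in_worker();	/* hypothetical; runs in worker context */
}

static int run_on_worker(struct vhost_dev *dev)
{
	struct sync_call call = { .ret = 0 };

	vhost_work_init(&call.work, sync_call_fn);
	vhost_work_queue(dev, &call.work);
	vhost_work_flush(dev, &call.work);	/* blocks until sync_call_fn has run */
	return call.ret;
}
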
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index afd7729..073d06a 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -15,11 +15,6 @@
struct vhost_device;
-enum {
- /* Enough place for all fragments, head, and virtio net header. */
- VHOST_NET_MAX_SG = MAX_SKB_FRAGS + 2,
-};
-
struct vhost_work;
typedef void (*vhost_work_fn_t)(struct vhost_work *work);
@@ -93,12 +88,15 @@ struct vhost_virtqueue {
bool log_used;
u64 log_addr;
- struct iovec indirect[VHOST_NET_MAX_SG];
- struct iovec iov[VHOST_NET_MAX_SG];
- struct iovec hdr[VHOST_NET_MAX_SG];
+ struct iovec iov[UIO_MAXIOV];
+ /* hdr is used to store the virtio header.
+ * Since each iovec has >= 1 byte length, we never need more than
+ * header length entries to store the header. */
+ struct iovec hdr[sizeof(struct virtio_net_hdr_mrg_rxbuf)];
+ struct iovec *indirect;
size_t vhost_hlen;
size_t sock_hlen;
- struct vring_used_elem heads[VHOST_NET_MAX_SG];
+ struct vring_used_elem *heads;
/* We use a kind of RCU to access private pointer.
* All readers access it from worker, which makes it possible to
* flush the vhost_work instead of synchronize_rcu. Therefore readers do
@@ -106,17 +104,17 @@ struct vhost_virtqueue {
* vhost_work execution acts instead of rcu_read_lock() and the end of
* vhost_work execution acts instead of rcu_read_unlock().
* Writers use virtqueue mutex. */
- void *private_data;
+ void __rcu *private_data;
/* Log write descriptors */
void __user *log_base;
- struct vhost_log log[VHOST_NET_MAX_SG];
+ struct vhost_log *log;
};
struct vhost_dev {
/* Readers use RCU to access memory table pointer
* log base pointer and features.
* Writers use mutex below.*/
- struct vhost_memory *memory;
+ struct vhost_memory __rcu *memory;
struct mm_struct *mm;
struct mutex mutex;
unsigned acked_features;
@@ -173,7 +171,11 @@ enum {
static inline int vhost_has_feature(struct vhost_dev *dev, int bit)
{
- unsigned acked_features = rcu_dereference(dev->acked_features);
+ unsigned acked_features;
+
+ acked_features =
+ rcu_dereference_index_check(dev->acked_features,
+ lockdep_is_held(&dev->mutex));
return acked_features & (1 << bit);
}