drivers/vhost/vsock.c


DEFINITIONS

This source file includes the following definitions:
  1. vhost_transport_get_local_cid
  2. vhost_vsock_get
  3. vhost_transport_do_send_pkt
  4. vhost_transport_send_pkt_work
  5. vhost_transport_send_pkt
  6. vhost_transport_cancel_pkt
  7. vhost_vsock_alloc_pkt
  8. vhost_vsock_more_replies
  9. vhost_vsock_handle_tx_kick
  10. vhost_vsock_handle_rx_kick
  11. vhost_vsock_start
  12. vhost_vsock_stop
  13. vhost_vsock_free
  14. vhost_vsock_dev_open
  15. vhost_vsock_flush
  16. vhost_vsock_reset_orphans
  17. vhost_vsock_dev_release
  18. vhost_vsock_set_cid
  19. vhost_vsock_set_features
  20. vhost_vsock_dev_ioctl
  21. vhost_vsock_dev_compat_ioctl
  22. vhost_vsock_init
  23. vhost_vsock_exit

   1 // SPDX-License-Identifier: GPL-2.0-only
   2 /*
   3  * vhost transport for vsock
   4  *
   5  * Copyright (C) 2013-2015 Red Hat, Inc.
   6  * Author: Asias He <asias@redhat.com>
   7  *         Stefan Hajnoczi <stefanha@redhat.com>
   8  */
   9 #include <linux/miscdevice.h>
  10 #include <linux/atomic.h>
  11 #include <linux/module.h>
  12 #include <linux/mutex.h>
  13 #include <linux/vmalloc.h>
  14 #include <net/sock.h>
  15 #include <linux/virtio_vsock.h>
  16 #include <linux/vhost.h>
  17 #include <linux/hashtable.h>
  18 
  19 #include <net/af_vsock.h>
  20 #include "vhost.h"
  21 
  22 #define VHOST_VSOCK_DEFAULT_HOST_CID    2
  23 /* Max number of bytes transferred before requeueing the job.
  24  * Using this limit prevents one virtqueue from starving others. */
  25 #define VHOST_VSOCK_WEIGHT 0x80000
  26 /* Max number of packets transferred before requeueing the job.
  27  * Using this limit prevents one virtqueue from starving others with
  28  * small pkts.
  29  */
  30 #define VHOST_VSOCK_PKT_WEIGHT 256
  31 
  32 enum {
  33         VHOST_VSOCK_FEATURES = VHOST_FEATURES,
  34 };
  35 
  36 /* Used to track all the vhost_vsock instances on the system. */
  37 static DEFINE_MUTEX(vhost_vsock_mutex);
  38 static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);
  39 
  40 struct vhost_vsock {
  41         struct vhost_dev dev;
  42         struct vhost_virtqueue vqs[2];
  43 
  44         /* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
  45         struct hlist_node hash;
  46 
  47         struct vhost_work send_pkt_work;
  48         spinlock_t send_pkt_list_lock;
  49         struct list_head send_pkt_list; /* host->guest pending packets */
  50 
  51         atomic_t queued_replies;
  52 
  53         u32 guest_cid;
  54 };
  55 
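      /* The host side always uses the well-known host CID (2). */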
  56 static u32 vhost_transport_get_local_cid(void)
  57 {
  58         return VHOST_VSOCK_DEFAULT_HOST_CID;
  59 }
  60 
  61 /* Callers that dereference the return value must hold vhost_vsock_mutex or the
  62  * RCU read lock.
  63  */
  64 static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
  65 {
  66         struct vhost_vsock *vsock;
  67 
  68         hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
  69                 u32 other_cid = vsock->guest_cid;
  70 
  71                 /* Skip instances that have no CID yet */
  72                 if (other_cid == 0)
  73                         continue;
  74 
  75                 if (other_cid == guest_cid)
  76                         return vsock;
  77 
  78         }
  79 
  80         return NULL;
  81 }
  82 
  83 static void
  84 vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
  85                             struct vhost_virtqueue *vq)
  86 {
  87         struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
  88         int pkts = 0, total_len = 0;
  89         bool added = false;
  90         bool restart_tx = false;
  91 
  92         mutex_lock(&vq->mutex);
  93 
  94         if (!vq->private_data)
  95                 goto out;
  96 
  97         /* Avoid further vmexits, we're already processing the virtqueue */
  98         vhost_disable_notify(&vsock->dev, vq);
  99 
 100         do {
 101                 struct virtio_vsock_pkt *pkt;
 102                 struct iov_iter iov_iter;
 103                 unsigned out, in;
 104                 size_t nbytes;
 105                 size_t iov_len, payload_len;
 106                 int head;
 107 
 108                 spin_lock_bh(&vsock->send_pkt_list_lock);
 109                 if (list_empty(&vsock->send_pkt_list)) {
 110                         spin_unlock_bh(&vsock->send_pkt_list_lock);
 111                         vhost_enable_notify(&vsock->dev, vq);
 112                         break;
 113                 }
 114 
 115                 pkt = list_first_entry(&vsock->send_pkt_list,
 116                                        struct virtio_vsock_pkt, list);
 117                 list_del_init(&pkt->list);
 118                 spin_unlock_bh(&vsock->send_pkt_list_lock);
 119 
 120                 head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
 121                                          &out, &in, NULL, NULL);
 122                 if (head < 0) {
 123                         spin_lock_bh(&vsock->send_pkt_list_lock);
 124                         list_add(&pkt->list, &vsock->send_pkt_list);
 125                         spin_unlock_bh(&vsock->send_pkt_list_lock);
 126                         break;
 127                 }
 128 
 129                 if (head == vq->num) {
 130                         spin_lock_bh(&vsock->send_pkt_list_lock);
 131                         list_add(&pkt->list, &vsock->send_pkt_list);
 132                         spin_unlock_bh(&vsock->send_pkt_list_lock);
 133 
 134                         /* We cannot finish yet if more buffers snuck in while
 135                          * re-enabling notify.
 136                          */
 137                         if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
 138                                 vhost_disable_notify(&vsock->dev, vq);
 139                                 continue;
 140                         }
 141                         break;
 142                 }
 143 
 144                 if (out) {
 145                         virtio_transport_free_pkt(pkt);
 146                         vq_err(vq, "Expected 0 output buffers, got %u\n", out);
 147                         break;
 148                 }
 149 
 150                 iov_len = iov_length(&vq->iov[out], in);
 151                 if (iov_len < sizeof(pkt->hdr)) {
 152                         virtio_transport_free_pkt(pkt);
 153                         vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
 154                         break;
 155                 }
 156 
 157                 iov_iter_init(&iov_iter, READ, &vq->iov[out], in, iov_len);
 158                 payload_len = pkt->len - pkt->off;
 159 
 160                 /* If the packet is greater than the space available in the
 161                  * buffer, we split it using multiple buffers.
 162                  */
 163                 if (payload_len > iov_len - sizeof(pkt->hdr))
 164                         payload_len = iov_len - sizeof(pkt->hdr);
 165 
 166                 /* Set the correct length in the header */
 167                 pkt->hdr.len = cpu_to_le32(payload_len);
 168 
 169                 nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
 170                 if (nbytes != sizeof(pkt->hdr)) {
 171                         virtio_transport_free_pkt(pkt);
 172                         vq_err(vq, "Faulted on copying pkt hdr\n");
 173                         break;
 174                 }
 175 
 176                 nbytes = copy_to_iter(pkt->buf + pkt->off, payload_len,
 177                                       &iov_iter);
 178                 if (nbytes != payload_len) {
 179                         virtio_transport_free_pkt(pkt);
 180                         vq_err(vq, "Faulted on copying pkt buf\n");
 181                         break;
 182                 }
 183 
 184                 /* Deliver to monitoring devices all packets that we
 185                  * will transmit.
 186                  */
 187                 virtio_transport_deliver_tap_pkt(pkt);
 188 
 189                 vhost_add_used(vq, head, sizeof(pkt->hdr) + payload_len);
 190                 added = true;
 191 
 192                 pkt->off += payload_len;
 193                 total_len += payload_len;
 194 
 195                 /* If we didn't send all the payload we can requeue the packet
 196                  * to send it with the next available buffer.
 197                  */
 198                 if (pkt->off < pkt->len) {
 199                         spin_lock_bh(&vsock->send_pkt_list_lock);
 200                         list_add(&pkt->list, &vsock->send_pkt_list);
 201                         spin_unlock_bh(&vsock->send_pkt_list_lock);
 202                 } else {
 203                         if (pkt->reply) {
 204                                 int val;
 205 
 206                                 val = atomic_dec_return(&vsock->queued_replies);
 207 
 208                                 /* Do we have resources to resume tx
 209                                  * processing?
 210                                  */
 211                                 if (val + 1 == tx_vq->num)
 212                                         restart_tx = true;
 213                         }
 214 
 215                         virtio_transport_free_pkt(pkt);
 216                 }
  217         } while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
 218         if (added)
 219                 vhost_signal(&vsock->dev, vq);
 220 
 221 out:
 222         mutex_unlock(&vq->mutex);
 223 
 224         if (restart_tx)
 225                 vhost_poll_queue(&tx_vq->poll);
 226 }
 227 
 228 static void vhost_transport_send_pkt_work(struct vhost_work *work)
 229 {
 230         struct vhost_virtqueue *vq;
 231         struct vhost_vsock *vsock;
 232 
 233         vsock = container_of(work, struct vhost_vsock, send_pkt_work);
 234         vq = &vsock->vqs[VSOCK_VQ_RX];
 235 
 236         vhost_transport_do_send_pkt(vsock, vq);
 237 }
 238 
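      /* .send_pkt callback: queue a host->guest packet on send_pkt_list and
       * kick the send worker.  Called by the common virtio_vsock transport
       * code.
       */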
 239 static int
 240 vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
 241 {
 242         struct vhost_vsock *vsock;
 243         int len = pkt->len;
 244 
 245         rcu_read_lock();
 246 
 247         /* Find the vhost_vsock according to guest context id  */
 248         vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
 249         if (!vsock) {
 250                 rcu_read_unlock();
 251                 virtio_transport_free_pkt(pkt);
 252                 return -ENODEV;
 253         }
 254 
 255         if (pkt->reply)
 256                 atomic_inc(&vsock->queued_replies);
 257 
 258         spin_lock_bh(&vsock->send_pkt_list_lock);
 259         list_add_tail(&pkt->list, &vsock->send_pkt_list);
 260         spin_unlock_bh(&vsock->send_pkt_list_lock);
 261 
 262         vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
 263 
 264         rcu_read_unlock();
 265         return len;
 266 }
 267 
 268 static int
 269 vhost_transport_cancel_pkt(struct vsock_sock *vsk)
 270 {
 271         struct vhost_vsock *vsock;
 272         struct virtio_vsock_pkt *pkt, *n;
 273         int cnt = 0;
 274         int ret = -ENODEV;
 275         LIST_HEAD(freeme);
 276 
 277         rcu_read_lock();
 278 
 279         /* Find the vhost_vsock according to guest context id  */
 280         vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
 281         if (!vsock)
 282                 goto out;
 283 
 284         spin_lock_bh(&vsock->send_pkt_list_lock);
 285         list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
 286                 if (pkt->vsk != vsk)
 287                         continue;
 288                 list_move(&pkt->list, &freeme);
 289         }
 290         spin_unlock_bh(&vsock->send_pkt_list_lock);
 291 
 292         list_for_each_entry_safe(pkt, n, &freeme, list) {
 293                 if (pkt->reply)
 294                         cnt++;
 295                 list_del(&pkt->list);
 296                 virtio_transport_free_pkt(pkt);
 297         }
 298 
 299         if (cnt) {
 300                 struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
 301                 int new_cnt;
 302 
 303                 new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
 304                 if (new_cnt + cnt >= tx_vq->num && new_cnt < tx_vq->num)
 305                         vhost_poll_queue(&tx_vq->poll);
 306         }
 307 
 308         ret = 0;
 309 out:
 310         rcu_read_unlock();
 311         return ret;
 312 }
 313 
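      /* Copy one guest->host packet (header plus optional payload) out of the
       * descriptor chain in vq->iov into a newly allocated virtio_vsock_pkt.
       */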
 314 static struct virtio_vsock_pkt *
 315 vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
 316                       unsigned int out, unsigned int in)
 317 {
 318         struct virtio_vsock_pkt *pkt;
 319         struct iov_iter iov_iter;
 320         size_t nbytes;
 321         size_t len;
 322 
 323         if (in != 0) {
 324                 vq_err(vq, "Expected 0 input buffers, got %u\n", in);
 325                 return NULL;
 326         }
 327 
 328         pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
 329         if (!pkt)
 330                 return NULL;
 331 
 332         len = iov_length(vq->iov, out);
 333         iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);
 334 
 335         nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
 336         if (nbytes != sizeof(pkt->hdr)) {
 337                 vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
 338                        sizeof(pkt->hdr), nbytes);
 339                 kfree(pkt);
 340                 return NULL;
 341         }
 342 
 343         if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
 344                 pkt->len = le32_to_cpu(pkt->hdr.len);
 345 
 346         /* No payload */
 347         if (!pkt->len)
 348                 return pkt;
 349 
 350         /* The pkt is too big */
 351         if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
 352                 kfree(pkt);
 353                 return NULL;
 354         }
 355 
 356         pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
 357         if (!pkt->buf) {
 358                 kfree(pkt);
 359                 return NULL;
 360         }
 361 
 362         pkt->buf_len = pkt->len;
 363 
 364         nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
 365         if (nbytes != pkt->len) {
 366                 vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
 367                        pkt->len, nbytes);
 368                 virtio_transport_free_pkt(pkt);
 369                 return NULL;
 370         }
 371 
 372         return pkt;
 373 }
 374 
 375 /* Is there space left for replies to rx packets? */
 376 static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
 377 {
 378         struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX];
 379         int val;
 380 
 381         smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
 382         val = atomic_read(&vsock->queued_replies);
 383 
 384         return val < vq->num;
 385 }
 386 
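      /* The guest kicked the TX virtqueue: drain guest->host packets and hand
       * correctly addressed ones to virtio_transport_recv_pkt().
       */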
 387 static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
 388 {
 389         struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
 390                                                   poll.work);
 391         struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
 392                                                  dev);
 393         struct virtio_vsock_pkt *pkt;
 394         int head, pkts = 0, total_len = 0;
 395         unsigned int out, in;
 396         bool added = false;
 397 
 398         mutex_lock(&vq->mutex);
 399 
 400         if (!vq->private_data)
 401                 goto out;
 402 
 403         vhost_disable_notify(&vsock->dev, vq);
 404         do {
 405                 u32 len;
 406 
 407                 if (!vhost_vsock_more_replies(vsock)) {
 408                         /* Stop tx until the device processes already
 409                          * pending replies.  Leave tx virtqueue
 410                          * callbacks disabled.
 411                          */
 412                         goto no_more_replies;
 413                 }
 414 
 415                 head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
 416                                          &out, &in, NULL, NULL);
 417                 if (head < 0)
 418                         break;
 419 
 420                 if (head == vq->num) {
 421                         if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
 422                                 vhost_disable_notify(&vsock->dev, vq);
 423                                 continue;
 424                         }
 425                         break;
 426                 }
 427 
 428                 pkt = vhost_vsock_alloc_pkt(vq, out, in);
 429                 if (!pkt) {
 430                         vq_err(vq, "Faulted on pkt\n");
 431                         continue;
 432                 }
 433 
 434                 len = pkt->len;
 435 
 436                 /* Deliver to monitoring devices all received packets */
 437                 virtio_transport_deliver_tap_pkt(pkt);
 438 
 439                 /* Only accept correctly addressed packets */
 440                 if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
 441                     le64_to_cpu(pkt->hdr.dst_cid) ==
 442                     vhost_transport_get_local_cid())
 443                         virtio_transport_recv_pkt(pkt);
 444                 else
 445                         virtio_transport_free_pkt(pkt);
 446 
 447                 len += sizeof(pkt->hdr);
 448                 vhost_add_used(vq, head, len);
 449                 total_len += len;
 450                 added = true;
  451         } while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
 452 
 453 no_more_replies:
 454         if (added)
 455                 vhost_signal(&vsock->dev, vq);
 456 
 457 out:
 458         mutex_unlock(&vq->mutex);
 459 }
 460 
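      /* The guest kicked the RX virtqueue: new guest buffers are available,
       * so retry flushing pending host->guest packets.
       */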
 461 static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
 462 {
 463         struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
 464                                                 poll.work);
 465         struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
 466                                                  dev);
 467 
 468         vhost_transport_do_send_pkt(vsock, vq);
 469 }
 470 
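      /* VHOST_VSOCK_SET_RUNNING(1): publish the vsock pointer in each vq's
       * private_data so the kick handlers start processing.
       */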
 471 static int vhost_vsock_start(struct vhost_vsock *vsock)
 472 {
 473         struct vhost_virtqueue *vq;
 474         size_t i;
 475         int ret;
 476 
 477         mutex_lock(&vsock->dev.mutex);
 478 
 479         ret = vhost_dev_check_owner(&vsock->dev);
 480         if (ret)
 481                 goto err;
 482 
 483         for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
 484                 vq = &vsock->vqs[i];
 485 
 486                 mutex_lock(&vq->mutex);
 487 
 488                 if (!vhost_vq_access_ok(vq)) {
 489                         ret = -EFAULT;
 490                         goto err_vq;
 491                 }
 492 
 493                 if (!vq->private_data) {
 494                         vq->private_data = vsock;
 495                         ret = vhost_vq_init_access(vq);
 496                         if (ret)
 497                                 goto err_vq;
 498                 }
 499 
 500                 mutex_unlock(&vq->mutex);
 501         }
 502 
 503         /* Some packets may have been queued before the device was started,
 504          * let's kick the send worker to send them.
 505          */
 506         vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
 507 
 508         mutex_unlock(&vsock->dev.mutex);
 509         return 0;
 510 
 511 err_vq:
 512         vq->private_data = NULL;
 513         mutex_unlock(&vq->mutex);
 514 
 515         for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
 516                 vq = &vsock->vqs[i];
 517 
 518                 mutex_lock(&vq->mutex);
 519                 vq->private_data = NULL;
 520                 mutex_unlock(&vq->mutex);
 521         }
 522 err:
 523         mutex_unlock(&vsock->dev.mutex);
 524         return ret;
 525 }
 526 
 527 static int vhost_vsock_stop(struct vhost_vsock *vsock)
 528 {
 529         size_t i;
 530         int ret;
 531 
 532         mutex_lock(&vsock->dev.mutex);
 533 
 534         ret = vhost_dev_check_owner(&vsock->dev);
 535         if (ret)
 536                 goto err;
 537 
 538         for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
 539                 struct vhost_virtqueue *vq = &vsock->vqs[i];
 540 
 541                 mutex_lock(&vq->mutex);
 542                 vq->private_data = NULL;
 543                 mutex_unlock(&vq->mutex);
 544         }
 545 
 546 err:
 547         mutex_unlock(&vsock->dev.mutex);
 548         return ret;
 549 }
 550 
 551 static void vhost_vsock_free(struct vhost_vsock *vsock)
 552 {
 553         kvfree(vsock);
 554 }
 555 
 556 static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
 557 {
 558         struct vhost_virtqueue **vqs;
 559         struct vhost_vsock *vsock;
 560         int ret;
 561 
 562         /* This struct is large and allocation could fail, fall back to vmalloc
 563          * if there is no other way.
 564          */
 565         vsock = kvmalloc(sizeof(*vsock), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
 566         if (!vsock)
 567                 return -ENOMEM;
 568 
 569         vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL);
 570         if (!vqs) {
 571                 ret = -ENOMEM;
 572                 goto out;
 573         }
 574 
 575         vsock->guest_cid = 0; /* no CID assigned yet */
 576 
 577         atomic_set(&vsock->queued_replies, 0);
 578 
 579         vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX];
 580         vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX];
 581         vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
 582         vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;
 583 
 584         vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs),
 585                        UIO_MAXIOV, VHOST_VSOCK_PKT_WEIGHT,
 586                        VHOST_VSOCK_WEIGHT);
 587 
 588         file->private_data = vsock;
 589         spin_lock_init(&vsock->send_pkt_list_lock);
 590         INIT_LIST_HEAD(&vsock->send_pkt_list);
 591         vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
 592         return 0;
 593 
 594 out:
 595         vhost_vsock_free(vsock);
 596         return ret;
 597 }
 598 
 599 static void vhost_vsock_flush(struct vhost_vsock *vsock)
 600 {
 601         int i;
 602 
 603         for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++)
 604                 if (vsock->vqs[i].handle_kick)
 605                         vhost_poll_flush(&vsock->vqs[i].poll);
 606         vhost_work_flush(&vsock->dev, &vsock->send_pkt_work);
 607 }
 608 
 609 static void vhost_vsock_reset_orphans(struct sock *sk)
 610 {
 611         struct vsock_sock *vsk = vsock_sk(sk);
 612 
 613         /* vmci_transport.c doesn't take sk_lock here either.  At least we're
 614          * under vsock_table_lock so the sock cannot disappear while we're
 615          * executing.
 616          */
 617 
 618         /* If the peer is still valid, no need to reset connection */
 619         if (vhost_vsock_get(vsk->remote_addr.svm_cid))
 620                 return;
 621 
 622         /* If the close timeout is pending, let it expire.  This avoids races
 623          * with the timeout callback.
 624          */
 625         if (vsk->close_work_scheduled)
 626                 return;
 627 
 628         sock_set_flag(sk, SOCK_DONE);
 629         vsk->peer_shutdown = SHUTDOWN_MASK;
 630         sk->sk_state = SS_UNCONNECTED;
 631         sk->sk_err = ECONNRESET;
 632         sk->sk_error_report(sk);
 633 }
 634 
 635 static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
 636 {
 637         struct vhost_vsock *vsock = file->private_data;
 638 
 639         mutex_lock(&vhost_vsock_mutex);
 640         if (vsock->guest_cid)
 641                 hash_del_rcu(&vsock->hash);
 642         mutex_unlock(&vhost_vsock_mutex);
 643 
 644         /* Wait for other CPUs to finish using vsock */
 645         synchronize_rcu();
 646 
 647         /* Iterating over all connections for all CIDs to find orphans is
 648          * inefficient.  Room for improvement here. */
 649         vsock_for_each_connected_socket(vhost_vsock_reset_orphans);
 650 
 651         vhost_vsock_stop(vsock);
 652         vhost_vsock_flush(vsock);
 653         vhost_dev_stop(&vsock->dev);
 654 
 655         spin_lock_bh(&vsock->send_pkt_list_lock);
 656         while (!list_empty(&vsock->send_pkt_list)) {
 657                 struct virtio_vsock_pkt *pkt;
 658 
 659                 pkt = list_first_entry(&vsock->send_pkt_list,
 660                                 struct virtio_vsock_pkt, list);
 661                 list_del_init(&pkt->list);
 662                 virtio_transport_free_pkt(pkt);
 663         }
 664         spin_unlock_bh(&vsock->send_pkt_list_lock);
 665 
 666         vhost_dev_cleanup(&vsock->dev);
 667         kfree(vsock->dev.vqs);
 668         vhost_vsock_free(vsock);
 669         return 0;
 670 }
 671 
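      /* VHOST_VSOCK_SET_GUEST_CID: validate and assign the guest CID, making
       * this instance reachable through vhost_vsock_hash.
       */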
 672 static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
 673 {
 674         struct vhost_vsock *other;
 675 
 676         /* Refuse reserved CIDs */
 677         if (guest_cid <= VMADDR_CID_HOST ||
 678             guest_cid == U32_MAX)
 679                 return -EINVAL;
 680 
 681         /* 64-bit CIDs are not yet supported */
 682         if (guest_cid > U32_MAX)
 683                 return -EINVAL;
 684 
 685         /* Refuse if CID is already in use */
 686         mutex_lock(&vhost_vsock_mutex);
 687         other = vhost_vsock_get(guest_cid);
 688         if (other && other != vsock) {
 689                 mutex_unlock(&vhost_vsock_mutex);
 690                 return -EADDRINUSE;
 691         }
 692 
 693         if (vsock->guest_cid)
 694                 hash_del_rcu(&vsock->hash);
 695 
 696         vsock->guest_cid = guest_cid;
 697         hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
 698         mutex_unlock(&vhost_vsock_mutex);
 699 
 700         return 0;
 701 }
 702 
 703 static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
 704 {
 705         struct vhost_virtqueue *vq;
 706         int i;
 707 
 708         if (features & ~VHOST_VSOCK_FEATURES)
 709                 return -EOPNOTSUPP;
 710 
 711         mutex_lock(&vsock->dev.mutex);
 712         if ((features & (1 << VHOST_F_LOG_ALL)) &&
 713             !vhost_log_access_ok(&vsock->dev)) {
 714                 mutex_unlock(&vsock->dev.mutex);
 715                 return -EFAULT;
 716         }
 717 
 718         for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
 719                 vq = &vsock->vqs[i];
 720                 mutex_lock(&vq->mutex);
 721                 vq->acked_features = features;
 722                 mutex_unlock(&vq->mutex);
 723         }
 724         mutex_unlock(&vsock->dev.mutex);
 725         return 0;
 726 }
 727 
 728 static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
 729                                   unsigned long arg)
 730 {
 731         struct vhost_vsock *vsock = f->private_data;
 732         void __user *argp = (void __user *)arg;
 733         u64 guest_cid;
 734         u64 features;
 735         int start;
 736         int r;
 737 
 738         switch (ioctl) {
 739         case VHOST_VSOCK_SET_GUEST_CID:
 740                 if (copy_from_user(&guest_cid, argp, sizeof(guest_cid)))
 741                         return -EFAULT;
 742                 return vhost_vsock_set_cid(vsock, guest_cid);
 743         case VHOST_VSOCK_SET_RUNNING:
 744                 if (copy_from_user(&start, argp, sizeof(start)))
 745                         return -EFAULT;
 746                 if (start)
 747                         return vhost_vsock_start(vsock);
 748                 else
 749                         return vhost_vsock_stop(vsock);
 750         case VHOST_GET_FEATURES:
 751                 features = VHOST_VSOCK_FEATURES;
 752                 if (copy_to_user(argp, &features, sizeof(features)))
 753                         return -EFAULT;
 754                 return 0;
 755         case VHOST_SET_FEATURES:
 756                 if (copy_from_user(&features, argp, sizeof(features)))
 757                         return -EFAULT;
 758                 return vhost_vsock_set_features(vsock, features);
 759         default:
 760                 mutex_lock(&vsock->dev.mutex);
 761                 r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
 762                 if (r == -ENOIOCTLCMD)
 763                         r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
 764                 else
 765                         vhost_vsock_flush(vsock);
 766                 mutex_unlock(&vsock->dev.mutex);
 767                 return r;
 768         }
 769 }
 770 
 771 #ifdef CONFIG_COMPAT
 772 static long vhost_vsock_dev_compat_ioctl(struct file *f, unsigned int ioctl,
 773                                          unsigned long arg)
 774 {
 775         return vhost_vsock_dev_ioctl(f, ioctl, (unsigned long)compat_ptr(arg));
 776 }
 777 #endif
 778 
 779 static const struct file_operations vhost_vsock_fops = {
 780         .owner          = THIS_MODULE,
 781         .open           = vhost_vsock_dev_open,
 782         .release        = vhost_vsock_dev_release,
 783         .llseek         = noop_llseek,
 784         .unlocked_ioctl = vhost_vsock_dev_ioctl,
 785 #ifdef CONFIG_COMPAT
 786         .compat_ioctl   = vhost_vsock_dev_compat_ioctl,
 787 #endif
 788 };
 789 
 790 static struct miscdevice vhost_vsock_misc = {
 791         .minor = VHOST_VSOCK_MINOR,
 792         .name = "vhost-vsock",
 793         .fops = &vhost_vsock_fops,
 794 };
 795 
 796 static struct virtio_transport vhost_transport = {
 797         .transport = {
 798                 .get_local_cid            = vhost_transport_get_local_cid,
 799 
 800                 .init                     = virtio_transport_do_socket_init,
 801                 .destruct                 = virtio_transport_destruct,
 802                 .release                  = virtio_transport_release,
 803                 .connect                  = virtio_transport_connect,
 804                 .shutdown                 = virtio_transport_shutdown,
 805                 .cancel_pkt               = vhost_transport_cancel_pkt,
 806 
 807                 .dgram_enqueue            = virtio_transport_dgram_enqueue,
 808                 .dgram_dequeue            = virtio_transport_dgram_dequeue,
 809                 .dgram_bind               = virtio_transport_dgram_bind,
 810                 .dgram_allow              = virtio_transport_dgram_allow,
 811 
 812                 .stream_enqueue           = virtio_transport_stream_enqueue,
 813                 .stream_dequeue           = virtio_transport_stream_dequeue,
 814                 .stream_has_data          = virtio_transport_stream_has_data,
 815                 .stream_has_space         = virtio_transport_stream_has_space,
 816                 .stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
 817                 .stream_is_active         = virtio_transport_stream_is_active,
 818                 .stream_allow             = virtio_transport_stream_allow,
 819 
 820                 .notify_poll_in           = virtio_transport_notify_poll_in,
 821                 .notify_poll_out          = virtio_transport_notify_poll_out,
 822                 .notify_recv_init         = virtio_transport_notify_recv_init,
 823                 .notify_recv_pre_block    = virtio_transport_notify_recv_pre_block,
 824                 .notify_recv_pre_dequeue  = virtio_transport_notify_recv_pre_dequeue,
 825                 .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
 826                 .notify_send_init         = virtio_transport_notify_send_init,
 827                 .notify_send_pre_block    = virtio_transport_notify_send_pre_block,
 828                 .notify_send_pre_enqueue  = virtio_transport_notify_send_pre_enqueue,
 829                 .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
 830 
 831                 .set_buffer_size          = virtio_transport_set_buffer_size,
 832                 .set_min_buffer_size      = virtio_transport_set_min_buffer_size,
 833                 .set_max_buffer_size      = virtio_transport_set_max_buffer_size,
 834                 .get_buffer_size          = virtio_transport_get_buffer_size,
 835                 .get_min_buffer_size      = virtio_transport_get_min_buffer_size,
 836                 .get_max_buffer_size      = virtio_transport_get_max_buffer_size,
 837         },
 838 
 839         .send_pkt = vhost_transport_send_pkt,
 840 };
 841 
 842 static int __init vhost_vsock_init(void)
 843 {
 844         int ret;
 845 
 846         ret = vsock_core_init(&vhost_transport.transport);
 847         if (ret < 0)
 848                 return ret;
 849         return misc_register(&vhost_vsock_misc);
  850 }
 851 
 852 static void __exit vhost_vsock_exit(void)
 853 {
 854         misc_deregister(&vhost_vsock_misc);
 855         vsock_core_exit();
  856 }
 857 
 858 module_init(vhost_vsock_init);
 859 module_exit(vhost_vsock_exit);
 860 MODULE_LICENSE("GPL v2");
 861 MODULE_AUTHOR("Asias He");
  862 MODULE_DESCRIPTION("vhost transport for vsock");
 863 MODULE_ALIAS_MISCDEV(VHOST_VSOCK_MINOR);
 864 MODULE_ALIAS("devname:vhost-vsock");
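
Usage sketch: the following is a minimal, hypothetical user-space program (not
part of this driver) showing the control-plane ioctl sequence a VMM would issue
against /dev/vhost-vsock, i.e. the commands handled by vhost_vsock_dev_ioctl()
above.  Guest memory and virtqueue setup (VHOST_SET_MEM_TABLE, VHOST_SET_VRING_*)
are elided, so as written the device is configured but cannot move data; the
guest CID value is only illustrative.

#include <fcntl.h>
#include <stdint.h>
#include <stdio.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

int main(void)
{
	uint64_t guest_cid = 3;	/* any CID above VMADDR_CID_HOST (2) and below U32_MAX */
	uint64_t features;
	int running = 1;
	int fd;

	fd = open("/dev/vhost-vsock", O_RDWR);
	if (fd < 0) {
		perror("open /dev/vhost-vsock");
		return 1;
	}

	/* Take ownership of the vhost device; required before other ioctls. */
	if (ioctl(fd, VHOST_SET_OWNER) < 0)
		perror("VHOST_SET_OWNER");

	/* Feature negotiation, served by VHOST_GET_FEATURES/VHOST_SET_FEATURES. */
	if (ioctl(fd, VHOST_GET_FEATURES, &features) < 0)
		perror("VHOST_GET_FEATURES");
	if (ioctl(fd, VHOST_SET_FEATURES, &features) < 0)
		perror("VHOST_SET_FEATURES");

	/* Assign the guest CID; vhost_vsock_set_cid() rejects reserved values. */
	if (ioctl(fd, VHOST_VSOCK_SET_GUEST_CID, &guest_cid) < 0)
		perror("VHOST_VSOCK_SET_GUEST_CID");

	/* VHOST_SET_MEM_TABLE and VHOST_SET_VRING_* setup would go here. */

	/* Start the device; vhost_vsock_start() installs vq->private_data. */
	if (ioctl(fd, VHOST_VSOCK_SET_RUNNING, &running) < 0)
		perror("VHOST_VSOCK_SET_RUNNING");

	close(fd);
	return 0;
}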
