diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c index cb332adb84cd..c361ce782412 100644 --- a/net/vmw_vsock/vmci_transport.c +++ b/net/vmw_vsock/vmci_transport.c @@ -263,6 +263,31 @@ vmci_transport_send_control_pkt_bh(struct sockaddr_vm *src, false); } +static int +vmci_transport_alloc_send_control_pkt(struct sockaddr_vm *src, + struct sockaddr_vm *dst, + enum vmci_transport_packet_type type, + u64 size, + u64 mode, + struct vmci_transport_waiting_info *wait, + u16 proto, + struct vmci_handle handle) +{ + struct vmci_transport_packet *pkt; + int err; + + pkt = kmalloc(sizeof(*pkt), GFP_KERNEL); + if (!pkt) + return -ENOMEM; + + err = __vmci_transport_send_control_pkt(pkt, src, dst, type, size, + mode, wait, proto, handle, + true); + kfree(pkt); + + return err; +} + static int vmci_transport_send_control_pkt(struct sock *sk, enum vmci_transport_packet_type type, @@ -272,9 +297,7 @@ vmci_transport_send_control_pkt(struct sock *sk, u16 proto, struct vmci_handle handle) { - struct vmci_transport_packet *pkt; struct vsock_sock *vsk; - int err; vsk = vsock_sk(sk); @@ -284,17 +307,10 @@ vmci_transport_send_control_pkt(struct sock *sk, if (!vsock_addr_bound(&vsk->remote_addr)) return -EINVAL; - pkt = kmalloc(sizeof(*pkt), GFP_KERNEL); - if (!pkt) - return -ENOMEM; - - err = __vmci_transport_send_control_pkt(pkt, &vsk->local_addr, - &vsk->remote_addr, type, size, - mode, wait, proto, handle, - true); - kfree(pkt); - - return err; + return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, + &vsk->remote_addr, + type, size, mode, + wait, proto, handle); } static int vmci_transport_send_reset_bh(struct sockaddr_vm *dst, @@ -312,12 +328,29 @@ static int vmci_transport_send_reset_bh(struct sockaddr_vm *dst, static int vmci_transport_send_reset(struct sock *sk, struct vmci_transport_packet *pkt) { + struct sockaddr_vm *dst_ptr; + struct sockaddr_vm dst; + struct vsock_sock *vsk; + if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST) return 0; - return vmci_transport_send_control_pkt(sk, - VMCI_TRANSPORT_PACKET_TYPE_RST, - 0, 0, NULL, VSOCK_PROTO_INVALID, - VMCI_INVALID_HANDLE); + + vsk = vsock_sk(sk); + + if (!vsock_addr_bound(&vsk->local_addr)) + return -EINVAL; + + if (vsock_addr_bound(&vsk->remote_addr)) { + dst_ptr = &vsk->remote_addr; + } else { + vsock_addr_init(&dst, pkt->dg.src.context, + pkt->src_port); + dst_ptr = &dst; + } + return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, dst_ptr, + VMCI_TRANSPORT_PACKET_TYPE_RST, + 0, 0, NULL, VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); } static int vmci_transport_send_negotiate(struct sock *sk, size_t size)