diff --git a/include/uapi/linux/rds.h b/include/uapi/linux/rds.h index e71d4491f225..12e3bca32cad 100644 --- a/include/uapi/linux/rds.h +++ b/include/uapi/linux/rds.h @@ -103,6 +103,7 @@ #define RDS_CMSG_MASKED_ATOMIC_FADD 8 #define RDS_CMSG_MASKED_ATOMIC_CSWP 9 #define RDS_CMSG_RXPATH_LATENCY 11 +#define RDS_CMSG_ZCOPY_COOKIE 12 #define RDS_INFO_FIRST 10000 #define RDS_INFO_COUNTERS 10000 diff --git a/net/rds/message.c b/net/rds/message.c index bf1a656b198a..651834513481 100644 --- a/net/rds/message.c +++ b/net/rds/message.c @@ -341,12 +341,14 @@ struct rds_message *rds_message_map_pages(unsigned long *page_addrs, unsigned in return rm; } -int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from) +int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from, + bool zcopy) { unsigned long to_copy, nbytes; unsigned long sg_off; struct scatterlist *sg; int ret = 0; + int length = iov_iter_count(from); rm->m_inc.i_hdr.h_len = cpu_to_be32(iov_iter_count(from)); @@ -356,6 +358,53 @@ int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from) sg = rm->data.op_sg; sg_off = 0; /* Dear gcc, sg->page will be null from kzalloc. */ + if (zcopy) { + int total_copied = 0; + struct sk_buff *skb; + + skb = alloc_skb(SO_EE_ORIGIN_MAX_ZCOOKIES * sizeof(u32), + GFP_KERNEL); + if (!skb) + return -ENOMEM; + rm->data.op_mmp_znotifier = RDS_ZCOPY_SKB(skb); + if (mm_account_pinned_pages(&rm->data.op_mmp_znotifier->z_mmp, + length)) { + ret = -ENOMEM; + goto err; + } + while (iov_iter_count(from)) { + struct page *pages; + size_t start; + ssize_t copied; + + copied = iov_iter_get_pages(from, &pages, PAGE_SIZE, + 1, &start); + if (copied < 0) { + struct mmpin *mmp; + int i; + + for (i = 0; i < rm->data.op_nents; i++) + put_page(sg_page(&rm->data.op_sg[i])); + mmp = &rm->data.op_mmp_znotifier->z_mmp; + mm_unaccount_pinned_pages(mmp); + ret = -EFAULT; + goto err; + } + total_copied += copied; + iov_iter_advance(from, copied); + length -= copied; + sg_set_page(sg, pages, copied, start); + rm->data.op_nents++; + sg++; + } + WARN_ON_ONCE(length != 0); + return ret; +err: + consume_skb(skb); + rm->data.op_mmp_znotifier = NULL; + return ret; + } /* zcopy */ + while (iov_iter_count(from)) { if (!sg_page(sg)) { ret = rds_page_remainder_alloc(sg, iov_iter_count(from), diff --git a/net/rds/rds.h b/net/rds/rds.h index 24576bc4a5e9..31cd38852050 100644 --- a/net/rds/rds.h +++ b/net/rds/rds.h @@ -785,7 +785,8 @@ rds_conn_connecting(struct rds_connection *conn) /* message.c */ struct rds_message *rds_message_alloc(unsigned int nents, gfp_t gfp); struct scatterlist *rds_message_alloc_sgs(struct rds_message *rm, int nents); -int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from); +int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from, + bool zcopy); struct rds_message *rds_message_map_pages(unsigned long *page_addrs, unsigned int total_len); void rds_message_populate_header(struct rds_header *hdr, __be16 sport, __be16 dport, u64 seq); diff --git a/net/rds/send.c b/net/rds/send.c index e8f3ff471b15..028ab598ac1b 100644 --- a/net/rds/send.c +++ b/net/rds/send.c @@ -875,12 +875,13 @@ out: * rds_message is getting to be quite complicated, and we'd like to allocate * it all in one go. This figures out how big it needs to be up front. */ -static int rds_rm_size(struct msghdr *msg, int data_len) +static int rds_rm_size(struct msghdr *msg, int num_sgs) { struct cmsghdr *cmsg; int size = 0; int cmsg_groups = 0; int retval; + bool zcopy_cookie = false; for_each_cmsghdr(cmsg, msg) { if (!CMSG_OK(msg, cmsg)) @@ -899,6 +900,8 @@ static int rds_rm_size(struct msghdr *msg, int data_len) break; + case RDS_CMSG_ZCOPY_COOKIE: + zcopy_cookie = true; case RDS_CMSG_RDMA_DEST: case RDS_CMSG_RDMA_MAP: cmsg_groups |= 2; @@ -919,7 +922,10 @@ static int rds_rm_size(struct msghdr *msg, int data_len) } - size += ceil(data_len, PAGE_SIZE) * sizeof(struct scatterlist); + if ((msg->msg_flags & MSG_ZEROCOPY) && !zcopy_cookie) + return -EINVAL; + + size += num_sgs * sizeof(struct scatterlist); /* Ensure (DEST, MAP) are never used with (ARGS, ATOMIC) */ if (cmsg_groups == 3) @@ -928,6 +934,18 @@ static int rds_rm_size(struct msghdr *msg, int data_len) return size; } +static int rds_cmsg_zcopy(struct rds_sock *rs, struct rds_message *rm, + struct cmsghdr *cmsg) +{ + u32 *cookie; + + if (cmsg->cmsg_len < CMSG_LEN(sizeof(*cookie))) + return -EINVAL; + cookie = CMSG_DATA(cmsg); + rm->data.op_mmp_znotifier->z_cookie = *cookie; + return 0; +} + static int rds_cmsg_send(struct rds_sock *rs, struct rds_message *rm, struct msghdr *msg, int *allocated_mr) { @@ -970,6 +988,10 @@ static int rds_cmsg_send(struct rds_sock *rs, struct rds_message *rm, ret = rds_cmsg_atomic(rs, rm, cmsg); break; + case RDS_CMSG_ZCOPY_COOKIE: + ret = rds_cmsg_zcopy(rs, rm, cmsg); + break; + default: return -EINVAL; } @@ -1040,10 +1062,13 @@ int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len) long timeo = sock_sndtimeo(sk, nonblock); struct rds_conn_path *cpath; size_t total_payload_len = payload_len, rdma_payload_len = 0; + bool zcopy = ((msg->msg_flags & MSG_ZEROCOPY) && + sock_flag(rds_rs_to_sk(rs), SOCK_ZEROCOPY)); + int num_sgs = ceil(payload_len, PAGE_SIZE); /* Mirror Linux UDP mirror of BSD error message compatibility */ /* XXX: Perhaps MSG_MORE someday */ - if (msg->msg_flags & ~(MSG_DONTWAIT | MSG_CMSG_COMPAT)) { + if (msg->msg_flags & ~(MSG_DONTWAIT | MSG_CMSG_COMPAT | MSG_ZEROCOPY)) { ret = -EOPNOTSUPP; goto out; } @@ -1087,8 +1112,15 @@ int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len) goto out; } + if (zcopy) { + if (rs->rs_transport->t_type != RDS_TRANS_TCP) { + ret = -EOPNOTSUPP; + goto out; + } + num_sgs = iov_iter_npages(&msg->msg_iter, INT_MAX); + } /* size of rm including all sgs */ - ret = rds_rm_size(msg, payload_len); + ret = rds_rm_size(msg, num_sgs); if (ret < 0) goto out; @@ -1100,12 +1132,12 @@ int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len) /* Attach data to the rm */ if (payload_len) { - rm->data.op_sg = rds_message_alloc_sgs(rm, ceil(payload_len, PAGE_SIZE)); + rm->data.op_sg = rds_message_alloc_sgs(rm, num_sgs); if (!rm->data.op_sg) { ret = -ENOMEM; goto out; } - ret = rds_message_copy_from_user(rm, &msg->msg_iter); + ret = rds_message_copy_from_user(rm, &msg->msg_iter, zcopy); if (ret) goto out; }