@@ -873,6 +873,81 @@ static bool mptcp_frag_can_collapse_to(const struct mptcp_sock *msk,
df->data_seq + df->data_len == msk->write_seq;
}
+/* Estimate the write memory needed to queue @size bytes of payload: the
+ * payload itself plus sizeof(struct mptcp_data_frag) for every PAGE_SIZE
+ * bytes of it (presumably one frag descriptor per page - confirm against
+ * the page frag allocation in mptcp_sendmsg()).
+ */
+static int mptcp_wmem_with_overhead(int size)
+{
+	/* computed in size_t, exactly as the single-expression form would */
+	size_t frag_overhead = (sizeof(struct mptcp_data_frag) * size) >> PAGE_SHIFT;
+
+	return size + frag_overhead;
+}
+
+/* Try to reserve up-front the write memory needed to queue @size bytes.
+ *
+ * On success the reserved amount is moved out of sk->sk_forward_alloc and
+ * recorded in msk->wmem_reserved. On failure msk->wmem_reserved is set to -1,
+ * which mptcp_wmem_alloc() reports as an allocation error.
+ *
+ * mptcp_sendmsg() passes its size_t length here, so after the conversion to
+ * int @size may be negative. A negative amount would pass the forward-alloc
+ * check below, inflate sk->sk_forward_alloc and leave wmem_reserved in the
+ * error state. The request is therefore clamped: the reservation is only an
+ * estimate, and mptcp_wmem_alloc() schedules more memory when it runs short.
+ */
+static void __mptcp_wmem_reserve(struct sock *sk, int size)
+{
+	const int max_reserve = 1 << 20;
+	struct mptcp_sock *msk = mptcp_sk(sk);
+	int amount;
+
+	/* a negative value can only come from a truncated, very large length */
+	if (size < 0 || size > max_reserve)
+		size = max_reserve;
+	amount = mptcp_wmem_with_overhead(size);
+
+	/* a previous reservation must have been flushed before a new one */
+	WARN_ON_ONCE(msk->wmem_reserved);
+	if (amount <= sk->sk_forward_alloc)
+		goto reserve;
+
+	/* under memory pressure try to reserve at most a single page,
+	 * otherwise try to reserve the full estimate and fall back
+	 * to a single page before entering the error path
+	 */
+	if ((tcp_under_memory_pressure(sk) && amount > PAGE_SIZE) ||
+	    !sk_wmem_schedule(sk, amount)) {
+		if (amount <= PAGE_SIZE)
+			goto nomem;
+
+		amount = PAGE_SIZE;
+		if (!sk_wmem_schedule(sk, amount))
+			goto nomem;
+	}
+
+reserve:
+	msk->wmem_reserved = amount;
+	sk->sk_forward_alloc -= amount;
+	return;
+
+nomem:
+	/* we will wait for memory on next allocation */
+	msk->wmem_reserved = -1;
+}
+
+/* Flush the msk write memory reservation: any unused reserved amount is
+ * given back to sk->sk_forward_alloc and a pending error marker (negative
+ * value) is cleared. In both cases msk->wmem_reserved ends up at zero.
+ */
+static void __mptcp_update_wmem(struct sock *sk)
+{
+	struct mptcp_sock *msk = mptcp_sk(sk);
+	int reserved = msk->wmem_reserved;
+
+	/* nothing reserved and no error recorded: leave the state untouched */
+	if (!reserved)
+		return;
+
+	/* only a real reservation is returned; a negative value is an error
+	 * marker and carries no memory
+	 */
+	if (reserved > 0)
+		sk->sk_forward_alloc += reserved;
+
+	msk->wmem_reserved = 0;
+}
+
+/* Account @size bytes of write memory for the msk.
+ *
+ * The bytes are taken from msk->wmem_reserved when it is large enough,
+ * otherwise they are scheduled via sk_wmem_schedule() and charged to
+ * sk->sk_forward_alloc.
+ *
+ * Returns true on success, false if the reservation is in the error state
+ * (negative) or the memory can't be scheduled.
+ */
+static bool mptcp_wmem_alloc(struct sock *sk, int size)
+{
+	struct mptcp_sock *msk = mptcp_sk(sk);
+	int reserved = msk->wmem_reserved;
+
+	/* a failed reservation is reported on the first allocation attempt */
+	if (reserved < 0)
+		return false;
+
+	/* fast path: the reservation covers the request */
+	if (reserved >= size) {
+		msk->wmem_reserved = reserved - size;
+		return true;
+	}
+
+	/* slow path: charge the whole request to the socket; whatever is
+	 * left in the reservation stays there untouched
+	 */
+	if (!sk_wmem_schedule(sk, size))
+		return false;
+
+	sk->sk_forward_alloc -= size;
+	return true;
+}
+
static void dfrag_uncharge(struct sock *sk, int len)
{
sk_mem_uncharge(sk, len);
@@ -930,7 +1005,7 @@ static void mptcp_clean_una(struct sock *sk)
}
out:
- if (cleaned)
+ if (cleaned && tcp_under_memory_pressure(sk))
sk_mem_reclaim_partial(sk);
}
@@ -1307,7 +1382,7 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
return -EOPNOTSUPP;
- lock_sock(sk);
+ mptcp_lock_sock(sk, __mptcp_wmem_reserve(sk, len));
timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
@@ -1356,11 +1431,12 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
offset = dfrag->offset + dfrag->data_len;
psize = pfrag->size - offset;
psize = min_t(size_t, psize, msg_data_left(msg));
- if (!sk_wmem_schedule(sk, psize + frag_truesize))
+ if (!mptcp_wmem_alloc(sk, psize + frag_truesize))
goto wait_for_memory;
if (copy_page_from_iter(dfrag->page, offset, psize,
&msg->msg_iter) != psize) {
+ msk->wmem_reserved += psize + frag_truesize;
ret = -EFAULT;
goto out;
}
@@ -1376,7 +1452,6 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
* Note: we charge such data both to sk and ssk
*/
sk_wmem_queued_add(sk, frag_truesize);
- sk->sk_forward_alloc -= frag_truesize;
if (!dfrag_collapsed) {
get_page(dfrag->page);
list_add_tail(&dfrag->list, &msk->rtx_queue);
@@ -2003,6 +2078,7 @@ static int __mptcp_init_sock(struct sock *sk)
INIT_WORK(&msk->work, mptcp_worker);
msk->out_of_order_queue = RB_ROOT;
msk->first_pending = NULL;
+ msk->wmem_reserved = 0;
msk->ack_hint = NULL;
msk->first = NULL;
@@ -2197,6 +2273,7 @@ static void __mptcp_destroy_sock(struct sock *sk)
sk->sk_prot->destroy(sk);
+ WARN_ON_ONCE(msk->wmem_reserved);
sk_stream_kill_queues(sk);
xfrm_sk_free_policy(sk);
sk_refcnt_debug_release(sk);
@@ -2542,13 +2619,14 @@ static int mptcp_getsockopt(struct sock *sk, int level, int optname,
#define MPTCP_DEFERRED_ALL (TCPF_WRITE_TIMER_DEFERRED)
-/* this is very alike tcp_release_cb() but we must handle differently a
- * different set of events
- */
+/* process deferred events and flush the wmem reservation */
static void mptcp_release_cb(struct sock *sk)
{
unsigned long flags, nflags;
+ /* clear any wmem reservation and errors */
+ __mptcp_update_wmem(sk);
+
do {
flags = sk->sk_tsq_flags;
if (!(flags & MPTCP_DEFERRED_ALL))
@@ -218,6 +218,7 @@ struct mptcp_sock {
u64 ack_seq;
u64 rcv_wnd_sent;
u64 rcv_data_fin_seq;
+ int wmem_reserved;
struct sock *last_snd;
int snd_burst;
int old_wspace;