diff mbox series

[v2,15/35] net/tcp: Wire TCP-AO to request sockets

Message ID 20220923201319.493208-16-dima@arista.com
State New
Headers show
Series [v2,01/35] crypto: Introduce crypto_pool | expand

Commit Message

Dmitry Safonov Sept. 23, 2022, 8:12 p.m. UTC
Now when the new request socket is created from the listening socket,
it's recorded what MKT was used by the peer. tcp_rsk_used_ao() is
a new helper for checking if TCP-AO option was used to create the
request socket.
tcp_ao_copy_all_matching() will copy all keys that match the peer on the
request socket, as well as preparing them for the usage (creating
traffic keys).

Co-developed-by: Francesco Ruggeri <fruggeri@arista.com>
Signed-off-by: Francesco Ruggeri <fruggeri@arista.com>
Co-developed-by: Salam Noureddine <noureddine@arista.com>
Signed-off-by: Salam Noureddine <noureddine@arista.com>
Signed-off-by: Dmitry Safonov <dima@arista.com>
---
 include/linux/tcp.h      |  18 +++++
 include/net/tcp.h        |   7 ++
 include/net/tcp_ao.h     |  16 +++++
 net/ipv4/tcp_ao.c        | 141 +++++++++++++++++++++++++++++++++++++++
 net/ipv4/tcp_input.c     |  19 +++++-
 net/ipv4/tcp_ipv4.c      |  63 +++++++++++++++--
 net/ipv4/tcp_minisocks.c |  10 +++
 net/ipv4/tcp_output.c    |  36 +++++++---
 net/ipv6/tcp_ao.c        |  21 ++++++
 net/ipv6/tcp_ipv6.c      |  78 ++++++++++++++++++----
 10 files changed, 379 insertions(+), 30 deletions(-)

Comments

kernel test robot Sept. 24, 2022, 5:23 a.m. UTC | #1
Hi Dmitry,

Thank you for the patch! Perhaps something to improve:

[auto build test WARNING on bf682942cd26ce9cd5e87f73ae099b383041e782]

url:    https://github.com/intel-lab-lkp/linux/commits/Dmitry-Safonov/net-tcp-Add-TCP-AO-support/20220924-042311
base:   bf682942cd26ce9cd5e87f73ae099b383041e782
config: um-x86_64_defconfig
compiler: gcc-11 (Debian 11.3.0-5) 11.3.0
reproduce (this is a W=1 build):
        # https://github.com/intel-lab-lkp/linux/commit/714d0dcced0422997aec848b06f2e8f57533e255
        git remote add linux-review https://github.com/intel-lab-lkp/linux
        git fetch --no-tags linux-review Dmitry-Safonov/net-tcp-Add-TCP-AO-support/20220924-042311
        git checkout 714d0dcced0422997aec848b06f2e8f57533e255
        # save the config file
        mkdir build_dir && cp config build_dir/.config
        make W=1 O=build_dir ARCH=um SUBARCH=x86_64 SHELL=/bin/bash net/ipv4/

If you fix the issue, kindly add following tag where applicable
| Reported-by: kernel test robot <lkp@intel.com>

All warnings (new ones prefixed by >>):

>> net/ipv4/tcp_ao.c:55:20: warning: no previous prototype for 'tcp_ao_do_lookup_keyid' [-Wmissing-prototypes]
      55 | struct tcp_ao_key *tcp_ao_do_lookup_keyid(struct tcp_ao_info *ao,
         |                    ^~~~~~~~~~~~~~~~~~~~~~
>> net/ipv4/tcp_ao.c:212:20: warning: no previous prototype for 'tcp_ao_copy_key' [-Wmissing-prototypes]
     212 | struct tcp_ao_key *tcp_ao_copy_key(struct sock *sk, struct tcp_ao_key *key)
         |                    ^~~~~~~~~~~~~~~


vim +/tcp_ao_do_lookup_keyid +55 net/ipv4/tcp_ao.c

    54	
  > 55	struct tcp_ao_key *tcp_ao_do_lookup_keyid(struct tcp_ao_info *ao,
    56						  int sndid, int rcvid)
    57	{
    58		struct tcp_ao_key *key;
    59	
    60		hlist_for_each_entry_rcu(key, &ao->head, node) {
    61			if ((sndid >= 0 && key->sndid != sndid) ||
    62			    (rcvid >= 0 && key->rcvid != rcvid))
    63				continue;
    64			return key;
    65		}
    66	
    67		return NULL;
    68	}
    69	
    70	static struct tcp_ao_key *tcp_ao_do_lookup_rcvid(struct sock *sk, u8 keyid)
    71	{
    72		struct tcp_sock *tp = tcp_sk(sk);
    73		struct tcp_ao_key *key;
    74		struct tcp_ao_info *ao;
    75	
    76		ao = rcu_dereference_check(tp->ao_info, lockdep_sock_is_held(sk));
    77	
    78		if (!ao)
    79			return NULL;
    80	
    81		hlist_for_each_entry_rcu(key, &ao->head, node) {
    82			if (key->rcvid == keyid)
    83				return key;
    84		}
    85		return NULL;
    86	}
    87	
    88	struct tcp_ao_key *tcp_ao_do_lookup_sndid(const struct sock *sk, u8 keyid)
    89	{
    90		struct tcp_ao_key *key;
    91		struct tcp_ao_info *ao;
    92	
    93		if (sk->sk_state == TCP_TIME_WAIT)
    94			ao = rcu_dereference_check(tcp_twsk(sk)->ao_info,
    95						   lockdep_sock_is_held(sk));
    96		else
    97			ao = rcu_dereference_check(tcp_sk(sk)->ao_info,
    98						   lockdep_sock_is_held(sk));
    99	
   100		if (!ao)
   101			return NULL;
   102	
   103		hlist_for_each_entry_rcu(key, &ao->head, node) {
   104			if (key->sndid == keyid)
   105				return key;
   106		}
   107		return NULL;
   108	}
   109	
   110	static inline int ipv4_prefix_cmp(const struct in_addr *addr1,
   111					  const struct in_addr *addr2,
   112					  unsigned int prefixlen)
   113	{
   114		__be32 mask = inet_make_mask(prefixlen);
   115	
   116		if ((addr1->s_addr & mask) == (addr2->s_addr & mask))
   117			return 0;
   118		return ((addr1->s_addr & mask) > (addr2->s_addr & mask)) ? 1 : -1;
   119	}
   120	
   121	static int __tcp_ao_key_cmp(const struct tcp_ao_key *key,
   122				    const union tcp_ao_addr *addr, u8 prefixlen,
   123				    int family, int sndid, int rcvid, u16 port)
   124	{
   125		if (sndid >= 0 && key->sndid != sndid)
   126			return (key->sndid > sndid) ? 1 : -1;
   127		if (rcvid >= 0 && key->rcvid != rcvid)
   128			return (key->rcvid > rcvid) ? 1 : -1;
   129		if (port != 0 && key->port != 0 && port != key->port)
   130			return (key->port > port) ? 1 : -1;
   131	
   132		if (family == AF_UNSPEC)
   133			return 0;
   134		if (key->family != family)
   135			return (key->family > family) ? 1 : -1;
   136	
   137		if (family == AF_INET) {
   138			if (key->addr.a4.s_addr == INADDR_ANY)
   139				return 0;
   140			if (addr->a4.s_addr == INADDR_ANY)
   141				return 0;
   142			return ipv4_prefix_cmp(&key->addr.a4, &addr->a4, prefixlen);
   143	#if IS_ENABLED(CONFIG_IPV6)
   144		} else {
   145			if (ipv6_addr_any(&key->addr.a6) || ipv6_addr_any(&addr->a6))
   146				return 0;
   147			if (ipv6_prefix_equal(&key->addr.a6, &addr->a6, prefixlen))
   148				return 0;
   149			return memcmp(&key->addr.a6, &addr->a6, prefixlen);
   150	#endif
   151		}
   152		return -1;
   153	}
   154	
   155	static int tcp_ao_key_cmp(const struct tcp_ao_key *key,
   156				  const union tcp_ao_addr *addr, u8 prefixlen,
   157				  int family, int sndid, int rcvid, u16 port)
   158	{
   159	#if IS_ENABLED(CONFIG_IPV6)
   160		if (family == AF_INET6 && ipv6_addr_v4mapped(&addr->a6)) {
   161			__be32 addr4 = addr->a6.s6_addr32[3];
   162	
   163			return __tcp_ao_key_cmp(key, (union tcp_ao_addr *)&addr4,
   164						prefixlen, AF_INET, sndid, rcvid, port);
   165		}
   166	#endif
   167		return __tcp_ao_key_cmp(key, addr, prefixlen, family, sndid, rcvid, port);
   168	}
   169	
   170	struct tcp_ao_key *tcp_ao_do_lookup(const struct sock *sk,
   171					    const union tcp_ao_addr *addr,
   172					    int family, int sndid, int rcvid, u16 port)
   173	{
   174		struct tcp_ao_key *key;
   175		struct tcp_ao_info *ao;
   176	
   177		ao = rcu_dereference_check(tcp_sk(sk)->ao_info,
   178					   lockdep_sock_is_held(sk));
   179		if (!ao)
   180			return NULL;
   181	
   182		hlist_for_each_entry_rcu(key, &ao->head, node) {
   183			if (!tcp_ao_key_cmp(key, addr, key->prefixlen,
   184					    family, sndid, rcvid, port))
   185				return key;
   186		}
   187		return NULL;
   188	}
   189	EXPORT_SYMBOL(tcp_ao_do_lookup);
   190	
   191	static struct tcp_ao_info *tcp_ao_alloc_info(gfp_t flags,
   192			struct tcp_ao_info *cloned_from)
   193	{
   194		struct tcp_ao_info *ao;
   195	
   196		ao = kzalloc(sizeof(*ao), flags);
   197		if (!ao)
   198			return NULL;
   199		INIT_HLIST_HEAD(&ao->head);
   200		atomic_set(&ao->refcnt, 1);
   201	
   202		if (cloned_from)
   203			ao->ao_flags = cloned_from->ao_flags;
   204		return ao;
   205	}
   206	
   207	static void tcp_ao_link_mkt(struct tcp_ao_info *ao, struct tcp_ao_key *mkt)
   208	{
   209		hlist_add_head_rcu(&mkt->node, &ao->head);
   210	}
   211	
 > 212	struct tcp_ao_key *tcp_ao_copy_key(struct sock *sk, struct tcp_ao_key *key)
   213	{
   214		struct tcp_ao_key *new_key;
   215	
   216		new_key = sock_kmalloc(sk, tcp_ao_sizeof_key(key),
   217				       GFP_ATOMIC);
   218		if (!new_key)
   219			return NULL;
   220	
   221		*new_key = *key;
   222		INIT_HLIST_NODE(&new_key->node);
   223		crypto_pool_add(new_key->crypto_pool_id);
   224	
   225		return new_key;
   226	}
   227
diff mbox series

Patch

diff --git a/include/linux/tcp.h b/include/linux/tcp.h
index 8031995b58a2..0e0f07d12f7b 100644
--- a/include/linux/tcp.h
+++ b/include/linux/tcp.h
@@ -165,6 +165,11 @@  struct tcp_request_sock {
 						  * after data-in-SYN.
 						  */
 	u8				syn_tos;
+#ifdef CONFIG_TCP_AO
+	u8				ao_keyid;
+	u8				ao_rcv_next;
+	u8				maclen;
+#endif
 };
 
 static inline struct tcp_request_sock *tcp_rsk(const struct request_sock *req)
@@ -172,6 +177,19 @@  static inline struct tcp_request_sock *tcp_rsk(const struct request_sock *req)
 	return (struct tcp_request_sock *)req;
 }
 
+static inline bool tcp_rsk_used_ao(const struct request_sock *req)
+{
+	/* The real length of MAC is saved in the request socket,
+	 * signing anything with zero-length makes no sense, so here is
+	 * a little hack..
+	 */
+#ifndef CONFIG_TCP_AO
+	return false;
+#else
+	return tcp_rsk(req)->maclen != 0;
+#endif
+}
+
 struct tcp_sock {
 	/* inet_connection_sock has to be the first member of tcp_sock */
 	struct inet_connection_sock	inet_conn;
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 061fe8471bfc..f21898bb31bd 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -2103,6 +2103,13 @@  struct tcp_request_sock_ops {
 					  const struct sock *sk,
 					  const struct sk_buff *skb);
 #endif
+#ifdef CONFIG_TCP_AO
+	struct tcp_ao_key	*(*ao_lookup)(const struct sock *sk,
+					       struct request_sock *req,
+					       int sndid, int rcvid);
+	int			(*ao_calc_key)(struct tcp_ao_key *mkt, u8 *key,
+						struct request_sock *sk);
+#endif
 #ifdef CONFIG_SYN_COOKIES
 	__u32 (*cookie_init_seq)(const struct sk_buff *skb,
 				 __u16 *mss);
diff --git a/include/net/tcp_ao.h b/include/net/tcp_ao.h
index af82b4aeef11..a1d26e1a0b82 100644
--- a/include/net/tcp_ao.h
+++ b/include/net/tcp_ao.h
@@ -119,6 +119,9 @@  int tcp_ao_hash_skb(unsigned short int family,
 int tcp_parse_ao(struct sock *sk, int cmd, unsigned short int family,
 		 sockptr_t optval, int optlen);
 struct tcp_ao_key *tcp_ao_do_lookup_sndid(const struct sock *sk, u8 keyid);
+int tcp_ao_copy_all_matching(const struct sock *sk, struct sock *newsk,
+			     struct request_sock *req, struct sk_buff *skb,
+			     int family);
 int tcp_ao_calc_traffic_key(struct tcp_ao_key *mkt, u8 *key, void *ctx,
 			    unsigned int len);
 void tcp_ao_destroy_sock(struct sock *sk, bool twsk);
@@ -142,6 +145,11 @@  struct tcp_ao_key *tcp_v4_ao_lookup(const struct sock *sk, struct sock *addr_sk,
 int tcp_v4_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
 			  const struct sock *sk,
 			  __be32 sisn, __be32 disn, bool send);
+int tcp_v4_ao_calc_key_rsk(struct tcp_ao_key *mkt, u8 *key,
+			   struct request_sock *req);
+struct tcp_ao_key *tcp_v4_ao_lookup_rsk(const struct sock *sk,
+					struct request_sock *req,
+					int sndid, int rcvid);
 int tcp_v4_ao_hash_skb(char *ao_hash, struct tcp_ao_key *key,
 		       const struct sock *sk, const struct sk_buff *skb,
 		       const u8 *tkey, int hash_offset, u32 sne);
@@ -152,9 +160,17 @@  int tcp_v6_ao_hash_pseudoheader(struct crypto_pool_ahash *hp,
 int tcp_v6_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
 				 const struct sock *sk, __be32 sisn,
 				 __be32 disn, bool send);
+int tcp_v6_ao_calc_key_rsk(struct tcp_ao_key *mkt, u8 *key,
+			   struct request_sock *req);
+struct tcp_ao_key *tcp_v6_ao_do_lookup(const struct sock *sk,
+				       const struct in6_addr *addr,
+				       int sndid, int rcvid);
 struct tcp_ao_key *tcp_v6_ao_lookup(const struct sock *sk,
 				    struct sock *addr_sk,
 				    int sndid, int rcvid);
+struct tcp_ao_key *tcp_v6_ao_lookup_rsk(const struct sock *sk,
+					struct request_sock *req,
+					int sndid, int rcvid);
 int tcp_v6_ao_hash_skb(char *ao_hash, struct tcp_ao_key *key,
 		       const struct sock *sk, const struct sk_buff *skb,
 		       const u8 *tkey, int hash_offset, u32 sne);
diff --git a/net/ipv4/tcp_ao.c b/net/ipv4/tcp_ao.c
index 1086fa1ed2fd..a90b950781f9 100644
--- a/net/ipv4/tcp_ao.c
+++ b/net/ipv4/tcp_ao.c
@@ -52,6 +52,21 @@  int tcp_ao_calc_traffic_key(struct tcp_ao_key *mkt, u8 *key, void *ctx,
 	return 1;
 }
 
+struct tcp_ao_key *tcp_ao_do_lookup_keyid(struct tcp_ao_info *ao,
+					  int sndid, int rcvid)
+{
+	struct tcp_ao_key *key;
+
+	hlist_for_each_entry_rcu(key, &ao->head, node) {
+		if ((sndid >= 0 && key->sndid != sndid) ||
+		    (rcvid >= 0 && key->rcvid != rcvid))
+			continue;
+		return key;
+	}
+
+	return NULL;
+}
+
 static struct tcp_ao_key *tcp_ao_do_lookup_rcvid(struct sock *sk, u8 keyid)
 {
 	struct tcp_sock *tp = tcp_sk(sk);
@@ -194,6 +209,22 @@  static void tcp_ao_link_mkt(struct tcp_ao_info *ao, struct tcp_ao_key *mkt)
 	hlist_add_head_rcu(&mkt->node, &ao->head);
 }
 
+struct tcp_ao_key *tcp_ao_copy_key(struct sock *sk, struct tcp_ao_key *key)
+{
+	struct tcp_ao_key *new_key;
+
+	new_key = sock_kmalloc(sk, tcp_ao_sizeof_key(key),
+			       GFP_ATOMIC);
+	if (!new_key)
+		return NULL;
+
+	*new_key = *key;
+	INIT_HLIST_NODE(&new_key->node);
+	crypto_pool_add(new_key->crypto_pool_id);
+
+	return new_key;
+}
+
 static void tcp_ao_key_free_rcu(struct rcu_head *head)
 {
 	struct tcp_ao_key *key = container_of(head, struct tcp_ao_key, rcu);
@@ -291,6 +322,18 @@  int tcp_v4_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
 					  htons(sk->sk_num), disn, sisn);
 }
 
+int tcp_v4_ao_calc_key_rsk(struct tcp_ao_key *mkt, u8 *key,
+			   struct request_sock *req)
+{
+	struct inet_request_sock *ireq = inet_rsk(req);
+
+	return tcp_v4_ao_calc_key(mkt, key,
+				  ireq->ir_loc_addr, ireq->ir_rmt_addr,
+				  htons(ireq->ir_num), ireq->ir_rmt_port,
+				  htonl(tcp_rsk(req)->snt_isn),
+				  htonl(tcp_rsk(req)->rcv_isn));
+}
+
 static int tcp_v4_ao_hash_pseudoheader(struct crypto_pool_ahash *hp,
 				       __be32 daddr, __be32 saddr,
 				       int nbytes)
@@ -560,6 +603,16 @@  int tcp_v4_ao_hash_skb(char *ao_hash, struct tcp_ao_key *key,
 			       tkey, hash_offset, sne);
 }
 
+struct tcp_ao_key *tcp_v4_ao_lookup_rsk(const struct sock *sk,
+					struct request_sock *req,
+					int sndid, int rcvid)
+{
+	union tcp_ao_addr *addr =
+			(union tcp_ao_addr *)&inet_rsk(req)->ir_rmt_addr;
+
+	return tcp_ao_do_lookup(sk, addr, AF_INET, sndid, rcvid, 0);
+}
+
 struct tcp_ao_key *tcp_v4_ao_lookup(const struct sock *sk, struct sock *addr_sk,
 				    int sndid, int rcvid)
 {
@@ -665,6 +718,94 @@  void tcp_ao_finish_connect(struct sock *sk, struct sk_buff *skb)
 	}
 }
 
+int tcp_ao_copy_all_matching(const struct sock *sk, struct sock *newsk,
+			     struct request_sock *req, struct sk_buff *skb,
+			     int family)
+{
+	struct tcp_ao_info *ao;
+	struct tcp_ao_info *new_ao;
+	struct tcp_ao_key *key, *new_key, *first_key;
+	struct hlist_node *n;
+	union tcp_ao_addr *addr;
+	bool match = false;
+
+	ao = rcu_dereference(tcp_sk(sk)->ao_info);
+	if (!ao)
+		return 0;
+
+	/* New socket without TCP-AO on it */
+	if (!tcp_rsk_used_ao(req))
+		return 0;
+
+	new_ao = tcp_ao_alloc_info(GFP_ATOMIC, ao);
+	if (!new_ao)
+		return -ENOMEM;
+	new_ao->lisn = htonl(tcp_rsk(req)->snt_isn);
+	new_ao->risn = htonl(tcp_rsk(req)->rcv_isn);
+
+	if (family == AF_INET)
+		addr = (union tcp_ao_addr *)&newsk->sk_daddr;
+#if IS_ENABLED(CONFIG_IPV6)
+	else if (family == AF_INET6)
+		addr = (union tcp_ao_addr *)&newsk->sk_v6_daddr;
+#endif
+	else
+		return -EOPNOTSUPP;
+
+	hlist_for_each_entry_rcu(key, &ao->head, node) {
+		if (tcp_ao_key_cmp(key, addr, key->prefixlen, family,
+				   -1, -1, 0))
+			continue;
+
+		new_key = tcp_ao_copy_key(newsk, key);
+		if (!new_key)
+			goto free_and_exit;
+
+		tcp_ao_cache_traffic_keys(newsk, new_ao, new_key);
+		tcp_ao_link_mkt(new_ao, new_key);
+		match = true;
+	}
+
+	if (match) {
+		struct hlist_node *key_head;
+
+		key_head = rcu_dereference(hlist_first_rcu(&new_ao->head));
+		first_key = hlist_entry_safe(key_head, struct tcp_ao_key, node);
+
+		/* set current_key */
+		key = tcp_ao_do_lookup_keyid(new_ao, tcp_rsk(req)->ao_keyid, -1);
+		if (key)
+			new_ao->current_key = key;
+		else
+			new_ao->current_key = first_key;
+
+		/* set rnext_key */
+		key = tcp_ao_do_lookup_keyid(new_ao, -1, tcp_rsk(req)->ao_rcv_next);
+		if (key)
+			new_ao->rnext_key = key;
+		else
+			new_ao->rnext_key = first_key;
+
+		new_ao->snd_sne_seq = tcp_rsk(req)->snt_isn;
+		new_ao->rcv_sne_seq = tcp_rsk(req)->rcv_isn;
+
+		sk_gso_disable(newsk);
+		rcu_assign_pointer(tcp_sk(newsk)->ao_info, new_ao);
+	}
+
+	return 0;
+
+free_and_exit:
+	hlist_for_each_entry_safe(key, n, &new_ao->head, node) {
+		hlist_del(&key->node);
+		crypto_pool_release(key->crypto_pool_id);
+		atomic_sub(tcp_ao_sizeof_key(key), &newsk->sk_omem_alloc);
+		kfree(key);
+	}
+	kfree(new_ao);
+	return -ENOMEM;
+}
+
 static int tcp_ao_current_rnext(struct sock *sk, u16 tcpa_flags,
 				u8 tcpa_sndid, u8 tcpa_rcvid)
 {
diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c
index c8eac064eab6..59d4e7b246a9 100644
--- a/net/ipv4/tcp_input.c
+++ b/net/ipv4/tcp_input.c
@@ -6933,6 +6933,14 @@  int tcp_conn_request(struct request_sock_ops *rsk_ops,
 	struct dst_entry *dst;
 	struct flowi fl;
 	u8 syncookies;
+	const struct tcp_ao_hdr *aoh = NULL;
+
+#ifdef CONFIG_TCP_AO
+	/* TODO: Add an option to require TCP-AO signature */
+	/* Packet was already validated in tcp_v[46]_rcv */
+	if (tcp_parse_auth_options(tcp_hdr(skb), NULL, &aoh))
+		goto drop; /* Invalid TCP options */
+#endif
 
 	syncookies = READ_ONCE(net->ipv4.sysctl_tcp_syncookies);
 
@@ -6940,7 +6948,7 @@  int tcp_conn_request(struct request_sock_ops *rsk_ops,
 	 * limitations, they conserve resources and peer is
 	 * evidently real one.
 	 */
-	if ((syncookies == 2 || inet_csk_reqsk_queue_is_full(sk)) && !isn) {
+	if (!aoh && (syncookies == 2 || inet_csk_reqsk_queue_is_full(sk)) && !isn) {
 		want_cookie = tcp_syn_flood_action(sk, rsk_ops->slab_name);
 		if (!want_cookie)
 			goto drop;
@@ -7019,6 +7027,15 @@  int tcp_conn_request(struct request_sock_ops *rsk_ops,
 			inet_rsk(req)->ecn_ok = 0;
 	}
 
+#ifdef CONFIG_TCP_AO
+	if (aoh) {
+		tcp_rsk(req)->maclen = aoh->length - sizeof(struct tcp_ao_hdr);
+		tcp_rsk(req)->ao_rcv_next = aoh->keyid;
+		tcp_rsk(req)->ao_keyid = aoh->rnext_keyid;
+	} else {
+		tcp_rsk(req)->maclen = 0;
+	}
+#endif
 	tcp_rsk(req)->snt_isn = isn;
 	tcp_rsk(req)->txhash = net_tx_rndhash();
 	tcp_rsk(req)->syn_tos = TCP_SKB_CB(skb)->ip_dsfield;
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 7aa02d228fc3..cbe428957686 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -1064,30 +1064,73 @@  static void tcp_v4_timewait_ack(struct sock *sk, struct sk_buff *skb)
 static void tcp_v4_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb,
 				  struct request_sock *req)
 {
+	struct tcp_md5sig_key *md5_key = NULL;
+	struct tcp_ao_key *ao_key = NULL;
 	const union tcp_md5_addr *addr;
-	int l3index;
+	u8 keyid = 0;
+#ifdef CONFIG_TCP_AO
+	u8 traffic_key[TCP_AO_MAX_HASH_SIZE] __tcp_ao_key_align;
+	const struct tcp_ao_hdr *aoh;
+#else
+	u8 *traffic_key = NULL;
+#endif
 
 	/* sk->sk_state == TCP_LISTEN -> for regular TCP_SYN_RECV
 	 * sk->sk_state == TCP_SYN_RECV -> for Fast Open.
 	 */
 	u32 seq = (sk->sk_state == TCP_LISTEN) ? tcp_rsk(req)->snt_isn + 1 :
 					     tcp_sk(sk)->snd_nxt;
+	addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr;
+
+	if (tcp_rsk_used_ao(req)) {
+#ifdef CONFIG_TCP_AO
+		/* Invalid TCP option size or twice included auth */
+		if (tcp_parse_auth_options(tcp_hdr(skb), NULL, &aoh))
+			return;
+
+		if (!aoh)
+			return;
+
+		ao_key = tcp_ao_do_lookup(sk, addr, AF_INET,
+					  aoh->rnext_keyid, -1, 0);
+		if (unlikely(!ao_key)) {
+			/* Send ACK with any matching MKT for the peer */
+			ao_key = tcp_ao_do_lookup(sk, addr,
+						  AF_INET, -1, -1, 0);
+			/* Matching key disappeared (user removed the key?)
+			 * let the handshake timeout.
+			 */
+			if (!ao_key) {
+				net_info_ratelimited("TCP-AO key for (%pI4, %d)->(%pI4, %d) suddenly disappeared, won't ACK new connection\n",
+						addr,
+						ntohs(tcp_hdr(skb)->source),
+						&ip_hdr(skb)->daddr,
+						ntohs(tcp_hdr(skb)->dest));
+				return;
+			}
+		}
 
+		keyid = aoh->keyid;
+		tcp_v4_ao_calc_key_rsk(ao_key, traffic_key, req);
+#endif
+	} else {
+		int l3index;
+
+		l3index = tcp_v4_sdif(skb) ? inet_iif(skb) : 0;
+		md5_key = tcp_md5_do_lookup(sk, l3index, addr, AF_INET);
+	}
 	/* RFC 7323 2.3
 	 * The window field (SEG.WND) of every outgoing segment, with the
 	 * exception of <SYN> segments, MUST be right-shifted by
 	 * Rcv.Wind.Shift bits:
 	 */
-	addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr;
-	l3index = tcp_v4_sdif(skb) ? inet_iif(skb) : 0;
 	tcp_v4_send_ack(sk, skb, seq,
 			tcp_rsk(req)->rcv_nxt,
 			req->rsk_rcv_wnd >> inet_rsk(req)->rcv_wscale,
 			tcp_time_stamp_raw() + tcp_rsk(req)->ts_off,
 			req->ts_recent,
 			0,
-			tcp_md5_do_lookup(sk, l3index, addr, AF_INET),
-			NULL, NULL, 0, 0,
+			md5_key, ao_key, traffic_key, keyid, 0,
 			inet_rsk(req)->no_srccheck ? IP_REPLY_ARG_NOSRCCHECK : 0,
 			ip_hdr(skb)->tos);
 }
@@ -1608,6 +1651,10 @@  const struct tcp_request_sock_ops tcp_request_sock_ipv4_ops = {
 	.req_md5_lookup	=	tcp_v4_md5_lookup,
 	.calc_md5_hash	=	tcp_v4_md5_hash_skb,
 #endif
+#ifdef CONFIG_TCP_AO
+	.ao_lookup	=	tcp_v4_ao_lookup_rsk,
+	.ao_calc_key	=	tcp_v4_ao_calc_key_rsk,
+#endif
 #ifdef CONFIG_SYN_COOKIES
 	.cookie_init_seq =	cookie_v4_init_sequence,
 #endif
@@ -1709,7 +1756,7 @@  struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb,
 	/* Copy over the MD5 key from the original socket */
 	addr = (union tcp_md5_addr *)&newinet->inet_daddr;
 	key = tcp_md5_do_lookup(sk, l3index, addr, AF_INET);
-	if (key) {
+	if (key && !tcp_rsk_used_ao(req)) {
 		/*
 		 * We're using one, so create a matching key
 		 * on the newsk structure. If we fail to get
@@ -1720,6 +1767,10 @@  struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb,
 		sk_gso_disable(newsk);
 	}
 #endif
+#ifdef CONFIG_TCP_AO
+	if (tcp_ao_copy_all_matching(sk, newsk, req, skb, AF_INET))
+		goto put_and_exit; /* OOM, release back memory */
+#endif
 
 	if (__inet_inherit_port(sk, newsk) < 0)
 		goto put_and_exit;
diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c
index 94012a015bd0..03f1b809866b 100644
--- a/net/ipv4/tcp_minisocks.c
+++ b/net/ipv4/tcp_minisocks.c
@@ -473,6 +473,9 @@  struct sock *tcp_create_openreq_child(const struct sock *sk,
 	struct inet_connection_sock *newicsk;
 	struct tcp_sock *oldtp, *newtp;
 	u32 seq;
+#ifdef CONFIG_TCP_AO
+	struct tcp_ao_key *ao_key;
+#endif
 
 	if (!newsk)
 		return NULL;
@@ -551,6 +554,13 @@  struct sock *tcp_create_openreq_child(const struct sock *sk,
 	if (treq->af_specific->req_md5_lookup(sk, req_to_sk(req)))
 		newtp->tcp_header_len += TCPOLEN_MD5SIG_ALIGNED;
 #endif
+#ifdef CONFIG_TCP_AO
+	newtp->ao_info = NULL;
+	ao_key = treq->af_specific->ao_lookup(sk, req,
+				tcp_rsk(req)->ao_keyid, -1);
+	if (ao_key)
+		newtp->tcp_header_len += tcp_ao_len(ao_key);
+ #endif
 	if (skb->len >= TCP_MSS_DEFAULT + newtp->tcp_header_len)
 		newicsk->icsk_ack.last_seg_size = skb->len - newtp->tcp_header_len;
 	newtp->rx_opt.mss_clamp = req->mss;
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index 5130e42e3dbf..be92fe85e82c 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -606,6 +606,7 @@  static void bpf_skops_write_hdr_opt(struct sock *sk, struct sk_buff *skb,
  * (but it may well be that other scenarios fail similarly).
  */
 static void tcp_options_write(struct tcphdr *th, struct tcp_sock *tp,
+			      const struct tcp_request_sock *tcprsk,
 			      struct tcp_out_options *opts,
 			      struct tcp_ao_key *ao_key)
 {
@@ -620,19 +621,32 @@  static void tcp_options_write(struct tcphdr *th, struct tcp_sock *tp,
 		ptr += 4;
 	}
 #ifdef CONFIG_TCP_AO
-	if (unlikely(OPTION_AO & options) && tp) {
-		struct tcp_ao_info *ao_info;
+	if (unlikely(OPTION_AO & options)) {
 		u8 maclen;
 
-		ao_info = rcu_dereference_check(tp->ao_info,
+		if (tp) {
+			struct tcp_ao_info *ao_info;
+
+			ao_info = rcu_dereference_check(tp->ao_info,
 				lockdep_sock_is_held(&tp->inet_conn.icsk_inet.sk));
-		if (WARN_ON_ONCE(!ao_key || !ao_info || !ao_info->rnext_key))
+			if (WARN_ON_ONCE(!ao_key || !ao_info || !ao_info->rnext_key))
+				goto out_ao;
+			maclen = tcp_ao_maclen(ao_key);
+			*ptr++ = htonl((TCPOPT_AO << 24) |
+				       (tcp_ao_len(ao_key) << 16) |
+				       (ao_key->sndid << 8) |
+				       (ao_info->rnext_key->rcvid));
+		} else if (tcprsk) {
+			u8 aolen = tcprsk->maclen + sizeof(struct tcp_ao_hdr);
+
+			maclen = tcprsk->maclen;
+			*ptr++ = htonl((TCPOPT_AO << 24) | (aolen << 16) |
+				       (tcprsk->ao_keyid << 8) |
+				       (tcprsk->ao_rcv_next));
+		} else {
+			WARN_ON_ONCE(1);
 			goto out_ao;
-		maclen = tcp_ao_maclen(ao_key);
-		*ptr++ = htonl((TCPOPT_AO << 24) |
-				(tcp_ao_len(ao_key) << 16) |
-				(ao_key->sndid << 8) |
-				(ao_info->rnext_key->rcvid));
+		}
 		opts->hash_location = (__u8 *)ptr;
 		ptr += maclen / sizeof(*ptr);
 		if (unlikely(maclen % sizeof(*ptr))) {
@@ -1411,7 +1425,7 @@  static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb,
 		th->window	= htons(min(tp->rcv_wnd, 65535U));
 	}
 
-	tcp_options_write(th, tp, &opts, ao_key);
+	tcp_options_write(th, tp, NULL, &opts, ao_key);
 
 #ifdef CONFIG_TCP_MD5SIG
 	/* Calculate the MD5 hash, as we have all we need now */
@@ -3697,7 +3711,7 @@  struct sk_buff *tcp_make_synack(const struct sock *sk, struct dst_entry *dst,
 
 	/* RFC1323: The window in SYN & SYN/ACK segments is never scaled. */
 	th->window = htons(min(req->rsk_rcv_wnd, 65535U));
-	tcp_options_write(th, NULL, &opts, NULL);
+	tcp_options_write(th, NULL, NULL, &opts, NULL);
 	th->doff = (tcp_header_size >> 2);
 	__TCP_INC_STATS(sock_net(sk), TCP_MIB_OUTSEGS);
 
diff --git a/net/ipv6/tcp_ao.c b/net/ipv6/tcp_ao.c
index 7fd31c60488a..31ae504af8e6 100644
--- a/net/ipv6/tcp_ao.c
+++ b/net/ipv6/tcp_ao.c
@@ -53,6 +53,18 @@  int tcp_v6_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
 					  htons(sk->sk_num), disn, sisn);
 }
 
+int tcp_v6_ao_calc_key_rsk(struct tcp_ao_key *mkt, u8 *key,
+			   struct request_sock *req)
+{
+	struct inet_request_sock *ireq = inet_rsk(req);
+
+	return tcp_v6_ao_calc_key(mkt, key,
+			&ireq->ir_v6_loc_addr, &ireq->ir_v6_rmt_addr,
+			htons(ireq->ir_num), ireq->ir_rmt_port,
+			htonl(tcp_rsk(req)->snt_isn),
+			htonl(tcp_rsk(req)->rcv_isn));
+}
+
 struct tcp_ao_key *tcp_v6_ao_do_lookup(const struct sock *sk,
 				       const struct in6_addr *addr,
 				       int sndid, int rcvid)
@@ -70,6 +82,15 @@  struct tcp_ao_key *tcp_v6_ao_lookup(const struct sock *sk,
 	return tcp_v6_ao_do_lookup(sk, addr, sndid, rcvid);
 }
 
+struct tcp_ao_key *tcp_v6_ao_lookup_rsk(const struct sock *sk,
+					struct request_sock *req,
+					int sndid, int rcvid)
+{
+	struct in6_addr *addr = &inet_rsk(req)->ir_v6_rmt_addr;
+
+	return tcp_v6_ao_do_lookup(sk, addr, sndid, rcvid);
+}
+
 int tcp_v6_ao_hash_pseudoheader(struct crypto_pool_ahash *hp,
 				const struct in6_addr *daddr,
 				const struct in6_addr *saddr, int nbytes)
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index bab4a1883b3c..7610fdf5d519 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -835,6 +835,10 @@  const struct tcp_request_sock_ops tcp_request_sock_ipv6_ops = {
 	.req_md5_lookup	=	tcp_v6_md5_lookup,
 	.calc_md5_hash	=	tcp_v6_md5_hash_skb,
 #endif
+#ifdef CONFIG_TCP_AO
+	.ao_lookup	=	tcp_v6_ao_lookup_rsk,
+	.ao_calc_key	=	tcp_v6_ao_calc_key_rsk,
+#endif
 #ifdef CONFIG_SYN_COOKIES
 	.cookie_init_seq =	cookie_v6_init_sequence,
 #endif
@@ -1220,9 +1224,51 @@  static void tcp_v6_timewait_ack(struct sock *sk, struct sk_buff *skb)
 static void tcp_v6_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb,
 				  struct request_sock *req)
 {
+	struct tcp_md5sig_key *md5_key = NULL;
+	struct tcp_ao_key *ao_key = NULL;
+	const struct in6_addr *addr;
+	u8 keyid = 0;
+#ifdef CONFIG_TCP_AO
+	char traffic_key[TCP_AO_MAX_HASH_SIZE] __tcp_ao_key_align;
+	const struct tcp_ao_hdr *aoh;
+#else
+	u8 *traffic_key = NULL;
+#endif
 	int l3index;
 
 	l3index = tcp_v6_sdif(skb) ? tcp_v6_iif_l3_slave(skb) : 0;
+	addr = &ipv6_hdr(skb)->saddr;
+
+	if (tcp_rsk_used_ao(req)) {
+#ifdef CONFIG_TCP_AO
+		/* Invalid TCP option size or twice included auth */
+		if (tcp_parse_auth_options(tcp_hdr(skb), NULL, &aoh))
+			return;
+		if (!aoh)
+			return;
+		ao_key = tcp_v6_ao_do_lookup(sk, addr, aoh->rnext_keyid, -1);
+		if (unlikely(!ao_key)) {
+			/* Send ACK with any matching MKT for the peer */
+			ao_key = tcp_v6_ao_do_lookup(sk, addr, -1, -1);
+			/* Matching key disappeared (user removed the key?)
+			 * let the handshake timeout.
+			 */
+			if (!ao_key) {
+				net_info_ratelimited("TCP-AO key for (%pI6, %d)->(%pI6, %d) suddenly disappeared, won't ACK new connection\n",
+						addr,
+						ntohs(tcp_hdr(skb)->source),
+						&ipv6_hdr(skb)->daddr,
+						ntohs(tcp_hdr(skb)->dest));
+				return;
+			}
+		}
+
+		keyid = aoh->keyid;
+		tcp_v6_ao_calc_key_rsk(ao_key, traffic_key, req);
+#endif
+	} else {
+		md5_key = tcp_v6_md5_do_lookup(sk, addr, l3index);
+	}
 
 	/* sk->sk_state == TCP_LISTEN -> for regular TCP_SYN_RECV
 	 * sk->sk_state == TCP_SYN_RECV -> for Fast Open.
@@ -1238,9 +1284,9 @@  static void tcp_v6_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb,
 			req->rsk_rcv_wnd >> inet_rsk(req)->rcv_wscale,
 			tcp_time_stamp_raw() + tcp_rsk(req)->ts_off,
 			req->ts_recent, sk->sk_bound_dev_if,
-			tcp_v6_md5_do_lookup(sk, &ipv6_hdr(skb)->saddr, l3index),
+			md5_key,
 			ipv6_get_dsfield(ipv6_hdr(skb)), 0, sk->sk_priority,
-			NULL, NULL, 0, 0);
+			ao_key, traffic_key, keyid, 0);
 }
 
 
@@ -1470,18 +1516,26 @@  static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
 #ifdef CONFIG_TCP_MD5SIG
 	l3index = l3mdev_master_ifindex_by_index(sock_net(sk), ireq->ir_iif);
 
-	/* Copy over the MD5 key from the original socket */
-	key = tcp_v6_md5_do_lookup(sk, &newsk->sk_v6_daddr, l3index);
-	if (key) {
-		/* We're using one, so create a matching key
-		 * on the newsk structure. If we fail to get
-		 * memory, then we end up not copying the key
-		 * across. Shucks.
-		 */
-		tcp_md5_key_copy(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr,
-			       AF_INET6, 128, l3index, key);
+	if (!tcp_rsk_used_ao(req)) {
+		const struct in6_addr *daddr = &newsk->sk_v6_daddr;
+		/* Copy over the MD5 key from the original socket */
+		key = tcp_v6_md5_do_lookup(sk, daddr, l3index);
+		if (key) {
+			/* We're using one, so create a matching key
+			 * on the newsk structure. If we fail to get
+			 * memory, then we end up not copying the key
+			 * across. Shucks.
+			 */
+			tcp_md5_key_copy(newsk, (union tcp_md5_addr *)daddr,
+					 AF_INET6, 128, l3index, key);
+		}
 	}
 #endif
+#ifdef CONFIG_TCP_AO
+	/* Copy over tcp_ao_info if any */
+	if (tcp_ao_copy_all_matching(sk, newsk, req, skb, AF_INET6))
+		goto out; /* OOM */
+#endif
 
 	if (__inet_inherit_port(sk, newsk) < 0) {
 		inet_csk_prepare_forced_close(newsk);