diff mbox series

[15/15] wifi: cfg80211: validate MLO connections better

Message ID 20240102213313.ff83c034cb9a.I9962db0bfa8c73b37b8d5b59a3fad7f02f2129ae@changeid
State New
Headers show
Series cfg80211/mac80211 patches from our internal tree 2024-01-02 | expand

Commit Message

Korenblit, Miriam Rachel Jan. 2, 2024, 7:35 p.m. UTC
From: Johannes Berg <johannes.berg@intel.com>

When going into an MLO connection, validate that the link IDs
match what userspace indicated, and that the AP MLD addresses
and capabilities are all matching between the links.

Signed-off-by: Johannes Berg <johannes.berg@intel.com>
Reviewed-by: Gregory Greenman <gregory.greenman@intel.com>
Signed-off-by: Miri Korenblit <miriam.rachel.korenblit@intel.com>
---
 include/linux/ieee80211.h |  24 +++++++
 net/wireless/core.h       |   3 +-
 net/wireless/mlme.c       | 131 ++++++++++++++++++++++++++++++++++----
 net/wireless/nl80211.c    |   3 +-
 net/wireless/sme.c        |   3 +-
 5 files changed, 148 insertions(+), 16 deletions(-)
diff mbox series

Patch

diff --git a/include/linux/ieee80211.h b/include/linux/ieee80211.h
index 61dbee51025d..189f08aefecd 100644
--- a/include/linux/ieee80211.h
+++ b/include/linux/ieee80211.h
@@ -4934,6 +4934,30 @@  static inline u8 ieee80211_mle_common_size(const u8 *data)
 	return sizeof(*mle) + common + mle->variable[0];
 }
 
+/**
+ * ieee80211_mle_get_link_id - returns the link ID
+ * @data: the basic multi link element
+ *
+ * The element is assumed to be of the correct type (BASIC) and big enough,
+ * this must be checked using ieee80211_mle_type_ok().
+ *
+ * If the BSS link ID can't be found, -1 will be returned
+ */
+static inline int ieee80211_mle_get_link_id(const u8 *data)
+{
+	const struct ieee80211_multi_link_elem *mle = (const void *)data;
+	u16 control = le16_to_cpu(mle->control);
+	const u8 *common = mle->variable;
+
+	/* common points now at the beginning of ieee80211_mle_basic_common_info */
+	common += sizeof(struct ieee80211_mle_basic_common_info);
+
+	if (!(control & IEEE80211_MLC_BASIC_PRES_LINK_ID))
+		return -1;
+
+	return *common;
+}
+
 /**
  * ieee80211_mle_get_bss_param_ch_cnt - returns the BSS parameter change count
  * @mle: the basic multi link element
diff --git a/net/wireless/core.h b/net/wireless/core.h
index 1963958263d2..956ebec51bd6 100644
--- a/net/wireless/core.h
+++ b/net/wireless/core.h
@@ -362,7 +362,8 @@  int cfg80211_mlme_auth(struct cfg80211_registered_device *rdev,
 		       struct cfg80211_auth_request *req);
 int cfg80211_mlme_assoc(struct cfg80211_registered_device *rdev,
 			struct net_device *dev,
-			struct cfg80211_assoc_request *req);
+			struct cfg80211_assoc_request *req,
+			struct netlink_ext_ack *extack);
 int cfg80211_mlme_deauth(struct cfg80211_registered_device *rdev,
 			 struct net_device *dev, const u8 *bssid,
 			 const u8 *ie, int ie_len, u16 reason,
diff --git a/net/wireless/mlme.c b/net/wireless/mlme.c
index f635a8b6ca2e..3f90bbe7753d 100644
--- a/net/wireless/mlme.c
+++ b/net/wireless/mlme.c
@@ -325,28 +325,133 @@  void cfg80211_oper_and_vht_capa(struct ieee80211_vht_cap *vht_capa,
 		p1[i] &= p2[i];
 }
 
-/* Note: caller must cfg80211_put_bss() regardless of result */
-int cfg80211_mlme_assoc(struct cfg80211_registered_device *rdev,
-			struct net_device *dev,
-			struct cfg80211_assoc_request *req)
+static int
+cfg80211_mlme_check_mlo_compat(const struct ieee80211_multi_link_elem *mle_a,
+			       const struct ieee80211_multi_link_elem *mle_b,
+			       struct netlink_ext_ack *extack)
 {
-	struct wireless_dev *wdev = dev->ieee80211_ptr;
-	int err, i, j;
+	const struct ieee80211_mle_basic_common_info *common_a, *common_b;
 
-	lockdep_assert_wiphy(wdev->wiphy);
+	common_a = (const void *)mle_a->variable;
+	common_b = (const void *)mle_b->variable;
+
+	if (memcmp(common_a->mld_mac_addr, common_b->mld_mac_addr, ETH_ALEN)) {
+		NL_SET_ERR_MSG(extack, "AP MLD address mismatch");
+		return -EINVAL;
+	}
+
+	if (ieee80211_mle_get_eml_med_sync_delay((const u8 *)mle_a) !=
+	    ieee80211_mle_get_eml_med_sync_delay((const u8 *)mle_b)) {
+		NL_SET_ERR_MSG(extack, "link EML medium sync delay mismatch");
+		return -EINVAL;
+	}
+
+	if (ieee80211_mle_get_eml_cap((const u8 *)mle_a) !=
+	    ieee80211_mle_get_eml_cap((const u8 *)mle_b)) {
+		NL_SET_ERR_MSG(extack, "link EML capabilities mismatch");
+		return -EINVAL;
+	}
+
+	if (ieee80211_mle_get_mld_capa_op((const u8 *)mle_a) !=
+	    ieee80211_mle_get_mld_capa_op((const u8 *)mle_b)) {
+		NL_SET_ERR_MSG(extack, "link MLD capabilities/ops mismatch");
+		return -EINVAL;
+	}
+
+	return 0;
+}
+
+static int cfg80211_mlme_check_mlo(struct net_device *dev,
+				   struct cfg80211_assoc_request *req,
+				   struct netlink_ext_ack *extack)
+{
+	const struct ieee80211_multi_link_elem *mles[ARRAY_SIZE(req->links)] = {};
+	int i;
+
+	if (req->link_id < 0)
+		return 0;
+
+	if (!req->links[req->link_id].bss) {
+		NL_SET_ERR_MSG(extack, "no BSS for assoc link");
+		return -EINVAL;
+	}
+
+	rcu_read_lock();
+	for (i = 0; i < ARRAY_SIZE(req->links); i++) {
+		const struct cfg80211_bss_ies *ies;
+		const struct element *ml;
 
-	for (i = 1; i < ARRAY_SIZE(req->links); i++) {
 		if (!req->links[i].bss)
 			continue;
-		for (j = 0; j < i; j++) {
-			if (req->links[i].bss == req->links[j].bss)
-				return -EINVAL;
+
+		if (ether_addr_equal(req->links[i].bss->bssid, dev->dev_addr)) {
+			NL_SET_ERR_MSG(extack, "BSSID must not be our address");
+			req->links[i].error = -EINVAL;
+			goto error;
 		}
 
-		if (ether_addr_equal(req->links[i].bss->bssid, dev->dev_addr))
-			return -EINVAL;
+		ies = rcu_dereference(req->links[i].bss->ies);
+		ml = cfg80211_find_ext_elem(WLAN_EID_EXT_EHT_MULTI_LINK,
+					    ies->data, ies->len);
+		if (!ml) {
+			NL_SET_ERR_MSG(extack, "MLO BSS w/o ML element");
+			req->links[i].error = -EINVAL;
+			goto error;
+		}
+
+		if (!ieee80211_mle_type_ok(ml->data + 1,
+					   IEEE80211_ML_CONTROL_TYPE_BASIC,
+					   ml->datalen - 1)) {
+			NL_SET_ERR_MSG(extack, "BSS with invalid ML element");
+			req->links[i].error = -EINVAL;
+			goto error;
+		}
+
+		mles[i] = (const void *)(ml->data + 1);
+
+		if (ieee80211_mle_get_link_id((const u8 *)mles[i]) != i) {
+			NL_SET_ERR_MSG(extack, "link ID mismatch");
+			req->links[i].error = -EINVAL;
+			goto error;
+		}
+	}
+
+	if (WARN_ON(!mles[req->link_id]))
+		goto error;
+
+	for (i = 0; i < ARRAY_SIZE(req->links); i++) {
+		if (i == req->link_id || !req->links[i].bss)
+			continue;
+
+		if (WARN_ON(!mles[i]))
+			goto error;
+
+		if (cfg80211_mlme_check_mlo_compat(mles[req->link_id], mles[i],
+						   extack)) {
+			req->links[i].error = -EINVAL;
+			goto error;
+		}
 	}
 
+	rcu_read_unlock();
+	return 0;
+error:
+	rcu_read_unlock();
+	return -EINVAL;
+}
+
+/* Note: caller must cfg80211_put_bss() regardless of result */
+int cfg80211_mlme_assoc(struct cfg80211_registered_device *rdev,
+			struct net_device *dev,
+			struct cfg80211_assoc_request *req,
+			struct netlink_ext_ack *extack)
+{
+	struct wireless_dev *wdev = dev->ieee80211_ptr;
+	int err;
+
+	lockdep_assert_wiphy(wdev->wiphy);
+
+	err = cfg80211_mlme_check_mlo(dev, req, extack);
 	if (wdev->connected &&
 	    (!req->prev_bssid ||
 	     !ether_addr_equal(wdev->u.client.connected_addr, req->prev_bssid)))
diff --git a/net/wireless/nl80211.c b/net/wireless/nl80211.c
index 5640ac8c72ad..cfabf847024e 100644
--- a/net/wireless/nl80211.c
+++ b/net/wireless/nl80211.c
@@ -11244,7 +11244,8 @@  static int nl80211_associate(struct sk_buff *skb, struct genl_info *info)
 		struct nlattr *link;
 		int rem = 0;
 
-		err = cfg80211_mlme_assoc(rdev, dev, &req);
+		err = cfg80211_mlme_assoc(rdev, dev, &req,
+					  info->extack);
 
 		if (!err && info->attrs[NL80211_ATTR_SOCKET_OWNER]) {
 			dev->ieee80211_ptr->conn_owner_nlportid =
diff --git a/net/wireless/sme.c b/net/wireless/sme.c
index 195c8532734b..82e3ce42206c 100644
--- a/net/wireless/sme.c
+++ b/net/wireless/sme.c
@@ -209,7 +209,8 @@  static int cfg80211_conn_do_work(struct wireless_dev *wdev,
 		if (!req.bss) {
 			err = -ENOENT;
 		} else {
-			err = cfg80211_mlme_assoc(rdev, wdev->netdev, &req);
+			err = cfg80211_mlme_assoc(rdev, wdev->netdev,
+						  &req, NULL);
 			cfg80211_put_bss(&rdev->wiphy, req.bss);
 		}