@@ -135,7 +135,7 @@ static int __bpf_set_link_xdp_fd_replace(int ifindex, int fd, int old_fd,
__u32 flags)
{
int sock, seq = 0, ret;
- struct nlattr *nla, *nla_xdp;
+ struct nlattr *nla;
struct {
struct nlmsghdr nh;
struct ifinfomsg ifinfo;
@@ -157,36 +157,31 @@ static int __bpf_set_link_xdp_fd_replace(int ifindex, int fd, int old_fd,
req.ifinfo.ifi_index = ifindex;
/* started nested attribute for XDP */
- nla = (struct nlattr *)(((char *)&req)
- + NLMSG_ALIGN(req.nh.nlmsg_len));
- nla->nla_type = NLA_F_NESTED | IFLA_XDP;
- nla->nla_len = NLA_HDRLEN;
+ nla = nlattr_begin_nested(&req.nh, sizeof(req), IFLA_XDP);
+ if (!nla) {
+ ret = -EMSGSIZE;
+ goto cleanup;
+ }
/* add XDP fd */
- nla_xdp = (struct nlattr *)((char *)nla + nla->nla_len);
- nla_xdp->nla_type = IFLA_XDP_FD;
- nla_xdp->nla_len = NLA_HDRLEN + sizeof(int);
- memcpy((char *)nla_xdp + NLA_HDRLEN, &fd, sizeof(fd));
- nla->nla_len += nla_xdp->nla_len;
+ ret = nlattr_add(&req.nh, sizeof(req), IFLA_XDP_FD, &fd, sizeof(fd));
+ if (ret < 0)
+ goto cleanup;
/* if user passed in any flags, add those too */
if (flags) {
- nla_xdp = (struct nlattr *)((char *)nla + nla->nla_len);
- nla_xdp->nla_type = IFLA_XDP_FLAGS;
- nla_xdp->nla_len = NLA_HDRLEN + sizeof(flags);
- memcpy((char *)nla_xdp + NLA_HDRLEN, &flags, sizeof(flags));
- nla->nla_len += nla_xdp->nla_len;
+ ret = nlattr_add(&req.nh, sizeof(req), IFLA_XDP_FLAGS, &flags, sizeof(flags));
+ if (ret < 0)
+ goto cleanup;
}
if (flags & XDP_FLAGS_REPLACE) {
- nla_xdp = (struct nlattr *)((char *)nla + nla->nla_len);
- nla_xdp->nla_type = IFLA_XDP_EXPECTED_FD;
- nla_xdp->nla_len = NLA_HDRLEN + sizeof(old_fd);
- memcpy((char *)nla_xdp + NLA_HDRLEN, &old_fd, sizeof(old_fd));
- nla->nla_len += nla_xdp->nla_len;
+ ret = nlattr_add(&req.nh, sizeof(req), IFLA_XDP_EXPECTED_FD, &flags, sizeof(flags));
+ if (ret < 0)
+ goto cleanup;
}
- req.nh.nlmsg_len += NLA_ALIGN(nla->nla_len);
+ nlattr_end_nested(&req.nh, nla);
if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
ret = -errno;
@@ -10,7 +10,10 @@
#define __LIBBPF_NLATTR_H
#include <stdint.h>
+#include <string.h>
+#include <errno.h>
#include <linux/netlink.h>
+
/* avoid multiple definition of netlink features */
#define __LINUX_NETLINK_H
@@ -103,4 +106,49 @@ int libbpf_nla_parse_nested(struct nlattr *tb[], int maxtype,
int libbpf_nla_dump_errormsg(struct nlmsghdr *nlh);
+static inline struct nlattr *nla_data(struct nlattr *nla)
+{
+ return (struct nlattr *)((char *)nla + NLA_HDRLEN);
+}
+
+static inline struct nlattr *nh_tail(struct nlmsghdr *nh)
+{
+ return (struct nlattr *)((char *)nh + NLMSG_ALIGN(nh->nlmsg_len));
+}
+
+static inline int nlattr_add(struct nlmsghdr *nh, size_t maxsz, int type,
+ const void *data, int len)
+{
+ struct nlattr *nla;
+
+ if (NLMSG_ALIGN(nh->nlmsg_len) + NLA_ALIGN(NLA_HDRLEN + len) > maxsz)
+ return -EMSGSIZE;
+ if ((!data && len) || (data && !len))
+ return -EINVAL;
+
+ nla = nh_tail(nh);
+ nla->nla_type = type;
+ nla->nla_len = NLA_HDRLEN + len;
+ if (data)
+ memcpy(nla_data(nla), data, len);
+ nh->nlmsg_len = NLMSG_ALIGN(nh->nlmsg_len) + NLA_ALIGN(nla->nla_len);
+ return 0;
+}
+
+static inline struct nlattr *nlattr_begin_nested(struct nlmsghdr *nh,
+ size_t maxsz, int type)
+{
+ struct nlattr *tail;
+
+ tail = nh_tail(nh);
+ if (nlattr_add(nh, maxsz, type | NLA_F_NESTED, NULL, 0))
+ return NULL;
+ return tail;
+}
+
+static inline void nlattr_end_nested(struct nlmsghdr *nh, struct nlattr *tail)
+{
+ tail->nla_len = (char *)nh_tail(nh) - (char *)tail;
+}
+
#endif /* __LIBBPF_NLATTR_H */