diff mbox series

[RFCv3,04/15] selftests: tcp_authopt: Initial sockopt manipulation

Message ID e5d1d887325815300c92953575b2b5272e5bc0e5.1629840814.git.cdleonard@gmail.com
State New
Headers show
Series [RFCv3,01/15] tcp: authopt: Initial support and key management | expand

Commit Message

Leonard Crestez Aug. 24, 2021, 9:34 p.m. UTC
Signed-off-by: Leonard Crestez <cdleonard@gmail.com>
---
 .../tcp_authopt/tcp_authopt_test/conftest.py  |  21 ++
 .../tcp_authopt_test/linux_tcp_authopt.py     | 188 ++++++++++++++++++
 .../tcp_authopt/tcp_authopt_test/sockaddr.py  | 101 ++++++++++
 .../tcp_authopt_test/test_sockopt.py          |  74 +++++++
 4 files changed, 384 insertions(+)
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/conftest.py
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/sockaddr.py
 create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sockopt.py
diff mbox series

Patch

diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/conftest.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/conftest.py
new file mode 100644
index 000000000000..c17c8ea2a943
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/conftest.py
@@ -0,0 +1,21 @@ 
+# SPDX-License-Identifier: GPL-2.0
+from tcp_authopt_test.linux_tcp_authopt import has_tcp_authopt
+import pytest
+import logging
+from contextlib import ExitStack
+
+logger = logging.getLogger(__name__)
+
+skipif_missing_tcp_authopt = pytest.mark.skipif(
+    not has_tcp_authopt(), reason="Need CONFIG_TCP_AUTHOPT"
+)
+
+
+@pytest.fixture
+def exit_stack():
+    """Return a contextlib.ExitStack as a pytest fixture
+
+    This reduces indentation making code more readable
+    """
+    with ExitStack() as exit_stack:
+        yield exit_stack
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
new file mode 100644
index 000000000000..41374f9851aa
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/linux_tcp_authopt.py
@@ -0,0 +1,188 @@ 
+# SPDX-License-Identifier: GPL-2.0
+"""Python wrapper around linux TCP_AUTHOPT ABI"""
+
+from dataclasses import dataclass
+from ipaddress import IPv4Address, IPv6Address, ip_address
+import socket
+import errno
+import logging
+from .sockaddr import sockaddr_in, sockaddr_in6, sockaddr_storage, sockaddr_unpack
+import typing
+import struct
+
+logger = logging.getLogger(__name__)
+
+
+def BIT(x):
+    return 1 << x
+
+
+TCP_AUTHOPT = 38
+TCP_AUTHOPT_KEY = 39
+
+TCP_AUTHOPT_MAXKEYLEN = 80
+
+TCP_AUTHOPT_FLAG_REJECT_UNEXPECTED = BIT(2)
+
+TCP_AUTHOPT_KEY_DEL = BIT(0)
+TCP_AUTHOPT_KEY_EXCLUDE_OPTS = BIT(1)
+TCP_AUTHOPT_KEY_BIND_ADDR = BIT(2)
+
+TCP_AUTHOPT_ALG_HMAC_SHA_1_96 = 1
+TCP_AUTHOPT_ALG_AES_128_CMAC_96 = 2
+
+
+@dataclass
+class tcp_authopt:
+    """Like linux struct tcp_authopt"""
+
+    flags: int = 0
+    sizeof = 4
+
+    def pack(self) -> bytes:
+        return struct.pack(
+            "I",
+            self.flags,
+        )
+
+    def __bytes__(self):
+        return self.pack()
+
+    @classmethod
+    def unpack(cls, b: bytes):
+        tup = struct.unpack("I", b)
+        return cls(*tup)
+
+
+def set_tcp_authopt(sock, opt: tcp_authopt):
+    return sock.setsockopt(socket.IPPROTO_TCP, TCP_AUTHOPT, bytes(opt))
+
+
+def get_tcp_authopt(sock: socket.socket) -> tcp_authopt:
+    b = sock.getsockopt(socket.IPPROTO_TCP, TCP_AUTHOPT, tcp_authopt.sizeof)
+    return tcp_authopt.unpack(b)
+
+
+class tcp_authopt_key:
+    """Like linux struct tcp_authopt_key"""
+
+    def __init__(
+        self,
+        flags: int = 0,
+        send_id: int = 0,
+        recv_id: int = 0,
+        alg=TCP_AUTHOPT_ALG_HMAC_SHA_1_96,
+        key: bytes = b"",
+        addr: bytes = b"",
+        include_options=None,
+    ):
+        self.flags = flags
+        self.send_id = send_id
+        self.recv_id = recv_id
+        self.alg = alg
+        self.key = key
+        self.addr = addr
+        if include_options is not None:
+            self.include_options = include_options
+
+    def pack(self):
+        if len(self.key) > TCP_AUTHOPT_MAXKEYLEN:
+            raise ValueError(f"Max key length is {TCP_AUTHOPT_MAXKEYLEN}")
+        data = struct.pack(
+            "IBBBB80s",
+            self.flags,
+            self.send_id,
+            self.recv_id,
+            self.alg,
+            len(self.key),
+            self.key,
+        )
+        data += bytes(self.addrbuf.ljust(sockaddr_storage.sizeof, b"\x00"))
+        return data
+
+    def __bytes__(self):
+        return self.pack()
+
+    @property
+    def key(self) -> bytes:
+        return self._key
+
+    @key.setter
+    def key(self, val: typing.Union[bytes, str]) -> bytes:
+        if isinstance(val, str):
+            val = val.encode("utf-8")
+        if len(val) > TCP_AUTHOPT_MAXKEYLEN:
+            raise ValueError(f"Max key length is {TCP_AUTHOPT_MAXKEYLEN}")
+        self._key = val
+        return val
+
+    @property
+    def addr(self):
+        if not self.addrbuf:
+            return None
+        else:
+            return sockaddr_unpack(bytes(self.addrbuf))
+
+    @addr.setter
+    def addr(self, val):
+        if isinstance(val, bytes):
+            if len(val) > sockaddr_storage.sizeof:
+                raise ValueError(f"Must be up to {sockaddr_storage.sizeof}")
+            self.addrbuf = val
+        elif val is None:
+            self.addrbuf = b""
+        elif isinstance(val, str):
+            self.addr = ip_address(val)
+        elif isinstance(val, IPv4Address):
+            self.addr = sockaddr_in(addr=val)
+        elif isinstance(val, IPv6Address):
+            self.addr = sockaddr_in6(addr=val)
+        elif (
+            isinstance(val, sockaddr_in)
+            or isinstance(val, sockaddr_in6)
+            or isinstance(val, sockaddr_storage)
+        ):
+            self.addr = bytes(val)
+        else:
+            raise TypeError(f"Can't handle addr {val}")
+        return self.addr
+
+    @property
+    def include_options(self) -> bool:
+        return (self.flags & TCP_AUTHOPT_KEY_EXCLUDE_OPTS) == 0
+
+    @include_options.setter
+    def include_options(self, value) -> bool:
+        if value:
+            self.flags &= ~TCP_AUTHOPT_KEY_EXCLUDE_OPTS
+        else:
+            self.flags |= TCP_AUTHOPT_KEY_EXCLUDE_OPTS
+
+    @property
+    def delete_flag(self) -> bool:
+        return bool(self.flags & TCP_AUTHOPT_KEY_DEL)
+
+    @delete_flag.setter
+    def delete_flag(self, value) -> bool:
+        if value:
+            self.flags |= TCP_AUTHOPT_KEY_DEL
+        else:
+            self.flags &= ~TCP_AUTHOPT_KEY_DEL
+
+
+def set_tcp_authopt_key(sock, key: tcp_authopt_key):
+    return sock.setsockopt(socket.IPPROTO_TCP, TCP_AUTHOPT_KEY, bytes(key))
+
+
+def has_tcp_authopt() -> bool:
+    """Check is TCP_AUTHOPT is implemented by the OS"""
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        try:
+            optbuf = bytes(4)
+            sock.setsockopt(socket.IPPROTO_TCP, TCP_AUTHOPT, optbuf)
+            return True
+        except OSError as e:
+            if e.errno == errno.ENOPROTOOPT:
+                return False
+            else:
+                raise
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/sockaddr.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/sockaddr.py
new file mode 100644
index 000000000000..f61d0f190a0c
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/sockaddr.py
@@ -0,0 +1,101 @@ 
+# SPDX-License-Identifier: GPL-2.0
+"""pack/unpack wrappers for sockaddr"""
+import socket
+import struct
+from dataclasses import dataclass
+from ipaddress import IPv4Address, IPv6Address
+
+
+@dataclass
+class sockaddr_in:
+    port: int
+    addr: IPv4Address
+    sizeof = 8
+
+    def __init__(self, port=0, addr=None):
+        self.port = port
+        if addr is None:
+            addr = IPv4Address(0)
+        self.addr = IPv4Address(addr)
+
+    def pack(self):
+        return struct.pack("HH4s", socket.AF_INET, self.port, self.addr.packed)
+
+    @classmethod
+    def unpack(cls, buffer):
+        family, port, addr_packed = struct.unpack("HH4s", buffer[:8])
+        if family != socket.AF_INET:
+            raise ValueError(f"Must be AF_INET not {family}")
+        return cls(port, addr_packed)
+
+    def __bytes__(self):
+        return self.pack()
+
+
+@dataclass
+class sockaddr_in6:
+    """Like sockaddr_in6 but for python. Always contains scope_id"""
+
+    port: int
+    addr: IPv6Address
+    flowinfo: int
+    scope_id: int
+    sizeof = 28
+
+    def __init__(self, port=0, addr=None, flowinfo=0, scope_id=0):
+        self.port = port
+        if addr is None:
+            addr = IPv6Address(0)
+        self.addr = IPv6Address(addr)
+        self.flowinfo = flowinfo
+        self.scope_id = scope_id
+
+    def pack(self):
+        return struct.pack(
+            "HHI16sI",
+            socket.AF_INET6,
+            self.port,
+            self.flowinfo,
+            self.addr.packed,
+            self.scope_id,
+        )
+
+    @classmethod
+    def unpack(cls, buffer):
+        family, port, flowinfo, addr_packed, scope_id = struct.unpack(
+            "HHI16sI", buffer[:28]
+        )
+        if family != socket.AF_INET6:
+            raise ValueError(f"Must be AF_INET6 not {family}")
+        return cls(port, addr_packed, flowinfo=flowinfo, scope_id=scope_id)
+
+    def __bytes__(self):
+        return self.pack()
+
+
+@dataclass
+class sockaddr_storage:
+    family: int
+    data: bytes
+    sizeof = 128
+
+    def pack(self):
+        return struct.pack("H126s", self.family, self.data)
+
+    def __bytes__(self):
+        return self.pack()
+
+    @classmethod
+    def unpack(cls, buffer):
+        return cls(*struct.unpack("H126s", buffer))
+
+
+def sockaddr_unpack(buffer: bytes):
+    """Unpack based on family"""
+    family = struct.unpack("H", buffer[:2])[0]
+    if family == socket.AF_INET:
+        return sockaddr_in.unpack(buffer)
+    elif family == socket.AF_INET6:
+        return sockaddr_in6.unpack(buffer)
+    else:
+        return sockaddr_storage.unpack(buffer)
diff --git a/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sockopt.py b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sockopt.py
new file mode 100644
index 000000000000..06a05bf8aeec
--- /dev/null
+++ b/tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_sockopt.py
@@ -0,0 +1,74 @@ 
+# SPDX-License-Identifier: GPL-2.0
+"""Test TCP_AUTHOPT sockopt API"""
+import errno
+import socket
+import struct
+from ipaddress import IPv4Address, IPv6Address
+
+import pytest
+
+from .linux_tcp_authopt import (
+    set_tcp_authopt,
+    set_tcp_authopt_key,
+    tcp_authopt,
+    tcp_authopt_key,
+)
+from .sockaddr import sockaddr_unpack
+from .conftest import skipif_missing_tcp_authopt
+
+pytestmark = skipif_missing_tcp_authopt
+
+
+def test_authopt_key_pack_noaddr():
+    b = bytes(tcp_authopt_key(key=b"a\x00b"))
+    assert b[7] == 3
+    assert b[8:13] == b"a\x00b\x00\x00"
+
+
+def test_authopt_key_pack_addr():
+    b = bytes(tcp_authopt_key(key=b"a\x00b", addr="10.0.0.1"))
+    assert struct.unpack("H", b[88:90])[0] == socket.AF_INET
+    assert sockaddr_unpack(b[88:]).addr == IPv4Address("10.0.0.1")
+
+
+def test_authopt_key_pack_addr6():
+    b = bytes(tcp_authopt_key(key=b"abc", addr="fd00::1"))
+    assert struct.unpack("H", b[88:90])[0] == socket.AF_INET6
+    assert sockaddr_unpack(b[88:]).addr == IPv6Address("fd00::1")
+
+
+def test_tcp_authopt_key_del_without_active(exit_stack):
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    exit_stack.push(sock)
+
+    # nothing happens:
+    key = tcp_authopt_key()
+    assert key.delete_flag is False
+    key.delete_flag = True
+    assert key.delete_flag is True
+    with pytest.raises(OSError) as e:
+        set_tcp_authopt_key(sock, key)
+    assert e.value.errno in [errno.EINVAL, errno.ENOENT]
+
+
+def test_tcp_authopt_key_setdel(exit_stack):
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    exit_stack.push(sock)
+    set_tcp_authopt(sock, tcp_authopt())
+
+    # delete returns ENOENT
+    key = tcp_authopt_key()
+    key.delete_flag = True
+    with pytest.raises(OSError) as e:
+        set_tcp_authopt_key(sock, key)
+    assert e.value.errno == errno.ENOENT
+
+    key = tcp_authopt_key(send_id=1, recv_id=2)
+    set_tcp_authopt_key(sock, key)
+    # First delete works fine:
+    key.delete_flag = True
+    set_tcp_authopt_key(sock, key)
+    # Duplicate delete returns ENOENT
+    with pytest.raises(OSError) as e:
+        set_tcp_authopt_key(sock, key)
+    assert e.value.errno == errno.ENOENT