diff mbox

[Branch,~linaro-validation/lava-tool/trunk] Rev 153: When accessing and storing tokens in the keyring, use the full URL of the

Message ID 20110614230220.17492.25753.launchpad@loganberry.canonical.com
State Accepted
Headers show

Commit Message

Michael-Doyle Hudson June 14, 2011, 11:02 p.m. UTC
Merge authors:
  Michael Hudson-Doyle (mwhudson)
Related merge proposals:
  https://code.launchpad.net/~mwhudson/lava-tool/full-url-as-keyring-service-name/+merge/64484
  proposed by: Michael Hudson-Doyle (mwhudson)
  review: Approve - Zygmunt Krynicki (zkrynicki)
------------------------------------------------------------
revno: 153 [merge]
committer: Michael-Doyle Hudson <michael.hudson@linaro.org>
branch nick: trunk
timestamp: Wed 2011-06-15 11:00:32 +1200
message:
  When accessing and storing tokens in the keyring, use the full URL of the
  xml-rpc endpoint as the system name rather than just the hostname.
modified:
  lava_tool/authtoken.py
  lava_tool/commands/auth.py
  lava_tool/tests/test_auth_commands.py
  lava_tool/tests/test_authtoken.py


--
lp:lava-tool
https://code.launchpad.net/~linaro-validation/lava-tool/trunk

You are subscribed to branch lp:lava-tool.
To unsubscribe from this branch go to https://code.launchpad.net/~linaro-validation/lava-tool/trunk/+edit-subscription
diff mbox

Patch

=== modified file 'lava_tool/authtoken.py'
--- lava_tool/authtoken.py	2011-06-08 03:49:30 +0000
+++ lava_tool/authtoken.py	2011-06-14 03:15:20 +0000
@@ -24,67 +24,73 @@ 
 
 from lava_tool.interface import LavaCommandError
 
+
 class AuthBackend(object):
 
-    def add_token(self, username, hostname, token):
+    def add_token(self, username, endpoint_url, token):
         raise NotImplementedError
 
-    def get_token_for_host(self, user, host):
+    def get_token_for_endpoint(self, user, endpoint_url):
         raise NotImplementedError
 
 
 class KeyringAuthBackend(AuthBackend):
 
-    def add_token(self, username, hostname, token):
-        keyring.core.set_password("lava-tool-%s" % hostname, username, token)
+    def add_token(self, username, endpoint_url, token):
+        keyring.core.set_password(
+            "lava-tool-%s" % endpoint_url, username, token)
 
-    def get_token_for_host(self, username, hostname):
-        return keyring.core.get_password("lava-tool-%s" % hostname, username)
+    def get_token_for_host(self, username, endpoint_url):
+        return keyring.core.get_password(
+            "lava-tool-%s" % endpoint_url, username)
 
 
 class MemoryAuthBackend(AuthBackend):
 
-    def __init__(self, user_host_token_list):
+    def __init__(self, user_endpoint_token_list):
         self._tokens = {}
-        for user, host, token in user_host_token_list:
-            self._tokens[(user, host)] = token
-
-    def add_token(self, username, hostname, token):
-        self._tokens[(username, hostname)] = token
-
-    def get_token_for_host(self, username, host):
-        return self._tokens.get((username, host))
+        for user, endpoint, token in user_endpoint_token_list:
+            self._tokens[(user, endpoint)] = token
+
+    def add_token(self, username, endpoint_url, token):
+        self._tokens[(username, endpoint_url)] = token
+
+    def get_token_for_endpoint(self, username, endpoint_url):
+        return self._tokens.get((username, endpoint_url))
 
 
 class AuthenticatingTransportMixin:
 
+    def send_request(self, connection, handler, request_body):
+        xmlrpclib.Transport.send_request(
+            self, connection, handler, request_body)
+        auth, host = urllib.splituser(self._connection[0])
+        if auth is None:
+            return
+        user, token = urllib.splitpasswd(auth)
+        if token is None:
+            endpoint_url = '%s://%s%s' % (self._scheme, host, handler)
+            token = self.auth_backend.get_token_for_endpoint(
+                user, endpoint_url)
+            if token is None:
+                raise LavaCommandError(
+                    "Username provided but no token found.")
+        auth = base64.b64encode(urllib.unquote(user + ':' + token))
+        connection.putheader("Authorization", "Basic " + auth)
+
     def get_host_info(self, host):
-
+        # We override to never send any authorization header based soley on
+        # the host; we do all that in send_request above.
         x509 = {}
         if isinstance(host, tuple):
             host, x509 = host
-
         auth, host = urllib.splituser(host)
-
-        if auth:
-            user, token = urllib.splitpasswd(auth)
-            if token is None:
-                token = self.auth_backend.get_token_for_host(user, host)
-                if token is None:
-                    raise LavaCommandError(
-                        "Username provided but no token found.")
-            auth = base64.b64encode(urllib.unquote(user + ':' + token))
-            extra_headers = [
-                ("Authorization", "Basic " + auth)
-                ]
-        else:
-            extra_headers = None
-
-        return host, extra_headers, x509
+        return host, None, x509
 
 
 class AuthenticatingTransport(
         AuthenticatingTransportMixin, xmlrpclib.Transport):
+    _scheme = 'http'
     def __init__(self, use_datetime=0, auth_backend=None):
         xmlrpclib.Transport.__init__(self, use_datetime)
         self.auth_backend = auth_backend
@@ -92,6 +98,7 @@ 
 
 class AuthenticatingSafeTransport(
         AuthenticatingTransportMixin, xmlrpclib.SafeTransport):
+    _scheme = 'https'
     def __init__(self, use_datetime=0, auth_backend=None):
         xmlrpclib.SafeTransport.__init__(self, use_datetime)
         self.auth_backend = auth_backend

=== modified file 'lava_tool/commands/auth.py'
--- lava_tool/commands/auth.py	2011-06-09 05:35:17 +0000
+++ lava_tool/commands/auth.py	2011-06-14 01:30:55 +0000
@@ -80,7 +80,8 @@ 
         if parsed_host.port:
             host += ':' + str(parsed_host.port)
 
-        uri = '%s://%s@%s/RPC2/' % (parsed_host.scheme, username, host)
+        uri = '%s://%s@%s%s' % (
+            parsed_host.scheme, username, host, parsed_host.path)
 
         if self.args.token_file:
             if parsed_host.password:
@@ -118,6 +119,7 @@ 
                     "whoami() returned %s rather than expected %s -- this is "
                     "a bug." % (token_user, username))
 
-        self.auth_backend.add_token(username, host, token)
+        userless_uri = '%s://%s%s' % (parsed_host.scheme, host, parsed_host.path)
+        self.auth_backend.add_token(username, userless_uri, token)
 
         print 'Token added successfully for user %s.' % username

=== modified file 'lava_tool/tests/test_auth_commands.py'
--- lava_tool/tests/test_auth_commands.py	2011-06-09 05:35:17 +0000
+++ lava_tool/tests/test_auth_commands.py	2011-06-14 00:51:14 +0000
@@ -57,10 +57,49 @@ 
     def test_token_taken_from_argument(self):
         auth_backend = MemoryAuthBackend([])
         cmd = self.make_command(
+            auth_backend, HOST='http://user:TOKEN@example.com/RPC2/',
+            no_check=True)
+        cmd.invoke()
+        self.assertEqual(
+            'TOKEN',
+            auth_backend.get_token_for_endpoint(
+                'user', 'http://example.com/RPC2/'))
+
+    def test_RPC2_implied(self):
+        auth_backend = MemoryAuthBackend([])
+        cmd = self.make_command(
             auth_backend, HOST='http://user:TOKEN@example.com', no_check=True)
         cmd.invoke()
         self.assertEqual(
-            'TOKEN', auth_backend.get_token_for_host('user', 'example.com'))
+            'TOKEN',
+            auth_backend.get_token_for_endpoint(
+                'user', 'http://example.com/RPC2/'))
+
+    def test_scheme_recorded(self):
+        auth_backend = MemoryAuthBackend([])
+        cmd = self.make_command(
+            auth_backend, HOST='https://user:TOKEN@example.com/RPC2/',
+            no_check=True)
+        cmd.invoke()
+        self.assertEqual(
+            None,
+            auth_backend.get_token_for_endpoint(
+                'user', 'http://example.com/RPC2/'))
+        self.assertEqual(
+            'TOKEN',
+            auth_backend.get_token_for_endpoint(
+                'user', 'https://example.com/RPC2/'))
+
+    def test_path_on_server_recorded(self):
+        auth_backend = MemoryAuthBackend([])
+        cmd = self.make_command(
+            auth_backend, HOST='https://user:TOKEN@example.com/path',
+            no_check=True)
+        cmd.invoke()
+        self.assertEqual(
+            'TOKEN',
+            auth_backend.get_token_for_endpoint(
+                'user', 'https://example.com/path/RPC2/'))
 
     def test_token_taken_from_getpass(self):
         mocked_getpass = self.mocker.replace('getpass.getpass', passthrough=False)
@@ -72,7 +111,9 @@ 
             auth_backend, HOST='http://user@example.com', no_check=True)
         cmd.invoke()
         self.assertEqual(
-            'TOKEN', auth_backend.get_token_for_host('user', 'example.com'))
+            'TOKEN',
+            auth_backend.get_token_for_endpoint(
+                'user', 'http://example.com/RPC2/'))
 
     def test_token_taken_from_file(self):
         auth_backend = MemoryAuthBackend([])
@@ -84,7 +125,9 @@ 
             token_file=token_file.name)
         cmd.invoke()
         self.assertEqual(
-            'TOKEN', auth_backend.get_token_for_host('user', 'example.com'))
+            'TOKEN',
+            auth_backend.get_token_for_endpoint(
+                'user', 'http://example.com/RPC2/'))
 
     def test_token_file_and_in_url_conflict(self):
         auth_backend = MemoryAuthBackend([])
@@ -114,7 +157,9 @@ 
             token_file=token_file.name)
         cmd.invoke()
         self.assertEqual(
-            'TOKEN', auth_backend.get_token_for_host('user', 'example.com'))
+            'TOKEN',
+            auth_backend.get_token_for_endpoint(
+                'user', 'http://example.com/RPC2/'))
 
     def test_port_included(self):
         auth_backend = MemoryAuthBackend([])
@@ -122,7 +167,9 @@ 
             auth_backend, HOST='http://user:TOKEN@example.com:1234', no_check=True)
         cmd.invoke()
         self.assertEqual(
-            'TOKEN', auth_backend.get_token_for_host('user', 'example.com:1234'))
+            'TOKEN',
+            auth_backend.get_token_for_endpoint(
+                'user', 'http://example.com:1234/RPC2/'))
 
     def test_check_made(self):
         mocked_AuthenticatingServerProxy = self.mocker.replace(
@@ -136,10 +183,12 @@ 
         self.mocker.replay()
         auth_backend = MemoryAuthBackend([])
         cmd = self.make_command(
-            auth_backend, HOST='http://user:TOKEN@example.com:1234', no_check=False)
+            auth_backend, HOST='http://user:TOKEN@example.com', no_check=False)
         cmd.invoke()
         self.assertEqual(
-            'TOKEN', auth_backend.get_token_for_host('user', 'example.com:1234'))
+            'TOKEN',
+            auth_backend.get_token_for_endpoint(
+                'user', 'http://example.com/RPC2/'))
 
     def test_check_auth_failure_reported_nicely(self):
         mocked_AuthenticatingServerProxy = self.mocker.replace(
@@ -169,7 +218,9 @@ 
             auth_backend, HOST='http://user:TOKEN@example.com', no_check=False)
         self.assertRaises(LavaCommandError, cmd.invoke)
         self.assertEqual(
-            None, auth_backend.get_token_for_host('user', 'example.com'))
+            None,
+            auth_backend.get_token_for_endpoint(
+                'user', 'http://example.com/RPC2/'))
 
     def test_check_other_http_failure_just_raised(self):
         mocked_AuthenticatingServerProxy = self.mocker.replace(

=== modified file 'lava_tool/tests/test_authtoken.py'
--- lava_tool/tests/test_authtoken.py	2011-06-09 05:30:32 +0000
+++ lava_tool/tests/test_authtoken.py	2011-06-14 23:00:09 +0000
@@ -21,31 +21,75 @@ 
 """
 
 import base64
+import StringIO
 from unittest import TestCase
+import urlparse
+import xmlrpclib
 
 from lava_tool.authtoken import (
-    AuthenticatingTransportMixin,
+    AuthenticatingServerProxy,
     MemoryAuthBackend,
     )
 from lava_tool.interface import LavaCommandError
-
-
-class TestAuthenticatingTransportMixin(TestCase):
-
-    def headers_for_host(self, host, auth_backend):
-        a = AuthenticatingTransportMixin()
-        a.auth_backend = auth_backend
-        _, headers, _ = a.get_host_info(host)
-        return headers
-
-    def user_and_password_from_headers(self, headers):
-        if len(headers) != 1:
-            self.fail("expected exactly 1 header, got %r" % headers)
-        [(name, value)] = headers
-        if name != 'Authorization':
-            self.fail("non-authorization header found in %r" % headers)
+from lava_tool.mocker import ARGS, KWARGS, Mocker
+
+
+class TestAuthenticatingServerProxy(TestCase):
+
+    def auth_headers_for_method_call_on(self, url, auth_backend):
+        parsed = urlparse.urlparse(url)
+        expected_host = parsed.hostname
+        if parsed.port:
+            expected_host += ':' + str(parsed.port)
+        server_proxy = AuthenticatingServerProxy(
+            url, auth_backend=auth_backend)
+        mocker = Mocker()
+        if url.startswith('https'):
+            cls_name = 'httplib.HTTPSConnection'
+            expected_constructor_args = (expected_host, None)
+        else:
+            cls_name = 'httplib.HTTPConnection'
+            expected_constructor_args = (expected_host,)
+        mocked_HTTPConnection = mocker.replace(cls_name, passthrough=False)
+        mocked_connection = mocked_HTTPConnection(*expected_constructor_args)
+        # nospec() is required because of
+        # https://bugs.launchpad.net/mocker/+bug/794351
+        mocker.nospec()
+        auth_data = []
+        mocked_connection.putrequest(ARGS, KWARGS)
+
+        def match_header(header, *values):
+            if header.lower() == 'authorization':
+                if len(values) != 1:
+                    self.fail(
+                        'more than one value for '
+                        'putheader("Authorization", ...)')
+                auth_data.append(values[0])
+        mocked_connection.putheader(ARGS)
+        mocker.call(match_header)
+        mocker.count(1, None)
+
+        mocked_connection.endheaders(ARGS, KWARGS)
+
+        mocked_connection.getresponse(ARGS, KWARGS)
+        s = StringIO.StringIO(xmlrpclib.dumps((1,), methodresponse=True))
+        s.status = 200
+        mocker.result(s)
+
+        mocked_connection.close()
+        mocker.count(0, 1)
+
+        with mocker:
+            server_proxy.method()
+
+        return auth_data
+
+    def user_and_password_from_auth_data(self, auth_data):
+        if len(auth_data) != 1:
+            self.fail("expected exactly 1 header, got %r" % len(auth_data))
+        [value] = auth_data
         if not value.startswith("Basic "):
-            self.fail("non-basic auth header found in %r" % headers)
+            self.fail("non-basic auth header found in %r" % auth_data)
         auth = base64.b64decode(value[len("Basic "):])
         if ':' in auth:
             return tuple(auth.split(':', 1))
@@ -53,17 +97,38 @@ 
             return (auth, None)
 
     def test_no_user_no_auth(self):
-        headers = self.headers_for_host('example.com', MemoryAuthBackend([]))
-        self.assertEqual(None, headers)
+        auth_headers = self.auth_headers_for_method_call_on(
+            'http://localhost/RPC2/', MemoryAuthBackend([]))
+        self.assertEqual([], auth_headers)
+
+    def test_token_used_for_auth_http(self):
+        auth_headers = self.auth_headers_for_method_call_on(
+            'http://user@localhost/RPC2/',
+            MemoryAuthBackend([('user', 'http://localhost/RPC2/', 'TOKEN')]))
+        self.assertEqual(
+            ('user', 'TOKEN'),
+            self.user_and_password_from_auth_data(auth_headers))
+
+    def test_token_used_for_auth_https(self):
+        auth_headers = self.auth_headers_for_method_call_on(
+            'https://user@localhost/RPC2/',
+            MemoryAuthBackend([('user', 'https://localhost/RPC2/', 'TOKEN')]))
+        self.assertEqual(
+            ('user', 'TOKEN'),
+            self.user_and_password_from_auth_data(auth_headers))
+
+    def test_port_included(self):
+        auth_headers = self.auth_headers_for_method_call_on(
+            'http://user@localhost:1234/RPC2/',
+            MemoryAuthBackend(
+                [('user', 'http://localhost:1234/RPC2/', 'TOKEN')]))
+        self.assertEqual(
+            ('user', 'TOKEN'),
+            self.user_and_password_from_auth_data(auth_headers))
 
     def test_error_when_user_but_no_token(self):
         self.assertRaises(
             LavaCommandError,
-            self.headers_for_host, 'user@example.com', MemoryAuthBackend([]))
-
-    def test_token_used_for_auth(self):
-        headers = self.headers_for_host(
-            'user@example.com',
-            MemoryAuthBackend([('user', 'example.com', "TOKEN")]))
-        self.assertEqual(
-            ('user', 'TOKEN'), self.user_and_password_from_headers(headers))
+            self.auth_headers_for_method_call_on,
+            'http://user@localhost/RPC2/',
+            MemoryAuthBackend([]))