Adds redirect URL to login handler

This commit is contained in:
Aaron Louie 2020-02-24 16:59:16 -05:00
parent 1ac9180304
commit 002207cbca
3 changed files with 28 additions and 13 deletions

View File

@ -53,6 +53,11 @@ paths:
required: false required: false
schema: schema:
type: string type: string
- name: redirect_url
in: query
required: false
schema:
type: string
tags: tags:
- Users - Users
responses: responses:

View File

@ -32,10 +32,11 @@ def get_current_user():
@sso.login_handler @sso.login_handler
def sso_login(user_info): def sso_login(user_info):
# TODO: Get redirect URL from Shibboleth request header
_handle_login(user_info) _handle_login(user_info)
def _handle_login(user_info): def _handle_login(user_info, redirect_url=app.config['FRONTEND_AUTH_CALLBACK']):
"""On successful login, adds user to database if the user is not already in the system, """On successful login, adds user to database if the user is not already in the system,
then returns the frontend auth callback URL, with auth token appended. then returns the frontend auth callback URL, with auth token appended.
@ -50,6 +51,7 @@ def _handle_login(user_info):
last_name: Optional[str], last_name: Optional[str],
title: Optional[str], title: Optional[str],
}): Dictionary of user attributes }): Dictionary of user attributes
redirect_url: Optional[str]
Returns: Returns:
Response. 302 - Redirects to the frontend auth callback URL, with auth token appended. Response. 302 - Redirects to the frontend auth callback URL, with auth token appended.
@ -79,8 +81,7 @@ def _handle_login(user_info):
# Return the frontend auth callback URL, with auth token appended. # Return the frontend auth callback URL, with auth token appended.
auth_token = user.encode_auth_token().decode() auth_token = user.encode_auth_token().decode()
response_url = ('%s/%s' % (app.config['FRONTEND_AUTH_CALLBACK'], auth_token)) return redirect('%s/%s' % (redirect_url, auth_token))
return redirect(response_url)
def backdoor( def backdoor(
@ -91,7 +92,8 @@ def backdoor(
eppn=None, eppn=None,
first_name=None, first_name=None,
last_name=None, last_name=None,
title=None title=None,
redirect_url=None,
): ):
"""A backdoor for end-to-end system testing that allows the system to simulate logging in as a specific user. """A backdoor for end-to-end system testing that allows the system to simulate logging in as a specific user.
Only works if the application is running in a non-production environment. Only works if the application is running in a non-production environment.
@ -105,6 +107,7 @@ def backdoor(
first_name: Optional[str] first_name: Optional[str]
last_name: Optional[str] last_name: Optional[str]
title: Optional[str] title: Optional[str]
redirect_url: Optional[str]
Returns: Returns:
str. If not on production, returns the frontend auth callback URL, with auth token appended. str. If not on production, returns the frontend auth callback URL, with auth token appended.
@ -118,6 +121,6 @@ def backdoor(
if key in connexion.request.args: if key in connexion.request.args:
user_info[key] = connexion.request.args[key] user_info[key] = connexion.request.args[key]
return _handle_login(user_info) return _handle_login(user_info, redirect_url)
else: else:
raise ApiError('404', 'unknown') raise ApiError('404', 'unknown')

View File

@ -8,7 +8,7 @@ from tests.base_test import BaseTest
class TestAuthentication(BaseTest): class TestAuthentication(BaseTest):
test_uid = "dhf8r" test_uid = "dhf8r"
def logged_in_headers(self, user=None): def logged_in_headers(self, user=None, redirect_url='http://some/frontend/url'):
if user is None: if user is None:
uid = self.test_uid uid = self.test_uid
user_info = {'uid': self.test_uid, 'first_name': 'Daniel', 'last_name': 'Funk', user_info = {'uid': self.test_uid, 'first_name': 'Daniel', 'last_name': 'Funk',
@ -18,9 +18,11 @@ class TestAuthentication(BaseTest):
user_info = {'uid': user.uid, 'first_name': user.first_name, 'last_name': user.last_name, user_info = {'uid': user.uid, 'first_name': user.first_name, 'last_name': user.last_name,
'email_address': user.email_address} 'email_address': user.email_address}
query_string = self.user_info_to_query_string(user_info) query_string = self.user_info_to_query_string(user_info, redirect_url)
rv = self.app.get("/v1.0/sso_backdoor%s" % query_string, follow_redirects=True, rv = self.app.get("/v1.0/sso_backdoor%s" % query_string, follow_redirects=False)
content_type="application/json") self.assertTrue(rv.status_code == 302)
self.assertTrue(str.startswith(rv.location, redirect_url))
user_model = UserModel.query.filter_by(uid=uid).first() user_model = UserModel.query.filter_by(uid=uid).first()
self.assertIsNotNone(user_model.display_name) self.assertIsNotNone(user_model.display_name)
return dict(Authorization='Bearer ' + user_model.encode_auth_token().decode()) return dict(Authorization='Bearer ' + user_model.encode_auth_token().decode())
@ -39,10 +41,12 @@ class TestAuthentication(BaseTest):
user_info = {'uid': self.test_uid, 'first_name': 'Daniel', 'last_name': 'Funk', user_info = {'uid': self.test_uid, 'first_name': 'Daniel', 'last_name': 'Funk',
'email_address': 'dhf8r@virginia.edu'} 'email_address': 'dhf8r@virginia.edu'}
query_string = self.user_info_to_query_string(user_info) redirect_url = 'http://worlds.best.website/admin'
query_string = self.user_info_to_query_string(user_info, redirect_url)
url = '/v1.0/sso_backdoor%s' % query_string url = '/v1.0/sso_backdoor%s' % query_string
rv_1 = self.app.get(url, follow_redirects=False) rv_1 = self.app.get(url, follow_redirects=False)
self.assertTrue(rv_1.status_code == 302) self.assertTrue(rv_1.status_code == 302)
self.assertTrue(str.startswith(rv_1.location, redirect_url))
user = db.session.query(UserModel).filter(UserModel.uid == self.test_uid).first() user = db.session.query(UserModel).filter(UserModel.uid == self.test_uid).first()
self.assertIsNotNone(user) self.assertIsNotNone(user)
@ -51,7 +55,8 @@ class TestAuthentication(BaseTest):
# Hitting the same endpoint again with the same info should not cause an error # Hitting the same endpoint again with the same info should not cause an error
rv_2 = self.app.get(url, follow_redirects=False) rv_2 = self.app.get(url, follow_redirects=False)
self.assertTrue(rv_1.status_code == 302) self.assertTrue(rv_2.status_code == 302)
self.assertTrue(str.startswith(rv_2.location, redirect_url))
def test_current_user_status(self): def test_current_user_status(self):
self.load_example_data() self.load_example_data()
@ -62,15 +67,17 @@ class TestAuthentication(BaseTest):
self.assert_success(rv) self.assert_success(rv)
user = UserModel(uid="ajl2j", first_name='Aaron', last_name='Louie', email_address='ajl2j@virginia.edu') user = UserModel(uid="ajl2j", first_name='Aaron', last_name='Louie', email_address='ajl2j@virginia.edu')
rv = self.app.get('/v1.0/user', headers=self.logged_in_headers(user)) rv = self.app.get('/v1.0/user', headers=self.logged_in_headers(user, redirect_url='http://omg.edu/lolwut'))
self.assert_success(rv) self.assert_success(rv)
def user_info_to_query_string(self, user_info): def user_info_to_query_string(self, user_info, redirect_url):
query_string_list = [] query_string_list = []
items = user_info.items() items = user_info.items()
for key, value in items: for key, value in items:
query_string_list.append('%s=%s' % (key, urllib.parse.quote(value))) query_string_list.append('%s=%s' % (key, urllib.parse.quote(value)))
query_string_list.append('redirect_url=%s' % redirect_url)
return '?%s' % '&'.join(query_string_list) return '?%s' % '&'.join(query_string_list)