diff --git a/api/login_api.py b/api/login_api.py index c7503b74..69a19e7e 100644 --- a/api/login_api.py +++ b/api/login_api.py @@ -17,10 +17,9 @@ import logging from google.oauth2 import id_token from google.auth.transport import requests -from flask import session from framework import basehandlers -from framework import xsrf +from framework import users import settings @@ -37,11 +36,7 @@ class LoginAPI(basehandlers.APIHandler): idinfo = id_token.verify_oauth2_token( token, requests.Request(), settings.GOOGLE_SIGN_IN_CLIENT_ID) - user_info = { - 'email': idinfo['email'], - } - signature = xsrf.generate_token(str(user_info)) - session['signed_user_info'] = user_info, signature + users.add_signed_user_info_to_session(idinfo['email']) message = "Done" # print(idinfo['email'], file=sys.stderr) except ValueError: diff --git a/api/login_api_test.py b/api/login_api_test.py index 38e64eb5..7e15b52d 100644 --- a/api/login_api_test.py +++ b/api/login_api_test.py @@ -46,25 +46,18 @@ class LoginAPITest(testing_config.CustomTestCase): """We reject login requests that have an invalid credential_token.""" params = {'credential': 'fake bad token'} with test_app.test_request_context(self.request_path, json=params): - session['something else'] = 'some other aspect of the session' + session.clear() actual_response = self.handler.do_post() self.assertEqual({'message': 'Invalid token'}, actual_response) - self.assertEqual(1, len(session)) + self.assertNotIn('signed_user_info', session) @mock.patch('google.oauth2.id_token.verify_oauth2_token') def test_post__normal(self, mock_verify): """We log in the user if they provide a good credential_token.""" mock_verify.return_value = {'email': 'user@example.com'} - params = {'credential': 'fake bad token'} + params = {'credential': 'fake good token'} with test_app.test_request_context(self.request_path, json=params): session.clear() - session['something else'] = 'some other aspect of the session' actual_response = self.handler.do_post() self.assertEqual({'message': 'Done'}, actual_response) - self.assertEqual(2, len(session)) - user_info, signature = session['signed_user_info'] - self.assertEqual({'email': 'user@example.com'}, user_info) - xsrf.validate_token( - signature, - str(user_info), - timeout=xsrf.REFRESH_TOKEN_TIMEOUT_SEC) + self.assertIn('signed_user_info', session) diff --git a/api/token_refresh_api.py b/api/token_refresh_api.py index e3dd8fa8..4f43ec0b 100644 --- a/api/token_refresh_api.py +++ b/api/token_refresh_api.py @@ -13,14 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - - - import logging from framework import basehandlers from framework import xsrf -from internals import models +from framework import users class TokenRefreshAPI(basehandlers.APIHandler): @@ -41,8 +38,9 @@ class TokenRefreshAPI(basehandlers.APIHandler): # Note: we use only POST instead of GET to avoid attacks that use GETs. def do_post(self): - """Return a new XSRF token for the current user.""" + """Refresh the session and return a new XSRF token for the current user.""" user = self.get_current_user() + users.refresh_user_session() result = { 'token': xsrf.generate_token(user.email()), 'token_expires_sec': xsrf.token_expires_sec(), diff --git a/api/token_refresh_api_test.py b/api/token_refresh_api_test.py index c5b045d6..b9efe2f3 100644 --- a/api/token_refresh_api_test.py +++ b/api/token_refresh_api_test.py @@ -15,6 +15,7 @@ import testing_config # Must be imported before the module under test. import flask +from flask import session from unittest import mock import werkzeug.exceptions # Flask HTTP stuff. @@ -22,6 +23,7 @@ from api import token_refresh_api from framework import xsrf test_app = flask.Flask(__name__) +test_app.secret_key = 'testing secret' class TokenRefreshAPITest(testing_config.CustomTestCase): @@ -73,9 +75,12 @@ class TokenRefreshAPITest(testing_config.CustomTestCase): def test_do_post__OK(self): """If the request is accepted, we return a new token.""" - testing_config.sign_in('user@example.com', 111) params = {'token': 'checked in base class'} with test_app.test_request_context(self.request_path, json=params): + session.clear() + testing_config.sign_in('user@example.com', 111) actual = self.handler.do_post() - self.assertIn('token', actual) - self.assertIn('token_expires_sec', actual) + + self.assertIn('signed_user_info', session) + self.assertIn('token', actual) + self.assertIn('token_expires_sec', actual) diff --git a/framework/basehandlers.py b/framework/basehandlers.py index f080f37d..5f5faa4f 100644 --- a/framework/basehandlers.py +++ b/framework/basehandlers.py @@ -29,6 +29,7 @@ from framework import csp from framework import permissions from framework import ramcache from framework import secrets +from framework import users from framework import utils from framework import xsrf from internals import approval_defs @@ -40,7 +41,6 @@ import django from google.auth.transport import requests from flask import session import sys -from framework import users # Initialize django so that it'll function when run as a standalone script. # https://django.readthedocs.io/en/latest/releases/1.7.html#standalone-scripts @@ -328,6 +328,7 @@ class FlaskHandler(BaseHandler): ramcache.check_for_distributed_invalidation() handler_data = self.get_template_data(*args, **kwargs) + users.refresh_user_session() if self.JSONIFY and type(handler_data) in (dict, list): headers = self.get_headers() diff --git a/framework/users.py b/framework/users.py index a53eb157..48b56701 100644 --- a/framework/users.py +++ b/framework/users.py @@ -228,3 +228,19 @@ def get_current_user(): def is_current_user_admin(): return False + + +def add_signed_user_info_to_session(email): + """Create and sign the user info in the Flask session.""" + user_info = { + 'email': email, + } + signature = xsrf.generate_token(str(user_info)) + session['signed_user_info'] = user_info, signature + + +def refresh_user_session(): + """If the user is signed in, update the signed user info with a new date.""" + user = get_current_user() + if user: + add_signed_user_info_to_session(user.email()) diff --git a/framework/users_test.py b/framework/users_test.py index 954475a5..a6acc88c 100644 --- a/framework/users_test.py +++ b/framework/users_test.py @@ -98,3 +98,17 @@ class UsersTest(testing_config.CustomTestCase): """We never consider a user an admin based on old GAE auth info.""" actual = users.is_current_user_admin() self.assertFalse(actual) + + def test_add_signed_user_info_to_session(self): + """We log in the user by adding a signed user_info to the session.""" + with test_app.test_request_context('/any/path'): + session.clear() + session['something else'] = 'some other aspect of the session' + users.add_signed_user_info_to_session('user@example.com') + self.assertEqual(2, len(session)) + user_info, signature = session['signed_user_info'] + self.assertEqual({'email': 'user@example.com'}, user_info) + xsrf.validate_token( + signature, + str(user_info), + timeout=xsrf.REFRESH_TOKEN_TIMEOUT_SEC) diff --git a/static/js-src/cs-client.js b/static/js-src/cs-client.js index e6660700..637e5a38 100644 --- a/static/js-src/cs-client.js +++ b/static/js-src/cs-client.js @@ -31,7 +31,7 @@ class ChromeStatusClient { const refreshResponse = await this.doFetch( '/currentuser/token', 'POST', null); this.token = refreshResponse.token; - this.tokenExpiresSec = refreshResponse.tokenExpiresSec; + this.tokenExpiresSec = refreshResponse.token_expires_sec; } }