Address broken connections during open sessions (#466)

* reconnect if broken

* update naming

* retry for broken connection during request

* change name of SCRIPTAS_REQUEST

* expand node broken connection

* reconnect during refresh if needed

* lint fix

* test fix init

* none checks for tests

* fix connection during errors

* return after script request exception

* use mock connection for server in test

* put back try catch block for node expansion

* validate error

* mock server

* mock server

* Revert "mock server"

This reverts commit 2cdeeea68f.

* Revert "put back try catch block for node expansion"

This reverts commit a87a34b52e.

* retry state

* add back exception in expand node base

* lint error
This commit is contained in:
nasc17 2023-08-24 18:50:45 -04:00 коммит произвёл GitHub
Родитель a7f10d8b76
Коммит c3f1582088
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
12 изменённых файлов: 81 добавлений и 42 удалений

Просмотреть файл

@ -108,7 +108,7 @@ class ConnectionService:
# Get the connection for the given type and build a response if it is present, otherwise open the connection
connection = connection_info.get_connection(params.type)
if connection is not None:
if connection is not None and not connection.connection.broken:
return _build_connection_response(connection_info, params.type)
# The connection doesn't exist yet or is broken. Cancel any ongoing connection and set up a cancellation token
@ -167,7 +167,7 @@ class ConnectionService:
if connection_info is None:
raise ValueError('No connection associated with given owner URI')
if not connection_info.has_connection(connection_type):
if not connection_info.has_connection(connection_type) or not connection_info.get_connection(connection_type).open:
self.connect(ConnectRequestParams(connection_info.details, owner_uri, connection_type))
return connection_info.get_connection(connection_type)
@ -195,7 +195,7 @@ class ConnectionService:
"""Close a connection in response to an incoming disconnection request"""
request_context.send_response(self.disconnect(params.owner_uri, params.type))
def handle_list_databases(self, request_context: RequestContext, params: ListDatabasesParams):
def handle_list_databases(self, request_context: RequestContext, params: ListDatabasesParams, retry_state=False):
"""List all databases on the server that the given URI has a connection to"""
connection = None
try:
@ -209,9 +209,13 @@ class ConnectionService:
query_results = connection.list_databases()
except Exception as err:
if self._service_provider is not None and self._service_provider.logger is not None:
self._service_provider.logger.exception('Error listing databases')
request_context.send_error(str(err))
if connection is not None and connection.connection.broken and not retry_state:
self._service_provider.logger.warn('Server closed the connection unexpectedly. Attempting to reconnect...')
self.handle_list_databases(request_context, params, True)
else:
if self._service_provider is not None and self._service_provider.logger is not None:
self._service_provider.logger.exception('Error listing databases')
request_context.send_error(str(err))
return
database_names = [result[0] for result in query_results]

Просмотреть файл

@ -187,8 +187,8 @@ class ObjectExplorerService(object):
if session is None:
return
# Step 2: Start a task for expanding the node
try:
# Step 2: Start a task for expanding the node
key = params.node_path
if is_refresh:
task = session.refresh_tasks.get(key)
@ -206,18 +206,26 @@ class ObjectExplorerService(object):
session.refresh_tasks[key] = new_task
else:
session.expand_tasks[key] = new_task
except Exception as e:
self._expand_node_error(request_context, params, str(e))
def _expand_node_thread(self, is_refresh: bool, request_context: RequestContext, params: ExpandParameters, session: ObjectExplorerSession):
def _expand_node_thread(self, is_refresh: bool, request_context: RequestContext,
params: ExpandParameters, session: ObjectExplorerSession, retry_state=False):
try:
response = ExpandCompletedParameters(session.id, params.node_path)
response.nodes = self._route_request(is_refresh, session, params.node_path)
request_context.send_notification(EXPAND_COMPLETED_METHOD, response)
except Exception as e:
self._expand_node_error(request_context, params, str(e))
except BaseException as e:
if session.server.connection is not None and session.server.connection.connection.broken and not retry_state:
conn_service = self._service_provider[utils.constants.CONNECTION_SERVICE_NAME]
connection = conn_service.get_connection(session.id, ConnectionType.OBJECT_EXLPORER)
session.server.set_connection(connection)
session.server.refresh()
self._expand_node_thread(is_refresh, request_context, params, session, True)
return
else:
self._expand_node_error(request_context, params, str(e))
def _expand_node_error(self, request_context: RequestContext, params: ExpandParameters, message: str):
if self._service_provider.logger is not None:
@ -248,7 +256,7 @@ class ObjectExplorerService(object):
request_context.send_response(True)
return session
except Exception as e:
message = f'Failed to expand node: {str(e)}' # TODO: Localize
message = f'Failed to expand node base: {str(e)}' # TODO: Localize
if self._service_provider.logger is not None:
self._service_provider.logger.error(message)
request_context.send_error(message)

Просмотреть файл

@ -117,7 +117,7 @@ class Query:
def current_batch_index(self) -> int:
return self._current_batch_index
def execute(self, connection: ServerConnection):
def execute(self, connection: ServerConnection, retry_state=False):
"""
Execute the query using the given connection
@ -126,7 +126,7 @@ class Query:
:param batch_end_callback: A function to run after executing each batch
:raises RuntimeError: If the query was already executed
"""
if self._execution_state is ExecutionState.EXECUTED:
if self._execution_state is ExecutionState.EXECUTED and not retry_state:
raise RuntimeError('Cannot execute a query multiple times')
self._execution_state = ExecutionState.EXECUTING

Просмотреть файл

@ -348,7 +348,7 @@ class QueryExecutionService(object):
except BaseException as e:
raise e
def _execute_query_request_worker(self, worker_args: ExecuteRequestWorkerArgs):
def _execute_query_request_worker(self, worker_args: ExecuteRequestWorkerArgs, retry_state=False):
"""Worker method for 'handle execute query request' thread"""
_check_and_fire(worker_args.before_query_initialize, {})
@ -357,9 +357,15 @@ class QueryExecutionService(object):
# Wrap execution in a try/except block so that we can send an error if it fails
try:
query.execute(worker_args.connection)
query.execute(worker_args.connection, retry_state)
except Exception as e:
self._resolve_query_exception(e, query, worker_args)
if not retry_state and worker_args.connection.connection.broken:
self._resolve_query_exception(e, query, worker_args, False, True)
conn = self._get_connection(worker_args.owner_uri, ConnectionType.QUERY)
worker_args.connection = conn
self._execute_query_request_worker(worker_args, True)
else:
self._resolve_query_exception(e, query, worker_args)
finally:
# Send a query complete notification
batch_summaries = [batch.batch_summary for batch in query.batches]
@ -417,11 +423,14 @@ class QueryExecutionService(object):
# Then params must be an instance of ExecuteStringParams, which has the query as an attribute
return params.query
def _resolve_query_exception(self, e: Exception, query: Query, worker_args: ExecuteRequestWorkerArgs, is_rollback_error=False):
def _resolve_query_exception(self, e: Exception, query: Query, worker_args: ExecuteRequestWorkerArgs, is_rollback_error=False, retry_query=False):
utils.log.log_debug(self._service_provider.logger, f'Query execution failed for following query: {query.query_text}\n {e}')
if retry_query:
error_message = 'Server closed the connection unexpectedly. Attempting to reconnect...'
# If the error relates to the database, display the appropriate error message based on the provider
if isinstance(e, worker_args.connection.database_error) or isinstance(e, worker_args.connection.query_canceled_error):
elif isinstance(e, worker_args.connection.database_error) or isinstance(e, worker_args.connection.query_canceled_error):
# get_error_message may return None so ensure error_message is str type
error_message = str(worker_args.connection.get_error_message(e))
@ -438,12 +447,13 @@ class QueryExecutionService(object):
error_message = 'Error while rolling back open transaction due to previous failure: ' + error_message # TODO: Localize
# Send a message with the error to the client
result_message_params = self.build_message_params(query.owner_uri, query.batches[query.current_batch_index].id, error_message, True)
is_error_notification = not retry_query
result_message_params = self.build_message_params(query.owner_uri, query.batches[query.current_batch_index].id, error_message, is_error_notification)
_check_and_fire(worker_args.on_message_notification, result_message_params)
# If there was a failure in the middle of a transaction, roll it back.
# Note that conn.rollback() won't work since the connection is in autocommit mode
if not is_rollback_error and worker_args.connection.transaction_in_error and not worker_args.connection.user_transaction:
if not is_rollback_error and not retry_query and worker_args.connection.transaction_in_error and not worker_args.connection.user_transaction:
rollback_query = Query(query.owner_uri, 'ROLLBACK', QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents())
try:
rollback_query.execute(worker_args.connection)

Просмотреть файл

@ -3,10 +3,10 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from ossdbtoolsservice.scripting.contracts.scriptas_request import (
ScriptAsParameters, ScriptAsResponse, SCRIPTAS_REQUEST, ScriptOperation)
from ossdbtoolsservice.scripting.contracts.script_as_request import (
ScriptAsParameters, ScriptAsResponse, SCRIPT_AS_REQUEST, ScriptOperation)
__all__ = [
'ScriptAsParameters', 'ScriptAsResponse', 'SCRIPTAS_REQUEST',
'ScriptAsParameters', 'ScriptAsResponse', 'SCRIPT_AS_REQUEST',
'ScriptOperation'
]

Просмотреть файл

@ -41,4 +41,4 @@ class ScriptAsResponse(Serializable):
self.script: str = script
SCRIPTAS_REQUEST = IncomingMessageConfiguration('scripting/script', ScriptAsParameters)
SCRIPT_AS_REQUEST = IncomingMessageConfiguration('scripting/script', ScriptAsParameters)

Просмотреть файл

@ -9,7 +9,7 @@ from ossdbtoolsservice.hosting import RequestContext, ServiceProvider
from ossdbtoolsservice.metadata.contracts.object_metadata import ObjectMetadata
from ossdbtoolsservice.scripting.scripter import Scripter
from ossdbtoolsservice.scripting.contracts import (
ScriptAsParameters, ScriptAsResponse, SCRIPTAS_REQUEST
ScriptAsParameters, ScriptAsResponse, SCRIPT_AS_REQUEST
)
from ossdbtoolsservice.connection.contracts import ConnectionType
import ossdbtoolsservice.utils as utils
@ -25,7 +25,7 @@ class ScriptingService(object):
self._service_provider = service_provider
# Register the request handlers with the server
self._service_provider.server.set_request_handler(SCRIPTAS_REQUEST, self._handle_scriptas_request)
self._service_provider.server.set_request_handler(SCRIPT_AS_REQUEST, self._handle_script_as_request)
# Find the provider type
self._provider: str = self._service_provider.provider
@ -43,10 +43,14 @@ class ScriptingService(object):
return object_metadata
# REQUEST HANDLERS #####################################################
def _handle_scriptas_request(self, request_context: RequestContext, params: ScriptAsParameters) -> None:
def _handle_script_as_request(self, request_context: RequestContext, params: ScriptAsParameters, retry_state=False) -> None:
try:
utils.validate.is_not_none('params', params)
except Exception as e:
self._request_error(request_context, params, str(e))
return
try:
scripting_operation = params.operation
connection_service = self._service_provider[utils.constants.CONNECTION_SERVICE_NAME]
connection = connection_service.get_connection(params.owner_uri, ConnectionType.QUERY)
@ -57,6 +61,13 @@ class ScriptingService(object):
script = scripter.script(scripting_operation, object_metadata)
request_context.send_response(ScriptAsResponse(params.owner_uri, script))
except Exception as e:
if self._service_provider.logger is not None:
self._service_provider.logger.exception('Scripting operation failed')
request_context.send_error(str(e), params)
if connection is not None and connection.connection.broken and not retry_state:
self._service_provider.logger.warn('Server closed the connection unexpectedly. Attempting to reconnect...')
self._handle_script_as_request(request_context, params, True)
else:
self._request_error(request_context, params, str(e))
def _request_error(self, request_context: RequestContext, params: ScriptAsParameters, message: str):
if self._service_provider.logger is not None:
self._service_provider.logger.exception('Scripting operation failed')
request_context.send_error(message, params)

Просмотреть файл

@ -283,6 +283,10 @@ class Server:
}
return object_map[object_type.capitalize()](metadata)
def set_connection(self, conn: ServerConnection) -> None:
    """
    Reset the connection to the server/db that this object will use.

    The given connection is stored on this object, replacing whatever
    connection was held before. Nothing is returned.

    :param conn: The connection this object should use from now on
    """
    self._conn = conn
# IMPLEMENTATION DETAILS ###############################################
def _fetch_recovery_state(self) -> Dict[str, Optional[bool]]:

Просмотреть файл

@ -104,6 +104,7 @@ class RequestFlowValidator:
'Expected additional messages: '
f'[{self._expected_messages[i].message_type}] '
f'{self._expected_messages[i].param_type}'
f'{self._expected_messages[i].message_method}'
)
received = self._received_messages[i]

Просмотреть файл

@ -195,7 +195,7 @@ class TestObjectExplorer(unittest.TestCase):
oe._provider = constants.PG_PROVIDER_NAME
# ... Patch the threading to throw
patch_mock = mock.MagicMock(side_effect=Exception('Boom!'))
patch_mock = mock.MagicMock(side_effect=Exception('Boom! Create Session Failed'))
patch_path = 'ossdbtoolsservice.object_explorer.object_explorer_service.threading.Thread'
with mock.patch(patch_path, patch_mock):
# If: I create a new session
@ -304,7 +304,7 @@ class TestObjectExplorer(unittest.TestCase):
# ... Create OE service with mock connection service that returns a failed connection response
cs = ConnectionService()
connect_response = ConnectionCompleteParams()
connect_response.error_message = 'Boom!'
connect_response.error_message = 'Boom! Init Session Failed'
cs.connect = mock.MagicMock(return_value=connect_response)
oe = ObjectExplorerService()
oe._service_provider = utils.get_mock_service_provider({constants.CONNECTION_SERVICE_NAME: cs})
@ -479,7 +479,7 @@ class TestObjectExplorer(unittest.TestCase):
oe, session, session_uri = self._preloaded_oe_service()
# ... Patch the threading to throw
patch_mock = mock.MagicMock(side_effect=Exception('Boom!'))
patch_mock = mock.MagicMock(side_effect=Exception('Boom! Thread Error Handling Failed'))
patch_path = 'ossdbtoolsservice.object_explorer.object_explorer_service.threading.Thread'
with mock.patch(patch_path, patch_mock):
# If: I expand a node (with threading that throws)
@ -502,11 +502,11 @@ class TestObjectExplorer(unittest.TestCase):
def _handle_er_exception_expanding(self, method: TEventHandler, get_tasks: TGetTask):
# Setup: Create an OE service with a session preloaded
oe, session, session_uri = self._preloaded_oe_service()
oe, session, session_uri = self._preloaded_oe_service(Server(MockPGServerConnection()))
# ... Patch the route_request to throw
# ... Patch the route_request to throw
patch_mock = mock.MagicMock(side_effect=Exception('Boom!'))
patch_mock = mock.MagicMock(side_effect=Exception('Boom! Expand Error Handling Failed'))
patch_path = 'ossdbtoolsservice.object_explorer.object_explorer_service.ObjectExplorerService._route_request'
with mock.patch(patch_path, patch_mock):
# If: I expand a node (with route_request that throws)
@ -600,14 +600,14 @@ class TestObjectExplorer(unittest.TestCase):
testevent.set()
# IMPLEMENTATION DETAILS ###############################################
def _preloaded_oe_service(self) -> Tuple[ObjectExplorerService, ObjectExplorerSession, str]:
def _preloaded_oe_service(self, server=mock.Mock()) -> Tuple[ObjectExplorerService, ObjectExplorerSession, str]:
oe = ObjectExplorerService()
oe._service_provider = utils.get_mock_service_provider({})
oe._routing_table = PG_ROUTING_TABLE
conn_details, session_uri = _connection_details()
session = ObjectExplorerSession(session_uri, conn_details)
session.server = mock.Mock()
session.server = server
session.is_ready = True
oe._session_map[session_uri] = session

Просмотреть файл

@ -10,7 +10,7 @@ import tests.utils as utils
from ossdbtoolsservice.connection import ConnectionService
from ossdbtoolsservice.connection.contracts import ConnectionCompleteParams
from ossdbtoolsservice.hosting import JSONRPCServer, ServiceProvider
from ossdbtoolsservice.scripting.contracts.scriptas_request import (
from ossdbtoolsservice.scripting.contracts.script_as_request import (
ScriptAsParameters, ScriptAsResponse, ScriptOperation)
from ossdbtoolsservice.scripting.scripter import Scripter
from ossdbtoolsservice.scripting.scripting_service import ScriptingService
@ -63,7 +63,7 @@ class TestScriptingService(unittest.TestCase):
# If: I make a scripting request missing params
rc: RequestFlowValidator = RequestFlowValidator()
rc.add_expected_error(type(None), RequestFlowValidator.basic_error_validation)
ss._handle_scriptas_request(rc.request_context, None)
ss._handle_script_as_request(rc.request_context, None)
# Then:
# ... I should get an error response
@ -81,7 +81,7 @@ class TestScriptingService(unittest.TestCase):
# If: I create an OE session with missing params
rc: RequestFlowValidator = RequestFlowValidator()
rc.add_expected_error(type(None), RequestFlowValidator.basic_error_validation)
ss._handle_scriptas_request(rc.request_context, None)
ss._handle_script_as_request(rc.request_context, None)
# Then:
# ... I should get an error response
@ -129,7 +129,7 @@ class TestScriptingService(unittest.TestCase):
'scripting_objects': [scripting_object]
})
ss._handle_scriptas_request(rc.request_context, params)
ss._handle_script_as_request(rc.request_context, params)
# Then:
# ... The request should have been handled correctly

Просмотреть файл

@ -119,6 +119,7 @@ class MockPsycopgConnection(object):
self.commit = mock.Mock()
self.pgconn = mock.Mock()
self.info = MockConnectionInfo(dsn_parameters, self.server_version)
self.broken = False
self._adapters: Optional[AdaptersMap] = mock.Mock()
self.notice_handlers: List[NoticeHandler] = []