Address broken connections during open sessions (#466)
* reconnect if broken * update naming * retry for broken connection during request * change name of SCRIPTAS_REQUEST * expand node broken connection * reconnect during refresh if needed * lint fix * test fix init * none checks for tests * fix connection during errors * return after script request exception * use mock connection for server in test * put back try catch block for node expansion * validate error * mock server * mock server * Revert "mock server" This reverts commit2cdeeea68f
. * Revert "put back try catch block for node expansion" This reverts commita87a34b52e
. * retry state * add back exception in expand node base * lint error
This commit is contained in:
Родитель
a7f10d8b76
Коммит
c3f1582088
|
@ -108,7 +108,7 @@ class ConnectionService:
|
|||
|
||||
# Get the connection for the given type and build a response if it is present, otherwise open the connection
|
||||
connection = connection_info.get_connection(params.type)
|
||||
if connection is not None:
|
||||
if connection is not None and not connection.connection.broken:
|
||||
return _build_connection_response(connection_info, params.type)
|
||||
|
||||
# The connection doesn't exist yet. Cancel any ongoing connection and set up a cancellation token
|
||||
|
@ -167,7 +167,7 @@ class ConnectionService:
|
|||
if connection_info is None:
|
||||
raise ValueError('No connection associated with given owner URI')
|
||||
|
||||
if not connection_info.has_connection(connection_type):
|
||||
if not connection_info.has_connection(connection_type) or not connection_info.get_connection(connection_type).open:
|
||||
self.connect(ConnectRequestParams(connection_info.details, owner_uri, connection_type))
|
||||
return connection_info.get_connection(connection_type)
|
||||
|
||||
|
@ -195,7 +195,7 @@ class ConnectionService:
|
|||
"""Close a connection in response to an incoming disconnection request"""
|
||||
request_context.send_response(self.disconnect(params.owner_uri, params.type))
|
||||
|
||||
def handle_list_databases(self, request_context: RequestContext, params: ListDatabasesParams):
|
||||
def handle_list_databases(self, request_context: RequestContext, params: ListDatabasesParams, retry_state=False):
|
||||
"""List all databases on the server that the given URI has a connection to"""
|
||||
connection = None
|
||||
try:
|
||||
|
@ -209,9 +209,13 @@ class ConnectionService:
|
|||
query_results = connection.list_databases()
|
||||
|
||||
except Exception as err:
|
||||
if self._service_provider is not None and self._service_provider.logger is not None:
|
||||
self._service_provider.logger.exception('Error listing databases')
|
||||
request_context.send_error(str(err))
|
||||
if connection is not None and connection.connection.broken and not retry_state:
|
||||
self._service_provider.logger.warn('Server closed the connection unexpectedly. Attempting to reconnect...')
|
||||
self.handle_list_databases(request_context, params, True)
|
||||
else:
|
||||
if self._service_provider is not None and self._service_provider.logger is not None:
|
||||
self._service_provider.logger.exception('Error listing databases')
|
||||
request_context.send_error(str(err))
|
||||
return
|
||||
|
||||
database_names = [result[0] for result in query_results]
|
||||
|
|
|
@ -187,8 +187,8 @@ class ObjectExplorerService(object):
|
|||
if session is None:
|
||||
return
|
||||
|
||||
# Step 2: Start a task for expanding the node
|
||||
try:
|
||||
# Step 2: Start a task for expanding the node
|
||||
key = params.node_path
|
||||
if is_refresh:
|
||||
task = session.refresh_tasks.get(key)
|
||||
|
@ -206,18 +206,26 @@ class ObjectExplorerService(object):
|
|||
session.refresh_tasks[key] = new_task
|
||||
else:
|
||||
session.expand_tasks[key] = new_task
|
||||
|
||||
except Exception as e:
|
||||
self._expand_node_error(request_context, params, str(e))
|
||||
|
||||
def _expand_node_thread(self, is_refresh: bool, request_context: RequestContext, params: ExpandParameters, session: ObjectExplorerSession):
|
||||
def _expand_node_thread(self, is_refresh: bool, request_context: RequestContext,
|
||||
params: ExpandParameters, session: ObjectExplorerSession, retry_state=False):
|
||||
try:
|
||||
response = ExpandCompletedParameters(session.id, params.node_path)
|
||||
response.nodes = self._route_request(is_refresh, session, params.node_path)
|
||||
|
||||
request_context.send_notification(EXPAND_COMPLETED_METHOD, response)
|
||||
except Exception as e:
|
||||
self._expand_node_error(request_context, params, str(e))
|
||||
except BaseException as e:
|
||||
if session.server.connection is not None and session.server.connection.connection.broken and not retry_state:
|
||||
conn_service = self._service_provider[utils.constants.CONNECTION_SERVICE_NAME]
|
||||
connection = conn_service.get_connection(session.id, ConnectionType.OBJECT_EXLPORER)
|
||||
session.server.set_connection(connection)
|
||||
session.server.refresh()
|
||||
self._expand_node_thread(is_refresh, request_context, params, session, True)
|
||||
return
|
||||
else:
|
||||
self._expand_node_error(request_context, params, str(e))
|
||||
|
||||
def _expand_node_error(self, request_context: RequestContext, params: ExpandParameters, message: str):
|
||||
if self._service_provider.logger is not None:
|
||||
|
@ -248,7 +256,7 @@ class ObjectExplorerService(object):
|
|||
request_context.send_response(True)
|
||||
return session
|
||||
except Exception as e:
|
||||
message = f'Failed to expand node: {str(e)}' # TODO: Localize
|
||||
message = f'Failed to expand node base: {str(e)}' # TODO: Localize
|
||||
if self._service_provider.logger is not None:
|
||||
self._service_provider.logger.error(message)
|
||||
request_context.send_error(message)
|
||||
|
|
|
@ -117,7 +117,7 @@ class Query:
|
|||
def current_batch_index(self) -> int:
|
||||
return self._current_batch_index
|
||||
|
||||
def execute(self, connection: ServerConnection):
|
||||
def execute(self, connection: ServerConnection, retry_state=False):
|
||||
"""
|
||||
Execute the query using the given connection
|
||||
|
||||
|
@ -126,7 +126,7 @@ class Query:
|
|||
:param batch_end_callback: A function to run after executing each batch
|
||||
:raises RuntimeError: If the query was already executed
|
||||
"""
|
||||
if self._execution_state is ExecutionState.EXECUTED:
|
||||
if self._execution_state is ExecutionState.EXECUTED and not retry_state:
|
||||
raise RuntimeError('Cannot execute a query multiple times')
|
||||
|
||||
self._execution_state = ExecutionState.EXECUTING
|
||||
|
|
|
@ -348,7 +348,7 @@ class QueryExecutionService(object):
|
|||
except BaseException as e:
|
||||
raise e
|
||||
|
||||
def _execute_query_request_worker(self, worker_args: ExecuteRequestWorkerArgs):
|
||||
def _execute_query_request_worker(self, worker_args: ExecuteRequestWorkerArgs, retry_state=False):
|
||||
"""Worker method for 'handle execute query request' thread"""
|
||||
|
||||
_check_and_fire(worker_args.before_query_initialize, {})
|
||||
|
@ -357,9 +357,15 @@ class QueryExecutionService(object):
|
|||
|
||||
# Wrap execution in a try/except block so that we can send an error if it fails
|
||||
try:
|
||||
query.execute(worker_args.connection)
|
||||
query.execute(worker_args.connection, retry_state)
|
||||
except Exception as e:
|
||||
self._resolve_query_exception(e, query, worker_args)
|
||||
if not retry_state and worker_args.connection.connection.broken:
|
||||
self._resolve_query_exception(e, query, worker_args, False, True)
|
||||
conn = self._get_connection(worker_args.owner_uri, ConnectionType.QUERY)
|
||||
worker_args.connection = conn
|
||||
self._execute_query_request_worker(worker_args, True)
|
||||
else:
|
||||
self._resolve_query_exception(e, query, worker_args)
|
||||
finally:
|
||||
# Send a query complete notification
|
||||
batch_summaries = [batch.batch_summary for batch in query.batches]
|
||||
|
@ -417,11 +423,14 @@ class QueryExecutionService(object):
|
|||
# Then params must be an instance of ExecuteStringParams, which has the query as an attribute
|
||||
return params.query
|
||||
|
||||
def _resolve_query_exception(self, e: Exception, query: Query, worker_args: ExecuteRequestWorkerArgs, is_rollback_error=False):
|
||||
def _resolve_query_exception(self, e: Exception, query: Query, worker_args: ExecuteRequestWorkerArgs, is_rollback_error=False, retry_query=False):
|
||||
utils.log.log_debug(self._service_provider.logger, f'Query execution failed for following query: {query.query_text}\n {e}')
|
||||
|
||||
if retry_query:
|
||||
error_message = 'Server closed the connection unexpectedly. Attempting to reconnect...'
|
||||
|
||||
# If the error relates to the database, display the appropriate error message based on the provider
|
||||
if isinstance(e, worker_args.connection.database_error) or isinstance(e, worker_args.connection.query_canceled_error):
|
||||
elif isinstance(e, worker_args.connection.database_error) or isinstance(e, worker_args.connection.query_canceled_error):
|
||||
# get_error_message may return None so ensure error_message is str type
|
||||
error_message = str(worker_args.connection.get_error_message(e))
|
||||
|
||||
|
@ -438,12 +447,13 @@ class QueryExecutionService(object):
|
|||
error_message = 'Error while rolling back open transaction due to previous failure: ' + error_message # TODO: Localize
|
||||
|
||||
# Send a message with the error to the client
|
||||
result_message_params = self.build_message_params(query.owner_uri, query.batches[query.current_batch_index].id, error_message, True)
|
||||
is_error_notification = not retry_query
|
||||
result_message_params = self.build_message_params(query.owner_uri, query.batches[query.current_batch_index].id, error_message, is_error_notification)
|
||||
_check_and_fire(worker_args.on_message_notification, result_message_params)
|
||||
|
||||
# If there was a failure in the middle of a transaction, roll it back.
|
||||
# Note that conn.rollback() won't work since the connection is in autocommit mode
|
||||
if not is_rollback_error and worker_args.connection.transaction_in_error and not worker_args.connection.user_transaction:
|
||||
if not is_rollback_error and not retry_query and worker_args.connection.transaction_in_error and not worker_args.connection.user_transaction:
|
||||
rollback_query = Query(query.owner_uri, 'ROLLBACK', QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents())
|
||||
try:
|
||||
rollback_query.execute(worker_args.connection)
|
||||
|
|
|
@ -3,10 +3,10 @@
|
|||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
from ossdbtoolsservice.scripting.contracts.scriptas_request import (
|
||||
ScriptAsParameters, ScriptAsResponse, SCRIPTAS_REQUEST, ScriptOperation)
|
||||
from ossdbtoolsservice.scripting.contracts.script_as_request import (
|
||||
ScriptAsParameters, ScriptAsResponse, SCRIPT_AS_REQUEST, ScriptOperation)
|
||||
|
||||
__all__ = [
|
||||
'ScriptAsParameters', 'ScriptAsResponse', 'SCRIPTAS_REQUEST',
|
||||
'ScriptAsParameters', 'ScriptAsResponse', 'SCRIPT_AS_REQUEST',
|
||||
'ScriptOperation'
|
||||
]
|
||||
|
|
|
@ -41,4 +41,4 @@ class ScriptAsResponse(Serializable):
|
|||
self.script: str = script
|
||||
|
||||
|
||||
SCRIPTAS_REQUEST = IncomingMessageConfiguration('scripting/script', ScriptAsParameters)
|
||||
SCRIPT_AS_REQUEST = IncomingMessageConfiguration('scripting/script', ScriptAsParameters)
|
|
@ -9,7 +9,7 @@ from ossdbtoolsservice.hosting import RequestContext, ServiceProvider
|
|||
from ossdbtoolsservice.metadata.contracts.object_metadata import ObjectMetadata
|
||||
from ossdbtoolsservice.scripting.scripter import Scripter
|
||||
from ossdbtoolsservice.scripting.contracts import (
|
||||
ScriptAsParameters, ScriptAsResponse, SCRIPTAS_REQUEST
|
||||
ScriptAsParameters, ScriptAsResponse, SCRIPT_AS_REQUEST
|
||||
)
|
||||
from ossdbtoolsservice.connection.contracts import ConnectionType
|
||||
import ossdbtoolsservice.utils as utils
|
||||
|
@ -25,7 +25,7 @@ class ScriptingService(object):
|
|||
self._service_provider = service_provider
|
||||
|
||||
# Register the request handlers with the server
|
||||
self._service_provider.server.set_request_handler(SCRIPTAS_REQUEST, self._handle_scriptas_request)
|
||||
self._service_provider.server.set_request_handler(SCRIPT_AS_REQUEST, self._handle_script_as_request)
|
||||
|
||||
# Find the provider type
|
||||
self._provider: str = self._service_provider.provider
|
||||
|
@ -43,10 +43,14 @@ class ScriptingService(object):
|
|||
return object_metadata
|
||||
|
||||
# REQUEST HANDLERS #####################################################
|
||||
def _handle_scriptas_request(self, request_context: RequestContext, params: ScriptAsParameters) -> None:
|
||||
def _handle_script_as_request(self, request_context: RequestContext, params: ScriptAsParameters, retry_state=False) -> None:
|
||||
try:
|
||||
utils.validate.is_not_none('params', params)
|
||||
except Exception as e:
|
||||
self._request_error(request_context, params, str(e))
|
||||
return
|
||||
|
||||
try:
|
||||
scripting_operation = params.operation
|
||||
connection_service = self._service_provider[utils.constants.CONNECTION_SERVICE_NAME]
|
||||
connection = connection_service.get_connection(params.owner_uri, ConnectionType.QUERY)
|
||||
|
@ -57,6 +61,13 @@ class ScriptingService(object):
|
|||
script = scripter.script(scripting_operation, object_metadata)
|
||||
request_context.send_response(ScriptAsResponse(params.owner_uri, script))
|
||||
except Exception as e:
|
||||
if self._service_provider.logger is not None:
|
||||
self._service_provider.logger.exception('Scripting operation failed')
|
||||
request_context.send_error(str(e), params)
|
||||
if connection is not None and connection.connection.broken and not retry_state:
|
||||
self._service_provider.logger.warn('Server closed the connection unexpectedly. Attempting to reconnect...')
|
||||
self._handle_script_as_request(request_context, params, True)
|
||||
else:
|
||||
self._request_error(request_context, params, str(e))
|
||||
|
||||
def _request_error(self, request_context: RequestContext, params: ScriptAsParameters, message: str):
|
||||
if self._service_provider.logger is not None:
|
||||
self._service_provider.logger.exception('Scripting operation failed')
|
||||
request_context.send_error(message, params)
|
||||
|
|
|
@ -283,6 +283,10 @@ class Server:
|
|||
}
|
||||
return object_map[object_type.capitalize()](metadata)
|
||||
|
||||
def set_connection(self, conn: ServerConnection) -> ServerConnection:
|
||||
"""Reset connection to the server/db that this object will use"""
|
||||
self._conn = conn
|
||||
|
||||
# IMPLEMENTATION DETAILS ###############################################
|
||||
|
||||
def _fetch_recovery_state(self) -> Dict[str, Optional[bool]]:
|
||||
|
|
|
@ -104,6 +104,7 @@ class RequestFlowValidator:
|
|||
'Expected additional messages: '
|
||||
f'[{self._expected_messages[i].message_type}] '
|
||||
f'{self._expected_messages[i].param_type}'
|
||||
f'{self._expected_messages[i].message_method}'
|
||||
)
|
||||
received = self._received_messages[i]
|
||||
|
||||
|
|
|
@ -195,7 +195,7 @@ class TestObjectExplorer(unittest.TestCase):
|
|||
oe._provider = constants.PG_PROVIDER_NAME
|
||||
|
||||
# ... Patch the threading to throw
|
||||
patch_mock = mock.MagicMock(side_effect=Exception('Boom!'))
|
||||
patch_mock = mock.MagicMock(side_effect=Exception('Boom! Create Session Failed'))
|
||||
patch_path = 'ossdbtoolsservice.object_explorer.object_explorer_service.threading.Thread'
|
||||
with mock.patch(patch_path, patch_mock):
|
||||
# If: I create a new session
|
||||
|
@ -304,7 +304,7 @@ class TestObjectExplorer(unittest.TestCase):
|
|||
# ... Create OE service with mock connection service that returns a failed connection response
|
||||
cs = ConnectionService()
|
||||
connect_response = ConnectionCompleteParams()
|
||||
connect_response.error_message = 'Boom!'
|
||||
connect_response.error_message = 'Boom! Init Session Failed'
|
||||
cs.connect = mock.MagicMock(return_value=connect_response)
|
||||
oe = ObjectExplorerService()
|
||||
oe._service_provider = utils.get_mock_service_provider({constants.CONNECTION_SERVICE_NAME: cs})
|
||||
|
@ -479,7 +479,7 @@ class TestObjectExplorer(unittest.TestCase):
|
|||
oe, session, session_uri = self._preloaded_oe_service()
|
||||
|
||||
# ... Patch the threading to throw
|
||||
patch_mock = mock.MagicMock(side_effect=Exception('Boom!'))
|
||||
patch_mock = mock.MagicMock(side_effect=Exception('Boom! Thread Error Handling Failed'))
|
||||
patch_path = 'ossdbtoolsservice.object_explorer.object_explorer_service.threading.Thread'
|
||||
with mock.patch(patch_path, patch_mock):
|
||||
# If: I expand a node (with threading that throws)
|
||||
|
@ -502,11 +502,11 @@ class TestObjectExplorer(unittest.TestCase):
|
|||
|
||||
def _handle_er_exception_expanding(self, method: TEventHandler, get_tasks: TGetTask):
|
||||
# Setup: Create an OE service with a session preloaded
|
||||
oe, session, session_uri = self._preloaded_oe_service()
|
||||
oe, session, session_uri = self._preloaded_oe_service(Server(MockPGServerConnection()))
|
||||
|
||||
# ... Patch the route_request to throw
|
||||
# ... Patch the threading to throw
|
||||
patch_mock = mock.MagicMock(side_effect=Exception('Boom!'))
|
||||
patch_mock = mock.MagicMock(side_effect=Exception('Boom! Expand Error Handling Failed'))
|
||||
patch_path = 'ossdbtoolsservice.object_explorer.object_explorer_service.ObjectExplorerService._route_request'
|
||||
with mock.patch(patch_path, patch_mock):
|
||||
# If: I expand a node (with route_request that throws)
|
||||
|
@ -600,14 +600,14 @@ class TestObjectExplorer(unittest.TestCase):
|
|||
testevent.set()
|
||||
|
||||
# IMPLEMENTATION DETAILS ###############################################
|
||||
def _preloaded_oe_service(self) -> Tuple[ObjectExplorerService, ObjectExplorerSession, str]:
|
||||
def _preloaded_oe_service(self, server=mock.Mock()) -> Tuple[ObjectExplorerService, ObjectExplorerSession, str]:
|
||||
oe = ObjectExplorerService()
|
||||
oe._service_provider = utils.get_mock_service_provider({})
|
||||
oe._routing_table = PG_ROUTING_TABLE
|
||||
|
||||
conn_details, session_uri = _connection_details()
|
||||
session = ObjectExplorerSession(session_uri, conn_details)
|
||||
session.server = mock.Mock()
|
||||
session.server = server
|
||||
session.is_ready = True
|
||||
oe._session_map[session_uri] = session
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import tests.utils as utils
|
|||
from ossdbtoolsservice.connection import ConnectionService
|
||||
from ossdbtoolsservice.connection.contracts import ConnectionCompleteParams
|
||||
from ossdbtoolsservice.hosting import JSONRPCServer, ServiceProvider
|
||||
from ossdbtoolsservice.scripting.contracts.scriptas_request import (
|
||||
from ossdbtoolsservice.scripting.contracts.script_as_request import (
|
||||
ScriptAsParameters, ScriptAsResponse, ScriptOperation)
|
||||
from ossdbtoolsservice.scripting.scripter import Scripter
|
||||
from ossdbtoolsservice.scripting.scripting_service import ScriptingService
|
||||
|
@ -63,7 +63,7 @@ class TestScriptingService(unittest.TestCase):
|
|||
# If: I make a scripting request missing params
|
||||
rc: RequestFlowValidator = RequestFlowValidator()
|
||||
rc.add_expected_error(type(None), RequestFlowValidator.basic_error_validation)
|
||||
ss._handle_scriptas_request(rc.request_context, None)
|
||||
ss._handle_script_as_request(rc.request_context, None)
|
||||
|
||||
# Then:
|
||||
# ... I should get an error response
|
||||
|
@ -81,7 +81,7 @@ class TestScriptingService(unittest.TestCase):
|
|||
# If: I create an OE session with missing params
|
||||
rc: RequestFlowValidator = RequestFlowValidator()
|
||||
rc.add_expected_error(type(None), RequestFlowValidator.basic_error_validation)
|
||||
ss._handle_scriptas_request(rc.request_context, None)
|
||||
ss._handle_script_as_request(rc.request_context, None)
|
||||
|
||||
# Then:
|
||||
# ... I should get an error response
|
||||
|
@ -129,7 +129,7 @@ class TestScriptingService(unittest.TestCase):
|
|||
'scripting_objects': [scripting_object]
|
||||
})
|
||||
|
||||
ss._handle_scriptas_request(rc.request_context, params)
|
||||
ss._handle_script_as_request(rc.request_context, params)
|
||||
|
||||
# Then:
|
||||
# ... The request should have been handled correctly
|
||||
|
|
|
@ -119,6 +119,7 @@ class MockPsycopgConnection(object):
|
|||
self.commit = mock.Mock()
|
||||
self.pgconn = mock.Mock()
|
||||
self.info = MockConnectionInfo(dsn_parameters, self.server_version)
|
||||
self.broken = False
|
||||
|
||||
self._adapters: Optional[AdaptersMap] = mock.Mock()
|
||||
self.notice_handlers: List[NoticeHandler] = []
|
||||
|
|
Загрузка…
Ссылка в новой задаче