diff --git a/ossdbtoolsservice/connection/connection_service.py b/ossdbtoolsservice/connection/connection_service.py index eeab09e9..a72adfb7 100644 --- a/ossdbtoolsservice/connection/connection_service.py +++ b/ossdbtoolsservice/connection/connection_service.py @@ -108,7 +108,7 @@ class ConnectionService: # Get the connection for the given type and build a response if it is present, otherwise open the connection connection = connection_info.get_connection(params.type) - if connection is not None: + if connection is not None and not connection.connection.broken: return _build_connection_response(connection_info, params.type) # The connection doesn't exist yet. Cancel any ongoing connection and set up a cancellation token @@ -167,7 +167,7 @@ class ConnectionService: if connection_info is None: raise ValueError('No connection associated with given owner URI') - if not connection_info.has_connection(connection_type): + if not connection_info.has_connection(connection_type) or not connection_info.get_connection(connection_type).open: self.connect(ConnectRequestParams(connection_info.details, owner_uri, connection_type)) return connection_info.get_connection(connection_type) @@ -195,7 +195,7 @@ class ConnectionService: """Close a connection in response to an incoming disconnection request""" request_context.send_response(self.disconnect(params.owner_uri, params.type)) - def handle_list_databases(self, request_context: RequestContext, params: ListDatabasesParams): + def handle_list_databases(self, request_context: RequestContext, params: ListDatabasesParams, retry_state=False): """List all databases on the server that the given URI has a connection to""" connection = None try: @@ -209,9 +209,13 @@ class ConnectionService: query_results = connection.list_databases() except Exception as err: - if self._service_provider is not None and self._service_provider.logger is not None: - self._service_provider.logger.exception('Error listing databases') - request_context.send_error(str(err)) + if connection is not None and connection.connection.broken and not retry_state: + self._service_provider.logger.warn('Server closed the connection unexpectedly. Attempting to reconnect...') + self.handle_list_databases(request_context, params, True) + else: + if self._service_provider is not None and self._service_provider.logger is not None: + self._service_provider.logger.exception('Error listing databases') + request_context.send_error(str(err)) return database_names = [result[0] for result in query_results] diff --git a/ossdbtoolsservice/object_explorer/object_explorer_service.py b/ossdbtoolsservice/object_explorer/object_explorer_service.py index 3679d8f6..17e44ae0 100644 --- a/ossdbtoolsservice/object_explorer/object_explorer_service.py +++ b/ossdbtoolsservice/object_explorer/object_explorer_service.py @@ -187,8 +187,8 @@ class ObjectExplorerService(object): if session is None: return - # Step 2: Start a task for expanding the node try: + # Step 2: Start a task for expanding the node key = params.node_path if is_refresh: task = session.refresh_tasks.get(key) @@ -206,18 +206,26 @@ class ObjectExplorerService(object): session.refresh_tasks[key] = new_task else: session.expand_tasks[key] = new_task - except Exception as e: self._expand_node_error(request_context, params, str(e)) - def _expand_node_thread(self, is_refresh: bool, request_context: RequestContext, params: ExpandParameters, session: ObjectExplorerSession): + def _expand_node_thread(self, is_refresh: bool, request_context: RequestContext, + params: ExpandParameters, session: ObjectExplorerSession, retry_state=False): try: response = ExpandCompletedParameters(session.id, params.node_path) response.nodes = self._route_request(is_refresh, session, params.node_path) request_context.send_notification(EXPAND_COMPLETED_METHOD, response) - except Exception as e: - self._expand_node_error(request_context, params, str(e)) + except BaseException as e: + if session.server.connection is not None and session.server.connection.connection.broken and not retry_state: + conn_service = self._service_provider[utils.constants.CONNECTION_SERVICE_NAME] + connection = conn_service.get_connection(session.id, ConnectionType.OBJECT_EXLPORER) + session.server.set_connection(connection) + session.server.refresh() + self._expand_node_thread(is_refresh, request_context, params, session, True) + return + else: + self._expand_node_error(request_context, params, str(e)) def _expand_node_error(self, request_context: RequestContext, params: ExpandParameters, message: str): if self._service_provider.logger is not None: @@ -248,7 +256,7 @@ class ObjectExplorerService(object): request_context.send_response(True) return session except Exception as e: - message = f'Failed to expand node: {str(e)}' # TODO: Localize + message = f'Failed to expand node base: {str(e)}' # TODO: Localize if self._service_provider.logger is not None: self._service_provider.logger.error(message) request_context.send_error(message) diff --git a/ossdbtoolsservice/query/query.py b/ossdbtoolsservice/query/query.py index d4106628..941ac0ba 100644 --- a/ossdbtoolsservice/query/query.py +++ b/ossdbtoolsservice/query/query.py @@ -117,7 +117,7 @@ class Query: def current_batch_index(self) -> int: return self._current_batch_index - def execute(self, connection: ServerConnection): + def execute(self, connection: ServerConnection, retry_state=False): """ Execute the query using the given connection @@ -126,7 +126,7 @@ class Query: :param batch_end_callback: A function to run after executing each batch :raises RuntimeError: If the query was already executed """ - if self._execution_state is ExecutionState.EXECUTED: + if self._execution_state is ExecutionState.EXECUTED and not retry_state: raise RuntimeError('Cannot execute a query multiple times') self._execution_state = ExecutionState.EXECUTING diff --git a/ossdbtoolsservice/query_execution/query_execution_service.py b/ossdbtoolsservice/query_execution/query_execution_service.py index 6bd447cb..502962ae 100644 --- a/ossdbtoolsservice/query_execution/query_execution_service.py +++ b/ossdbtoolsservice/query_execution/query_execution_service.py @@ -348,7 +348,7 @@ class QueryExecutionService(object): except BaseException as e: raise e - def _execute_query_request_worker(self, worker_args: ExecuteRequestWorkerArgs): + def _execute_query_request_worker(self, worker_args: ExecuteRequestWorkerArgs, retry_state=False): """Worker method for 'handle execute query request' thread""" _check_and_fire(worker_args.before_query_initialize, {}) @@ -357,9 +357,15 @@ class QueryExecutionService(object): # Wrap execution in a try/except block so that we can send an error if it fails try: - query.execute(worker_args.connection) + query.execute(worker_args.connection, retry_state) except Exception as e: - self._resolve_query_exception(e, query, worker_args) + if not retry_state and worker_args.connection.connection.broken: + self._resolve_query_exception(e, query, worker_args, False, True) + conn = self._get_connection(worker_args.owner_uri, ConnectionType.QUERY) + worker_args.connection = conn + self._execute_query_request_worker(worker_args, True) + else: + self._resolve_query_exception(e, query, worker_args) finally: # Send a query complete notification batch_summaries = [batch.batch_summary for batch in query.batches] @@ -417,11 +423,14 @@ class QueryExecutionService(object): # Then params must be an instance of ExecuteStringParams, which has the query as an attribute return params.query - def _resolve_query_exception(self, e: Exception, query: Query, worker_args: ExecuteRequestWorkerArgs, is_rollback_error=False): + def _resolve_query_exception(self, e: Exception, query: Query, worker_args: ExecuteRequestWorkerArgs, is_rollback_error=False, retry_query=False): utils.log.log_debug(self._service_provider.logger, f'Query execution failed for following query: {query.query_text}\n {e}') + if retry_query: + error_message = 'Server closed the connection unexpectedly. Attempting to reconnect...' + # If the error relates to the database, display the appropriate error message based on the provider - if isinstance(e, worker_args.connection.database_error) or isinstance(e, worker_args.connection.query_canceled_error): + elif isinstance(e, worker_args.connection.database_error) or isinstance(e, worker_args.connection.query_canceled_error): # get_error_message may return None so ensure error_message is str type error_message = str(worker_args.connection.get_error_message(e)) @@ -438,12 +447,13 @@ class QueryExecutionService(object): error_message = 'Error while rolling back open transaction due to previous failure: ' + error_message # TODO: Localize # Send a message with the error to the client - result_message_params = self.build_message_params(query.owner_uri, query.batches[query.current_batch_index].id, error_message, True) + is_error_notification = not retry_query + result_message_params = self.build_message_params(query.owner_uri, query.batches[query.current_batch_index].id, error_message, is_error_notification) _check_and_fire(worker_args.on_message_notification, result_message_params) # If there was a failure in the middle of a transaction, roll it back. # Note that conn.rollback() won't work since the connection is in autocommit mode - if not is_rollback_error and worker_args.connection.transaction_in_error and not worker_args.connection.user_transaction: + if not is_rollback_error and not retry_query and worker_args.connection.transaction_in_error and not worker_args.connection.user_transaction: rollback_query = Query(query.owner_uri, 'ROLLBACK', QueryExecutionSettings(ExecutionPlanOptions(), None), QueryEvents()) try: rollback_query.execute(worker_args.connection) diff --git a/ossdbtoolsservice/scripting/contracts/__init__.py b/ossdbtoolsservice/scripting/contracts/__init__.py index 31a5a17a..686ab90b 100644 --- a/ossdbtoolsservice/scripting/contracts/__init__.py +++ b/ossdbtoolsservice/scripting/contracts/__init__.py @@ -3,10 +3,10 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from ossdbtoolsservice.scripting.contracts.scriptas_request import ( - ScriptAsParameters, ScriptAsResponse, SCRIPTAS_REQUEST, ScriptOperation) +from ossdbtoolsservice.scripting.contracts.script_as_request import ( + ScriptAsParameters, ScriptAsResponse, SCRIPT_AS_REQUEST, ScriptOperation) __all__ = [ - 'ScriptAsParameters', 'ScriptAsResponse', 'SCRIPTAS_REQUEST', + 'ScriptAsParameters', 'ScriptAsResponse', 'SCRIPT_AS_REQUEST', 'ScriptOperation' ] diff --git a/ossdbtoolsservice/scripting/contracts/scriptas_request.py b/ossdbtoolsservice/scripting/contracts/script_as_request.py similarity index 94% rename from ossdbtoolsservice/scripting/contracts/scriptas_request.py rename to ossdbtoolsservice/scripting/contracts/script_as_request.py index 9eb78294..b884c51a 100644 --- a/ossdbtoolsservice/scripting/contracts/scriptas_request.py +++ b/ossdbtoolsservice/scripting/contracts/script_as_request.py @@ -41,4 +41,4 @@ class ScriptAsResponse(Serializable): self.script: str = script -SCRIPTAS_REQUEST = IncomingMessageConfiguration('scripting/script', ScriptAsParameters) +SCRIPT_AS_REQUEST = IncomingMessageConfiguration('scripting/script', ScriptAsParameters) diff --git a/ossdbtoolsservice/scripting/scripting_service.py b/ossdbtoolsservice/scripting/scripting_service.py index 1f6a812e..a42c6ec6 100644 --- a/ossdbtoolsservice/scripting/scripting_service.py +++ b/ossdbtoolsservice/scripting/scripting_service.py @@ -9,7 +9,7 @@ from ossdbtoolsservice.hosting import RequestContext, ServiceProvider from ossdbtoolsservice.metadata.contracts.object_metadata import ObjectMetadata from ossdbtoolsservice.scripting.scripter import Scripter from ossdbtoolsservice.scripting.contracts import ( - ScriptAsParameters, ScriptAsResponse, SCRIPTAS_REQUEST + ScriptAsParameters, ScriptAsResponse, SCRIPT_AS_REQUEST ) from ossdbtoolsservice.connection.contracts import ConnectionType import ossdbtoolsservice.utils as utils @@ -25,7 +25,7 @@ class ScriptingService(object): self._service_provider = service_provider # Register the request handlers with the server - self._service_provider.server.set_request_handler(SCRIPTAS_REQUEST, self._handle_scriptas_request) + self._service_provider.server.set_request_handler(SCRIPT_AS_REQUEST, self._handle_script_as_request) # Find the provider type self._provider: str = self._service_provider.provider @@ -43,10 +43,14 @@ class ScriptingService(object): return object_metadata # REQUEST HANDLERS ##################################################### - def _handle_scriptas_request(self, request_context: RequestContext, params: ScriptAsParameters) -> None: + def _handle_script_as_request(self, request_context: RequestContext, params: ScriptAsParameters, retry_state=False) -> None: try: utils.validate.is_not_none('params', params) + except Exception as e: + self._request_error(request_context, params, str(e)) + return + try: scripting_operation = params.operation connection_service = self._service_provider[utils.constants.CONNECTION_SERVICE_NAME] connection = connection_service.get_connection(params.owner_uri, ConnectionType.QUERY) @@ -57,6 +61,13 @@ class ScriptingService(object): script = scripter.script(scripting_operation, object_metadata) request_context.send_response(ScriptAsResponse(params.owner_uri, script)) except Exception as e: - if self._service_provider.logger is not None: - self._service_provider.logger.exception('Scripting operation failed') - request_context.send_error(str(e), params) + if connection is not None and connection.connection.broken and not retry_state: + self._service_provider.logger.warn('Server closed the connection unexpectedly. Attempting to reconnect...') + self._handle_script_as_request(request_context, params, True) + else: + self._request_error(request_context, params, str(e)) + + def _request_error(self, request_context: RequestContext, params: ScriptAsParameters, message: str): + if self._service_provider.logger is not None: + self._service_provider.logger.exception('Scripting operation failed') + request_context.send_error(message, params) diff --git a/pgsmo/objects/server/server.py b/pgsmo/objects/server/server.py index dbc7928b..e24e8bdf 100644 --- a/pgsmo/objects/server/server.py +++ b/pgsmo/objects/server/server.py @@ -283,6 +283,10 @@ class Server: } return object_map[object_type.capitalize()](metadata) + def set_connection(self, conn: ServerConnection) -> ServerConnection: + """Reset connection to the server/db that this object will use""" + self._conn = conn + # IMPLEMENTATION DETAILS ############################################### def _fetch_recovery_state(self) -> Dict[str, Optional[bool]]: diff --git a/tests/mock_request_validation.py b/tests/mock_request_validation.py index 0f5f727e..7036df90 100644 --- a/tests/mock_request_validation.py +++ b/tests/mock_request_validation.py @@ -104,6 +104,7 @@ class RequestFlowValidator: 'Expected additional messages: ' f'[{self._expected_messages[i].message_type}] ' f'{self._expected_messages[i].param_type}' + f'{self._expected_messages[i].message_method}' ) received = self._received_messages[i] diff --git a/tests/object_explorer/test_object_explorer_service.py b/tests/object_explorer/test_object_explorer_service.py index 89b68a8c..908cf2f1 100644 --- a/tests/object_explorer/test_object_explorer_service.py +++ b/tests/object_explorer/test_object_explorer_service.py @@ -195,7 +195,7 @@ class TestObjectExplorer(unittest.TestCase): oe._provider = constants.PG_PROVIDER_NAME # ... Patch the threading to throw - patch_mock = mock.MagicMock(side_effect=Exception('Boom!')) + patch_mock = mock.MagicMock(side_effect=Exception('Boom! Create Session Failed')) patch_path = 'ossdbtoolsservice.object_explorer.object_explorer_service.threading.Thread' with mock.patch(patch_path, patch_mock): # If: I create a new session @@ -304,7 +304,7 @@ class TestObjectExplorer(unittest.TestCase): # ... Create OE service with mock connection service that returns a failed connection response cs = ConnectionService() connect_response = ConnectionCompleteParams() - connect_response.error_message = 'Boom!' + connect_response.error_message = 'Boom! Init Session Failed' cs.connect = mock.MagicMock(return_value=connect_response) oe = ObjectExplorerService() oe._service_provider = utils.get_mock_service_provider({constants.CONNECTION_SERVICE_NAME: cs}) @@ -479,7 +479,7 @@ class TestObjectExplorer(unittest.TestCase): oe, session, session_uri = self._preloaded_oe_service() # ... Patch the threading to throw - patch_mock = mock.MagicMock(side_effect=Exception('Boom!')) + patch_mock = mock.MagicMock(side_effect=Exception('Boom! Thread Error Handling Failed')) patch_path = 'ossdbtoolsservice.object_explorer.object_explorer_service.threading.Thread' with mock.patch(patch_path, patch_mock): # If: I expand a node (with threading that throws) @@ -502,11 +502,11 @@ class TestObjectExplorer(unittest.TestCase): def _handle_er_exception_expanding(self, method: TEventHandler, get_tasks: TGetTask): # Setup: Create an OE service with a session preloaded - oe, session, session_uri = self._preloaded_oe_service() + oe, session, session_uri = self._preloaded_oe_service(Server(MockPGServerConnection())) # ... Patch the route_request to throw # ... Patch the threading to throw - patch_mock = mock.MagicMock(side_effect=Exception('Boom!')) + patch_mock = mock.MagicMock(side_effect=Exception('Boom! Expand Error Handling Failed')) patch_path = 'ossdbtoolsservice.object_explorer.object_explorer_service.ObjectExplorerService._route_request' with mock.patch(patch_path, patch_mock): # If: I expand a node (with route_request that throws) @@ -600,14 +600,14 @@ class TestObjectExplorer(unittest.TestCase): testevent.set() # IMPLEMENTATION DETAILS ############################################### - def _preloaded_oe_service(self) -> Tuple[ObjectExplorerService, ObjectExplorerSession, str]: + def _preloaded_oe_service(self, server=mock.Mock()) -> Tuple[ObjectExplorerService, ObjectExplorerSession, str]: oe = ObjectExplorerService() oe._service_provider = utils.get_mock_service_provider({}) oe._routing_table = PG_ROUTING_TABLE conn_details, session_uri = _connection_details() session = ObjectExplorerSession(session_uri, conn_details) - session.server = mock.Mock() + session.server = server session.is_ready = True oe._session_map[session_uri] = session diff --git a/tests/scripting/test_scripting_service.py b/tests/scripting/test_scripting_service.py index d3b29b00..9957cb11 100644 --- a/tests/scripting/test_scripting_service.py +++ b/tests/scripting/test_scripting_service.py @@ -10,7 +10,7 @@ import tests.utils as utils from ossdbtoolsservice.connection import ConnectionService from ossdbtoolsservice.connection.contracts import ConnectionCompleteParams from ossdbtoolsservice.hosting import JSONRPCServer, ServiceProvider -from ossdbtoolsservice.scripting.contracts.scriptas_request import ( +from ossdbtoolsservice.scripting.contracts.script_as_request import ( ScriptAsParameters, ScriptAsResponse, ScriptOperation) from ossdbtoolsservice.scripting.scripter import Scripter from ossdbtoolsservice.scripting.scripting_service import ScriptingService @@ -63,7 +63,7 @@ class TestScriptingService(unittest.TestCase): # If: I make a scripting request missing params rc: RequestFlowValidator = RequestFlowValidator() rc.add_expected_error(type(None), RequestFlowValidator.basic_error_validation) - ss._handle_scriptas_request(rc.request_context, None) + ss._handle_script_as_request(rc.request_context, None) # Then: # ... I should get an error response @@ -81,7 +81,7 @@ class TestScriptingService(unittest.TestCase): # If: I create an OE session with missing params rc: RequestFlowValidator = RequestFlowValidator() rc.add_expected_error(type(None), RequestFlowValidator.basic_error_validation) - ss._handle_scriptas_request(rc.request_context, None) + ss._handle_script_as_request(rc.request_context, None) # Then: # ... I should get an error response @@ -129,7 +129,7 @@ class TestScriptingService(unittest.TestCase): 'scripting_objects': [scripting_object] }) - ss._handle_scriptas_request(rc.request_context, params) + ss._handle_script_as_request(rc.request_context, params) # Then: # ... The request should have been handled correctly diff --git a/tests/utils.py b/tests/utils.py index ac5d9446..c4867bc5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -119,6 +119,7 @@ class MockPsycopgConnection(object): self.commit = mock.Mock() self.pgconn = mock.Mock() self.info = MockConnectionInfo(dsn_parameters, self.server_version) + self.broken = False self._adapters: Optional[AdaptersMap] = mock.Mock() self.notice_handlers: List[NoticeHandler] = []