From 353d86a7801ffd83d7aaaf8dc2a6284986af00df Mon Sep 17 00:00:00 2001 From: ShobithNandakumar Date: Tue, 27 Aug 2024 14:36:50 +0530 Subject: [PATCH] refactor --- .../distillation/src/common/validation.py | 36 ++++++++----------- .../distillation/src/validate_pipeline.py | 2 +- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/assets/training/distillation/src/common/validation.py b/assets/training/distillation/src/common/validation.py index 6c2f9192e8..df20ec3d08 100644 --- a/assets/training/distillation/src/common/validation.py +++ b/assets/training/distillation/src/common/validation.py @@ -38,21 +38,31 @@ class ConversationModel(BaseModel): def validate_messages(self) -> Self: messages = self.messages if not messages or messages[0].role != ChatRoles.SYSTEM: - raise ValidationError( + raise ValueError( f"First message should be from {ChatRoles.SYSTEM}", ) if self.task_type == DataGenerationTaskType.CONVERSATION: - validate_task_type_conversation(messages) + if (len(messages) < 3) or (len(messages[1:]) % 2 != 0): + raise ValueError( + f"For task type {self.task_type}, \ + there must be a 'system' message followed by at least one pair of 'user' and 'assistant' roles." + ) + + # Applies to the rest of the task types i.e. NLI, NLU_QA. else: - validate_generic_chat(messages) + if len(messages) != 2: + raise ValueError( + f"For task type {self.task_type}, \ + there must be exactly two messages: 'system' followed by 'user'." + ) # Validate that the conversation is in the expected order. expected_input_roles = [ChatRoles.USER, ChatRoles.ASSISTANT] for i, message in enumerate(messages[1:], start=1): expected_role = expected_input_roles[(i - 1) % 2] if message.role != expected_role: - raise ValidationError( + raise ValueError( f"Message at index {i} should be from the '{expected_role}' role, but got '{message.role}'." ) @@ -171,21 +181,3 @@ def validate_min_endpoint_success_ratio(val: int): ), ) ) - - -def validate_task_type_conversation(messages: List[MessageModel]): - """Validate conversation task type.""" - if len(messages) < 3 or (len(messages[1:]) % 2 != 0): - raise ValidationError( - f"For task type {DataGenerationTaskType.CONVERSATION}, \ - there must be a 'system' message followed by at least one pair of 'user' and 'assistant' roles." - ) - - -def validate_generic_chat(messages: List[MessageModel]): - """Validate generic chat task type, applies to NLI and NLU_QA.""" - if len(messages) != 2: - raise ValueError( - f"For {DataGenerationTaskType.NLI} and {DataGenerationTaskType.NLU_QUESTION_ANSWERING}, \ - there must be exactly two messages: 'system' followed by 'user'." - ) diff --git a/assets/training/distillation/src/validate_pipeline.py b/assets/training/distillation/src/validate_pipeline.py index bc4a45b4e4..1083cb6047 100644 --- a/assets/training/distillation/src/validate_pipeline.py +++ b/assets/training/distillation/src/validate_pipeline.py @@ -198,7 +198,7 @@ class PipelineInputsValidator: ConversationModel.model_validate(data) except Exception as e: raise generic_validation_error( - f"Error validating record at index {idx}. Error: {str(e)}" + f"Error validating data-set record at index {idx}. Error: {str(e)}" ) self._validate_number_of_records(size=total_rows)