refactor
This commit is contained in:
Родитель
4e087250be
Коммит
353d86a780
|
@ -38,21 +38,31 @@ class ConversationModel(BaseModel):
|
|||
def validate_messages(self) -> Self:
|
||||
messages = self.messages
|
||||
if not messages or messages[0].role != ChatRoles.SYSTEM:
|
||||
raise ValidationError(
|
||||
raise ValueError(
|
||||
f"First message should be from {ChatRoles.SYSTEM}",
|
||||
)
|
||||
|
||||
if self.task_type == DataGenerationTaskType.CONVERSATION:
|
||||
validate_task_type_conversation(messages)
|
||||
if (len(messages) < 3) or (len(messages[1:]) % 2 != 0):
|
||||
raise ValueError(
|
||||
f"For task type {self.task_type}, \
|
||||
there must be a 'system' message followed by at least one pair of 'user' and 'assistant' roles."
|
||||
)
|
||||
|
||||
# Applies to the rest of the task types i.e. NLI, NLU_QA.
|
||||
else:
|
||||
validate_generic_chat(messages)
|
||||
if len(messages) != 2:
|
||||
raise ValueError(
|
||||
f"For task type {self.task_type}, \
|
||||
there must be exactly two messages: 'system' followed by 'user'."
|
||||
)
|
||||
|
||||
# Validate that the conversation is in the expected order.
|
||||
expected_input_roles = [ChatRoles.USER, ChatRoles.ASSISTANT]
|
||||
for i, message in enumerate(messages[1:], start=1):
|
||||
expected_role = expected_input_roles[(i - 1) % 2]
|
||||
if message.role != expected_role:
|
||||
raise ValidationError(
|
||||
raise ValueError(
|
||||
f"Message at index {i} should be from the '{expected_role}' role, but got '{message.role}'."
|
||||
)
|
||||
|
||||
|
@ -171,21 +181,3 @@ def validate_min_endpoint_success_ratio(val: int):
|
|||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def validate_task_type_conversation(messages: List[MessageModel]):
|
||||
"""Validate conversation task type."""
|
||||
if len(messages) < 3 or (len(messages[1:]) % 2 != 0):
|
||||
raise ValidationError(
|
||||
f"For task type {DataGenerationTaskType.CONVERSATION}, \
|
||||
there must be a 'system' message followed by at least one pair of 'user' and 'assistant' roles."
|
||||
)
|
||||
|
||||
|
||||
def validate_generic_chat(messages: List[MessageModel]):
|
||||
"""Validate generic chat task type, applies to NLI and NLU_QA."""
|
||||
if len(messages) != 2:
|
||||
raise ValueError(
|
||||
f"For {DataGenerationTaskType.NLI} and {DataGenerationTaskType.NLU_QUESTION_ANSWERING}, \
|
||||
there must be exactly two messages: 'system' followed by 'user'."
|
||||
)
|
||||
|
|
|
@ -198,7 +198,7 @@ class PipelineInputsValidator:
|
|||
ConversationModel.model_validate(data)
|
||||
except Exception as e:
|
||||
raise generic_validation_error(
|
||||
f"Error validating record at index {idx}. Error: {str(e)}"
|
||||
f"Error validating data-set record at index {idx}. Error: {str(e)}"
|
||||
)
|
||||
|
||||
self._validate_number_of_records(size=total_rows)
|
||||
|
|
Загрузка…
Ссылка в новой задаче