This commit is contained in:
ShobithNandakumar 2024-08-27 14:36:50 +05:30
Родитель 4e087250be
Коммит 353d86a780
2 изменённых файлов: 15 добавлений и 23 удалений

Просмотреть файл

@ -38,21 +38,31 @@ class ConversationModel(BaseModel):
def validate_messages(self) -> Self:
messages = self.messages
if not messages or messages[0].role != ChatRoles.SYSTEM:
raise ValidationError(
raise ValueError(
f"First message should be from {ChatRoles.SYSTEM}",
)
if self.task_type == DataGenerationTaskType.CONVERSATION:
validate_task_type_conversation(messages)
if (len(messages) < 3) or (len(messages[1:]) % 2 != 0):
raise ValueError(
f"For task type {self.task_type}, \
there must be a 'system' message followed by at least one pair of 'user' and 'assistant' roles."
)
# Applies to the rest of the task types i.e. NLI, NLU_QA.
else:
validate_generic_chat(messages)
if len(messages) != 2:
raise ValueError(
f"For task type {self.task_type}, \
there must be exactly two messages: 'system' followed by 'user'."
)
# Validate that the conversation is in the expected order.
expected_input_roles = [ChatRoles.USER, ChatRoles.ASSISTANT]
for i, message in enumerate(messages[1:], start=1):
expected_role = expected_input_roles[(i - 1) % 2]
if message.role != expected_role:
raise ValidationError(
raise ValueError(
f"Message at index {i} should be from the '{expected_role}' role, but got '{message.role}'."
)
@ -171,21 +181,3 @@ def validate_min_endpoint_success_ratio(val: int):
),
)
)
def validate_task_type_conversation(messages: List[MessageModel]):
    """Validate the message count for the CONVERSATION task type.

    The list must contain at least three messages, and the number of
    messages after the first one must be even. Only the count is checked
    here; message roles are not inspected.

    Args:
        messages: The conversation messages, including the first one.

    Raises:
        ValueError: If there are fewer than three messages, or the messages
            after the first do not form complete pairs.
    """
    if len(messages) < 3 or len(messages[1:]) % 2 != 0:
        # ValueError keeps this helper consistent with validate_generic_chat.
        # NOTE(review): if ValidationError here is pydantic's, it cannot be
        # constructed from a plain message string -- confirm against imports.
        # Implicit string concatenation (instead of a backslash continuation)
        # keeps source indentation out of the error message.
        raise ValueError(
            f"For task type {DataGenerationTaskType.CONVERSATION}, "
            "there must be a 'system' message followed by at least one pair of 'user' and 'assistant' roles."
        )
def validate_generic_chat(messages: List[MessageModel]):
    """Check the message count for generic chat tasks (NLI and NLU_QA).

    Args:
        messages: The messages of a single record.

    Raises:
        ValueError: If the record does not contain exactly two messages.
    """
    # Guard clause: exactly two messages is the only accepted shape.
    if len(messages) == 2:
        return
    raise ValueError(
        f"For {DataGenerationTaskType.NLI} and {DataGenerationTaskType.NLU_QUESTION_ANSWERING}, "
        "there must be exactly two messages: 'system' followed by 'user'."
    )

Просмотреть файл

@ -198,7 +198,7 @@ class PipelineInputsValidator:
ConversationModel.model_validate(data)
except Exception as e:
raise generic_validation_error(
f"Error validating record at index {idx}. Error: {str(e)}"
f"Error validating data-set record at index {idx}. Error: {str(e)}"
)
self._validate_number_of_records(size=total_rows)