Fixes
This commit is contained in:
Parent: 82ec219d64
Commit: 2acd8fdc35
@@ -1,7 +1,7 @@
 $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json

 name: convert_model_to_mlflow
-version: 0.0.33
+version: 0.0.34
 type: command

 is_deterministic: True
@@ -19,7 +19,7 @@ command: |
   pip_pkg_str="${pip_pkgs[*]}"
   if [[ -n "$pip_pkg_str" ]]; then echo "Installing $pip_pkg_str"; pip install $pip_pkg_str; echo "pip installation completed. For any installation errors, please check the logs above"; fi;
   echo "Running model conversion ... "
-  python -u run_model_preprocess.py $[[--model-id ${{inputs.model_id}}]] $[[--task-name ${{inputs.task_name}}]] $[[--model-download-metadata ${{inputs.model_download_metadata}}]] $[[--license-file-path ${{inputs.license_file_path}}]] $[[--hf-config-args "${{inputs.hf_config_args}}"]] $[[--hf-tokenizer-args "${{inputs.hf_tokenizer_args}}"]] $[[--hf-model-args "${{inputs.hf_model_args}}"]] $[[--hf-pipeline-args "${{inputs.hf_pipeline_args}}"]] $[[--hf-config-class ${{inputs.hf_config_class}}]] $[[--hf-model-class ${{inputs.hf_model_class}}]] $[[--hf-tokenizer-class ${{inputs.hf_tokenizer_class}}]] $[[--hf-use-experimental-features ${{inputs.hf_use_experimental_features}}]] $[[--extra-pip-requirements "${{inputs.extra_pip_requirements}}"]] $[[--inference-base-image "${{inputs.inference_base_image}}"]] --vllm-enabled ${{inputs.vllm_enabled}} --model-framework ${{inputs.model_framework}} --model-path ${{inputs.model_path}} --mlflow-model-output-dir ${{outputs.mlflow_model_folder}} --model-flavor ${{inputs.model_flavor}}
+  python -u run_model_preprocess.py $[[--model-id ${{inputs.model_id}}]] $[[--task-name ${{inputs.task_name}}]] $[[--model-download-metadata ${{inputs.model_download_metadata}}]] $[[--license-file-path ${{inputs.license_file_path}}]] $[[--hf-config-args "${{inputs.hf_config_args}}"]] $[[--hf-tokenizer-args "${{inputs.hf_tokenizer_args}}"]] $[[--hf-model-args "${{inputs.hf_model_args}}"]] $[[--hf-pipeline-args "${{inputs.hf_pipeline_args}}"]] $[[--hf-config-class ${{inputs.hf_config_class}}]] $[[--hf-model-class ${{inputs.hf_model_class}}]] $[[--hf-tokenizer-class ${{inputs.hf_tokenizer_class}}]] $[[--hf-use-experimental-features ${{inputs.hf_use_experimental_features}}]] $[[--extra-pip-requirements "${{inputs.extra_pip_requirements}}"]] $[[--inference-base-image "${{inputs.inference_base_image}}"]] --vllm-enabled ${{inputs.vllm_enabled}} --model-framework ${{inputs.model_framework}} $[[--model-path "${{inputs.model_path}}"]] $[[--model-path-mmd "${{inputs.model_path_mmd}}"]] --mlflow-model-output-dir ${{outputs.mlflow_model_folder}} --model-flavor ${{inputs.model_flavor}}
   echo "Completed model conversion ... "

 inputs:
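Note on the command change: each `$[[ ]]` block is dropped entirely when its optional input is unbound, so the script must tolerate missing flags. Below is a minimal argparse sketch of how run_model_preprocess.py could handle the now-optional model path plus the new MMD path; the flag handling is an assumption for illustration, not the repo's actual parser.

    import argparse

    # Hypothetical parser sketch; run_model_preprocess.py's real CLI may differ.
    parser = argparse.ArgumentParser(description="Convert a model to MLflow format.")
    parser.add_argument("--model-path", default=None,
                        help="Path to the model; absent when the input is unbound.")
    parser.add_argument("--model-path-mmd", default=None,
                        help="Path to the MMD model; new optional input.")
    parser.add_argument("--mlflow-model-output-dir", required=True)
    args, _ = parser.parse_known_args()

    # Since both paths are now optional, require at least one at runtime.
    if args.model_path is None and args.model_path_mmd is None:
        raise ValueError("Provide --model-path or --model-path-mmd.")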
@@ -79,6 +79,7 @@ inputs:
       - mask-generation
       - video-multi-object-tracking
       - visual-question-answering
+      - image-feature-extraction
     description: A Hugging Face task the model was trained on. A required parameter for the transformers MLflow flavor. Can be provided as input here or in the model_download_metadata JSON file.
     optional: true
@@ -156,7 +157,13 @@ inputs:
     type: uri_folder
     description: Path to the model.
     mode: ro_mount
-    optional: false
+    optional: true
+
+  model_path_mmd:
+    type: uri_folder
+    description: Path to the MMD model.
+    mode: ro_mount
+    optional: true

   license_file_path:
     type: uri_file
@@ -100,6 +100,7 @@ inputs:
       - mask-generation
       - video-multi-object-tracking
       - visual-question-answering
+      - image-feature-extraction
     optional: true
     type: string
@@ -67,6 +67,7 @@ inputs:
       - image-classification
       - text-to-image
       - chat-completion
+      - image-feature-extraction
     optional: true
     type: string
@@ -36,6 +36,7 @@ from azureml.model.mgmt.processors.pyfunc.convertors import (
     DinoV2MLFlowConvertor,
     LLaVAMLFlowConvertor,
     SegmentAnythingMLFlowConvertor,
+    VirchowMLFlowConvertor
 )
@@ -84,6 +85,10 @@ def get_mlflow_convertor(model_framework, model_dir, output_dir, temp_dir, translate_params):
         return SegmentAnythingMLflowConvertorFactory.create_mlflow_convertor(
             model_dir, output_dir, temp_dir, translate_params
         )
+    elif task == PyFuncSupportedTasks.IMAGE_FEATURE_EXTRACTION.value:
+        return VirchowMLflowConvertorFactory.create_mlflow_convertor(
+            model_dir, output_dir, temp_dir, translate_params
+        )
     else:
         raise Exception(
             f"Models from {model_framework} for task {task} and model {model_id} "
@@ -297,7 +302,18 @@ class SegmentAnythingMLflowConvertorFactory(MLflowConvertorFactoryInterface):
             temp_dir=temp_dir,
             translate_params=translate_params,
         )


+class VirchowMLflowConvertorFactory(MLflowConvertorFactoryInterface):
+    """Factory class for Virchow models."""
+
+    def create_mlflow_convertor(model_dir, output_dir, temp_dir, translate_params):
+        """Create MLflow convertor for Virchow models."""
+        return VirchowMLFlowConvertor(
+            model_dir=model_dir,
+            output_dir=output_dir,
+            temp_dir=temp_dir,
+            translate_params=translate_params,
+        )
+
+
 class MMLabTrackingMLflowConvertorFactory(MLflowConvertorFactoryInterface):
     """Factory class for MMTrack video model family."""
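As a usage sketch, the dispatcher above routes an image-feature-extraction model to the new factory roughly like this. Argument values are illustrative, and the snippet assumes the repo's package is importable; it is not the exact invocation used by the component.

    # Illustrative call; values are made up for the sketch.
    convertor = get_mlflow_convertor(
        model_framework="HuggingFace",       # assumed framework label
        model_dir="./virchow_checkpoint",
        output_dir="./mlflow_out",
        temp_dir="./tmp",
        translate_params={"task": "image-feature-extraction"},
    )
    convertor.save_as_mlflow()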
@@ -56,6 +56,8 @@ class SupportedTasks(_CustomEnum):
     IMAGE_OBJECT_DETECTION = "image-object-detection"
     IMAGE_INSTANCE_SEGMENTATION = "image-instance-segmentation"

+    # Virchow
+    IMAGE_FEATURE_EXTRACTION = "image-feature-extraction"

 class ModelFamilyPrefixes(_CustomEnum):
     """Prefixes for some of the models converted to PyFunc MLflow."""
@@ -65,3 +67,6 @@ class ModelFamilyPrefixes(_CustomEnum):

     # DinoV2 model family.
     DINOV2 = "facebook/dinov2"
+
+    # Virchow model family.
+    VIRCHOW = "paige-ai/Virchow"
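These prefixes are presumably matched against the Hugging Face model id to recognize a model family; a hypothetical check (is_virchow_family is not from this diff):

    # Hypothetical helper: detect the Virchow family by model-id prefix.
    def is_virchow_family(model_id: str) -> bool:
        return model_id.startswith("paige-ai/Virchow")  # ModelFamilyPrefixes.VIRCHOW.value

    print(is_virchow_family("paige-ai/Virchow"))      # True
    print(is_virchow_family("facebook/dinov2-base"))  # False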
@@ -41,6 +41,8 @@ from azureml.model.mgmt.processors.pyfunc.segment_anything.config import \
     MLflowSchemaLiterals as SegmentAnythingMLFlowSchemaLiterals, MLflowLiterals as SegmentAnythingMLflowLiterals
 from azureml.model.mgmt.processors.pyfunc.vision.config import \
     MLflowSchemaLiterals as VisionMLFlowSchemaLiterals, MMDetLiterals
+from azureml.model.mgmt.processors.pyfunc.virchow.config import \
+    MLflowSchemaLiterals as VirchowMLFlowSchemaLiterals, MLflowLiterals as VirchowMLflowLiterals


 logger = get_logger(__name__)
@@ -1136,3 +1138,95 @@ class AutoMLMLFlowConvertor(PyFuncMLFLowConvertor):
             conda_env=conda_env_file,
             code_path=None,
         )
+
+
+class VirchowMLFlowConvertor(PyFuncMLFLowConvertor):
+    """PyFunc MLflow convertor for Virchow models."""
+
+    MODEL_DIR = os.path.join(os.path.dirname(__file__), "virchow")
+    COMMON_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "common")
+
+    def __init__(self, **kwargs):
+        """Initialize MLflow convertor for Virchow models."""
+        super().__init__(**kwargs)
+        if self._task not in \
+                [SupportedTasks.IMAGE_FEATURE_EXTRACTION.value]:
+            raise Exception("Unsupported task")
+
+    def get_model_signature(self) -> ModelSignature:
+        """Return MLflow model signature with input and output schema for the given input task.
+
+        :return: MLflow model signature.
+        :rtype: mlflow.models.signature.ModelSignature
+        """
+        input_schema = Schema(
+            [
+                ColSpec(VirchowMLFlowSchemaLiterals.INPUT_COLUMN_IMAGE_DATA_TYPE,
+                        VirchowMLFlowSchemaLiterals.INPUT_COLUMN_IMAGE),
+                ColSpec(VirchowMLFlowSchemaLiterals.INPUT_COLUMN_TEXT_DATA_TYPE,
+                        VirchowMLFlowSchemaLiterals.INPUT_COLUMN_TEXT),
+            ]
+        )
+
+        if self._task == SupportedTasks.IMAGE_FEATURE_EXTRACTION.value:
+            output_schema = Schema(
+                [
+                    ColSpec(VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_DATA_TYPE,
+                            VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_PROBS),
+                    ColSpec(VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_DATA_TYPE,
+                            VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_LABELS),
+                    ColSpec(VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_DATA_TYPE,
+                            VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_IMAGE_FEATURES),
+                    ColSpec(VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_DATA_TYPE,
+                            VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_TEXT_FEATURES),
+                ]
+            )
+        else:
+            raise Exception("Unsupported task")
+
+        return ModelSignature(inputs=input_schema, outputs=output_schema)
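For reference, a minimal sketch of a scoring payload that conforms to this signature: a binary/base64 "image" column plus a string "text" column (column names come from the Virchow config added below; the file name is illustrative).

    import base64
    import pandas as pd

    # One-row frame matching the input schema above.
    with open("tile.png", "rb") as f:  # hypothetical image file
        image_b64 = base64.b64encode(f.read()).decode("utf-8")
    input_df = pd.DataFrame({"image": [image_b64], "text": [""]})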
+
+    def save_as_mlflow(self):
+        """Prepare model for save to MLflow."""
+        sys.path.append(self.MODEL_DIR)
+
+        from virchow_mlflow_model_wrapper import VirchowModelWrapper
+        mlflow_model_wrapper = VirchowModelWrapper()
+
+        artifacts_dict = self._prepare_artifacts_dict()
+        conda_env_file = os.path.join(self.MODEL_DIR, "conda.yaml")
+        code_path = self._get_code_path()
+
+        super()._save(
+            mlflow_model_wrapper=mlflow_model_wrapper,
+            artifacts_dict=artifacts_dict,
+            conda_env=conda_env_file,
+            code_path=code_path,
+        )
+
+    def _get_code_path(self):
+        """Return code path for saving MLflow model depending on task type.
+
+        :return: code path
+        :rtype: List[str]
+        """
+        code_path = [
+            os.path.join(self.MODEL_DIR, "virchow_mlflow_model_wrapper.py"),
+            os.path.join(self.MODEL_DIR, "config.py"),
+            os.path.join(self.COMMON_DIR, "vision_utils.py"),
+        ]
+
+        return code_path
+
+    def _prepare_artifacts_dict(self) -> Dict:
+        """Prepare artifacts dict for MLflow model.
+
+        :return: artifacts dict
+        :rtype: Dict
+        """
+        artifacts_dict = {
+            VirchowMLflowLiterals.MODEL_DIR: self._model_dir
+        }
+        return artifacts_dict
@@ -0,0 +1,17 @@
+channels:
+  - conda-forge
+dependencies:
+  - python=3.10.14
+  - pip<=24.0
+  - pip:
+      - mlflow==2.13.2
+      - cffi==1.16.0
+      - cloudpickle==2.2.1
+      - numpy==1.23.5
+      - pandas==2.2.2
+      - pyyaml==6.0.1
+      - requests==2.32.3
+      - timm==1.0.9,>=0.9.11
+      - torch>2
+      - pillow>=10
+name: mlflow-env
@@ -0,0 +1,33 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Common Config."""
+
+from enum import Enum
+
+from mlflow.types import DataType
+
+
+class _CustomEnum(Enum):
+    @classmethod
+    def has_value(cls, value):
+        return value in cls._value2member_map_
+
+
+class MLflowSchemaLiterals:
+    """MLflow model signature related schema."""
+
+    INPUT_COLUMN_IMAGE_DATA_TYPE = DataType.binary
+    INPUT_COLUMN_IMAGE = "image"
+    INPUT_COLUMN_TEXT_DATA_TYPE = DataType.string
+    INPUT_COLUMN_TEXT = "text"
+    OUTPUT_COLUMN_DATA_TYPE = DataType.string
+    OUTPUT_COLUMN_PROBS = "probs"
+    OUTPUT_COLUMN_LABELS = "labels"
+    OUTPUT_COLUMN_IMAGE_FEATURES = "image_features"
+    OUTPUT_COLUMN_TEXT_FEATURES = "text_features"
+
+
+class MLflowLiterals:
+    """MLflow export related literals."""
+
+    MODEL_DIR = "model_dir"
@@ -0,0 +1,71 @@
+import io
+import json
+import logging
+
+import mlflow.pyfunc
+import pandas as pd
+import timm
+import torch
+from PIL import Image
+from timm.data import resolve_data_config
+from timm.data.transforms_factory import create_transform
+from timm.layers import SwiGLUPacked
+
+from config import MLflowSchemaLiterals
+
+logger = logging.getLogger("mlflow")
+logger.setLevel(logging.DEBUG)  # set log level to debug
+
+
+class VirchowModelWrapper(mlflow.pyfunc.PythonModel):
+    def load_context(self, context):
+        """Load the timm checkpoint and its transforms from MLflow artifacts."""
+        config_path = context.artifacts["config_path"]
+        checkpoint_path = context.artifacts["checkpoint_path"]
+        with open(config_path) as f:
+            config = json.load(f)
+        self.model = timm.create_model(
+            model_name="vit_huge_patch14_224",
+            checkpoint_path=checkpoint_path,
+            mlp_layer=SwiGLUPacked,
+            act_layer=torch.nn.SiLU,
+            pretrained_cfg=config["pretrained_cfg"],
+            **config["model_args"]
+        )
+        self.model.eval()
+        self.transforms = create_transform(
+            **resolve_data_config(self.model.pretrained_cfg, model=self.model)
+        )
+
+    def predict(self, context: mlflow.pyfunc.PythonModelContext, input_data: pd.DataFrame,
+                params: dict = None) -> pd.DataFrame:
+        """Return the class-token embedding for the first image in input_data."""
+        from vision_utils import process_image
+
+        pil_images = [
+            Image.open(io.BytesIO(process_image(image)))
+            for image in input_data[MLflowSchemaLiterals.INPUT_COLUMN_IMAGE]
+        ]
+        # Note: only the first image in the batch is scored.
+        pil_image = self.transforms(pil_images[0]).unsqueeze(0)  # size: 1 x 3 x 224 x 224
+
+        params = params or {}
+        device_type = params.get("device_type", "cuda")
+        to_half_precision = params.get("to_half_precision", False)
+
+        with torch.inference_mode(), torch.autocast(
+            device_type=device_type, dtype=torch.float16
+        ):
+            output = self.model(pil_image)  # size: 1 x 257 x 1280
+
+        class_token = output[:, 0]  # size: 1 x 1280
+        patch_tokens = output[:, 1:]  # size: 1 x 256 x 1280, unused here
+
+        # Use the class token only as the embedding, size: 1 x 1280.
+        embedding = class_token
+
+        # The model output is fp32 because the final LayerNorm is run in mixed precision;
+        # optionally convert the embedding to fp16 for efficiency in downstream use.
+        if to_half_precision:
+            embedding = embedding.to(torch.float16)
+
+        df_result = pd.DataFrame()
+        df_result["output"] = embedding.tolist()
+        return df_result
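A hedged end-to-end sketch of loading and scoring the saved model: the model URI and image file are illustrative, while the params keys mirror what predict reads above. Whether params are passed through depends on the registered signature, so treat this as a sketch rather than the component's documented API.

    import base64
    import mlflow.pyfunc
    import pandas as pd

    model = mlflow.pyfunc.load_model("models:/virchow_mlflow/1")  # hypothetical URI
    with open("tile.png", "rb") as f:  # hypothetical image
        payload = pd.DataFrame({"image": [base64.b64encode(f.read()).decode("utf-8")],
                                "text": [""]})

    # predict reads device_type and to_half_precision from params.
    result = model.predict(payload, params={"device_type": "cuda", "to_half_precision": True})
    print(len(result["output"][0]))  # embedding width, 1280 for Virchow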
@@ -0,0 +1,234 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Helper utils for vision MLflow models."""
+
+import base64
+import io
+import logging
+import os
+import re
+import requests
+import uuid
+
+import PIL
+import pandas as pd
+import numpy as np
+import torch
+
+from ast import literal_eval
+from PIL import Image, UnidentifiedImageError
+from typing import Union
+
+
+logger = logging.getLogger(__name__)
+
+# Uncomment the following line for mlflow debug mode
+# logging.getLogger("mlflow").setLevel(logging.DEBUG)
+
+
+def save_image(output_folder: str, img: PIL.Image.Image, format: str) -> str:
+    """
+    Save image in a folder designated for batch output and return the image file name.
+
+    :param output_folder: directory path where we need to save files
+    :type output_folder: str
+    :param img: image object
+    :type img: PIL.Image.Image
+    :param format: format to save image
+    :type format: str
+    :return: file name of image.
+    :rtype: str
+    """
+    filename = f"image_{uuid.uuid4()}.{format.lower()}"
+    img.save(os.path.join(output_folder, filename), format=format)
+    return filename
+
+
+def get_pil_image(image: bytes) -> PIL.Image.Image:
+    """
+    Convert image bytes to PIL image.
+
+    :param image: image bytes
+    :type image: bytes
+    :return: PIL image object
+    :rtype: PIL.Image.Image
+    """
+    try:
+        return Image.open(io.BytesIO(image))
+    except UnidentifiedImageError as e:
+        logger.error("Invalid image format. Please use base64 encoding for input images.")
+        raise e
+
+
+def image_to_base64(img: PIL.Image.Image, format: str) -> str:
+    """
+    Convert image into Base64 encoded string.
+
+    :param img: image object
+    :type img: PIL.Image.Image
+    :param format: image format
+    :type format: str
+    :return: base64 encoded string
+    :rtype: str
+    """
+    buffered = io.BytesIO()
+    img.save(buffered, format=format)
+    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+    return img_str
+
+
+def process_image(image: Union[str, bytes]) -> bytes:
+    """Process image.
+
+    If the input image is in bytes format, return it as is.
+    If the input image is a base64 string, decode it to bytes.
+    If the input image is a URL, download it and return the bytes.
+    https://github.com/mlflow/mlflow/blob/master/examples/flower_classifier/image_pyfunc.py
+
+    :param image: image in base64 string format or url or bytes.
+    :type image: string or bytes
+    :return: decoded image.
+    :rtype: bytes
+    """
+    if isinstance(image, bytes):
+        return image
+    elif isinstance(image, str):
+        if _is_valid_url(image):
+            try:
+                response = requests.get(image)
+                response.raise_for_status()  # Raise exception in case of unsuccessful response code.
+                image = response.content
+                return image
+            except requests.exceptions.RequestException as ex:
+                raise ValueError(f"Unable to retrieve image from url string due to exception: {ex}")
+        else:
+            try:
+                return base64.b64decode(image)
+            except ValueError:
+                raise ValueError(
+                    "The provided image string cannot be decoded. Expected format is base64 string or url string."
+                )
+    else:
+        raise ValueError(
+            f"Image received in {type(image)} format which is not supported. "
+            "Expected format is bytes, base64 string or url string."
+        )
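A quick usage sketch of the three accepted input forms; the file name and URL are illustrative.

    import base64

    raw = open("sample.png", "rb").read()                        # hypothetical file
    assert process_image(raw) == raw                             # bytes pass through
    assert process_image(base64.b64encode(raw).decode()) == raw  # base64 string -> bytes
    img = process_image("https://example.com/tile.png")          # URL -> downloaded bytes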
+
+
+def process_image_pandas_series(image_pandas_series: pd.Series) -> pd.Series:
+    """Process image in Pandas series form.
+
+    If the input image is in bytes format, return it as is.
+    If the input image is a base64 string, decode it to bytes.
+    If the input image is a URL, download it and return the bytes.
+    https://github.com/mlflow/mlflow/blob/master/examples/flower_classifier/image_pyfunc.py
+
+    :param image_pandas_series: pandas series with image in base64 string format or url or bytes.
+    :type image_pandas_series: pd.Series
+    :return: decoded image in pandas series format.
+    :rtype: Pandas Series
+    """
+    image = image_pandas_series[0]
+    return pd.Series(process_image(image))
+
+
+def _is_valid_url(text: str) -> bool:
+    """Check if text is a url.
+
+    :param text: text to validate
+    :type text: str
+    :return: True if url else False
+    :rtype: bool
+    """
+    regex = (
+        "((http|https)://)(www.)?"
+        + "[a-zA-Z0-9@:%._\\+~#?&//=\\-]"
+        + "{2,256}\\.[a-z]"
+        + "{2,6}\\b([-a-zA-Z0-9@:%"
+        + "._\\+~#?&//=]*)"
+    )
+    p = re.compile(regex)
+
+    # If the string is empty, return False.
+    if text is None:
+        return False
+
+    # Return whether the string matches the regex.
+    if re.search(p, text):
+        return True
+    else:
+        return False
+
+
+def get_current_device() -> torch.device:
+    """Get current cuda device.
+
+    :return: current device
+    :rtype: torch.device
+    """
+    # check if GPU is available
+    if torch.cuda.is_available():
+        try:
+            # get the current device index
+            device_idx = torch.distributed.get_rank()
+        except RuntimeError as ex:
+            if "Default process group has not been initialized".lower() in str(ex).lower():
+                device_idx = 0
+            else:
+                logger.error(str(ex))
+                raise ex
+        return torch.device(type="cuda", index=device_idx)
+    else:
+        return torch.device(type="cpu")
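Usage sketch: pick the device once and move model and tensors to it; the linear layer here is a stand-in model, not part of this repo.

    import torch

    device = get_current_device()
    model = torch.nn.Linear(4, 2).to(device)  # stand-in model
    batch = torch.randn(8, 4, device=device)
    with torch.inference_mode():
        out = model(batch)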
+
+
+def string_to_nested_float_list(input_str: str) -> list:
+    """Convert string to nested list of floats.
+
+    :return: string converted to nested list of floats
+    :rtype: list
+    """
+    if input_str in ["null", "None", "", "nan", "NoneType", np.nan, None]:
+        return None
+    try:
+        # Use ast.literal_eval to safely evaluate the string into a list
+        nested_list = literal_eval(input_str)
+
+        # Recursive function to convert all numbers in the nested list to floats
+        def to_floats(lst) -> list:
+            """
+            Recursively convert all numbers in a nested list to floats.
+
+            :param lst: nested list
+            :type lst: list
+            :return: nested list of floats
+            :rtype: list
+            """
+            return [to_floats(item) if isinstance(item, list) else float(item) for item in lst]
+
+        # Use the recursive function to process the nested list
+        return to_floats(nested_list)
+    except (ValueError, SyntaxError) as e:
+        # In case of an error during conversion, log a warning and ignore the value.
+        logger.warning(f"Invalid input {input_str}: {e}, ignoring.")
+        return None
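Expected behavior, as a small sketch:

    assert string_to_nested_float_list("[[1, 2], [3.5]]") == [[1.0, 2.0], [3.5]]
    assert string_to_nested_float_list("None") is None
    assert string_to_nested_float_list("not a list") is None  # logs a warning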
+
+
+def bool_array_to_pil_image(bool_array: np.ndarray) -> PIL.Image.Image:
+    """Convert boolean array to PIL Image.
+
+    :param bool_array: boolean array
+    :type bool_array: np.ndarray
+    :return: PIL Image
+    :rtype: PIL.Image.Image
+    """
+    # Convert boolean array to uint8
+    uint8_array = bool_array.astype(np.uint8) * 255
+
+    # Create a PIL Image
+    pil_image = Image.fromarray(uint8_array)
+
+    return pil_image