* response filter

* rewrite implement based on the filter

* multi responses

* abs path

* code handling

* option to not use docker

* context

* eval_only -> raise_error

* notebook

* utils

* utils

* separate tests

* test

* test

* test

* test

* test

* test

* test

* test

* **config in test()

* test

* test

* filename
Chi Wang 2023-05-21 15:22:29 -07:00 committed by GitHub
Parent 7de4eb347d
Commit e463146cb8
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 2253 additions and 1820 deletions

29
.github/workflows/openai.yml vendored
View File

@ -7,10 +7,10 @@ on:
pull_request:
branches: ['main']
paths:
- 'flaml/integrations/oai/**'
- 'test/openai/**'
- 'notebook/integrate_openai.ipynb'
- 'notebook/integrate_chatgpt_math.ipynb'
- 'flaml/autogen/**'
- 'test/autogen/**'
- 'notebook/autogen_openai_completion.ipynb'
- 'notebook/autogen_chatgpt_gpt4.ipynb'
- '.github/workflows/openai.yml'
jobs:
@ -18,7 +18,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python-version: [3.9]
python-version: ["3.9", "3.10", "3.11"]
runs-on: ${{ matrix.os }}
environment: openai
steps:
@ -33,14 +33,27 @@ jobs:
python -m pip install --upgrade pip wheel
pip install -e .[autogen,blendsearch]
python -c "import flaml"
pip install coverage pytest datasets
- name: Coverage
if: matrix.python-version == '3.9'
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
run: |
pip install coverage pytest datasets nbconvert nbformat ipykernel
coverage run -a -m pytest test/openai
coverage run -a -m pytest test/autogen
coverage xml
cat "$(pwd)/test/openai/executed_openai_notebook_output.txt"
- name: Coverage and check notebook outputs
if: matrix.python-version != '3.9'
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
run: |
pip install nbconvert nbformat ipykernel
coverage run -a -m pytest test/autogen/oai/test_notebook.py
coverage xml
cat "$(pwd)/test/autogen/oai/executed_openai_notebook_output.txt"
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:

View File

@ -37,7 +37,7 @@ class Agent:
def _receive(self, message, sender):
"""Receive a message from another agent."""
print("****", self.name, "received message from", sender.name, "****")
print("\n****", self.name, "received message from", sender.name, "****\n")
print(message)
self._conversations[sender.name].append({"content": message, "role": "user"})

View File

@ -3,8 +3,8 @@ from flaml.autogen.code_utils import extract_code, execute_code
from collections import defaultdict
class HumanProxyAgent(Agent):
"""(Experimental) A proxy agent for human, that can execute code and provide feedback to the other agents."""
class UserProxyAgent(Agent):
"""(Experimental) A proxy agent for the user, that can execute code and provide feedback to the other agents."""
MAX_CONSECUTIVE_AUTO_REPLY = 100 # maximum number of consecutive auto replies (subject to future change)
@ -16,6 +16,7 @@ class HumanProxyAgent(Agent):
human_input_mode="ALWAYS",
max_consecutive_auto_reply=None,
is_termination_msg=None,
use_docker=True,
**config,
):
"""
@ -51,22 +52,26 @@ class HumanProxyAgent(Agent):
max_consecutive_auto_reply if max_consecutive_auto_reply is not None else self.MAX_CONSECUTIVE_AUTO_REPLY
)
self._consecutive_auto_reply_counter = defaultdict(int)
self._use_docker = use_docker
def _execute_code(self, code, lang):
"""Execute the code and return the result."""
if lang == "bash":
assert code.startswith("python "), code
if lang in ["bash", "shell"]:
if not code.startswith("python "):
return 1, f"please do not suggest bash or shell commands like {code}"
file_name = code[len("python ") :]
exitcode, logs = execute_code(filename=file_name, work_dir=self._work_dir)
exitcode, logs = execute_code(filename=file_name, work_dir=self._work_dir, use_docker=self._use_docker)
logs = logs.decode("utf-8")
elif lang == "python":
if code.startswith("# filename: "):
filename = code[11 : code.find("\n")].strip()
else:
filename = None
exitcode, logs = execute_code(code, work_dir=self._work_dir, filename=filename)
exitcode, logs = execute_code(code, work_dir=self._work_dir, filename=filename, use_docker=self._use_docker)
logs = logs.decode("utf-8")
else:
# TODO: could this happen?
exitcode, logs = 1, "unknown language"
exitcode, logs = 1, f"unknown language {lang}"
# raise NotImplementedError
return exitcode, logs
@ -80,7 +85,7 @@ class HumanProxyAgent(Agent):
# try to execute the code
exitcode, logs = self._execute_code(code, lang)
exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed"
self._send(f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs.decode('utf-8')}", sender)
self._send(f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}", sender)
def receive(self, message, sender):
"""Receive a message from the sender agent.

View File

@ -37,8 +37,7 @@ def generate_code(pattern: str = CODE_BLOCK_PATTERN, **config) -> Tuple[str, flo
float: The cost of the generation.
"""
response = oai.Completion.create(**config)
cost = oai.Completion.cost(response)
return extract_code(oai.Completion.extract_text(response)[0], pattern), cost
return extract_code(oai.Completion.extract_text(response)[0], pattern), response["cost"]
_IMPROVE_FUNCTION_CONFIG = {
@ -59,8 +58,7 @@ def improve_function(file_name, func_name, objective, **config):
response = oai.Completion.create(
{"func_name": func_name, "objective": objective, "file_string": file_string}, **params
)
cost = oai.Completion.cost(response)
return oai.Completion.extract_text(response)[0], cost
return oai.Completion.extract_text(response)[0], response["cost"]
_IMPROVE_CODE_CONFIG = {
@ -97,8 +95,7 @@ def improve_code(files, objective, suggest_only=True, **config):
params = {**_IMPROVE_CODE_CONFIG, **config}
followup = "" if suggest_only else " followed by the improved code"
response = oai.Completion.create({"objective": objective, "code": code, "followup": followup}, **params)
cost = oai.Completion.cost(response)
return oai.Completion.extract_text(response)[0], cost
return oai.Completion.extract_text(response)[0], response["cost"]
def timeout_handler(signum, frame):
@ -281,9 +278,8 @@ def generate_assertions(definition: str, **config) -> Tuple[str, float]:
{"definition": definition},
**params,
)
cost = oai.Completion.cost(response)
assertions = oai.Completion.extract_text(response)[0]
return assertions, cost
return assertions, response["cost"]
def _remove_check(response):
@ -387,6 +383,23 @@ _IMPLEMENT_CONFIGS = [
]
class PassAssertionFilter:
def __init__(self, assertions):
self._assertions = assertions
self.cost = 0
self.metrics = self.responses = None
def pass_assertions(self, context, response, **_):
"""Check if the response passes the assertions."""
responses = oai.Completion.extract_text(response)
metrics = eval_function_completions(responses, context["definition"], assertions=self._assertions)
self._assertions = metrics["assertions"]
self.cost += metrics["gen_cost"]
self.metrics = metrics
self.responses = responses
return metrics["succeed_assertions"]
def implement(
definition: str,
configs: Optional[List[Dict]] = None,
@ -408,12 +421,19 @@ def implement(
configs = configs or _IMPLEMENT_CONFIGS
if len(configs) > 1 and callable(assertions):
assertions, cost = assertions(definition)
for i, config in enumerate(configs):
response = oai.Completion.create({"definition": definition}, **config)
cost += oai.Completion.cost(response)
responses = oai.Completion.extract_text(response)
metrics = eval_function_completions(responses, definition, assertions=assertions)
assertions = metrics["assertions"]
cost += metrics["gen_cost"]
if metrics["succeed_assertions"] or i == len(configs) - 1:
return responses[metrics["index_selected"]], cost, i
assertion_filter = PassAssertionFilter(assertions)
response = oai.Completion.create(
{"definition": definition}, config_list=configs, filter_func=assertion_filter.pass_assertions
)
cost += assertion_filter.cost + response["cost"]
return assertion_filter.responses[assertion_filter.metrics["index_selected"]], cost, response["config_id"]
# for i, config in enumerate(configs):
# response = oai.Completion.create({"definition": definition}, **config)
# cost += oai.Completion.cost(response)
# responses = oai.Completion.extract_text(response)
# metrics = eval_function_completions(responses, definition, assertions=assertions)
# assertions = metrics["assertions"]
# cost += metrics["gen_cost"]
# if metrics["succeed_assertions"] or i == len(configs) - 1:
# return responses[metrics["index_selected"]], cost, i

View File

@ -20,9 +20,8 @@ def solve_problem(problem: str, **config) -> str:
"""
params = {**_MATH_CONFIG, **config}
response = oai.Completion.create({"problem": problem}, **params)
cost = oai.Completion.cost(response)
results = eval_math_responses(oai.Completion.extract_text(response))
return results.get("voted_answer"), cost
return results.get("voted_answer"), response["cost"]
def remove_boxed(string: str) -> Optional[str]:

View File

@ -1,3 +1,4 @@
from flaml.autogen.oai.completion import Completion, ChatCompletion
from flaml.autogen.oai.openai_utils import get_config_list, config_list_gpt4_gpt35, config_list_openai_aoai
__all__ = ["Completion", "ChatCompletion"]
__all__ = ["Completion", "ChatCompletion", "get_config_list", "config_list_gpt4_gpt35", "config_list_openai_aoai"]

View File

@ -2,11 +2,13 @@ from time import sleep
import logging
import numpy as np
import time
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Callable, Any
import sys
import json
import shutil
from flaml import tune, BlendSearch
from flaml.tune.space import is_constant
from flaml.automl.logger import logger_formatter
from .openai_utils import get_key
try:
import openai
@ -34,23 +36,6 @@ if not logger.handlers:
logger.addHandler(_ch)
def get_key(config):
"""Get a unique identifier of a configuration.
Args:
config (dict or list): A configuration.
Returns:
tuple: A unique identifier which can be used as a key for a dict.
"""
# if isinstance(config, dict):
# return tuple(get_key(x) for x in sorted(config.items()))
# if isinstance(config, list):
# return tuple(get_key(x) for x in config)
# return config
return json.dumps(config, sort_keys=True)
class Completion(openai_Completion):
"""A class for OpenAI completion API.
@ -123,7 +108,7 @@ class Completion(openai_Completion):
_history_dict = _count_create = None
@classmethod
def set_cache(cls, seed=41, cache_path=".cache"):
def set_cache(cls, seed: Optional[int] = 41, cache_path_root: Optional[str] = ".cache"):
"""Set cache path.
Args:
@ -133,11 +118,29 @@ class Completion(openai_Completion):
The complete cache path will be {cache_path}/{seed}.
"""
cls.seed = seed
cls.cache_path = f"{cache_path}/{seed}"
cls.cache_path = f"{cache_path_root}/{seed}"
@classmethod
def clear_cache(cls, seed: Optional[int] = None, cache_path_root: Optional[str] = ".cache"):
"""Clear cache.
Args:
seed (int, Optional): The integer identifier for the pseudo seed.
If omitted, all caches under cache_path_root will be cleared.
cache_path_root (str, Optional): The root path for the cache.
The complete cache path will be {cache_path_root}/{seed}.
"""
if seed is None:
shutil.rmtree(cache_path_root, ignore_errors=True)
return
with diskcache.Cache(f"{cache_path_root}/{seed}") as cache:
cache.clear()
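A minimal usage sketch of the cache controls above (seed and path are illustrative):

```python
from flaml import oai

oai.Completion.set_cache(seed=41, cache_path_root=".cache")  # cache dir: .cache/41
oai.Completion.clear_cache(seed=41)  # clear only the cache for seed 41
oai.Completion.clear_cache()         # remove all caches under .cache
```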
@classmethod
def _book_keeping(cls, config: Dict, response):
"""Book keeping for the created completions."""
if response != -1 and "cost" not in response:
response["cost"] = cls.cost(response)
if cls._history_dict is None:
return
if cls._history_compact:
@ -154,7 +157,7 @@ class Completion(openai_Completion):
else:
key = get_key([config["prompt"]] + [choice.get("text") for choice in response["choices"]])
value["created_at"].append(cls._count_create)
value["cost"].append(cls.cost(response))
value["cost"].append(response["cost"])
cls._history_dict[key] = value
cls._count_create += 1
return
@ -165,7 +168,7 @@ class Completion(openai_Completion):
cls._count_create += 1
@classmethod
def _get_response(cls, config: Dict, eval_only=False, use_cache=True):
def _get_response(cls, config: Dict, raise_error=False, use_cache=True):
"""Get the response from the openai api call.
Try cache first. If not found, call the openai api. If the api call fails, retry after retry_time.
@ -175,7 +178,7 @@ class Completion(openai_Completion):
key = get_key(config)
if use_cache:
response = cls._cache.get(key, None)
if response is not None and (response != -1 or not eval_only):
if response is not None and (response != -1 or not raise_error):
# print("using cached response")
cls._book_keeping(config, response)
return response
@ -197,7 +200,7 @@ class Completion(openai_Completion):
APIConnectionError,
):
# transient error
logger.warning(f"retrying in {cls.retry_time} seconds...", exc_info=1)
logger.info(f"retrying in {cls.retry_time} seconds...", exc_info=1)
sleep(cls.retry_time)
except APIError as err:
error_code = err and err.json_body and err.json_body.get("error")
@ -205,7 +208,7 @@ class Completion(openai_Completion):
if error_code == "content_filter":
raise
# transient error
logger.warning(f"retrying in {cls.retry_time} seconds...", exc_info=1)
logger.info(f"retrying in {cls.retry_time} seconds...", exc_info=1)
sleep(cls.retry_time)
except (RateLimitError, Timeout) as err:
time_left = cls.retry_timeout - (time.time() - start_time + cls.retry_time)
@ -216,10 +219,16 @@ class Completion(openai_Completion):
and isinstance(err, Timeout)
):
logger.info(f"retrying in {cls.retry_time} seconds...", exc_info=1)
elif eval_only:
elif raise_error:
raise
else:
break
response = -1
if use_cache and isinstance(err, Timeout):
cls._cache.set(key, response)
logger.warning(
f"Failed to get response from openai api due to getting RateLimitError or Timeout for {cls.retry_timeout} seconds."
)
return response
if isinstance(err, Timeout):
if "request_timeout" in config:
raise
@ -237,13 +246,6 @@ class Completion(openai_Completion):
cls._cache.set(key, response)
cls._book_keeping(config, response)
return response
logger.warning(
f"Failed to get response from openai api due to getting RateLimitError or Timeout for {cls.retry_timeout} seconds."
)
response = -1
if use_cache:
cls._cache.set(key, response)
return response
@classmethod
def _get_max_valid_n(cls, key, max_tokens):
@ -266,6 +268,7 @@ class Completion(openai_Completion):
@classmethod
def _get_region_key(cls, config):
# get a key for the valid/invalid region corresponding to the given config
config = cls._pop_subspace(config, always_copy=False)
return (
config["model"],
config.get("prompt", config.get("messages")),
@ -282,31 +285,28 @@ class Completion(openai_Completion):
invalid_n[max_tokens] = min(num_completions, invalid_n.get(max_tokens, np.inf))
@classmethod
def _pop_subspace(cls, config):
def _pop_subspace(cls, config, always_copy=True):
if "subspace" in config:
config = config.copy()
config.update(config.pop("subspace"))
return config
return config.copy() if always_copy else config
@classmethod
def _get_prompt_messages_from_config(cls, model, config):
prompt, messages = None, None
if model in cls.chat_models or issubclass(cls, ChatCompletion):
# either "prompt" should be in config (for being compatible with non-chat models)
# or "messages" should be in config (for tuning chat models only)
prompt = config.get("prompt")
messages = config.get("messages")
# either prompt or messages should be in config, but not both
assert (prompt is None) != (
messages is None
), "Either prompt or messages should be in config for chat models."
if prompt is None:
messages = cls._messages[messages]
else:
prompt = cls._prompts[prompt]
def _get_params_for_create(cls, config: Dict) -> Dict:
"""Get the params for the openai api call from a config in the search space."""
params = cls._pop_subspace(config)
if cls._prompts:
params["prompt"] = cls._prompts[config["prompt"]]
else:
prompt = cls._prompts[config["prompt"]]
return prompt, messages
params["messages"] = cls._messages[config["messages"]]
if "stop" in params:
params["stop"] = cls._stops and cls._stops[params["stop"]]
temperature_or_top_p = params.pop("temperature_or_top_p", None)
if temperature_or_top_p:
params.update(temperature_or_top_p)
if cls._config_list and "config_list" not in params:
params["config_list"] = cls._config_list
return params
@classmethod
def _eval(cls, config: dict, prune=True, eval_only=False):
@ -315,7 +315,8 @@ class Completion(openai_Completion):
Args:
config (dict): Hyperparameter setting for the openai api call.
prune (bool, optional): Whether to enable pruning. Defaults to True.
eval_only (bool, optional): Whether to evaluate only (ignore the inference budget and no timeout).
eval_only (bool, optional): Whether to evaluate only
(ignore the inference budget and do not raise error when a request fails).
Defaults to False.
Returns:
@ -323,18 +324,18 @@ class Completion(openai_Completion):
"""
cost = 0
data = cls.data
config = cls._pop_subspace(config)
model = config["model"]
params = cls._get_params_for_create(config)
model = params["model"]
data_length = len(data)
price = cls.price1K.get(model)
price_input, price_output = price if isinstance(price, tuple) else (price, price)
inference_budget = getattr(cls, "inference_budget", None)
prune_hp = getattr(cls, "_prune_hp", "n")
metric = cls._metric
config_n = config.get(prune_hp, 1) # default value in OpenAI is 1
max_tokens = config.get("max_tokens", np.inf if model in cls.chat_models else 16)
prompt, messages = cls._get_prompt_messages_from_config(model, config)
stop = cls._stops and cls._stops[config["stop"]]
config_n = params.get(prune_hp, 1) # default value in OpenAI is 1
max_tokens = params.get(
"max_tokens", np.inf if model in cls.chat_models or issubclass(cls, ChatCompletion) else 16
)
target_output_tokens = None
if not cls.avg_input_tokens:
input_tokens = [None] * data_length
@ -365,12 +366,6 @@ class Completion(openai_Completion):
else:
start_n = config_n
region_key = None
params = config.copy()
if "stop" in config:
params["stop"] = stop
temperature_or_top_p = params.pop("temperature_or_top_p", None)
if temperature_or_top_p:
params.update(temperature_or_top_p)
num_completions, previous_num_completions = start_n, 0
n_tokens_list, result, responses_list = [], {}, []
while True: # n <= config_n
@ -383,9 +378,9 @@ class Completion(openai_Completion):
for i in range(prev_data_limit, data_limit):
logger.debug(f"num_completions={num_completions}, data instance={i}")
data_i = data[i]
params = cls._construct_params(data_i, params, prompt, messages)
response = cls._get_response(params, eval_only)
if response == -1: # rate limit error, treat as invalid
# params = cls._construct_params(data_i, params, prompt, messages)
response = cls.create(data_i, raise_error=eval_only, **params)
if response == -1: # rate limit/timeout error, treat as invalid
cls._update_invalid_n(prune, region_key, max_tokens, num_completions)
result[metric] = 0
result["cost"] = cost
@ -398,7 +393,7 @@ class Completion(openai_Completion):
if not cls.avg_input_tokens and not input_tokens[i]:
# store the # input tokens
input_tokens[i] = n_input_tokens
query_cost = (price_input * n_input_tokens + price_output * n_output_tokens) / 1000
query_cost = response["cost"]
cls._total_cost += query_cost
cost += query_cost
if cls.optimization_budget and cls._total_cost >= cls.optimization_budget and not eval_only:
@ -489,15 +484,15 @@ class Completion(openai_Completion):
@classmethod
def tune(
cls,
data,
metric,
mode,
eval_func,
log_file_name=None,
inference_budget=None,
optimization_budget=None,
num_samples=1,
logging_level=logging.WARNING,
data: List[Dict],
metric: str,
mode: str,
eval_func: Callable,
log_file_name: Optional[str] = None,
inference_budget: Optional[float] = None,
optimization_budget: Optional[float] = None,
num_samples: Optional[int] = 1,
logging_level: Optional[int] = logging.WARNING,
**config,
):
"""Tune the parameters for the OpenAI API call.
@ -597,6 +592,11 @@ class Completion(openai_Completion):
if not (isinstance(cls._stops, list) and isinstance(cls._stops[0], list)):
cls._stops = [cls._stops]
space["stop"] = tune.choice(list(range(len(cls._stops))))
cls._config_list = space.get("config_list")
if cls._config_list is not None:
is_const = is_constant(cls._config_list)
if is_const:
space.pop("config_list")
cls._metric, cls._mode = metric, mode
cls._total_cost = 0 # total optimization cost
cls._eval_func = eval_func
@ -663,16 +663,9 @@ class Completion(openai_Completion):
verbose=3,
)
config = analysis.best_config
params = cls._pop_subspace(config)
if cls._prompts:
params["prompt"] = cls._prompts[config["prompt"]]
else:
params["messages"] = cls._messages[config["messages"]]
stop = cls._stops and cls._stops[config["stop"]]
params["stop"] = stop
temperature_or_top_p = params.pop("temperature_or_top_p", None)
if temperature_or_top_p:
params.update(temperature_or_top_p)
params = cls._get_params_for_create(config)
if cls._config_list is not None and is_const:
params.pop("config_list")
logger.setLevel(old_level)
return params, analysis
@ -681,14 +674,16 @@ class Completion(openai_Completion):
cls,
context: Optional[Dict] = None,
use_cache: Optional[bool] = True,
config_list: Optional[List] = None,
config_list: Optional[List[Dict]] = None,
filter_func: Optional[Callable[[Dict, Dict, Dict], bool]] = None,
raise_error: Optional[bool] = True,
**config,
):
"""Make a completion for a given context.
Args:
context (Dict, Optional): The context to instantiate the prompt.
It needs to contain keys that are used by the prompt template.
It needs to contain keys that are used by the prompt template or the filter function.
E.g., `prompt="Complete the following sentence: {prefix}"`, `context={"prefix": "Today I feel"}`.
The actual prompt will be:
"Complete the following sentence: Today I feel".
@ -714,19 +709,28 @@ class Completion(openai_Completion):
"api_key": os.environ.get("OPENAI_API_KEY"),
"api_type": "open_ai",
"api_base": "https://api.openai.com/v1",
"api_version": None,
},
{
"model": "llama-7B",
"api_base": "http://127.0.0.1:8080",
"api_type": "open_ai",
"api_version": None,
}
],
prompt="Hi",
)
```
filter_func (Callable, Optional): A function that takes in the context, the config and the response and returns a boolean to indicate whether the response is valid. E.g.,
```python
def yes_or_no_filter(context, config, response):
return context.get("yes_or_no_choice", False) is False or any(
text in ["Yes.", "No."] for text in oai.Completion.extract_text(response)
)
```
raise_error (bool, Optional): Whether to raise error when all configs fail.
When set to False, -1 will be returned when all configs fail.
**config: Configuration for the completion.
Besides the parameters for the openai API call, it can also contain a seed (int) for the cache.
This is useful when implementing "controlled randomness" for the completion.
@ -739,28 +743,39 @@ class Completion(openai_Completion):
raise ERROR
if config_list:
retry_timeout = cls.retry_timeout
last = len(config_list) - 1
cost = 0
for i, each_config in enumerate(config_list):
base_config = config.copy()
base_config.update(each_config)
try:
cls.retry_timeout = 0 if i < len(config_list) - 1 else retry_timeout
# retry_timeout = 0 to avoid retrying
return cls.create(context, use_cache, **base_config)
cls.retry_timeout = 0 if i < last and filter_func is None else retry_timeout
# retry_timeout = 0 to avoid retrying when no filter is given
response = cls.create(context, use_cache, **base_config)
pass_filter = filter_func is None or filter_func(
context=context, base_config=config, response=response
)
if pass_filter or i == last:
response["cost"] = cost + response["cost"]
response["config_id"] = i
response["pass_filter"] = pass_filter
return response
cost += response["cost"]
except (AuthenticationError, RateLimitError, Timeout):
logger.info(f"failed with config {i}", exc_info=1)
if i == len(config_list) - 1:
logger.debug(f"failed with config {i}", exc_info=1)
if i == last:
raise
finally:
cls.retry_timeout = retry_timeout
params = cls._construct_params(context, config)
if not use_cache:
return cls._get_response(params, eval_only=True, use_cache=False)
return cls._get_response(params, raise_error=raise_error, use_cache=False)
seed = cls.seed
if "seed" in params:
cls.set_cache(params.pop("seed"))
with diskcache.Cache(cls.cache_path) as cls._cache:
cls.set_cache(seed)
return cls._get_response(params, eval_only=True)
return cls._get_response(params, raise_error=raise_error)
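A hedged sketch of the renamed flag in use: with `raise_error=False`, a request that keeps failing with rate-limit/timeout errors past `retry_timeout` yields -1 instead of raising (model name is illustrative):

```python
from flaml import oai

response = oai.Completion.create(model="text-davinci-003", prompt="Hi", raise_error=False)
if response == -1:
    print("request failed within retry_timeout; handle gracefully")
else:
    print(oai.Completion.extract_text(response)[0])
```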
@classmethod
def _instantiate(cls, template: str, context: Optional[Dict] = None):
@ -810,18 +825,17 @@ class Completion(openai_Completion):
def test(
cls,
data,
config,
eval_func=None,
use_cache=True,
agg_method="avg",
return_responses_and_per_instance_result=False,
logging_level=logging.WARNING,
**config,
):
"""Evaluate the responses created with the config for the OpenAI API call.
Args:
data (list): The list of test data points.
config (dict): Hyperparameter setting for the openai api call.
eval_func (Callable): The evaluation function for responses per data instance.
The function should take a list of responses and a data point as input,
and return a dict of metrics. You need to either provide a valid callable
@ -867,6 +881,7 @@ class Completion(openai_Completion):
return_responses_and_per_instance_result (bool): Whether to also return responses
and per instance results in addition to the aggregated results.
logging_level (optional): logging level. Defaults to logging.WARNING.
**config (dict): parameters passed to the openai api call `create()`.
Returns:
None when no valid eval_func is provided in either test or tune;
@ -881,7 +896,7 @@ class Completion(openai_Completion):
for i, data_i in enumerate(data):
logger.info(f"evaluating data instance {i}")
response = cls.create(data_i, use_cache, **config)
cost += cls.cost(response)
cost += response["cost"]
# evaluate the quality of the responses
responses = cls.extract_text(response)
if eval_func is not None:
@ -947,11 +962,12 @@ class Completion(openai_Completion):
response (dict): The response from OpenAI API.
Returns:
The cost in USD.
The cost in USD. 0 if the model is not supported.
"""
model = response["model"]
if model not in cls.price1K:
raise ValueError(f"Unknown model: {model}")
return 0
# raise ValueError(f"Unknown model: {model}")
usage = response["usage"]
n_input_tokens = usage["prompt_tokens"]
n_output_tokens = usage.get("completion_tokens", 0)

View File

@ -0,0 +1,142 @@
import os
import json
from typing import List, Optional, Dict
import logging
NON_CACHE_KEY = ["api_key", "api_base", "api_type", "api_version"]
def get_key(config):
"""Get a unique identifier of a configuration.
Args:
config (dict or list): A configuration.
Returns:
tuple: A unique identifier which can be used as a key for a dict.
"""
copied = False
for key in NON_CACHE_KEY:
if key in config:
config, copied = config.copy() if not copied else config, True
config.pop(key)
# if isinstance(config, dict):
# return tuple(get_key(x) for x in sorted(config.items()))
# if isinstance(config, list):
# return tuple(get_key(x) for x in config)
# return config
return json.dumps(config, sort_keys=True)
def get_config_list(
api_keys: List, api_bases: Optional[List] = None, api_type: Optional[str] = None, api_version: Optional[str] = None
) -> List[Dict]:
"""Get a list of configs for openai api calls.
Args:
api_keys (list): The api keys for openai api calls.
api_bases (list, optional): The api bases for openai api calls.
api_type (str, optional): The api type for openai api calls.
api_version (str, optional): The api version for openai api calls.
"""
config_list = []
for i, api_key in enumerate(api_keys):
if not api_key.strip():
continue
config = {"api_key": api_key}
if api_bases:
config["api_base"] = api_bases[i]
if api_type:
config["api_type"] = api_type
if api_version:
config["api_version"] = api_version
config_list.append(config)
return config_list
def config_list_openai_aoai(
key_file_path: Optional[str] = ".",
openai_api_key_file: Optional[str] = "key_openai.txt",
aoai_api_key_file: Optional[str] = "key_aoai.txt",
aoai_api_base_file: Optional[str] = "base_aoai.txt",
) -> List[Dict]:
"""Get a list of configs for openai + azure openai api calls.
Args:
key_file_path (str, optional): The path to the key files.
openai_api_key_file (str, optional): The file name of the openai api key.
aoai_api_key_file (str, optional): The file name of the azure openai api key.
aoai_api_base_file (str, optional): The file name of the azure openai api base.
Returns:
list: A list of configs for openai api calls.
"""
if "OPENAI_API_KEY" not in os.environ:
try:
os.environ["OPENAI_API_KEY"] = open(f"{key_file_path}/{openai_api_key_file}").read().strip()
except FileNotFoundError:
logging.info(
"To use OpenAI API, please set OPENAI_API_KEY in os.environ "
"or create key_openai.txt in the specified path, or specify the api_key in config_list."
)
if "AZURE_OPENAI_API_KEY" not in os.environ:
try:
os.environ["AZURE_OPENAI_API_KEY"] = open(f"{key_file_path}/{aoai_api_key_file}").read().strip()
except FileNotFoundError:
logging.info(
"To use Azure OpenAI API, please set AZURE_OPENAI_API_KEY in os.environ "
"or create key_aoai.txt in the specified path, or specify the api_key in config_list."
)
if "AZURE_OPENAI_API_BASE" not in os.environ:
try:
os.environ["AZURE_OPENAI_API_BASE"] = open(f"{key_file_path}/{aoai_api_base_file}").read().strip()
except FileNotFoundError:
logging.info(
"To use Azure OpenAI API, please set AZURE_OPENAI_API_BASE in os.environ "
"or create base_aoai.txt in the specified path, or specify the api_base in config_list."
)
aoai_config = get_config_list(
# Assuming Azure OpenAI api keys in os.environ["AZURE_OPENAI_API_KEY"], one per line
api_keys=os.environ.get("AZURE_OPENAI_API_KEY", "").split("\n"),
# Assuming Azure OpenAI api bases in os.environ["AZURE_OPENAI_API_BASE"], one per line
api_bases=os.environ.get("AZURE_OPENAI_API_BASE", "").split("\n"),
api_type="azure",
api_version="2023-03-15-preview", # change if necessary
)
openai_config = get_config_list(
# Assuming OpenAI API_KEY in os.environ["OPENAI_API_KEY"]
api_keys=os.environ.get("OPENAI_API_KEY", "").split("\n"),
# "api_type": "open_ai",
# "api_base": "https://api.openai.com/v1",
)
config_list = openai_config + aoai_config
return config_list
def config_list_gpt4_gpt35(
key_file_path: Optional[str] = ".",
openai_api_key_file: Optional[str] = "key_openai.txt",
aoai_api_key_file: Optional[str] = "key_aoai.txt",
aoai_api_base_file: Optional[str] = "base_aoai.txt",
) -> List[Dict]:
"""Get a list of configs for gpt-4 followed by gpt-3.5 api calls.
Args:
key_file_path (str, optional): The path to the key files.
openai_api_key_file (str, optional): The file name of the openai api key.
aoai_api_key_file (str, optional): The file name of the azure openai api key.
aoai_api_base_file (str, optional): The file name of the azure openai api base.
Returns:
list: A list of configs for openai api calls.
"""
config_list = config_list_openai_aoai(
key_file_path,
openai_api_key_file,
aoai_api_key_file,
aoai_api_base_file,
)
return [{**config, "model": "gpt-4"} for config in config_list] + [
{**config, "model": "gpt-3.5-turbo"} for config in config_list
]
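For illustration, with one OpenAI key and one Azure OpenAI key/base configured, the returned list looks roughly like this (all values are placeholders):

```python
config_list = config_list_gpt4_gpt35()
# [
#     {"model": "gpt-4", "api_key": "<openai-key>"},
#     {"model": "gpt-4", "api_key": "<aoai-key>", "api_base": "<aoai-base>",
#      "api_type": "azure", "api_version": "2023-03-15-preview"},
#     {"model": "gpt-3.5-turbo", "api_key": "<openai-key>"},
#     {"model": "gpt-3.5-turbo", "api_key": "<aoai-key>", "api_base": "<aoai-base>",
#      "api_type": "azure", "api_version": "2023-03-15-preview"},
# ]
```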

View File

@ -11,7 +11,7 @@ try:
except (ImportError, AssertionError):
from . import sample
from .searcher.variant_generator import generate_variants
from typing import Dict, Optional, Any, Tuple, Generator
from typing import Dict, Optional, Any, Tuple, Generator, List, Union
import numpy as np
import logging
@ -27,6 +27,29 @@ def generate_variants_compatible(
return generate_variants(unresolved_spec, constant_grid_search)
def is_constant(space: Union[Dict, List]) -> bool:
"""Whether the search space is all constant.
Returns:
A bool of whether the search space is all constant.
"""
if isinstance(space, dict):
for domain in space.values():
if isinstance(domain, (dict, list)):
if not is_constant(domain):
return False
continue
if isinstance(domain, sample.Domain):
return False
return True
elif isinstance(space, list):
for item in space:
if not is_constant(item):
return False
return True
return not isinstance(space, sample.Domain)
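A quick sketch of the intended behavior, using `flaml.tune` sampling domains:

```python
from flaml import tune
from flaml.tune.space import is_constant

assert is_constant({"model": "gpt-4", "max_tokens": 256})
assert is_constant([{"model": "gpt-4"}, {"model": "gpt-3.5-turbo"}])
assert not is_constant({"max_tokens": tune.randint(50, 1000)})  # contains a Domain
```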
def define_by_run_func(trial, space: Dict, path: str = "") -> Optional[Dict[str, Any]]:
"""Define-by-run function to create the search space.

File diff not shown because it is too large

File diff not shown because it is too large

View File

@ -757,7 +757,7 @@
}
],
"source": [
"# result = oai.Completion.test(test_data, config)\n",
"# result = oai.Completion.test(test_data, **config)\n",
"# print(\"performance on test data with the tuned config:\", result)"
]
},

View File

@ -313,7 +313,7 @@
"source": [
"### Evaluate the success rate on the test data\n",
"\n",
"You can use flaml's `oai.ChatCompletion.test` to evaluate the performance of an entire dataset with the tuned config."
"You can use flaml's `oai.ChatCompletion.test` to evaluate the performance of an entire dataset with a config."
]
},
{
@ -325,7 +325,7 @@
"import logging\n",
"\n",
"config_n1 = {\"model\": 'gpt-4', \"prompt\": prompt, \"max_tokens\": 600, \"n\": 1}\n",
"n1_result = oai.ChatCompletion.test(test_data[:50], config_n1, eval_math_responses)\n",
"n1_result = oai.ChatCompletion.test(test_data[:50], eval_math_responses, **config_n1)\n",
"print(n1_result)"
]
},
@ -336,8 +336,8 @@
"outputs": [],
"source": [
"oai.ChatCompletion.request_timeout = 120\n",
"config_n10 = {\"model\": 'gpt-4', \"prompt\": prompts[0], \"max_tokens\": 600, \"n\": 10}\n",
"n10_result = oai.ChatCompletion.test(test_data[:50], config_n10, eval_math_responses, logging_level=logging.INFO)\n",
"config_n10 = {\"model\": 'gpt-4', \"prompt\": prompt, \"max_tokens\": 600, \"n\": 10}\n",
"n10_result = oai.ChatCompletion.test(test_data[:50], eval_math_responses, logging_level=logging.INFO, **config_n10)\n",
"print(n10_result)"
]
},
@ -347,8 +347,8 @@
"metadata": {},
"outputs": [],
"source": [
"config_n30 = {\"model\": 'gpt-4', \"prompt\": prompts[0], \"max_tokens\": 600, \"n\": 30}\n",
"n30_result = oai.ChatCompletion.test(test_data[:50], config_n30, eval_math_responses, logging_level=logging.INFO)\n",
"config_n30 = {\"model\": 'gpt-4', \"prompt\": prompt, \"max_tokens\": 600, \"n\": 30}\n",
"n30_result = oai.ChatCompletion.test(test_data[:50], eval_math_responses, logging_level=logging.INFO, **config_n30)\n",
"print(n30_result)"
]
},

View File

@ -4,6 +4,7 @@ import numpy as np
import pytest
from functools import partial
import os
import json
from flaml import oai
from flaml.autogen.code_utils import (
eval_function_completions,
@ -17,6 +18,54 @@ from flaml.autogen.code_utils import (
)
from flaml.autogen.math_utils import eval_math_responses, solve_problem
KEY_LOC = "test/autogen"
here = os.path.abspath(os.path.dirname(__file__))
def yes_or_no_filter(context, response, **_):
return context.get("yes_or_no_choice", False) is False or any(
text in ["Yes.", "No."] for text in oai.Completion.extract_text(response)
)
def valid_json_filter(response, **_):
for text in oai.Completion.extract_text(response):
try:
json.loads(text)
return True
except ValueError:
pass
return False
def test_filter():
try:
import openai
except ImportError as exc:
print(exc)
return
response = oai.Completion.create(
context={"yes_or_no_choice": True},
config_list=[{"model": "text-ada-001"}, {"model": "gpt-3.5-turbo"}, {"model": "text-davinci-003"}],
prompt="Is 37 a prime number? Please answer 'Yes.' or 'No.'",
filter_func=yes_or_no_filter,
)
assert oai.Completion.extract_text(response)[0] in ["Yes.", "No."]
response = oai.Completion.create(
context={"yes_or_no_choice": False},
config_list=[{"model": "text-ada-001"}, {"model": "gpt-3.5-turbo"}, {"model": "text-davinci-003"}],
prompt="Is 37 a prime number?",
filter_func=yes_or_no_filter,
)
assert response["model"] == "text-ada-001"
response = oai.Completion.create(
config_list=[{"model": "text-ada-001"}, {"model": "gpt-3.5-turbo"}, {"model": "text-davinci-003"}],
prompt="How to construct a json request to Bing API to search for 'latest AI news'? Return the JSON request.",
filter_func=valid_json_filter,
)
assert response["config_id"] == 2 or response["pass_filter"], "the response must pass filter unless all fail"
assert not response["pass_filter"] or json.loads(oai.Completion.extract_text(response)[0])
def test_chatcompletion():
params = oai.ChatCompletion._construct_params(
@ -46,36 +95,7 @@ def test_multi_model():
print(exc)
return
response = oai.Completion.create(
config_list=[
{
"model": "gpt-4",
"api_key": os.environ.get("OPENAI_API_KEY"),
"api_type": "open_ai",
"api_base": "https://api.openai.com/v1",
"api_version": None,
},
{
"model": "gpt-4",
"api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
"api_type": "azure",
"api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
"api_version": "2023-03-15-preview",
},
{
"model": "gpt-3.5-turbo",
"api_key": os.environ.get("OPENAI_API_KEY"),
"api_type": "open_ai",
"api_base": "https://api.openai.com/v1",
"api_version": None,
},
{
"model": "gpt-3.5-turbo",
"api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
"api_type": "azure",
"api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
"api_version": "2023-03-15-preview",
},
],
config_list=oai.config_list_gpt4_gpt35(KEY_LOC),
prompt="Hi",
)
print(response)
@ -96,7 +116,7 @@ def test_execute_code():
# read a file
print(execute_code("with open('tmp/codetest.py', 'r') as f: a=f.read()"))
# create a file
print(execute_code("with open('tmp/codetest.py', 'w') as f: f.write('b=1')", work_dir="test/openai/my_tmp"))
print(execute_code("with open('tmp/codetest.py', 'w') as f: f.write('b=1')", work_dir=f"{here}/my_tmp"))
# execute code in a file
print(execute_code(filename="tmp/codetest.py"))
# execute code for assertion error
@ -116,25 +136,29 @@ def test_improve():
except ImportError as exc:
print(exc)
return
config_list = oai.config_list_openai_aoai(KEY_LOC)
improved, _ = improve_function(
"flaml/autogen/math_utils.py",
"solve_problem",
"Solve math problems accurately, by avoiding calculation errors and reduce reasoning errors.",
config_list=config_list,
)
with open("test/openai/math_utils.py.improved", "w") as f:
with open(f"{here}/math_utils.py.improved", "w") as f:
f.write(improved)
suggestion, _ = improve_code(
["flaml/autogen/code_utils.py", "flaml/autogen/math_utils.py"],
"leverage generative AI smartly and cost-effectively",
config_list=config_list,
)
print(suggestion)
improvement, cost = improve_code(
["flaml/autogen/code_utils.py", "flaml/autogen/math_utils.py"],
"leverage generative AI smartly and cost-effectively",
suggest_only=False,
config_list=config_list,
)
print(cost)
with open("test/openai/suggested_improvement.txt", "w") as f:
with open(f"{here}/suggested_improvement.txt", "w") as f:
f.write(improvement)
@ -196,7 +220,7 @@ print(f"Text: {text}")
"""
)
print(code)
solution, cost = solve_problem("1+1=")
solution, cost = solve_problem("1+1=", config_list=oai.config_list_gpt4_gpt35(KEY_LOC))
print(solution, cost)
@ -226,6 +250,7 @@ def test_humaneval(num_samples=1):
}
for x in range(n_tune_data, len(data))
]
oai.Completion.clear_cache(cache_path_root=f"{here}/cache")
oai.Completion.set_cache(seed)
try:
import openai
@ -233,6 +258,7 @@ def test_humaneval(num_samples=1):
except ImportError as exc:
print(exc)
return
oai.Completion.clear_cache(400)
# a minimal tuning example
config, _ = oai.Completion.tune(
data=tune_data,
@ -254,7 +280,8 @@ def test_humaneval(num_samples=1):
prompt="{definition}",
)
responses = oai.Completion.create(context=test_data[0], **config)
# a minimal tuning example for tuning chat completion models using the Completion class
# a minimal tuning example for tuning chat completion models using the ChatCompletion class
config_list = oai.config_list_openai_aoai(KEY_LOC)
config, _ = oai.ChatCompletion.tune(
data=tune_data,
metric="expected_success",
@ -262,12 +289,14 @@ def test_humaneval(num_samples=1):
eval_func=eval_function_completions,
n=1,
messages=[{"role": "user", "content": "{definition}"}],
config_list=config_list,
)
responses = oai.ChatCompletion.create(context=test_data[0], **config)
responses = oai.ChatCompletion.create(context=test_data[0], config_list=config_list, **config)
print(responses)
code, cost, _ = implement(tune_data[1], [config])
code, cost, selected = implement(tune_data[1], [{**config_list[-1], **config}])
print(code)
print(cost)
assert selected == 0
print(eval_function_completions([code], **tune_data[1]))
# a more comprehensive tuning example
config2, analysis = oai.Completion.tune(
@ -295,9 +324,11 @@ def test_humaneval(num_samples=1):
oai.Completion.data = test_data[:num_samples]
result = oai.Completion._eval(analysis.best_config, prune=False, eval_only=True)
print("result without pruning", result)
result = oai.Completion.test(test_data[:num_samples], config=config2)
result = oai.Completion.test(test_data[:num_samples], **config2)
print(result)
code, cost, selected = implement(tune_data[1], [config2, config])
print(code)
print(cost)
print(selected)
print(eval_function_completions([code], **tune_data[1]))
@ -352,12 +383,12 @@ def test_math(num_samples=-1):
"stop": "###",
}
test_data_sample = test_data[0:3]
result = oai.ChatCompletion.test(test_data_sample, vanilla_config, eval_math_responses)
result = oai.ChatCompletion.test(test_data_sample, eval_math_responses, **vanilla_config)
result = oai.ChatCompletion.test(
test_data_sample,
vanilla_config,
eval_math_responses,
agg_method="median",
**vanilla_config,
)
def my_median(results):
@ -368,13 +399,12 @@ def test_math(num_samples=-1):
result = oai.ChatCompletion.test(
test_data_sample,
vanilla_config,
eval_math_responses,
agg_method=my_median,
**vanilla_config,
)
result = oai.ChatCompletion.test(
test_data_sample,
vanilla_config,
eval_math_responses,
agg_method={
"expected_success": my_median,
@ -382,6 +412,7 @@ def test_math(num_samples=-1):
"success_vote": my_average,
"votes": np.mean,
},
**vanilla_config,
)
print(result)
@ -399,7 +430,7 @@ def test_math(num_samples=-1):
stop="###", # the stop sequence
)
print("tuned config", config)
result = oai.ChatCompletion.test(test_data_sample, config)
result = oai.ChatCompletion.test(test_data_sample, config_list=oai.config_list_openai_aoai(KEY_LOC), **config)
print("result from tuned config:", result)
print("empty responses", eval_math_responses([], None))
@ -407,13 +438,15 @@ def test_math(num_samples=-1):
if __name__ == "__main__":
import openai
openai.api_key = os.environ["OPENAI_API_KEY"] = open("test/openai/key.txt").read().strip()
os.environ["AZURE_OPENAI_API_KEY"] = open("test/openai/key_azure.txt").read().strip()
os.environ["AZURE_OPENAI_API_BASE"] = open("test/openai/base_azure.txt").read().strip()
test_chatcompletion()
config_list = oai.config_list_openai_aoai(KEY_LOC)
assert len(config_list) >= 3, config_list
openai.api_key = os.environ["OPENAI_API_KEY"]
# test_filter()
# test_chatcompletion()
# test_multi_model()
# test_execute_code()
# test_improve()
# test_nocontext()
# test_humaneval(1)
test_humaneval(1)
# test_math(1)

View File

@ -0,0 +1,64 @@
import sys
import os
import pytest
try:
import openai
skip = False
except ImportError:
skip = True
here = os.path.abspath(os.path.dirname(__file__))
def run_notebook(input_nb, output_nb="executed_openai_notebook.ipynb", save=False):
import nbformat
from nbconvert.preprocessors import ExecutePreprocessor
from nbconvert.preprocessors import CellExecutionError
try:
file_path = os.path.join(here, os.pardir, os.pardir, os.pardir, "notebook", input_nb)
with open(file_path) as nb_file:
nb = nbformat.read(nb_file, as_version=4)
preprocessor = ExecutePreprocessor(timeout=4800, kernel_name="python3")
preprocessor.preprocess(nb, {"metadata": {"path": here}})
output_file_name = "executed_openai_notebook_output.txt"
output_file = os.path.join(here, output_file_name)
with open(output_file, "a") as nb_output_file:
for cell in nb.cells:
if cell.cell_type == "code" and "outputs" in cell:
for output in cell.outputs:
if "text" in output:
nb_output_file.write(output["text"].strip() + "\n")
elif "data" in output and "text/plain" in output["data"]:
nb_output_file.write(output["data"]["text/plain"].strip() + "\n")
except CellExecutionError:
raise
finally:
if save:
with open(os.path.join(here, output_nb), "w", encoding="utf-8") as nb_executed_file:
nbformat.write(nb, nb_executed_file)
@pytest.mark.skipif(
skip or not sys.version.startswith("3.10"),
reason="do not run openai test if openai is not installed or py!=3.10",
)
def test_autogen_openai_completion(save=False):
run_notebook("autogen_openai_completion.ipynb", save=save)
@pytest.mark.skipif(
skip or not sys.version.startswith("3.11"),
reason="do not run openai test if openai is not installed or py!=3.11",
)
def test_autogen_chatgpt_gpt4(save=False):
run_notebook("autogen_chatgpt_gpt4.ipynb", save=save)
if __name__ == "__main__":
test_autogen_chatgpt_gpt4(save=True)
test_autogen_openai_completion(save=True)

View File

@ -1,6 +1,10 @@
import os
from flaml.autogen.code_utils import extract_code
from flaml import oai
KEY_LOC = "test/autogen"
here = os.path.abspath(os.path.dirname(__file__))
def test_extract_code():
print(extract_code("```bash\npython temp.py\n```"))
@ -12,12 +16,13 @@ def test_coding_agent(human_input_mode="NEVER", max_consecutive_auto_reply=10):
except ImportError:
return
from flaml.autogen.agent.coding_agent import PythonAgent
from flaml.autogen.agent.human_proxy_agent import HumanProxyAgent
from flaml.autogen.agent.user_proxy_agent import UserProxyAgent
config_list = oai.config_list_gpt4_gpt35(key_file_path=KEY_LOC)
conversations = {}
oai.ChatCompletion.start_logging(conversations)
agent = PythonAgent("coding_agent", request_timeout=600, seed=42)
user = HumanProxyAgent(
agent = PythonAgent("coding_agent", request_timeout=600, seed=42, config_list=config_list)
user = UserProxyAgent(
"user",
human_input_mode=human_input_mode,
max_consecutive_auto_reply=max_consecutive_auto_reply,
@ -48,8 +53,9 @@ def test_tsp(human_input_mode="NEVER", max_consecutive_auto_reply=10):
except ImportError:
return
from flaml.autogen.agent.coding_agent import PythonAgent
from flaml.autogen.agent.human_proxy_agent import HumanProxyAgent
from flaml.autogen.agent.user_proxy_agent import UserProxyAgent
config_list = oai.config_list_openai_aoai(key_file_path=KEY_LOC)
hard_questions = [
"What if we must go from node 1 to node 2?",
"Can we double all distances?",
@ -57,14 +63,14 @@ def test_tsp(human_input_mode="NEVER", max_consecutive_auto_reply=10):
]
oai.ChatCompletion.start_logging()
agent = PythonAgent("coding_agent", temperature=0)
user = HumanProxyAgent(
agent = PythonAgent("coding_agent", temperature=0, config_list=config_list)
user = UserProxyAgent(
"user",
work_dir="test/autogen",
work_dir=f"{here}",
human_input_mode=human_input_mode,
max_consecutive_auto_reply=max_consecutive_auto_reply,
)
with open("test/autogen/tsp_prompt.txt", "r") as f:
with open(f"{here}/tsp_prompt.txt", "r") as f:
prompt = f.read()
# agent.receive(prompt.format(question=hard_questions[0]), user)
# agent.receive(prompt.format(question=hard_questions[1]), user)
@ -74,14 +80,6 @@ def test_tsp(human_input_mode="NEVER", max_consecutive_auto_reply=10):
if __name__ == "__main__":
import openai
openai.api_key_path = "test/openai/key.txt"
# if you use Azure OpenAI, comment the above line and uncomment the following lines
# openai.api_type = "azure"
# openai.api_base = "https://<your_endpoint>.openai.azure.com/"
# openai.api_version = "2023-03-15-preview" # change if necessary
# openai.api_key = "<your_api_key>"
# test_extract_code()
test_coding_agent(human_input_mode="TERMINATE")
# when GPT-4, i.e., the DEFAULT_MODEL, is used, conversation in the following test

View File

@ -1,5 +1,7 @@
from flaml import oai
KEY_LOC = "test/autogen"
def test_human_agent():
try:
@ -7,12 +9,12 @@ def test_human_agent():
except ImportError:
return
from flaml.autogen.agent.chat_agent import ChatAgent
from flaml.autogen.agent.human_proxy_agent import HumanProxyAgent
from flaml.autogen.agent.user_proxy_agent import UserProxyAgent
conversations = {}
oai.ChatCompletion.start_logging(conversations)
agent = ChatAgent("chat_agent")
user = HumanProxyAgent("human_user", human_input_mode="NEVER", max_consecutive_auto_reply=2)
agent = ChatAgent("chat_agent", config_list=oai.config_list_gpt4_gpt35(key_file_path=KEY_LOC))
user = UserProxyAgent("human_user", human_input_mode="NEVER", max_consecutive_auto_reply=2)
agent.receive(
"""Write python code to solve the equation x^3=125. You must write code in the following format. You must always print the result.
Wait for me to return the result.
@ -27,13 +29,4 @@ def test_human_agent():
if __name__ == "__main__":
import openai
openai.api_key_path = "test/openai/key.txt"
# if you use Azure OpenAI, comment the above line and uncomment the following lines
# openai.api_type = "azure"
# openai.api_base = "https://<your_endpoint>.openai.azure.com/"
# openai.api_version = "2023-03-15-preview" # change if necessary
# openai.api_key = "<your_api_key>"
# test_extract_code()
test_human_agent()

View File

@ -1,62 +0,0 @@
import nbformat
from nbconvert.preprocessors import ExecutePreprocessor
from nbconvert.preprocessors import CellExecutionError
import os
import pytest
try:
import openai
skip = False
except ImportError:
skip = True
here = os.path.abspath(os.path.dirname(__file__))
def run_notebook(input_nb, output_nb="executed_openai_notebook.ipynb", save=False):
try:
file_path = os.path.join(here, os.pardir, os.pardir, "notebook", input_nb)
with open(file_path) as f:
nb = nbformat.read(f, as_version=4)
ep = ExecutePreprocessor(timeout=4800, kernel_name="python3")
ep.preprocess(nb, {"metadata": {"path": here}})
output_file_name = "executed_openai_notebook_output.txt"
output_file = os.path.join(here, output_file_name)
with open(output_file, "a") as f:
for cell in nb.cells:
if cell.cell_type == "code" and "outputs" in cell:
for output in cell.outputs:
if "text" in output:
f.write(output["text"].strip() + "\n")
elif "data" in output and "text/plain" in output["data"]:
f.write(output["data"]["text/plain"].strip() + "\n")
except CellExecutionError:
raise
finally:
if save:
with open(os.path.join(here, output_nb), "w", encoding="utf-8") as f:
nbformat.write(nb, f)
@pytest.mark.skipif(
skip,
reason="do not run openai test if openai is not installed",
)
def test_autogen_openai(save=False):
run_notebook("autogen_openai.ipynb", save=save)
@pytest.mark.skipif(
skip,
reason="do not run openai test if openai is not installed",
)
def test_autogen_chatgpt(save=False):
run_notebook("autogen_chatgpt.ipynb", save=save)
if __name__ == "__main__":
test_autogen_chatgpt(save=True)
test_autogen_openai(save=True)

View File

@ -38,6 +38,6 @@ We invite contributions from anyone interested in this topic and look forward to
## For Further Reading
* [Documentation about `flaml.autogen`](/docs/Use-Cases/Auto-Generation)
* [Code Example: Tune chatGPT for Math Problem Solving with FLAML](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_chatgpt.ipynb)
* [Code Example: Tune chatGPT for Math Problem Solving with FLAML](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_chatgpt_gpt4.ipynb)
*Do you have any experience to share about LLM applications? Do you like to see more support or research of LLMOps? Please join our [Discord](https://discord.gg/Cppx2vSPVP) server for discussion.*

View File

@ -128,10 +128,10 @@ print(eval_with_generated_assertions(oai.Completion.extract_text(response), **tu
You can use flaml's `oai.Completion.test` to evaluate the performance of an entire dataset with the tuned config.
```python
result = oai.Completion.test(test_data, config)
result = oai.Completion.test(test_data, **config)
print("performance on test data with the tuned config:", result)
```
The result will vary with the inference budget and optimization budget.
[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_openai.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/autogen_openai.ipynb)
[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_openai_completion.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/autogen_openai_completion.ipynb)

View File

@ -26,6 +26,9 @@ There are also complex interactions among subsets of the hyperparameters. For ex
the temperature and top_p are not recommended to be altered from their default values together because they both control the randomness of the generated text, and changing both at the same time can result in conflicting effects; n and best_of are rarely tuned together because if the application can process multiple outputs, filtering on the server side causes unnecessary information loss; both n and max_tokens will affect the total number of tokens generated, which in turn will affect the cost of the request.
These interactions and trade-offs make it difficult to manually determine the optimal hyperparameter settings for a given text generation task.
*Do the choices matter? Check this [blog post](/blog/2023/04/21/LLM-tuning-math) for a case study.*
## Tune Hyperparameters
The tuning can be performed with the following information:
@ -46,8 +49,9 @@ The evaluation function should take a list of responses, and other keyword argum
```python
def eval_math_responses(responses: List[str], solution: str, **args) -> Dict:
# select a response from the list of responses
answer = voted_answer(responses)
# check whether the answer is correct
return {"success": True or False}
return {"success": is_equivalent(answer, solution)}
```
[`flaml.autogen.code_utils`](../reference/autogen/code_utils) and [`flaml.autogen.math_utils`](../reference/autogen/math_utils) offer some example evaluation functions for code generation and math problem solving.
@ -100,6 +104,8 @@ The returned `config` contains the optimized configuration and `analysis` contai
The tuned config can be used to perform inference.
*Refer to this [page](../Examples/AutoGen-OpenAI) for a full example.*
## Perform Inference
One can use [`flaml.oai.Completion.create`](../reference/autogen/oai/completion#create) to perform inference.
@ -120,6 +126,8 @@ API call results are cached locally and reused when the same request is issued.
### Error handling
#### Runtime error
It is easy to hit errors when calling OpenAI APIs, due to connection, rate limit, or timeout. Some of the errors are transient. `flaml.oai.Completion.create` deals with the transient errors and retries automatically. Initial request timeout, retry timeout and retry time interval can be configured via `flaml.oai.request_timeout`, `flaml.oai.retry_timeout` and `flaml.oai.retry_time`.
Moreover, one can pass a list of configurations of different models/endpoints to mitigate the rate limits. For example,
@ -155,6 +163,29 @@ response = oai.Completion.create(
It will try querying Azure OpenAI gpt-4, OpenAI gpt-3.5-turbo, and a locally hosted llama-7B one by one, ignoring AuthenticationError, RateLimitError and Timeout,
until a valid result is returned. This can speed up the development process where the rate limit is a bottleneck. An error will be raised if the last choice fails. So make sure the last choice in the list has the best availability.
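A sketch of such a fallback list, mirroring the illustrative example in the `create` docstring (models, keys, and endpoints are placeholders):

```python
import os
from flaml import oai

response = oai.Completion.create(
    config_list=[
        {
            "model": "gpt-4",
            "api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
            "api_type": "azure",
            "api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
            "api_version": "2023-03-15-preview",
        },
        {
            "model": "gpt-3.5-turbo",
            "api_key": os.environ.get("OPENAI_API_KEY"),
            "api_type": "open_ai",
            "api_base": "https://api.openai.com/v1",
        },
        {
            # locally hosted model, tried last
            "model": "llama-7B",
            "api_base": "http://127.0.0.1:8080",
            "api_type": "open_ai",
        },
    ],
    prompt="Hi",
)
```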
#### Logic error
Another type of error is that the returned response does not satisfy a requirement. For example, if the response is required to be a valid json string, one would like to filter out the responses that are not. This can be achieved by providing a list of configurations and a filter function. For example,
```python
def valid_json_filter(context, config, response):
for text in oai.Completion.extract_text(response):
try:
json.loads(text)
return True
except ValueError:
pass
return False
response = oai.Completion.create(
config_list=[{"model": "text-ada-001"}, {"model": "gpt-3.5-turbo"}, {"model": "text-davinci-003"}],
prompt="How to construct a json request to Bing API to search for 'latest AI news'? Return the JSON request.",
filter_func=valid_json_filter,
)
```
The example above will try text-ada-001, gpt-3.5-turbo, and text-davinci-003 iteratively until a valid json string is returned or the last config is used. One can also repeat the same model multiple times in the list to increase the robustness of the final response.
### Templating
If the provided prompt or message is a template, it will be automatically materialized with a given context. For example,
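A minimal sketch, reusing the illustrative prompt from the `create` docstring:

```python
response = oai.Completion.create(
    context={"prefix": "Today I feel"},
    prompt="Complete the following sentence: {prefix}",
)
# The prompt sent to the API: "Complete the following sentence: Today I feel"
```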
@ -359,5 +390,5 @@ The compact history is more efficient and the individual API call history contai
*Interested in trying it yourself? Please check the following notebook examples:*
* [Optimize for Code Gen](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_openai.ipynb)
* [Optimize for Math](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_chatgpt.ipynb)
* [Optimize for Code Gen](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_openai_completion.ipynb)
* [Optimize for Math](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_chatgpt_gpt4.ipynb)