extract code from text; solve_problem; request_timeout in config; improve code (#999)

* extract code from text

* solve_problem; request_timeout in config

* improve

* move import statement

* improve code

* generate assertions

* constant

* configs for implement; voting

* doc

* execute code in docker

* success indicator of code executation in docker

* success indicator

* execute code

* strip n

* add cost in generate_code

* add docstr

* filename

* bytes

* check docker version

* print log

* python test

* remove api key address

* rename exit code

* success exit code

* datasets

* exit code

* recover openai tests

* cache and pattern match

* wait

* wait

* cache and test

* timeout test

* python image name and skip macos

* windows image

* docker images

* volume path and yaml

* win path -> posix

* extensions

* path

* path

* path

* path

* path

* path

* path

* path

* path

* path

* path

* skip windows

* path

* timeout in windows

* use_docker

* use_docker

* hot fix from #1000

---------

Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>
This commit is contained in:
Chi Wang 2023-04-23 04:50:29 -07:00 коммит произвёл GitHub
Родитель 7114b8f742
Коммит fa5ccea862
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
16 изменённых файлов: 486 добавлений и 70 удалений

4
.github/workflows/openai.yml поставляемый
Просмотреть файл

@ -29,10 +29,10 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies
run: |
docker --version
python -m pip install --upgrade pip wheel
pip install -e .
pip install -e .[autogen,blendsearch]
python -c "import flaml"
pip install -e .[openai]
- name: Coverage
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

Просмотреть файл

@ -3,8 +3,8 @@
[![Build](https://github.com/microsoft/FLAML/actions/workflows/python-package.yml/badge.svg)](https://github.com/microsoft/FLAML/actions/workflows/python-package.yml)
![Python Version](https://img.shields.io/badge/3.7%20%7C%203.8%20%7C%203.9%20%7C%203.10-blue)
[![Downloads](https://pepy.tech/badge/flaml)](https://pepy.tech/project/flaml)
<!-- [![Join the chat at https://gitter.im/FLAMLer/community](https://badges.gitter.im/FLAMLer/community.svg)](https://gitter.im/FLAMLer/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -->
[![](https://img.shields.io/discord/1025786666260111483?logo=discord&style=flat)](https://discord.gg/Cppx2vSPVP)
<!-- [![Join the chat at https://gitter.im/FLAMLer/community](https://badges.gitter.im/FLAMLer/community.svg)](https://gitter.im/FLAMLer/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -->
# A Fast Library for Automated Machine Learning & Tuning

Просмотреть файл

@ -0,0 +1,2 @@
DEFAULT_MODEL = "gpt-4"
FAST_MODEL = "gpt-3.5-turbo"

Просмотреть файл

@ -1,56 +1,287 @@
import signal
import subprocess
import sys
import os
import pathlib
from typing import List, Dict, Tuple, Optional, Union, Callable
from flaml import oai
import re
import time
from flaml.autogen import oai, DEFAULT_MODEL, FAST_MODEL
# Regular expression for finding a code block
CODE_BLOCK_PATTERN = r"```\w*\n(.*?)\n```"
WORKING_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extensions")
def extract_code(text: str, pattern: str = CODE_BLOCK_PATTERN) -> str:
# Use a regular expression to find the code block
match = re.search(pattern, text, flags=re.DOTALL)
# If a match is found, return the code
if match:
return match.group(1)
# If no code block is found, return the whole text
return text
def generate_code(pattern: str = CODE_BLOCK_PATTERN, **config) -> Tuple[str, float]:
"""Generate code.
Args:
pattern (Optional, str): The regular expression pattern for finding the code block.
The default pattern is for finding a code block in a markdown file.
config (Optional, dict): The configuration for the API call.
Returns:
str: The generated code.
float: The cost of the generation.
"""
response = oai.Completion.create(**config)
cost = oai.Completion.cost(config["model"], response)
return extract_code(oai.Completion.extract_text(response)[0], pattern), cost
_IMPROVE_FUNCTION_CONFIG = {
"prompt": """Improve the function '{func_name}' to achieve the objective '{objective}'.
The current implementation of the function is as follows:
{file_string}""",
"model": DEFAULT_MODEL,
"request_timeout": 300,
}
def improve_function(file_name, func_name, objective, **config):
"""(work in progress) Improve the function to achieve the objective."""
params = {**_IMPROVE_FUNCTION_CONFIG, **config}
# read the entire file into a str
with open(file_name, "r") as f:
file_string = f.read()
response = oai.Completion.create(
{"func_name": func_name, "objective": objective, "file_string": file_string}, **params
)
cost = oai.Completion.cost(params["model"], response)
return oai.Completion.extract_text(response)[0], cost
_IMPROVE_CODE_CONFIG = {
"prompt": """Analyze the code in the following files and return a list of suggestions for improvement{followup}, to achieve the objective of '{objective}'.
{code}
""",
"model": DEFAULT_MODEL,
"request_timeout": 900,
}
def improve_code(files, objective, suggest_only=True, **config):
"""Improve the code to achieve a given objective.
Args:
files (list): A list of file names containing the source code.
objective (str): The objective to achieve.
suggest_only (bool): Whether to return only the suggestions or the improved code.
config (Optional, dict): The configuration for the API call.
Returns:
str: The improved code if suggest_only=False; a list of suggestions if suggest_only=True (default).
float: The cost of the generation.
"""
code = ""
for file_name in files:
# read the entire file into a string
with open(file_name, "r") as f:
file_string = f.read()
code += f"""{file_name}:
{file_string}
"""
params = {**_IMPROVE_CODE_CONFIG, **config}
followup = "" if suggest_only else " followed by the improved code"
response = oai.Completion.create({"objective": objective, "code": code, "followup": followup}, **params)
cost = oai.Completion.cost(params["model"], response)
return oai.Completion.extract_text(response)[0], cost
def timeout_handler(signum, frame):
raise TimeoutError("Timed out!")
def execute_code(code: str, max_exec_time: Optional[int] = 3):
signal.signal(signal.SIGALRM, timeout_handler)
code = code.strip()
with open("codetest.py", "w") as fout:
fout.write(code)
try:
signal.alarm(max_exec_time)
result = subprocess.run(
[sys.executable, "codetest.py"],
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE,
)
signal.alarm(0)
except TimeoutError:
return 0
return int(result.returncode == 0)
def execute_code(
code: Optional[str] = None,
timeout: Optional[int] = 600,
filename: Optional[str] = None,
work_dir: Optional[str] = None,
use_docker: Optional[bool] = True,
) -> Tuple[int, bytes]:
"""Execute code in a docker container.
This function is not tested on MacOS.
Args:
code (Optional, str): The code to execute.
If None, the code from the file specified by filename will be executed.
Either code or filename must be provided.
timeout (Optional, int): The maximum execution time in seconds.
filename (Optional, str): The file name to save the code or where the code is stored when `code` is None.
If None, a file with a randomly generated name will be created.
The randomly generated file will be deleted after execution.
The file name must be a relative path. Relative paths are relative to the working directory.
work_dir (Optional, str): The working directory for the code execution.
If None, a default working directory will be used.
The default working directory is the "extensions" directory under
"xxx/flaml/autogen", where "xxx" is the path to the flaml package.
use_docker (Optional, bool): Whether to use a docker container for code execution.
If True, the code will be executed in a docker container.
If False, the code will be executed in the current environment.
Default is True. If the code is executed in the current environment,
the code must be trusted.
Returns:
int: 0 if the code executes successfully.
bytes: The error message if the code fails to execute; the stdout otherwise.
"""
assert code is not None or filename is not None, "Either code or filename must be provided."
original_filename = filename
if filename is None:
code_hash = hash(code)
# create a file with a automatically generated name
filename = f"tmp_code_{code_hash}.py"
if work_dir is None:
work_dir = WORKING_DIR
filepath = os.path.join(work_dir, filename)
file_dir = os.path.dirname(filepath)
os.makedirs(file_dir, exist_ok=True)
if code is not None:
code = code.strip()
with open(filepath, "w") as fout:
fout.write(code)
# check if already running in a docker container
in_docker_container = os.path.exists("/.dockerenv")
if not use_docker or in_docker_container:
# already running in a docker container
signal.signal(signal.SIGALRM, timeout_handler)
try:
signal.alarm(timeout)
# run the code in a subprocess in the current docker container in the working directory
result = subprocess.run(
[sys.executable, filename],
cwd=work_dir,
capture_output=True,
)
signal.alarm(0)
except TimeoutError:
if original_filename is None:
os.remove(filepath)
return 1, "Timeout"
if original_filename is None:
os.remove(filepath)
return result.returncode, result.stderr if result.returncode else result.stdout
import docker
from requests.exceptions import ReadTimeout, ConnectionError
# create a docker client
client = docker.from_env()
image_list = ["python:3-alpine", "python:3", "python:3-windowsservercore"]
for image in image_list:
# check if the image exists
try:
client.images.get(image)
break
except docker.errors.ImageNotFound:
# pull the image
print("Pulling image", image)
try:
client.images.pull(image)
break
except docker.errors.DockerException:
print("Failed to pull image", image)
# get a randomized str based on current time to wrap the exit code
exit_code_str = f"exitcode{time.time()}"
abs_path = pathlib.Path(work_dir).absolute()
# if sys.platform == "win32":
# abs_path = str(abs_path).replace("\\", "/")
# abs_path = f"/{abs_path[0].lower()}{abs_path[2:]}"
# create a docker container
container = client.containers.run(
image,
command=[
"sh",
"-c",
f"python {filename}; exit_code=$?; echo -n {exit_code_str}; echo -n $exit_code; echo {exit_code_str}",
],
working_dir="/workspace",
detach=True,
# get absolute path to the working directory
volumes={abs_path: {"bind": "/workspace", "mode": "rw"}},
)
start_time = time.time()
while container.status != "exited" and time.time() - start_time < timeout:
# Reload the container object
container.reload()
if container.status != "exited":
container.stop()
container.remove()
if original_filename is None:
os.remove(filepath)
return 1, "Timeout"
# try:
# container.wait(timeout=timeout)
# except (ReadTimeout, ConnectionError):
# container.stop()
# container.remove()
# if original_filename is None:
# os.remove(filepath)
# return 1, "Timeout"
# get the container logs
logs = container.logs().decode("utf-8").rstrip()
# remove the container
container.remove()
# check if the code executed successfully
exit_code = container.attrs["State"]["ExitCode"]
if exit_code == 0:
# extract the exit code from the logs
pattern = re.compile(f"{exit_code_str}(\\d+){exit_code_str}")
match = pattern.search(logs)
exit_code = int(match.group(1))
# remove the exit code from the logs
logs = pattern.sub("", logs)
logs = bytes(logs, "utf-8")
if original_filename is None:
os.remove(filepath)
# return the exit code and logs
return exit_code, logs
def generate_assertions(definition: str, model: Optional[str] = "gpt-3.5-turbo") -> Tuple[str, float]:
_GENERATE_ASSERTIONS_CONFIG = {
"prompt": """Given the signature and docstring, write the exactly same number of assertion(s) for the provided example(s) in the docstring, without assertion messages.
func signature:
{definition}
assertions:""",
"model": FAST_MODEL,
"max_tokens": 256,
"stop": "\n\n",
}
def generate_assertions(definition: str, **config) -> Tuple[str, float]:
"""Generate assertions for a function.
Args:
definition (str): The function definition, including the signature and docstr.
model (str): The model used for generation.
config (Optional, dict): The configuration for the API call.
Returns:
str: The generated assertions.
float: The cost of the generation.
"""
prompt = """Given the signature and docstring, write the exactly same number of assertion(s) for the provided example(s) in the docstring, without assertion messages.
func signature:
{definition}
assertions:"""
params = {**_GENERATE_ASSERTIONS_CONFIG, **config}
response = oai.Completion.create(
{"definition": definition},
model=model,
prompt=prompt,
max_tokens=256,
stop="\n\n",
**params,
)
cost = oai.Completion.cost(model, response)
cost = oai.Completion.cost(params["model"], response)
assertions = oai.Completion.extract_text(response)[0]
return assertions, cost
@ -70,6 +301,8 @@ def eval_function_completions(
test: Optional[str] = None,
entry_point: Optional[str] = None,
assertions: Optional[Union[str, Callable[[str], Tuple[str, float]]]] = None,
timeout: Optional[float] = 3,
use_docker: Optional[bool] = True,
) -> Dict:
"""Select a response from a list of responses for the function completion task (using generated assertions), and/or evaluate if the task is successful using a gold test.
@ -80,6 +313,7 @@ def eval_function_completions(
entry_point (Optional, str): The name of the function.
assertions (Optional, str or Callable): The assertion code which serves as a filter of the responses, or an assertion generator.
When provided, only the responses that pass the assertions will be considered for the actual test (if provided).
timeout (Optional, float): The timeout for executing the code.
Returns:
dict: The success metrics.
@ -95,7 +329,7 @@ def eval_function_completions(
if response.startswith("def")
else f"{definition}{response}\n{test}\ncheck({entry_point})"
)
success = execute_code(code)
success = execute_code(code, timeout=timeout, use_docker=use_docker)[0] == 0
success_list.append(success)
return {
"expected_success": 1 - pow(1 - sum(success_list) / n, n),
@ -112,7 +346,7 @@ def eval_function_completions(
code = (
f"{response}\n{assertions}" if response.startswith("def") else f"{definition}{response}\n{assertions}"
)
succeed_assertions = execute_code(code)
succeed_assertions = execute_code(code, timeout=timeout, use_docker=use_docker)[0] == 0
if succeed_assertions:
break
else:
@ -132,7 +366,7 @@ def eval_function_completions(
if response.startswith("def")
else f"{definition}{response}\n{test}\ncheck({entry_point})"
)
success = execute_code(code_test)
success = execute_code(code_test, timeout=timeout, use_docker=use_docker)[0] == 0
return {
"index_selected": i,
"succeed_assertions": succeed_assertions,
@ -142,9 +376,20 @@ def eval_function_completions(
}
_FUNC_COMPLETION_PROMPT = "# Python 3{definition}"
_FUNC_COMPLETION_STOP = ["\nclass", "\ndef", "\nif", "\nprint"]
_IMPLEMENT_CONFIGS = [
{"model": FAST_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "temperature": 0, "seed": 0},
{"model": FAST_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 7, "seed": 0},
{"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "temperature": 0, "seed": 1},
{"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 2, "seed": 2},
{"model": DEFAULT_MODEL, "prompt": _FUNC_COMPLETION_PROMPT, "stop": _FUNC_COMPLETION_STOP, "n": 1, "seed": 2},
]
def implement(
definition: str,
configs: List[Dict],
configs: Optional[List[Dict]] = None,
assertions: Optional[Union[str, Callable[[str], Tuple[str, float]]]] = generate_assertions,
) -> Tuple[str, float]:
"""Implement a function from a definition.
@ -160,6 +405,7 @@ def implement(
int: The index of the configuration which generates the implementation.
"""
cost = 0
configs = configs or _IMPLEMENT_CONFIGS
if len(configs) > 1 and callable(assertions):
assertions, cost = assertions(definition)
for i, config in enumerate(configs):

Просмотреть файл

Просмотреть файл

@ -1,4 +1,28 @@
from typing import Optional
from flaml.autogen import oai, DEFAULT_MODEL
_MATH_PROMPT = "{problem} Solve the problem carefully. Simplify your answer as much as possible. Put the final answer in \\boxed{{}}."
_MATH_CONFIG = {
"model": DEFAULT_MODEL,
"prompt": _MATH_PROMPT,
}
def solve_problem(problem: str, **config) -> str:
"""(work in progress) Solve the math problem.
Args:
problem (str): The problem statement.
config (Optional, dict): The configuration for the API call.
Returns:
str: The solution to the problem.
"""
params = {**_MATH_CONFIG, **config}
response = oai.Completion.create({"problem": problem}, **params)
cost = oai.Completion.cost(params["model"], response)
results = eval_math_responses(oai.Completion.extract_text(response))
return results.get("voted_answer"), cost
def remove_boxed(string: str) -> Optional[str]:

Просмотреть файл

@ -145,9 +145,10 @@ class Completion:
request_timeout = cls.request_timeout
while True:
try:
response = openai_completion.create(request_timeout=request_timeout, **config)
cls._cache.set(key, response)
return response
if "request_timeout" in config:
response = openai_completion.create(**config)
else:
response = openai_completion.create(request_timeout=request_timeout, **config)
except (
ServiceUnavailableError,
APIError,
@ -170,6 +171,8 @@ class Completion:
else:
break
if isinstance(e, Timeout):
if "request_timeout" in config:
raise
request_timeout <<= 1
request_timeout = min(request_timeout, time_left)
sleep(cls.retry_time)
@ -180,11 +183,16 @@ class Completion:
config["engine"] = config.pop("model").replace("gpt-3.5-turbo", "gpt-35-turbo")
else:
raise
else:
if use_cache:
cls._cache.set(key, response)
return response
logger.warning(
f"Failed to get response from openai api due to getting RateLimitError or Timeout for {cls.retry_timeout} seconds."
)
response = -1
cls._cache.set(key, response)
if use_cache:
cls._cache.set(key, response)
return response
@classmethod

Просмотреть файл

@ -1 +1 @@
__version__ = "1.2.1"
__version__ = "1.2.2"

Просмотреть файл

@ -21,9 +21,9 @@
"\n",
"## Requirements\n",
"\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the [openai] option:\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the [openai,blendsearch] option:\n",
"```bash\n",
"pip install flaml[openai]==1.2.0\n",
"pip install flaml[openai,blendsearch]==1.2.1\n",
"```"
]
},
@ -40,7 +40,7 @@
},
"outputs": [],
"source": [
"# %pip install flaml[openai]==1.2.0 datasets"
"# %pip install flaml[openai,blendsearch]==1.2.1 datasets"
]
},
{

Просмотреть файл

@ -21,9 +21,9 @@
"\n",
"## Requirements\n",
"\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the [openai] option:\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the [autogen,blendsearch] option:\n",
"```bash\n",
"pip install flaml[openai]==1.2.0\n",
"pip install flaml[autogen,blendsearch]==1.2.1\n",
"```"
]
},
@ -40,7 +40,7 @@
},
"outputs": [],
"source": [
"# %pip install flaml[openai]==1.2.0 datasets"
"# %pip install flaml[autogen,blendsearch]==1.2.1 datasets"
]
},
{
@ -297,7 +297,13 @@
"from functools import partial\n",
"from flaml.autogen.code_utils import eval_function_completions, generate_assertions\n",
"\n",
"eval_with_generated_assertions = partial(eval_function_completions, assertions=generate_assertions)"
"eval_with_generated_assertions = partial(\n",
" eval_function_completions,\n",
" assertions=generate_assertions,\n",
" use_docker=False,\n",
" # Please set use_docker=True if you have docker available to run the generated code.\n",
" # Using docker is safer than running the generated code directly.\n",
")\n"
]
},
{

Просмотреть файл

@ -19,9 +19,9 @@
"\n",
"## Requirements\n",
"\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the [openai] option:\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the [autogen] option:\n",
"```bash\n",
"pip install flaml[openai]==1.2.0\n",
"pip install flaml[autogen]==1.2.1\n",
"```"
]
},
@ -38,7 +38,7 @@
},
"outputs": [],
"source": [
"# %pip install flaml[openai]==1.2.0 datasets"
"# %pip install flaml[autogen]==1.2.1 datasets"
]
},
{
@ -381,7 +381,7 @@
"success = 0\n",
"for i, d in enumerate(data):\n",
" response, cost_i, j = implement(d[\"definition\"], configs)\n",
" metrics = eval_function_completions(responses=[response], **d)\n",
" metrics = eval_function_completions(responses=[response], use_docker=False, **d)\n",
" success += metrics[\"success\"]\n",
" cost += cost_i\n",
" print(f\"Example {i}, config {j}, success {success}\")\n",

Просмотреть файл

@ -21,7 +21,7 @@
"\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the [openai] option:\n",
"```bash\n",
"pip install flaml[openai]==1.2.0\n",
"pip install flaml[openai]==1.2.1\n",
"```"
]
},
@ -38,7 +38,7 @@
},
"outputs": [],
"source": [
"# %pip install flaml[openai]==1.2.0 datasets"
"# %pip install flaml[openai]==1.2.1 datasets"
]
},
{

Просмотреть файл

@ -120,7 +120,8 @@ setuptools.setup(
"pytorch-forecasting>=0.9.0",
],
"benchmark": ["catboost>=0.26", "psutil==5.8.0", "xgboost==1.3.3"],
"openai": ["openai==0.27.4", "diskcache", "optuna==2.8.0"],
"openai": ["openai==0.27.4", "diskcache"],
"autogen": ["openai==0.27.4", "diskcache", "docker"],
"synapse": ["joblibspark>=0.5.0", "optuna==2.8.0", "pyspark>=3.2.0"],
},
classifiers=[

Просмотреть файл

@ -8,8 +8,70 @@ from flaml.autogen.code_utils import (
eval_function_completions,
generate_assertions,
implement,
generate_code,
extract_code,
improve_function,
improve_code,
execute_code,
)
from flaml.autogen.math_utils import eval_math_responses
from flaml.autogen.math_utils import eval_math_responses, solve_problem
@pytest.mark.skipif(
sys.platform in ["darwin", "win32"],
reason="do not run on MacOS or windows",
)
def test_execute_code():
try:
import docker
except ImportError as exc:
print(exc)
return
exitcode, msg = execute_code("print('hello world')", filename="tmp/codetest.py")
assert exitcode == 0 and msg == b"hello world\n", msg
# read a file
print(execute_code("with open('tmp/codetest.py', 'r') as f: a=f.read()"))
# create a file
print(execute_code("with open('tmp/codetest.py', 'w') as f: f.write('b=1')", work_dir="test/openai/my_tmp"))
# execute code in a file
print(execute_code(filename="tmp/codetest.py"))
# execute code for assertion error
exit_code, msg = execute_code("assert 1==2")
assert exit_code, msg
# execute code which takes a long time
exit_code, error = execute_code("import time; time.sleep(2)", timeout=1)
assert exit_code and error == "Timeout"
exit_code, error = execute_code("import time; time.sleep(2)", timeout=1, use_docker=False)
assert exit_code and error == "Timeout"
def test_improve():
try:
import openai
import diskcache
except ImportError as exc:
print(exc)
return
improved, _ = improve_function(
"flaml/autogen/math_utils.py",
"solve_problem",
"Solve math problems accurately, by avoiding calculation errors and reduce reasoning errors.",
)
with open("test/openai/math_utils.py.improved", "w") as f:
f.write(improved)
suggestion, _ = improve_code(
["flaml/autogen/code_utils.py", "flaml/autogen/math_utils.py"],
"leverage generative AI smartly and cost-effectively",
)
print(suggestion)
improvement, cost = improve_code(
["flaml/autogen/code_utils.py", "flaml/autogen/math_utils.py"],
"leverage generative AI smartly and cost-effectively",
suggest_only=False,
)
print(cost)
with open("test/openai/suggested_improvement.txt", "w") as f:
f.write(improvement)
def test_nocontext():
@ -19,8 +81,59 @@ def test_nocontext():
except ImportError as exc:
print(exc)
return
response = oai.Completion.create(model="text-ada-001", prompt="1+1=", max_tokens=1)
response = oai.Completion.create(
model="text-ada-001", prompt="1+1=", max_tokens=1, use_cache=False, request_timeout=10
)
print(response)
code, _ = generate_code(
model="gpt-3.5-turbo",
messages=[
{
"role": "system",
"content": "You want to become a better assistant by learning new skills and improving your existing ones.",
},
{
"role": "user",
"content": "Write reusable code to use web scraping to get information from websites.",
},
],
)
print(code)
# test extract_code from markdown
code = extract_code(
"""
Example:
```
print("hello extract code")
```
"""
)
print(code)
code = extract_code(
"""
Example:
```python
def scrape(url):
import requests
from bs4 import BeautifulSoup
response = requests.get(url)
soup = BeautifulSoup(response.text, "html.parser")
title = soup.find("title").text
text = soup.find("div", {"id": "bodyContent"}).text
return title, text
```
Test:
```python
url = "https://en.wikipedia.org/wiki/Web_scraping"
title, text = scrape(url)
print(f"Title: {title}")
print(f"Text: {text}")
"""
)
print(code)
solution, cost = solve_problem("1+1=")
print(solution, cost)
@pytest.mark.skipif(
@ -102,6 +215,7 @@ def test_humaneval(num_samples=1):
inference_budget=0.002,
optimization_budget=2,
num_samples=num_samples,
# logging_level=logging.INFO,
prompt=[
"{definition}",
"# Python 3{definition}",
@ -175,12 +289,10 @@ def test_math(num_samples=-1):
}
test_data_sample = test_data[0:3]
result = oai.ChatCompletion.test(test_data_sample, vanilla_config, eval_math_responses)
test_data_sample = test_data[3:6]
result = oai.ChatCompletion.test(
test_data_sample,
vanilla_config,
eval_math_responses,
use_cache=False,
agg_method="median",
)
@ -194,14 +306,12 @@ def test_math(num_samples=-1):
test_data_sample,
vanilla_config,
eval_math_responses,
use_cache=False,
agg_method=my_median,
)
result = oai.ChatCompletion.test(
test_data_sample,
vanilla_config,
eval_math_responses,
use_cache=False,
agg_method={
"expected_success": my_median,
"success": my_average,
@ -231,9 +341,11 @@ def test_math(num_samples=-1):
if __name__ == "__main__":
import openai
# import openai
openai.api_key_path = "test/openai/key.txt"
test_nocontext()
test_humaneval(1)
test_math(1)
# openai.api_key_path = "test/openai/key.txt"
test_execute_code()
# test_improve()
# test_nocontext()
# test_humaneval(1)
# test_math(1)

Просмотреть файл

@ -5,9 +5,9 @@ In this example, we will tune several hyperparameters for the OpenAI's completio
### Prerequisites
Install the [openai] option. The OpenAI integration is in preview.
Install the [autogen,blendsearch] option. The OpenAI integration is in preview.
```bash
pip install "flaml[openai]==1.2.0"
pip install "flaml[autogen,blendsearch]==1.2.1 datasets"
```
Setup your OpenAI key:

Просмотреть файл

@ -126,12 +126,29 @@ response = oai.Completion.create(problme=problem, prompt="{problem} Solve the pr
```
## Other utilities
`flaml.oai.Completion` also offers some additional utilities, such as:
### Completion
[`flaml.oai.Completion`](../reference/autogen/oai/completion) also offers some additional utilities, such as:
- a [`cost`](../reference/autogen/oai/completion#cost) function to calculate the cost of an API call.
- a [`test`](../reference/autogen/oai/completion#test) function to conveniently evaluate the configuration over test data.
- a [`extract_text`](../reference/autogen/oai/completion#extract_text) function to extract the text from a completion or chat response.
- a [`set_cache`](../reference/autogen/oai/completion#extract_text) function to set the seed and cache path. The caching is introduced in the section above, with the benefit of cost saving, reproducibility, and controlled randomness.
Interested in trying it yourself? Please check the following notebook examples:
### Code
[`flaml.autogen.code_utils`](../reference/autogen/code_utils) offers code-related utilities, such as:
- a [`improve_code`](../reference/autogen/code_utils#improve_code) function to improve code for a given objective.
- a [`generate_assertions`](../reference/autogen/code_utils#generate_assertions) function to generate assertion statements from function signature and docstr.
- a [`implement`](../reference/autogen/code_utils#implement) function to implement a function from a definition.
- a [`eval_function_completions`](../reference/autogen/code_utils#eval_function_completions) function to evaluate the success of a function completion task, or select a response from a list of responses using generated assertions.
### Math
[`flaml.autogen.math_utils`](../reference/autogen/math_utils) offers utilities for math problems, such as:
- a [eval_math_responses](../reference/autogen/math_utils#eval_math_responses) function to select a response using voting, and check if the final answer is correct if the canonical solution is provided.
*Interested in trying it yourself? Please check the following notebook examples:*
* [Optimize for Code Gen](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_openai.ipynb)
* [Optimize for Math](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_chatgpt.ipynb)