зеркало из https://github.com/golang/vulndb.git
216 строки
5.1 KiB
Go
216 строки
5.1 KiB
Go
// Copyright 2023 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package genai
|
|
|
|
import (
|
|
"encoding/csv"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
)
|
|
|
|
// Examples is a list of input/output pairs for the Suggest function
// to follow as examples.
// It can be used to create few-shot prompts or to fine-tune a model.
type Examples []*Example
|
|
|
|
// Example is a single input/output pair: an Input together with the
// Suggestion associated with it.
type Example struct {
	// The JSON tag names the embedded value explicitly, so it is
	// encoded under the "Input" key.
	Input `json:"Input"`
	// The JSON tag names the embedded value explicitly, so it is
	// encoded under the "Suggestion" key.
	Suggestion `json:"Suggestion"`
}
|
|
|
|
// Names of the folder and files that WriteFiles produces and
// checkFiles reads back.
const (
	// dataFolder is the subfolder, joined onto the caller-supplied
	// folder, that holds all of the example files.
	dataFolder = "data"
	// promptFile holds the placeholder prompt text.
	promptFile = "example_prompt.txt"
	// csvFile holds the examples in CSV form.
	csvFile = "examples.csv"
	// jsonFile holds the examples in JSON form.
	jsonFile = "examples.json"
)
|
|
|
|
// WriteFiles writes the examples to the given folder in the following formats:
|
|
// - <folder>/data/examples.json: a JSON array of the examples
|
|
// (used directly by the Suggest function)
|
|
// - <folder>/data/examples.csv: a CSV-formatted list of each example,
|
|
// where the input and output are comma-separated JSON objects.
|
|
// Can be used as an input to a Makersuite data prompt.
|
|
// - <folder>/data/prompt.txt: the prompt that will be used by Suggest
|
|
// (with placeholder data for the final input)
|
|
func (es Examples) WriteFiles(folder string) error {
|
|
dir := filepath.Join(folder, dataFolder)
|
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
return fmt.Errorf("failed to create directory %q: %s", dir, err)
|
|
}
|
|
|
|
for filename, write := range map[string]func(w io.Writer) error{
|
|
promptFile: es.writePrompt,
|
|
csvFile: es.writeCSV,
|
|
jsonFile: es.writeJSON,
|
|
} {
|
|
f, err := os.Create(filepath.Join(folder, dataFolder, filename))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer f.Close()
|
|
if err := write(f); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func checkFiles(folder string, minExamples int) error {
|
|
fpath := func(fname string) string {
|
|
return filepath.Join(folder, dataFolder, fname)
|
|
}
|
|
|
|
// Check the JSON file.
|
|
jsonExamples, err := readFile(fpath(jsonFile))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(jsonExamples) < minExamples {
|
|
return fmt.Errorf("%s has fewer than %d examples", jsonFile, minExamples)
|
|
}
|
|
|
|
// Check the CSV file.
|
|
csvExamples, err := readFile(fpath(csvFile))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if diff := cmp.Diff(csvExamples, jsonExamples); diff != "" {
|
|
return fmt.Errorf("CSV and JSON examples don't match (-csv +json):\n%s", diff)
|
|
}
|
|
|
|
// Check the example prompt.
|
|
b, err := os.ReadFile(fpath(promptFile))
|
|
if err != nil {
|
|
return fmt.Errorf("could not read %s: %v", promptFile, err)
|
|
}
|
|
gotPrompt := string(b)
|
|
if !strings.HasPrefix(gotPrompt, defaultPreamble) {
|
|
return fmt.Errorf("prompt in %s does not start with default preamble", promptFile)
|
|
}
|
|
wantPrompt, err := jsonExamples.placeholderPrompt()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if diff := cmp.Diff(wantPrompt, gotPrompt); diff != "" {
|
|
return fmt.Errorf("prompt mismatch (-want +got):\n%s", diff)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (es *Examples) ReadJSON(r io.Reader) error {
|
|
d := json.NewDecoder(r)
|
|
return d.Decode(es)
|
|
}
|
|
|
|
func readFile(filename string) (Examples, error) {
|
|
file, err := os.Open(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer file.Close()
|
|
var examples Examples
|
|
switch filepath.Ext(filename) {
|
|
case ".csv":
|
|
if err := examples.readCSV(file); err != nil {
|
|
return nil, err
|
|
}
|
|
case ".json":
|
|
if err := examples.ReadJSON(file); err != nil {
|
|
return nil, err
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("unsupported file type: %s", filename)
|
|
}
|
|
return examples, nil
|
|
}
|
|
|
|
// placeholderInput is the stand-in input that placeholderPrompt passes
// to newPrompt in place of a real input.
var placeholderInput = &Input{
	Module:      "INPUT MODULE",
	Description: "INPUT DESCRIPTION",
}
|
|
|
|
// placeholderPrompt returns the result of newPrompt called with
// placeholderInput as the input, defaultPreamble, the examples es, and
// defaultMaxExamples.
func (es Examples) placeholderPrompt() (string, error) {
	return newPrompt(placeholderInput, defaultPreamble, es, defaultMaxExamples)
}
|
|
|
|
func (es Examples) writePrompt(w io.Writer) error {
|
|
prompt, err := es.placeholderPrompt()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write([]byte(prompt))
|
|
return err
|
|
}
|
|
|
|
func (es Examples) writeCSV(w io.Writer) error {
|
|
cw := csv.NewWriter(w)
|
|
for _, e := range es {
|
|
in, out, err := e.marshal()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := cw.Write([]string{in, out}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
cw.Flush()
|
|
return cw.Error()
|
|
}
|
|
|
|
func (es *Examples) readCSV(r io.Reader) error {
|
|
cr := csv.NewReader(r)
|
|
records, err := cr.ReadAll()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, record := range records {
|
|
if len(record) != 2 {
|
|
return fmt.Errorf("unexpected CSV record: %v", record)
|
|
}
|
|
e, err := unmarshal(record[0], record[1])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*es = append(*es, e)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (es Examples) writeJSON(w io.Writer) error {
|
|
e := json.NewEncoder(w)
|
|
return e.Encode(es)
|
|
}
|
|
|
|
func unmarshal(input string, output string) (*Example, error) {
|
|
var e Example
|
|
if err := json.Unmarshal([]byte(input), &e.Input); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := json.Unmarshal([]byte(output), &e.Suggestion); err != nil {
|
|
return nil, err
|
|
}
|
|
return &e, nil
|
|
}
|
|
|
|
func (e *Example) marshal() (input string, output string, err error) {
|
|
ib, err := json.Marshal(e.Input)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
ob, err := json.Marshal(e.Suggestion)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
return string(ib), string(ob), nil
|
|
}
|