diff --git a/client/client.go b/client/client.go index 62c5a466..782e17ea 100644 --- a/client/client.go +++ b/client/client.go @@ -702,20 +702,13 @@ func (cli Client) LoaderEntryStatus(le mig.LoaderEntry, status bool) (err error) } // Change the key on an existing loader entry -func (cli Client) LoaderEntryKey(le mig.LoaderEntry, key string) (err error) { +func (cli Client) LoaderEntryKey(le mig.LoaderEntry) (newle mig.LoaderEntry, err error) { defer func() { if e := recover(); e != nil { err = fmt.Errorf("LoaderEntryKey() -> %v", e) } }() - if key == "" { - panic("invalid loader key specified") - } - err = mig.ValidateLoaderKey(key) - if err != nil { - panic(err) - } - data := url.Values{"loaderid": {fmt.Sprintf("%.0f", le.ID)}, "loaderkey": {key}} + data := url.Values{"loaderid": {fmt.Sprintf("%.0f", le.ID)}} r, err := http.NewRequest("POST", cli.Conf.API.URL+"loader/key/", strings.NewReader(data.Encode())) if err != nil { @@ -743,16 +736,25 @@ func (cli Client) LoaderEntryKey(le mig.LoaderEntry, key string) (err error) { resp.StatusCode, resource.Collection.Error.Message, resource.Collection.Error.Code) panic(err) } + newle, err = ValueToLoaderEntry(resource.Collection.Items[0].Data[0].Value) + if err != nil { + panic(err) + } return } // Post a new loader entry for storage through the API -func (cli Client) PostNewLoader(le mig.LoaderEntry) (err error) { +func (cli Client) PostNewLoader(le mig.LoaderEntry) (newle mig.LoaderEntry, err error) { defer func() { if e := recover(); e != nil { err = fmt.Errorf("PostNewLoader() -> %v", e) } }() + // When adding a new loader entry, the prefix and key values should + // be "", since the server will be generating them. + if le.Prefix != "" || le.Key != "" { + panic("loader key and prefix must be unset") + } lebuf, err := json.Marshal(le) if err != nil { panic(err) @@ -785,6 +787,10 @@ func (cli Client) PostNewLoader(le mig.LoaderEntry) (err error) { resp.StatusCode, resource.Collection.Error.Message, resource.Collection.Error.Code) panic(err) } + newle, err = ValueToLoaderEntry(resource.Collection.Items[0].Data[0].Value) + if err != nil { + panic(err) + } return } diff --git a/client/mig-console/loader.go b/client/mig-console/loader.go index 49339d58..779721ff 100644 --- a/client/mig-console/loader.go +++ b/client/mig-console/loader.go @@ -124,24 +124,20 @@ r refresh the loader entry (get latest version from database) } fmt.Printf("%v\n", string(jsonle)) case "key": - fmt.Printf("New key component must be %v alphanumeric characters long, or type 'generate' to generate one\n", mig.LoaderKeyLength) - lkey, err := readline.String("New key for loader> ") + var nle mig.LoaderEntry + input, err := readline.String("generate new key for loader? (y/n)> ") if err != nil { panic(err) } - if lkey == "" { - panic("invalid key specified") + if input != "y" { + break } - if lkey == "generate" { - lkey = mig.GenerateLoaderKey() - fmt.Printf("New key will be set to %v\n", lkey) - } - fmt.Printf("New key including prefix to supply to client will be %q\n", le.Prefix+lkey) - err = cli.LoaderEntryKey(le, lkey) + nle, err = cli.LoaderEntryKey(le) if err != nil { panic(err) } - fmt.Println("Loader key changed") + fmt.Print("Loader key changed\n") + fmt.Printf("Loader key including prefix to supply to client will be %q\n", nle.Prefix+nle.Key) case "r": reloadfunc() case "": @@ -180,10 +176,6 @@ func loaderCreator(cli client.Client) (err error) { if err != nil { panic(err) } - fmt.Println("Generating loader prefix...") - newle.Prefix = mig.GenerateLoaderPrefix() - fmt.Println("Generating loader key...") - newle.Key = mig.GenerateLoaderKey() // Validate the new loader entry before sending it to the API err = newle.Validate() if err != nil { @@ -194,7 +186,7 @@ func loaderCreator(cli client.Client) (err error) { panic(err) } fmt.Printf("%s\n", jsonle) - fmt.Printf("Loader key including prefix to supply to client will be %q\n", newle.Prefix+newle.Key) + fmt.Print("Server will assign prefix and key on creation\n") input, err := readline.String("create loader entry? (y/n)> ") if err != nil { panic(err) @@ -203,10 +195,11 @@ func loaderCreator(cli client.Client) (err error) { fmt.Println("abort") return } - err = cli.PostNewLoader(newle) + createdle, err := cli.PostNewLoader(newle) if err != nil { panic(err) } - fmt.Println("New entry successfully created but is disabled") + fmt.Printf("Loader key including prefix to supply to client will be %q\n", createdle.Prefix+createdle.Key) + fmt.Printf("New entry successfully created (id %v) but is disabled\n", createdle.ID) return } diff --git a/database/loader.go b/database/loader.go index e4c155b6..8657194e 100644 --- a/database/loader.go +++ b/database/loader.go @@ -256,17 +256,22 @@ func (db *DB) LoaderUpdateKey(lid float64, hashkey []byte, salt []byte) (err err // Add a new loader entry to the database; the hashed loader key should // be provided as hashkey -func (db *DB) LoaderAdd(le mig.LoaderEntry, hashkey []byte, salt []byte) (err error) { +func (db *DB) LoaderAdd(le mig.LoaderEntry, hashkey []byte, salt []byte) (newle mig.LoaderEntry, err error) { var eval sql.NullString if le.ExpectEnv != "" { eval.String = le.ExpectEnv eval.Valid = true } - _, err = db.c.Exec(`INSERT INTO loaders + err = db.c.QueryRow(`INSERT INTO loaders (loadername, keyprefix, loaderkey, salt, lastseen, enabled, expectenv) VALUES - ($1, $2, $3, $4, now(), FALSE, $5)`, le.Name, - le.Prefix, hashkey, salt, eval) + ($1, $2, $3, $4, now(), FALSE, $5) + RETURNING id`, le.Name, + le.Prefix, hashkey, salt, eval).Scan(&le.ID) + if err != nil { + return + } + newle = le return } diff --git a/mig-api/apikey.go b/mig-api/apikey.go index 1f102175..68dbd839 100644 --- a/mig-api/apikey.go +++ b/mig-api/apikey.go @@ -23,6 +23,10 @@ const APIHashedKeyLength = 32 const APISaltLength = 16 func hashAPIKey(key string, salt []byte, keylen int, saltlen int) (ret []byte, retsalt []byte, err error) { + if key == "" { + err = fmt.Errorf("loader key cannot be zero length") + return + } if salt == nil { retsalt = make([]byte, saltlen) _, err = rand.Read(retsalt) diff --git a/mig-api/manifest_endpoints.go b/mig-api/manifest_endpoints.go index fea98fff..ea221095 100644 --- a/mig-api/manifest_endpoints.go +++ b/mig-api/manifest_endpoints.go @@ -578,15 +578,7 @@ func keyLoader(respWriter http.ResponseWriter, request *http.Request) { if err != nil { panic(err) } - lkey := request.FormValue("loaderkey") - if lkey == "" { - // bad request, return 400 - resource.SetError(cljs.Error{ - Code: fmt.Sprintf("%.0f", opid), - Message: "Invalid key specified"}) - respond(http.StatusBadRequest, resource, respWriter, request) - return - } + lkey := mig.GenerateLoaderKey() err = mig.ValidateLoaderKey(lkey) if err != nil { panic(err) @@ -599,6 +591,16 @@ func keyLoader(respWriter http.ResponseWriter, request *http.Request) { if err != nil { panic(err) } + le, err := ctx.DB.GetLoaderFromID(loaderid) + if err != nil { + panic(err) + } + le.Key = lkey + li, err := loaderEntryToItem(le, ctx) + if err != nil { + panic(err) + } + resource.AddItem(li) respond(http.StatusOK, resource, respWriter, request) } @@ -637,16 +639,36 @@ func newLoader(respWriter http.ResponseWriter, request *http.Request) { if err != nil { panic(err) } + // Generate the prefix and key will we use for this new loader entry + le.Prefix = mig.GenerateLoaderPrefix() + le.Key = mig.GenerateLoaderKey() + err = mig.ValidateLoaderPrefixAndKey(le.Prefix + le.Key) + if err != nil { + panic(err) + } // Hash the loader key to provide it to LoaderAdd hkey, salt, err := hashAPIKey(le.Key, nil, mig.LoaderHashedKeyLength, mig.LoaderSaltLength) if err != nil { panic(err) } - err = ctx.DB.LoaderAdd(le, hkey, salt) + createle, err := ctx.DB.LoaderAdd(le, hkey, salt) if err != nil { panic(err) } + // Retain the original loader entry rather than using createle directly here, + // since we will be echoing the new key back to the client and it is omitted from + // the returned LoaderEntry. We want to get the ID that was used for the new loader + // though. + le.ID = createle.ID + li, err := loaderEntryToItem(le, ctx) + if err != nil { + panic(err) + } + err = resource.AddItem(li) + if err != nil { + panic(err) + } respond(http.StatusCreated, resource, respWriter, request) }