// Copyright 2013 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package imports import ( "flag" "go/build" "io/ioutil" "os" "path/filepath" "sync" "testing" ) var only = flag.String("only", "", "If non-empty, the fix test to run") var tests = []struct { name string in, out string }{ // Adding an import to an existing parenthesized import { name: "factored_imports_add", in: `package foo import ( "fmt" ) func bar() { var b bytes.Buffer fmt.Println(b.String()) } `, out: `package foo import ( "bytes" "fmt" ) func bar() { var b bytes.Buffer fmt.Println(b.String()) } `, }, // Adding an import to an existing parenthesized import, // verifying it goes into the first section. { name: "factored_imports_add_first_sec", in: `package foo import ( "fmt" "appengine" ) func bar() { var b bytes.Buffer _ = appengine.IsDevServer fmt.Println(b.String()) } `, out: `package foo import ( "bytes" "fmt" "appengine" ) func bar() { var b bytes.Buffer _ = appengine.IsDevServer fmt.Println(b.String()) } `, }, // Adding an import to an existing parenthesized import, // verifying it goes into the first section. (test 2) { name: "factored_imports_add_first_sec_2", in: `package foo import ( "fmt" "appengine" ) func bar() { _ = math.NaN _ = fmt.Sprintf _ = appengine.IsDevServer } `, out: `package foo import ( "fmt" "math" "appengine" ) func bar() { _ = math.NaN _ = fmt.Sprintf _ = appengine.IsDevServer } `, }, // Adding a new import line, without parens { name: "add_import_section", in: `package foo func bar() { var b bytes.Buffer } `, out: `package foo import "bytes" func bar() { var b bytes.Buffer } `, }, // Adding two new imports, which should make a parenthesized import decl. { name: "add_import_paren_section", in: `package foo func bar() { _, _ := bytes.Buffer, zip.NewReader } `, out: `package foo import ( "archive/zip" "bytes" ) func bar() { _, _ := bytes.Buffer, zip.NewReader } `, }, // Make sure we don't add things twice { name: "no_double_add", in: `package foo func bar() { _, _ := bytes.Buffer, bytes.NewReader } `, out: `package foo import "bytes" func bar() { _, _ := bytes.Buffer, bytes.NewReader } `, }, // Remove unused imports, 1 of a factored block { name: "remove_unused_1_of_2", in: `package foo import ( "bytes" "fmt" ) func bar() { _, _ := bytes.Buffer, bytes.NewReader } `, out: `package foo import "bytes" func bar() { _, _ := bytes.Buffer, bytes.NewReader } `, }, // Remove unused imports, 2 of 2 { name: "remove_unused_2_of_2", in: `package foo import ( "bytes" "fmt" ) func bar() { } `, out: `package foo func bar() { } `, }, // Remove unused imports, 1 of 1 { name: "remove_unused_1_of_1", in: `package foo import "fmt" func bar() { } `, out: `package foo func bar() { } `, }, // Don't remove empty imports. { name: "dont_remove_empty_imports", in: `package foo import ( _ "image/png" _ "image/jpeg" ) `, out: `package foo import ( _ "image/jpeg" _ "image/png" ) `, }, // Don't remove dot imports. { name: "dont_remove_dot_imports", in: `package foo import ( . "foo" . "bar" ) `, out: `package foo import ( . "bar" . "foo" ) `, }, // Skip refs the parser can resolve. { name: "skip_resolved_refs", in: `package foo func f() { type t struct{ Println func(string) } fmt := t{Println: func(string) {}} fmt.Println("foo") } `, out: `package foo func f() { type t struct{ Println func(string) } fmt := t{Println: func(string) {}} fmt.Println("foo") } `, }, // Do not add a package we already have a resolution for. { name: "skip_template", in: `package foo import "html/template" func f() { t = template.New("sometemplate") } `, out: `package foo import "html/template" func f() { t = template.New("sometemplate") } `, }, // Don't touch cgo { name: "cgo", in: `package foo /* #include */ import "C" `, out: `package foo /* #include */ import "C" `, }, // Put some things in their own section { name: "make_sections", in: `package foo import ( "os" ) func foo () { _, _ = os.Args, fmt.Println _, _ = appengine.FooSomething, user.Current } `, out: `package foo import ( "fmt" "os" "appengine" "appengine/user" ) func foo() { _, _ = os.Args, fmt.Println _, _ = appengine.FooSomething, user.Current } `, }, // Delete existing empty import block { name: "delete_empty_import_block", in: `package foo import () `, out: `package foo `, }, // Use existing empty import block { name: "use_empty_import_block", in: `package foo import () func f() { _ = fmt.Println } `, out: `package foo import "fmt" func f() { _ = fmt.Println } `, }, // Blank line before adding new section. { name: "blank_line_before_new_group", in: `package foo import ( "fmt" "net" ) func f() { _ = net.Dial _ = fmt.Printf _ = snappy.Foo } `, out: `package foo import ( "fmt" "net" "code.google.com/p/snappy-go/snappy" ) func f() { _ = net.Dial _ = fmt.Printf _ = snappy.Foo } `, }, // Blank line between standard library and third-party stuff. { name: "blank_line_separating_std_and_third_party", in: `package foo import ( "code.google.com/p/snappy-go/snappy" "fmt" "net" ) func f() { _ = net.Dial _ = fmt.Printf _ = snappy.Foo } `, out: `package foo import ( "fmt" "net" "code.google.com/p/snappy-go/snappy" ) func f() { _ = net.Dial _ = fmt.Printf _ = snappy.Foo } `, }, // golang.org/issue/6884 { name: "issue 6884", in: `package main // A comment func main() { fmt.Println("Hello, world") } `, out: `package main import "fmt" // A comment func main() { fmt.Println("Hello, world") } `, }, // golang.org/issue/7132 { name: "issue 7132", in: `package main import ( "fmt" "gu" "github.com/foo/bar" ) var ( a = bar.a b = gu.a c = fmt.Printf ) `, out: `package main import ( "fmt" "gu" "github.com/foo/bar" ) var ( a = bar.a b = gu.a c = fmt.Printf ) `, }, { name: "renamed package", in: `package main var _ = str.HasPrefix `, out: `package main import str "strings" var _ = str.HasPrefix `, }, { name: "fragment with main", in: `func main(){fmt.Println("Hello, world")}`, out: `package main import "fmt" func main() { fmt.Println("Hello, world") } `, }, { name: "fragment without main", in: `func notmain(){fmt.Println("Hello, world")}`, out: `import "fmt" func notmain() { fmt.Println("Hello, world") }`, }, } func TestFixImports(t *testing.T) { simplePkgs := map[string]string{ "fmt": "fmt", "os": "os", "math": "math", "appengine": "appengine", "user": "appengine/user", "zip": "archive/zip", "bytes": "bytes", "snappy": "code.google.com/p/snappy-go/snappy", "str": "strings", } findImport = func(pkgName string, symbols map[string]bool) (string, bool, error) { return simplePkgs[pkgName], pkgName == "str", nil } options := &Options{ TabWidth: 8, TabIndent: true, Comments: true, Fragment: true, } for _, tt := range tests { if *only != "" && tt.name != *only { continue } buf, err := Process(tt.name+".go", []byte(tt.in), options) if err != nil { t.Errorf("error on %q: %v", tt.name, err) continue } if got := string(buf); got != tt.out { t.Errorf("results diff on %q\nGOT:\n%s\nWANT:\n%s\n", tt.name, got, tt.out) } } } func TestFindImportGoPath(t *testing.T) { goroot, err := ioutil.TempDir("", "goimports-") if err != nil { t.Fatal(err) } defer os.RemoveAll(goroot) pkgIndexOnce = sync.Once{} origStdlib := stdlib defer func() { stdlib = origStdlib }() stdlib = nil // Test against imaginary bits/bytes package in std lib bytesDir := filepath.Join(goroot, "src", "pkg", "bits", "bytes") if err := os.MkdirAll(bytesDir, 0755); err != nil { t.Fatal(err) } bytesSrcPath := filepath.Join(bytesDir, "bytes.go") bytesPkgPath := "bits/bytes" bytesSrc := []byte(`package bytes type Buffer2 struct {} `) if err := ioutil.WriteFile(bytesSrcPath, bytesSrc, 0775); err != nil { t.Fatal(err) } oldGOROOT := build.Default.GOROOT oldGOPATH := build.Default.GOPATH build.Default.GOROOT = goroot build.Default.GOPATH = "" defer func() { build.Default.GOROOT = oldGOROOT build.Default.GOPATH = oldGOPATH }() got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}) if err != nil { t.Fatal(err) } if got != bytesPkgPath || rename { t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath) } got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true}) if err != nil { t.Fatal(err) } if got != "" || rename { t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, %t, want "", false`, got, rename) } } func TestFindImportStdlib(t *testing.T) { tests := []struct { pkg string symbols []string want string }{ {"http", []string{"Get"}, "net/http"}, {"http", []string{"Get", "Post"}, "net/http"}, {"http", []string{"Get", "Foo"}, ""}, {"bytes", []string{"Buffer"}, "bytes"}, {"ioutil", []string{"Discard"}, "io/ioutil"}, } for _, tt := range tests { got, rename, ok := findImportStdlib(tt.pkg, strSet(tt.symbols)) if (got != "") != ok { t.Error("findImportStdlib return value inconsistent") } if got != tt.want || rename { t.Errorf("findImportStdlib(%q, %q) = %q, %t; want %q, false", tt.pkg, tt.symbols, got, rename, tt.want) } } } func strSet(ss []string) map[string]bool { m := make(map[string]bool) for _, s := range ss { m[s] = true } return m }