Initial version of go-sqlcmd (#1)

* commandline and variables

* add displayname to build tasks

* implement quit command

* initial version of batch parsing

* move variables

* new package for variables

* add vscode helpers for debug

* fix go and quit processing

* implement Out and Error commands

* fix custom batch separator

* move connectionString to sqlcmd

* Add sql connection and print column headers

* add row and error processing

* remove unused package

* fix binary rendering and screen fitting

* rewrite decodeBinary for performance

* fix test pipeline

* remove password command line param

* exit on ctrl-c

* implement -q and -Q

* de-lint and update readme

* add lint for PRs

* separate pkg and cmd folders

* fix pipeline for new folder layout

* de-lint and reduce public surface

* more de-linting

* hopefully last round of de-linting
This commit is contained in:
David Shiflet 2021-08-26 10:34:14 -04:00 коммит произвёл GitHub
Родитель 3d4fc056ea
Коммит 76685e94af
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
28 изменённых файлов: 4180 добавлений и 10 удалений

17
.github/workflows/golangci-lint.yml поставляемый Normal file
Просмотреть файл

@ -0,0 +1,17 @@
name: golangci-lint
on:
push:
branches:
- main
pull_request:
jobs:
golangci-pr:
name: lint-pr-changes
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
with:
version: v1.42.0
only-new-issues: true

5
.gitignore поставляемый
Просмотреть файл

@ -13,3 +13,8 @@
# Dependency directories (remove the comment below to include it)
# vendor/
coverage.json
coverage.txt
coverage.xml
testresults.xml

Просмотреть файл

@ -0,0 +1,80 @@
pool:
vmImage: 'ubuntu-latest'
steps:
- task: GoTool@0
inputs:
version: '1.16.5'
- task: Go@0
displayName: 'Go: get dependencies'
inputs:
command: 'get'
arguments: '-d'
workingDirectory: '$(Build.SourcesDirectory)/cmd/sqlcmd'
- task: Go@0
displayName: 'Go: install gotest.tools/gotestsum'
inputs:
command: 'custom'
customCommand: 'install'
arguments: 'gotest.tools/gotestsum@latest'
workingDirectory: '$(System.DefaultWorkingDirectory)'
- task: Go@0
displayName: 'Go: install github.com/axw/gocov/gocov'
inputs:
command: 'custom'
customCommand: 'install'
arguments: 'github.com/axw/gocov/gocov@latest'
workingDirectory: '$(System.DefaultWorkingDirectory)'
- task: Go@0
displayName: 'Go: install github.com/axw/gocov/gocov'
inputs:
command: 'custom'
customCommand: 'install'
arguments: 'github.com/AlekSi/gocov-xml@latest'
workingDirectory: '$(System.DefaultWorkingDirectory)'
#Your build pipeline references an undefined variables named SQLPASSWORD.
#Create or edit the build pipeline for this YAML file, define the variable on the Variables tab. See https://go.microsoft.com/fwlink/?linkid=865972
- task: Docker@2
displayName: 'Run SQL 2017 docker image'
inputs:
command: run
arguments: '-m 2GB -e ACCEPT_EULA=1 -d --name sql2017 -p:1433:1433 -e SA_PASSWORD=$(SQLPASSWORD) mcr.microsoft.com/mssql/server:2017-latest'
- script: |
~/go/bin/gotestsum --junitfile testresults.xml -- ./... -coverprofile=coverage.txt -covermode count
~/go/bin/gocov convert coverage.txt > coverage.json
~/go/bin/gocov-xml < coverage.json > coverage.xml
mkdir coverage
workingDirectory: '$(Build.SourcesDirectory)'
displayName: 'run tests'
env:
SQLPASSWORD: $(SQLPASSWORD)
SQLCMDUSER: sa
SQLCMDPASSWORD: $(SQLPASSWORD)
continueOnError: true
- task: PublishTestResults@2
displayName: "Publish junit-style results"
inputs:
testResultsFiles: 'testresults.xml'
testResultsFormat: JUnit
searchFolder: '$(Build.SourcesDirectory)'
testRunTitle: 'SQL 2017 - $(Build.SourceBranchName)'
condition: always()
continueOnError: true
- task: PublishCodeCoverageResults@1
inputs:
codeCoverageTool: Cobertura
pathToSources: '$(Build.SourcesDirectory)'
summaryFileLocation: $(Build.SourcesDirectory)/**/coverage.xml
reportDirectory: $(Build.SourcesDirectory)/**/coverage
failIfCoverageEmpty: true
condition: always()
continueOnError: true

27
.vscode/launch.json поставляемый Normal file
Просмотреть файл

@ -0,0 +1,27 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Attach using delve",
"type": "go",
"request": "attach",
"preLaunchTask": "delve",
"mode": "remote",
"remotePath": "${workspaceFolder}",
"port" : 23456,
"host" : "127.0.0.1",
"cwd" : "${workspaceFolder}",
},
{
"name" : "Run query and exit",
"type" : "go",
"request": "launch",
"mode" : "auto",
"program": "${fileDirname}",
"args" : ["-Q", "\"select 100 as Count\""],
}
]
}

36
.vscode/tasks.json поставляемый Normal file
Просмотреть файл

@ -0,0 +1,36 @@
{
// See https://go.microsoft.com/fwlink/?LinkId=733558
// for the documentation about the tasks.json format
"version": "2.0.0",
"tasks": [
{
"label": "delve",
"type": "shell",
"command": "dlv debug --headless --listen=:23456 --api-version=2 \"${workspaceFolder}\"",
"isBackground": true,
"presentation": {
"focus": true,
"panel": "dedicated",
"clear": false
},
"group": {
"kind": "build",
"isDefault": true
},
"problemMatcher": {
"pattern": {
"regexp": ""
},
"background": {
"activeOnStart": true,
"beginsPattern": {
"regexp": ".*"
},
"endsPattern": {
"regexp": ".*server listening.*"
}
}
}
}
]
}

Просмотреть файл

@ -1,14 +1,29 @@
# Project
# SQL Utilities - Go edition
> This repo has been populated by an initial template to help get you started. Please
> make sure to update the content to build a great experience for community-building.
This repo contains command line tools and go packages for working with Microsoft SQL Server, Azure SQL Database, and Azure Synapse.
As the maintainer of this project, please make a few updates:
## Sqlcmd
- Improving this README.MD file to provide a great experience
- Updating SUPPORT.MD with content about this project's support experience
- Understanding the security reporting process in SECURITY.MD
- Remove this section from the README
The `sqlcmd` project aims to be a complete port of the native sqlcmd to the `go` language, utilizing the [go-mssqldb](https://github.com/denisenkom/go-mssqldb) driver. For full documentation of the tool, see https://docs.microsoft.com/sql/tools/sqlcmd-utility
### Breaking changes
We will be implementing as many command line switches and behaviors as possible over time. Several switches and behaviors are expected to change in this implementation.
- `-P` switch will be removed. Passwords for SQL authentication can only be provided through these mechanisms:
-The `SQLCMDPASSWORD` environment variable
-The `:CONNECT` command
-When prompted, the user can type the password to complete a connection
- `-R` switch will be removed. The go runtime does not provide access to user locale information, and it's not readily available through syscall on all supported platforms.
- Some behaviors that were kept to maintain compatibility with `OSQL` may be changed, such as alignment of column headers for some data types.
### Packages
#### sqlcmd
#### batch
## Contributing
@ -26,8 +41,9 @@ contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additio
## Trademarks
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos are subject to those third-party's policies.

129
cmd/sqlcmd/main.go Normal file
Просмотреть файл

@ -0,0 +1,129 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package main
import (
"fmt"
"os"
"github.com/alecthomas/kong"
"github.com/microsoft/go-sqlcmd/pkg/sqlcmd"
"github.com/xo/usql/rline"
)
// SQLCmdArguments defines the command line arguments for sqlcmd
// The exhaustive list is at https://docs.microsoft.com/sql/tools/sqlcmd-utility?view=sql-server-ver15
type SQLCmdArguments struct {
// Which batch terminator to use. Default is GO
BatchTerminator string `short:"c" default:"GO" arghelp:"Specifies the batch terminator. The default value is GO."`
// Whether to trust the server certificate on an encrypted connection
TrustServerCertificate bool `short:"C" help:"Implicitly trust the server certificate without validation."`
DatabaseName string `short:"d" help:"This option sets the sqlcmd scripting variable SQLCMDDBNAME. This parameter specifies the initial database. The default is your login's default-database property. If the database does not exist, an error message is generated and sqlcmd exits."`
UseTrustedConnection bool `short:"E" xor:"uid" help:"Uses a trusted connection instead of using a user name and password to sign in to SQL Server, ignoring any any environment variables that define user name and password."`
UserName string `short:"U" xor:"uid" help:"The login name or contained database user name. For contained database users, you must provide the database name option"`
// Files from which to read query text
InputFile []string `short:"i" xor:"input1, input2" type:"existingFile" help:"Identifies one or more files that contain batches of SQL statements. If one or more files do not exist, sqlcmd will exit. Mutually exclusive with -Q/-q."`
OutputFile string `short:"o" type:"path" help:"Identifies the file that receives output from sqlcmd."`
// First query to run in interactive mode
InitialQuery string `short:"q" xor:"input1" help:"Executes a query when sqlcmd starts, but does not exit sqlcmd when the query has finished running. Multiple-semicolon-delimited queries can be executed."`
// Query to run then exit
Query string `short:"Q" xor:"input2" help:"Executes a query when sqlcmd starts and then immediately exits sqlcmd. Multiple-semicolon-delimited queries can be executed."`
Server string `short:"S" help:"[tcp:]server[\\instance_name][,port]Specifies the instance of SQL Server to which to connect. It sets the sqlcmd scripting variable SQLCMDSERVER."`
// Disable syscommands with a warning
DisableCmdAndWarn bool `short:"X" xor:"syscmd" help:"Disables commands that might compromise system security. Sqlcmd issues a warning and continues."`
}
// Breaking changes in command line are listed here.
// Any switch not listed in breaking changes and not also included in SqlCmdArguments just has not been implemented yet
// 1. -P: Passwords have to be provided through SQLCMDPASSWORD environment variable or typed when prompted
// 2. -R: Go runtime doesn't expose user locale information and syscall would only enable it on Windows, so we won't try to implement it
var args SQLCmdArguments
func main() {
kong.Parse(&args)
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
setVars(vars, &args)
exitCode, err := run(vars)
if err != nil {
fmt.Println(err.Error())
}
os.Exit(exitCode)
}
// Initializes scripting variables from command line arguments
func setVars(vars *sqlcmd.Variables, args *SQLCmdArguments) {
varmap := map[string]func(*SQLCmdArguments) string{
sqlcmd.SQLCMDDBNAME: func(a *SQLCmdArguments) string { return a.DatabaseName },
sqlcmd.SQLCMDLOGINTIMEOUT: func(a *SQLCmdArguments) string { return "" },
sqlcmd.SQLCMDUSEAAD: func(a *SQLCmdArguments) string { return "" },
sqlcmd.SQLCMDWORKSTATION: func(a *SQLCmdArguments) string { return "" },
sqlcmd.SQLCMDSERVER: func(a *SQLCmdArguments) string { return a.Server },
sqlcmd.SQLCMDERRORLEVEL: func(a *SQLCmdArguments) string { return "" },
sqlcmd.SQLCMDPACKETSIZE: func(a *SQLCmdArguments) string { return "" },
sqlcmd.SQLCMDUSER: func(a *SQLCmdArguments) string { return a.UserName },
sqlcmd.SQLCMDSTATTIMEOUT: func(a *SQLCmdArguments) string { return "" },
sqlcmd.SQLCMDHEADERS: func(a *SQLCmdArguments) string { return "" },
sqlcmd.SQLCMDCOLSEP: func(a *SQLCmdArguments) string { return "" },
sqlcmd.SQLCMDCOLWIDTH: func(a *SQLCmdArguments) string { return "" },
sqlcmd.SQLCMDMAXVARTYPEWIDTH: func(a *SQLCmdArguments) string { return "" },
sqlcmd.SQLCMDMAXFIXEDTYPEWIDTH: func(a *SQLCmdArguments) string { return "" },
}
for varname, set := range varmap {
val := set(args)
if val != "" {
vars.Set(varname, val)
}
}
}
func run(vars *sqlcmd.Variables) (exitcode int, err error) {
wd, err := os.Getwd()
if err != nil {
return 1, err
}
if args.BatchTerminator != "GO" {
err = sqlcmd.SetBatchTerminator(args.BatchTerminator)
if err != nil {
err = fmt.Errorf("invalid batch terminator '%s'", args.BatchTerminator)
}
}
if err != nil {
return 1, err
}
iactive := args.InputFile == nil
line, err := rline.New(!iactive, "", "")
if err != nil {
return 1, err
}
defer line.Close()
s := sqlcmd.New(line, wd, vars)
s.Connect.UseTrustedConnection = args.UseTrustedConnection
s.Connect.TrustServerCertificate = args.TrustServerCertificate
s.Format = sqlcmd.NewSQLCmdDefaultFormatter(false)
if args.OutputFile != "" {
err = s.RunCommand(sqlcmd.Commands["OUT"], []string{args.OutputFile})
if err != nil {
return 1, err
}
}
once := false
if args.InitialQuery != "" {
s.Query = args.InitialQuery
} else if args.Query != "" {
once = true
s.Query = args.Query
}
err = s.ConnectDb("", "", "", !iactive)
if err != nil {
return 1, err
}
if iactive {
err = s.Run(once)
}
return s.Exitcode, err
}

91
cmd/sqlcmd/main_test.go Normal file
Просмотреть файл

@ -0,0 +1,91 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package main
import (
"strings"
"testing"
"github.com/alecthomas/kong"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newKong(t *testing.T, cli interface{}, options ...kong.Option) *kong.Kong {
t.Helper()
options = append([]kong.Option{
kong.Name("test"),
kong.Exit(func(int) {
t.Helper()
t.Fatalf("unexpected exit()")
}),
}, options...)
parser, err := kong.New(cli, options...)
require.NoError(t, err)
return parser
}
func TestValidCommandLineToArgsConversion(t *testing.T) {
type cmdLineTest struct {
commandLine []string
check func(SQLCmdArguments) bool
}
// These tests only cover compatibility with the native sqlcmd, which only supports the short flags
// The long flag names are up for debate.
commands := []cmdLineTest{
{[]string{}, func(args SQLCmdArguments) bool {
return args.Server == "" && !args.UseTrustedConnection && args.UserName == ""
}},
{[]string{"-c", "MYGO", "-C", "-E", "-i", "file1", "-o", "outfile", "-i", "file2"}, func(args SQLCmdArguments) bool {
return args.BatchTerminator == "MYGO" && args.TrustServerCertificate && len(args.InputFile) == 2 && strings.HasSuffix(args.OutputFile, "outfile")
}},
{[]string{"-U", "someuser", "-d", "somedatabase", "-S", "someserver"}, func(args SQLCmdArguments) bool {
return args.BatchTerminator == "GO" && !args.TrustServerCertificate && args.UserName == "someuser" && args.DatabaseName == "somedatabase" && args.Server == "someserver"
}},
// native sqlcmd allows both -q and -Q but only runs the -Q query and exits. We could make them mutually exclusive if desired.
{[]string{"-q", "select 1", "-Q", "select 2"}, func(args SQLCmdArguments) bool {
return args.Server == "" && args.InitialQuery == "select 1" && args.Query == "select 2"
}},
{[]string{"-S", "someserver/someinstance"}, func(args SQLCmdArguments) bool {
return args.Server == "someserver/someinstance"
}},
{[]string{"-S", "tcp:someserver,10245"}, func(args SQLCmdArguments) bool {
return args.Server == "tcp:someserver,10245"
}},
{[]string{"-X"}, func(args SQLCmdArguments) bool {
return args.DisableCmdAndWarn
}},
}
for _, test := range commands {
arguments := &SQLCmdArguments{}
parser := newKong(t, arguments)
_, err := parser.Parse(test.commandLine)
msg := ""
if err != nil {
msg = err.Error()
}
if assert.Nil(t, err, "Unable to parse commandLine:%v\n%s", test.commandLine, msg) {
assert.True(t, test.check(*arguments), "Unexpected SqlCmdArguments from: %v\n%+v", test.commandLine, *arguments)
}
}
}
func TestInvalidCommandLine(t *testing.T) {
type cmdLineTest struct {
commandLine []string
errorMessage string
}
commands := []cmdLineTest{
{[]string{"-E", "-U", "someuser"}, "--use-trusted-connection and --user-name can't be used together"},
}
for _, test := range commands {
arguments := &SQLCmdArguments{}
parser := newKong(t, arguments)
_, err := parser.Parse(test.commandLine)
assert.EqualError(t, err, test.errorMessage, "Command line:%v", test.commandLine)
}
}

11
go.mod Normal file
Просмотреть файл

@ -0,0 +1,11 @@
module github.com/microsoft/go-sqlcmd
go 1.16
require (
github.com/alecthomas/kong v0.2.17
github.com/denisenkom/go-mssqldb v0.10.0
github.com/google/uuid v1.2.0
github.com/stretchr/testify v1.7.0
github.com/xo/usql v0.9.1
)

1513
go.sum Normal file

Разница между файлами не показана из-за своего большого размера Загрузить разницу

182
pkg/sqlcmd/batch.go Normal file
Просмотреть файл

@ -0,0 +1,182 @@
package sqlcmd
const minCapIncrease = 512
// lineend is the slice to use when appending a line.
var lineend = []rune{'\n'}
// Batch provides the query text to run
type Batch struct {
// read provides the next chunk of runes
read func() ([]rune, error)
// Buffer is the current batch text
Buffer []rune
// Length is the length of the statement
Length int
// raw is the unprocessed runes
raw []rune
// rawlen is the number of unprocessed runes
rawlen int
// quote indicates currently processing a quoted string
quote rune
// comment is the state of multi-line comment processing
comment bool
// batchline is the 1-based index of the next line.
// Used for the prompt in interactive mode
batchline int
// linecount is the total number of batch lines processed in the session
linecount uint
}
// NewBatch creates a Batch which converts runes provided by reader into SQL batches
func NewBatch(reader func() ([]rune, error)) *Batch {
b := &Batch{
read: reader,
}
b.Reset(nil)
return b
}
// String returns the current SQL batch text
func (b *Batch) String() string {
return string(b.Buffer)
}
// Reset clears the current batch text and replaces it with new runes
func (b *Batch) Reset(r []rune) {
b.Buffer, b.Length = nil, 0
b.quote = 0
b.comment = false
b.batchline = 1
if r != nil {
b.raw, b.rawlen = r, len(r)
}
}
// Next processes the next chunk of input and sets the Batch state accordingly.
// If the input contains a command to run, Next returns the Command and its
// parameters.
// Upon exit from Next, the caller can use the State method to determine if
// it represents a runnable SQL batch text.
func (b *Batch) Next() (*Command, []string, error) {
var err error
var i int
if b.rawlen == 0 {
b.raw, err = b.read()
if err != nil {
return nil, nil, err
}
b.rawlen = len(b.raw)
}
var command *Command
var args []string
var ok bool
var scannedCommand bool
parse:
for ; i < b.rawlen; i++ {
c, next := b.raw[i], grab(b.raw, i+1, b.rawlen)
switch {
// we're in a quoted string
case b.quote != 0:
i, ok = readString(b.raw, i, b.rawlen, b.quote)
if ok {
b.quote = 0
}
// inside a multiline comment
case b.comment:
i, ok = readMultilineComment(b.raw, i, b.rawlen)
b.comment = !ok
// start of a string
case c == '\'' || c == '"':
b.quote = c
// inline sql comment, skip to end of line
case c == '-' && next == '-':
i = b.rawlen
// start a multi-line comment
case c == '/' && next == '*':
b.comment = true
i++
// continue processing quoted string or multiline comment
case b.quote != 0 || b.comment:
// Commands have to be alone on the line
case !scannedCommand:
var cend int
scannedCommand = true
command, args, cend = readCommand(b.raw, i, b.rawlen)
if command != nil {
// remove the command from raw
b.raw = append(b.raw[:i], b.raw[cend:]...)
break parse
}
}
}
i = min(i, b.rawlen)
empty := isEmptyLine(b.raw, 0, i)
appendLine := b.quote != 0 || b.comment || !empty
if !b.comment && command != nil && empty {
appendLine = false
}
if appendLine {
// skip leading space when empty
st := 0
if b.Length == 0 {
st, _ = findNonSpace(b.raw, 0, i)
}
// log.Printf(">> appending: `%s`", string(r[st:i]))
b.append(b.raw[st:i], lineend)
b.batchline++
}
b.raw = b.raw[i:]
b.rawlen = len(b.raw)
b.linecount++
return command, args, nil
}
// append appends r to b.Buffer separated by sep when b.Buffer is not already empty.
//
// Dynamically grows b.Buf as necessary to accommodate r and the separator.
// Specifically, when b.Buf is not empty, b.Buf will grow by increments of
// MinCapIncrease.
//
// After a call to append, b.Len will be len(b.Buf)+len(sep)+len(r). Call Reset
// to reset the Buf.
func (b *Batch) append(r, sep []rune) {
rlen := len(r)
// initial
if b.Buffer == nil {
b.Buffer, b.Length = r, rlen
return
}
blen, seplen := b.Length, len(sep)
tlen := blen + rlen + seplen
// grow
if bcap := cap(b.Buffer); tlen > bcap {
n := tlen + 2*rlen
n += minCapIncrease - (n % minCapIncrease)
z := make([]rune, blen, n)
copy(z, b.Buffer)
b.Buffer = z
}
b.Buffer = b.Buffer[:tlen]
copy(b.Buffer[blen:], sep)
copy(b.Buffer[blen+seplen:], r)
b.Length = tlen
}
// State returns a string representing the state of statement parsing.
// * Is in the middle of a multi-line comment
// - Has a non-empty batch ready to run
// = Is empty
// ' " Is in the middle of a multi-line quoted string
func (b *Batch) State() string {
switch {
case b.quote != 0:
return string(b.quote)
case b.comment:
return "*"
case b.Length != 0:
return "-"
}
return "="
}

73
pkg/sqlcmd/batch_test.go Normal file
Просмотреть файл

@ -0,0 +1,73 @@
package sqlcmd
import (
"io"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestBatchNextReset(t *testing.T) {
tests := []struct {
s string
stmts []string
cmds []string
state string
}{
{"", nil, nil, "="},
{"select 1", []string{"select 1"}, nil, "-"},
{"select 1\nquit", []string{"select 1"}, []string{"QUIT"}, "="},
{"select 1\nquite", []string{"select 1\nquite"}, nil, "-"},
{"select 1\nquit\nselect 2", []string{"select 1", "select 2"}, []string{"QUIT"}, "-"},
{"select '1\n", []string{"select '1\n"}, nil, "'"},
{"select 1 /* comment\nGO", []string{"select 1 /* comment\nGO"}, nil, "*"},
{"select '1\n00' \n/* comm\nent*/\nGO 4", []string{"select '1\n00' \n/* comm\nent*/"}, []string{"GO"}, "="},
}
for _, test := range tests {
b := NewBatch(sp(test.s, "\n"))
var stmts, cmds []string
loop:
for {
cmd, _, err := b.Next()
switch {
case err == io.EOF:
// if we get EOF before a command we will try to run
// whatever is in the buffer
if s := b.String(); s != "" {
stmts = append(stmts, s)
}
break loop
case err != nil:
t.Fatalf("test %s did not expect error, got: %v", test.s, err)
}
// resetting the buffer for every command purely for test purposes
if cmd != nil {
stmts = append(stmts, b.String())
cmds = append(cmds, cmd.name)
b.Reset(nil)
}
}
assert.Equal(t, test.stmts, stmts, "Statements for %s", test.s)
assert.Equal(t, test.state, b.State(), "State for %s", test.s)
assert.Equal(t, test.cmds, cmds, "Commands for %s", test.s)
b.Reset(nil)
assert.Zero(t, b.Length, "Length after Reset")
assert.Zero(t, len(b.Buffer), "len(Buffer) after Reset")
assert.Zero(t, b.quote, "quote after Reset")
assert.False(t, b.comment, "comment after Reset")
assert.Equal(t, "=", b.State(), "State() after Reset")
}
}
func sp(a, sep string) func() ([]rune, error) {
s := strings.Split(a, sep)
return func() ([]rune, error) {
if len(s) > 0 {
z := s[0]
s = s[1:]
return []rune(z), nil
}
return nil, io.EOF
}
}

172
pkg/sqlcmd/commands.go Normal file
Просмотреть файл

@ -0,0 +1,172 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package sqlcmd
import (
"fmt"
"os"
"regexp"
"strings"
"syscall"
)
// Command defines a sqlcmd action which can be intermixed with the SQL batch
// Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands
type Command struct {
// regex must include at least one group if it has parameters
// Will be matched using FindStringSubmatch
regex *regexp.Regexp
// The function that implements the command. Third parameter is the line number
action func(*Sqlcmd, []string, uint) error
// Name of the command
name string
}
// Commands is the set of Command implementations
var Commands = map[string]*Command{
"QUIT": {
regex: regexp.MustCompile(`(?im)^[\t ]*?:?QUIT(?:[ \t]+(.*$)|$)`),
action: quitCommand,
name: "QUIT",
},
"GO": {
regex: regexp.MustCompile(batchTerminatorRegex("GO")),
action: goCommand,
name: "GO",
},
"OUT": {
regex: regexp.MustCompile(`(?im)^[ \t]*:OUT(?:[ \t]+(.*$)|$)`),
action: outCommand,
name: "OUT",
},
"ERROR": {
regex: regexp.MustCompile(`(?im)^[ \t]*:ERROR(?:[ \t]+(.*$)|$)`),
action: errorCommand,
name: "ERROR",
},
}
func matchCommand(line string) (*Command, []string) {
for _, cmd := range Commands {
matchedCommand := cmd.regex.FindStringSubmatch(line)
if matchedCommand != nil {
return cmd, matchedCommand[1:]
}
}
return nil, nil
}
func batchTerminatorRegex(terminator string) string {
return fmt.Sprintf(`(?im)^[\t ]*?%s(?:[ ]+(.*$)|$)`, regexp.QuoteMeta(terminator))
}
// SetBatchTerminator attempts to set the batch terminator to the given value
// Returns an error if the new value is not usable in the regex
func SetBatchTerminator(terminator string) error {
cmd := Commands["GO"]
regex, err := regexp.Compile(batchTerminatorRegex(terminator))
if err != nil {
return err
}
cmd.regex = regex
return nil
}
// quitCommand immediately exits the program without running any more batches
func quitCommand(s *Sqlcmd, args []string, line uint) error {
if args != nil && strings.TrimSpace(args[0]) != "" {
return InvalidCommandError("QUIT", line)
}
return ErrExitRequested
}
// goCommand runs the current batch the number of times specified
func goCommand(s *Sqlcmd, args []string, line uint) error {
// default to 1 execution
n := 1
var err error
if len(args) > 0 {
cnt := strings.TrimSpace(args[0])
if cnt != "" {
_, err = fmt.Sscanf(cnt, "%d", &n)
}
}
if err != nil || n < 1 {
return InvalidCommandError("GO", line)
}
// This loop will likely be refactored to a helper when we implement -Q and :EXIT(query)
for i := 0; i < n; i++ {
query := s.Query
if query == "" {
query = s.batch.String()
}
s.Format.BeginBatch(query, s.vars, s.GetOutput(), s.GetError())
rows, qe := s.db.Query(query)
if qe != nil {
s.Format.AddError(qe)
}
results := true
for qe == nil && results {
cols, err := rows.ColumnTypes()
if err != nil {
s.Format.AddError(err)
} else {
s.Format.BeginResultSet(cols)
active := rows.Next()
for active {
s.Format.AddRow(rows)
active = rows.Next()
}
if err = rows.Err(); err != nil {
s.Format.AddError(err)
}
s.Format.EndResultSet()
}
results = rows.NextResultSet()
if err = rows.Err(); err != nil {
s.Format.AddError(err)
}
}
s.Format.EndBatch()
}
s.Query = ""
s.batch.Reset(nil)
return nil
}
// outCommand changes the output writer to use a file
func outCommand(s *Sqlcmd, args []string, line uint) error {
switch {
case strings.EqualFold(args[0], "stdout"):
s.SetOutput(nil)
case strings.EqualFold(args[0], "stderr"):
s.SetOutput(os.NewFile(uintptr(syscall.Stderr), "/dev/stderr"))
default:
o, err := os.OpenFile(args[0], os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return InvalidFileError(err, args[0])
}
s.SetOutput(o)
}
return nil
}
// errorCommand changes the error writer to use a file
func errorCommand(s *Sqlcmd, args []string, line uint) error {
switch {
case strings.EqualFold(args[0], "stderr"):
s.SetError(nil)
case strings.EqualFold(args[0], "stdout"):
s.SetError(os.NewFile(uintptr(syscall.Stderr), "/dev/stdout"))
default:
o, err := os.OpenFile(args[0], os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return InvalidFileError(err, args[0])
}
s.SetError(o)
}
return nil
}

Просмотреть файл

@ -0,0 +1,59 @@
package sqlcmd
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestQuitCommand(t *testing.T) {
s := &Sqlcmd{}
err := quitCommand(s, nil, 1)
require.ErrorIs(t, err, ErrExitRequested)
err = quitCommand(s, []string{"extra parameters"}, 2)
require.Error(t, err, "Quit should error out with extra parameters")
assert.NotErrorIs(t, err, ErrExitRequested, "Error with extra arguments")
}
func TestCommandParsing(t *testing.T) {
type commandTest struct {
line string
cmd string
args []string
}
commands := []commandTest{
{"quite", "", nil},
{"quit", "QUIT", []string{""}},
{":QUIT\n", "QUIT", []string{""}},
{" QUIT \n", "QUIT", []string{""}},
{"quit extra\n", "QUIT", []string{"extra"}},
{`:Out c:\folder\file`, "OUT", []string{`c:\folder\file`}},
{` :Error c:\folder\file`, "ERROR", []string{`c:\folder\file`}},
}
for _, test := range commands {
cmd, args := matchCommand(test.line)
if test.cmd != "" {
if assert.NotNil(t, cmd, "No command found for `%s`", test.line) {
assert.Equal(t, test.cmd, cmd.name, "Incorrect command for `%s`", test.line)
assert.Equal(t, test.args, args, "Incorrect arguments for `%s`", test.line)
}
} else {
assert.Nil(t, cmd, "Unexpected match for %s", test.line)
}
}
}
func TestCustomBatchSeparator(t *testing.T) {
err := SetBatchTerminator("me!")
if assert.NoError(t, err, "SetBatchTerminator should succeed") {
cmd, args := matchCommand(" me! 5 \n")
if assert.NotNil(t, cmd, "matchCommand didn't find GO for custom batch separator") {
assert.Equal(t, "GO", cmd.name, "command name")
assert.Equal(t, "5", strings.TrimSpace(args[0]), "go argument")
}
}
}

69
pkg/sqlcmd/errors.go Normal file
Просмотреть файл

@ -0,0 +1,69 @@
package sqlcmd
import (
"errors"
"fmt"
)
// ErrorPrefix is the prefix for all sqlcmd-generated errors
const ErrorPrefix = "Sqlcmd: Error: "
// WarningPrefix is the prefix for all sqlcmd-generated warnings
const WarningPrefix = "Sqlcmd: Warning: "
// ArgumentError is related to command line switch validation not handled by kong
type ArgumentError struct {
Parameter string
Rule string
}
func (e *ArgumentError) Error() string {
return ErrorPrefix + e.Rule
}
// InvalidServerName indicates the SQLCMDSERVER variable has an incorrect format
var InvalidServerName = ArgumentError{
Parameter: "server",
Rule: "server must be of the form [tcp]:server[[/instance]|[,port]]",
}
// VariableError is an error about scripting variables
type VariableError struct {
Variable string
MessageFormat string
}
func (e *VariableError) Error() string {
return ErrorPrefix + fmt.Sprintf(e.MessageFormat, e.Variable)
}
// ReadOnlyVariable indicates the user tried to set a value to a read-only variable
func ReadOnlyVariable(variable string) *VariableError {
return &VariableError{
Variable: variable,
MessageFormat: "The scripting variable: '%s' is read-only",
}
}
// CommandError indicates syntax errors for specific sqlcmd commands
type CommandError struct {
Command string
LineNumber uint
}
func (e *CommandError) Error() string {
return ErrorPrefix + fmt.Sprintf("Syntax error at line %d near command '%s'.", e.LineNumber, e.Command)
}
// InvalidCommandError creates a SQLCmdCommandError
func InvalidCommandError(command string, lineNumber uint) *CommandError {
return &CommandError{
Command: command,
LineNumber: lineNumber,
}
}
// InvalidFileError indicates a file could not be opened
func InvalidFileError(err error, path string) error {
return errors.New(ErrorPrefix + " Error occurred while opening or operating on file " + path + " (Reason: " + err.Error() + ").")
}

562
pkg/sqlcmd/format.go Normal file
Просмотреть файл

@ -0,0 +1,562 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package sqlcmd
import (
"database/sql"
"fmt"
"io"
"strings"
"time"
mssql "github.com/denisenkom/go-mssqldb"
)
const (
defaultMaxDisplayWidth = 1024 * 1024
maxPadWidth = 8000
)
// Formatter defines methods to process query output
type Formatter interface {
// BeginBatch is called before the query runs
BeginBatch(query string, vars *Variables, out io.Writer, err io.Writer)
// EndBatch is the last function called during batch execution and signals the end of the batch
EndBatch()
// BeginResultSet is called when a new result set is encountered
BeginResultSet([]*sql.ColumnType)
// EndResultSet is called after all rows in a result set have been processed
EndResultSet()
// AddRow is called for each row in a result set
AddRow(*sql.Rows)
// AddMessage is called for every information message returned by the server during the batch
AddMessage(string)
// AddError is called for each error encountered during batch execution
AddError(err error)
}
// ControlCharacterBehavior specifies the text handling required for control characters in the output
type ControlCharacterBehavior int
const (
// ControlIgnore preserves control characters in the output
ControlIgnore ControlCharacterBehavior = iota
// ControlReplace replaces control characters with spaces, 1 space per character
ControlReplace
// ControlRemove removes control characters from the output
ControlRemove
// ControlReplaceConsecutive replaces multiple consecutive control characters with a single space
ControlReplaceConsecutive
)
type columnDetail struct {
displayWidth int64
leftJustify bool
zeroesAfterDecimal bool
col sql.ColumnType
}
// The default formatter based on the native sqlcmd style
type sqlCmdFormatterType struct {
out io.Writer
err io.Writer
vars *Variables
colsep string
removeTrailingSpaces bool
ccb ControlCharacterBehavior
columnDetails []columnDetail
rowcount int
writepos int64
}
// NewSQLCmdDefaultFormatter returns a Formatter that mimics the original ODBC-based sqlcmd formatter
func NewSQLCmdDefaultFormatter(removeTrailingSpaces bool) Formatter {
return &sqlCmdFormatterType{
removeTrailingSpaces: removeTrailingSpaces,
}
}
// Adds the given string to the current line, wrapping it based on the screen width setting
func (f *sqlCmdFormatterType) writeOut(s string) {
w := f.vars.ScreenWidth()
if w == 0 {
f.mustWriteOut(s)
return
}
r := []rune(s)
for i := 0; true; {
if i == len(r) {
f.mustWriteOut(string(r))
return
} else if f.writepos == w {
f.mustWriteOut(string(r[:i]))
f.mustWriteOut(SqlcmdEol)
r = []rune(string(r[i:]))
f.writepos = 0
i = 0
} else {
c := r[i]
if c != '\r' && c != '\n' {
f.writepos++
} else {
f.writepos = 0
}
i++
}
}
}
// Stores the settings to use for processing the current batch
// TODO: add a third io.Writer for messages when we add -r support
func (f *sqlCmdFormatterType) BeginBatch(_ string, vars *Variables, out io.Writer, err io.Writer) {
f.out = out
f.err = err
f.vars = vars
f.colsep = vars.ColumnSeparator()
}
func (f *sqlCmdFormatterType) EndBatch() {
}
// Calculate the widths for each column and print the column names
// Since sql.ColumnType only provides sizes for variable length types we will
// base our numbers for most types on https://docs.microsoft.com/sql/odbc/reference/appendixes/column-size
func (f *sqlCmdFormatterType) BeginResultSet(cols []*sql.ColumnType) {
f.rowcount = 0
f.columnDetails = calcColumnDetails(cols, f.vars.MaxFixedColumnWidth(), f.vars.MaxVarColumnWidth())
if f.vars.RowsBetweenHeaders() > -1 {
f.printColumnHeadings()
}
}
// Writes a blank line to the designated output writer
func (f *sqlCmdFormatterType) EndResultSet() {
f.writeOut(SqlcmdEol)
}
// Writes the current row to the designated output writer
func (f *sqlCmdFormatterType) AddRow(row *sql.Rows) {
f.writepos = 0
values, err := f.scanRow(row)
if err != nil {
f.mustWriteErr(err.Error())
return
}
// values are the full values, look at the displaywidth of each column and truncate accordingly
for i, v := range values {
if i > 0 {
f.writeOut(f.vars.ColumnSeparator())
}
f.printColumnValue(v, i)
}
f.rowcount++
gap := f.vars.RowsBetweenHeaders()
if gap > 0 && (int64(f.rowcount)%gap == 0) {
f.writeOut(SqlcmdEol)
f.printColumnHeadings()
}
f.writeOut(SqlcmdEol)
}
// Writes a non-error message to the designated message writer
func (f *sqlCmdFormatterType) AddMessage(string) {}
// Writes an error to the designated err Writer
func (f *sqlCmdFormatterType) AddError(err error) {
b := new(strings.Builder)
msg := err.Error()
switch e := (err).(type) {
case mssql.Error:
b.WriteString(fmt.Sprintf("Msg %d, Level %d, State %d, Server %s, Line %d%s", e.Number, e.Class, e.State, e.ServerName, e.LineNo, SqlcmdEol))
msg = strings.TrimPrefix(msg, "mssql: ")
}
b.WriteString(msg)
b.WriteString(SqlcmdEol)
f.mustWriteErr(fitToScreen(b, f.vars.ScreenWidth()).String())
}
// Prints column headings based on columnDetail, variables, and command line arguments
func (f *sqlCmdFormatterType) printColumnHeadings() {
names := new(strings.Builder)
sep := new(strings.Builder)
var leftPad, rightPad int64
for i, c := range f.columnDetails {
nameLen := int64(len([]rune(c.col.Name())))
if f.removeTrailingSpaces {
if nameLen == 0 {
// special case for unnamed columns when using -W
// print a single -
rightPad = 1
sep = padRight(sep, 1, "-")
} else {
sep = padRight(sep, nameLen, "-")
}
} else {
length := min64(c.displayWidth, maxPadWidth)
if nameLen < length {
rightPad = length - nameLen
}
sep = padRight(sep, length, "-")
}
names = padRight(names, leftPad, " ")
names.WriteString(c.col.Name()[:min64(nameLen, c.displayWidth)])
names = padRight(names, rightPad, " ")
if i != len(f.columnDetails)-1 {
names.WriteString(f.colsep)
sep.WriteString(f.colsep)
}
}
names.WriteString(SqlcmdEol)
sep.WriteString(SqlcmdEol)
names = fitToScreen(names, f.vars.ScreenWidth())
sep = fitToScreen(sep, f.vars.ScreenWidth())
f.mustWriteOut(names.String())
f.mustWriteOut(sep.String())
}
// Wraps the input string every width characters when width > 0
// When width == 0 returns the input Builder
// When width > 0 returns a new Builder containing the wrapped string
func fitToScreen(s *strings.Builder, width int64) *strings.Builder {
str := s.String()
runes := []rune(str)
if width == 0 || int64(len(runes)) < width {
return s
}
line := new(strings.Builder)
line.Grow(len(str))
var c int64
for i, r := range runes {
if c == width {
// We have printed a line's worth
// if the next character is not part of a carriage return write our Eol
if (SqlcmdEol == "\r\n" && (i == len(runes)-1 || (i < len(runes)-1 && string(runes[i:i+2]) != SqlcmdEol))) || (SqlcmdEol == "\n" && r != '\n') {
line.WriteString(SqlcmdEol)
c = 0
}
}
line.WriteRune(r)
if r == '\n' {
c = 0
// we are assuming \r is a non-printed character
// The likelihood of a \r not being followed by \n is low
} else if r == '\r' && SqlcmdEol == "\r\n" {
c = 0
} else {
c++
}
}
return line
}
// Given the array of driver-provided columnType values and the sqlcmd size limits,
// return an array of columnDetail objects describing the output format for each column
func calcColumnDetails(cols []*sql.ColumnType, fixed int64, variable int64) (columnDetails []columnDetail) {
columnDetails = make([]columnDetail, len(cols))
for i, c := range cols {
length, _ := c.Length()
nameLen := int64(len([]rune(c.Name())))
columnDetails[i].col = *c
columnDetails[i].leftJustify = true
columnDetails[i].zeroesAfterDecimal = false
if length == 0 {
columnDetails[i].displayWidth = defaultMaxDisplayWidth
} else {
columnDetails[i].displayWidth = length
}
switch c.DatabaseTypeName() {
// Types with 0 size from sql.ColumnType
case "BIT":
columnDetails[i].leftJustify = false
columnDetails[i].displayWidth = max64(1, nameLen)
case "TINYINT":
columnDetails[i].leftJustify = false
columnDetails[i].displayWidth = max64(3, nameLen)
case "SMALLINT":
columnDetails[i].leftJustify = false
columnDetails[i].displayWidth = max64(6, nameLen)
case "INT":
columnDetails[i].leftJustify = false
columnDetails[i].displayWidth = max64(11, nameLen)
case "BIGINT":
columnDetails[i].leftJustify = false
columnDetails[i].displayWidth = max64(21, nameLen)
case "REAL":
columnDetails[i].leftJustify = false
columnDetails[i].displayWidth = max64(14, nameLen)
columnDetails[i].zeroesAfterDecimal = true
case "FLOAT":
columnDetails[i].leftJustify = false
columnDetails[i].displayWidth = max64(24, nameLen)
columnDetails[i].zeroesAfterDecimal = true
case "DECIMAL":
columnDetails[i].leftJustify = false
d, _, ok := c.DecimalSize()
// maybe panic on !ok?
if !ok {
d = 24
}
columnDetails[i].displayWidth = max64(d+2, nameLen)
columnDetails[i].zeroesAfterDecimal = true
case "DATE":
columnDetails[i].leftJustify = false
columnDetails[i].displayWidth = max64(16, nameLen)
case "DATETIME":
columnDetails[i].leftJustify = false
columnDetails[i].displayWidth = max64(23, nameLen)
case "SMALLDATETIME":
columnDetails[i].leftJustify = false
columnDetails[i].displayWidth = max64(19, nameLen)
columnDetails[i].zeroesAfterDecimal = true
case "DATETIME2":
columnDetails[i].leftJustify = false
columnDetails[i].displayWidth = max64(38, nameLen)
columnDetails[i].zeroesAfterDecimal = true
case "DATETIMEOFFSET":
columnDetails[i].leftJustify = false
columnDetails[i].displayWidth = max64(45, nameLen)
case "UNIQUEIDENTIFIER":
columnDetails[i].displayWidth = max64(36, nameLen)
// Types that can be fixed or variable
case "VARCHAR":
if length > 8000 {
columnDetails[i].displayWidth = variable
} else {
if fixed > 0 {
length = min64(fixed, length)
}
columnDetails[i].displayWidth = max64(length, nameLen)
}
case "NVARCHAR":
if length > 4000 {
columnDetails[i].displayWidth = variable
} else {
if fixed > 0 {
length = min64(fixed, length)
}
columnDetails[i].displayWidth = max64(length, nameLen)
}
case "VARBINARY":
if length <= 8000 {
if fixed > 0 {
length = min64(fixed, length)
}
columnDetails[i].displayWidth = max64(length, nameLen)
} else {
columnDetails[i].displayWidth = variable
}
// Fixed length types
case "CHAR", "NCHAR", "VARIANT":
if fixed > 0 {
length = min64(fixed, length)
}
columnDetails[i].displayWidth = max64(length, nameLen)
// Variable length types
// TODO: Fix BINARY once we have a driver with fix for https://github.com/denisenkom/go-mssqldb/issues/685
case "XML", "TEXT", "NTEXT", "IMAGE", "BINARY":
columnDetails[i].displayWidth = variable
default:
columnDetails[i].displayWidth = length
}
// When max var length is 0 we don't print column headers and print every value with unlimited width
if variable == 0 {
columnDetails[i].displayWidth = 0
}
}
return columnDetails
}
// scanRow fetches the next row and converts each value to the appropriate string representation
func (f *sqlCmdFormatterType) scanRow(rows *sql.Rows) ([]string, error) {
r := make([]interface{}, len(f.columnDetails))
for i := range r {
r[i] = new(interface{})
}
if err := rows.Scan(r...); err != nil {
return nil, err
}
row := make([]string, len(f.columnDetails))
for n, z := range r {
j := z.(*interface{})
if *j == nil {
row[n] = "NULL"
} else {
switch x := (*j).(type) {
case []byte:
if isBinaryDataType(&f.columnDetails[n].col) {
row[n] = decodeBinary(x)
} else {
row[n] = string(x)
}
case string:
row[n] = x
case time.Time:
// Go lacks any way to get the user's preferred time format or even the system default
row[n] = x.String()
case fmt.Stringer:
row[n] = x.String()
// not sure why go-mssql reports bit as bool
case bool:
if x {
row[n] = "1"
} else {
row[n] = "0"
}
default:
var err error
if row[n], err = fmt.Sprintf("%v", x), nil; err != nil {
return nil, err
}
}
}
}
return row, nil
}
// Prints the final version of a cell based on formatting variables and command line parameters
func (f *sqlCmdFormatterType) printColumnValue(val string, col int) {
c := f.columnDetails[col]
s := new(strings.Builder)
if isNeedingControlCharacterTreatment(&c.col) {
val = applyControlCharacterBehavior(val, f.ccb)
}
if isNeedingHexPrefix(&c.col) {
val = "0x" + val
}
s.WriteString(val)
r := []rune(val)
if !f.removeTrailingSpaces {
if f.vars.MaxVarColumnWidth() != 0 || !isLargeVariableType(&c.col) {
padding := c.displayWidth - min64(c.displayWidth, int64(len(r)))
if padding > 0 {
if c.leftJustify {
s = padRight(s, padding, " ")
} else {
s = padLeft(s, padding, " ")
}
}
}
}
r = []rune(s.String())
if c.displayWidth > 0 && int64(len(r)) > c.displayWidth {
s.Reset()
s.WriteString(string(r[:c.displayWidth]))
}
f.writeOut(s.String())
}
func (f *sqlCmdFormatterType) mustWriteOut(s string) {
_, err := f.out.Write([]byte(s))
if err != nil {
panic(err)
}
}
func (f *sqlCmdFormatterType) mustWriteErr(s string) {
_, err := f.err.Write([]byte(s))
if err != nil {
panic(err)
}
}
func isLargeVariableType(col *sql.ColumnType) bool {
l, _ := col.Length()
switch col.DatabaseTypeName() {
case "VARCHAR", "VARBINARY":
return l > 8000
case "NVARCHAR":
return l > 4000
case "XML", "TEXT", "NTEXT", "IMAGE":
return true
}
return false
}
func isNeedingControlCharacterTreatment(col *sql.ColumnType) bool {
switch col.DatabaseTypeName() {
case "CHAR", "VARCHAR", "TEXT", "NTEXT", "NCHAR", "NVARCHAR", "XML":
return true
}
return false
}
func isBinaryDataType(col *sql.ColumnType) bool {
switch col.DatabaseTypeName() {
case "BINARY", "VARBINARY":
return true
}
return false
}
func isNeedingHexPrefix(col *sql.ColumnType) bool {
return isBinaryDataType(col) // || col.DatabaseTypeName() == "UDT"
}
func isControlChar(r rune) bool {
c := int(r)
return c == 0x7f || (c >= 0 && c <= 0x1f)
}
func applyControlCharacterBehavior(val string, ccb ControlCharacterBehavior) string {
if ccb == ControlIgnore {
return val
}
b := new(strings.Builder)
r := []rune(val)
if ccb == ControlReplace {
for _, l := range r {
if isControlChar(l) {
b.WriteRune(' ')
} else {
b.WriteRune(l)
}
}
} else {
for i := 0; i < len(r); {
if !isControlChar(r[i]) {
b.WriteRune(r[i])
i++
} else {
for ; i < len(r) && isControlChar(r[i]); i++ {
}
if ccb == ControlReplaceConsecutive {
b.WriteRune(' ')
}
}
}
}
return b.String()
}
// Per https://docs.microsoft.com/sql/odbc/reference/appendixes/sql-to-c-binary
var hexDigits = []rune{'A', 'B', 'C', 'D', 'E', 'F'}
func decodeBinary(b []byte) string {
s := new(strings.Builder)
s.Grow(len(b) * 2)
for _, ch := range b {
b1 := ch >> 4
b2 := ch & 0x0f
if b1 >= 10 {
s.WriteRune(hexDigits[b1-10])
} else {
s.WriteRune(rune('0' + b1))
}
if b2 >= 10 {
s.WriteRune(hexDigits[b2-10])
} else {
s.WriteRune(rune('0' + b2))
}
}
return s.String()
}

Просмотреть файл

@ -0,0 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package sqlcmd
// SqlcmdEol is the end-of-line marker for sqlcmd output
const SqlcmdEol = "\n"

Просмотреть файл

@ -0,0 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package sqlcmd
// SqlcmdEol is the end-of-line marker for sqlcmd output
const SqlcmdEol = "\n"

136
pkg/sqlcmd/format_test.go Normal file
Просмотреть файл

@ -0,0 +1,136 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package sqlcmd
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestFitToScreen(t *testing.T) {
type fitTest struct {
width int64
raw string
fit string
}
tests := []fitTest{
{0, "this is a string", "this is a string"},
{9, "12345678", "12345678"},
{9, "123456789", "123456789"},
{9, "123456789A", "123456789" + SqlcmdEol + "A"},
{9, "123456789" + SqlcmdEol, "123456789" + SqlcmdEol},
{9, "12345678" + SqlcmdEol + "9A", "12345678" + SqlcmdEol + "9A"},
{9, "123456789\rA", "123456789" + SqlcmdEol + "\rA"},
}
for _, test := range tests {
line := new(strings.Builder)
line.WriteString(test.raw)
t.Log(test.raw)
f := fitToScreen(line, test.width).String()
assert.Equal(t, test.fit, f, "Mismatched fit for raw string: '%s'", test.raw)
}
}
func TestCalcColumnDetails(t *testing.T) {
type colTest struct {
fixed int64
variable int64
query string
details []columnDetail
}
tests := []colTest{
{8, 8,
"select 100 as '123456789ABC', getdate() as '987654321', 'string' as col1",
[]columnDetail{
{leftJustify: false, displayWidth: 12},
{leftJustify: false, displayWidth: 23},
{leftJustify: true, displayWidth: 6},
},
},
}
db, err := ConnectDb()
if assert.NoError(t, err, "ConnectDB failed") {
defer db.Close()
for _, test := range tests {
rows, err := db.Query(test.query)
if assert.NoError(t, err, "Query failed: %s", test.query) {
defer rows.Close()
cols, err := rows.ColumnTypes()
if assert.NoError(t, err, "ColumnTypes failed:%s", test.query) {
actual := calcColumnDetails(cols, test.fixed, test.variable)
for i, a := range actual {
if test.details[i].displayWidth != a.displayWidth ||
test.details[i].leftJustify != a.leftJustify ||
test.details[i].zeroesAfterDecimal != a.zeroesAfterDecimal {
assert.Failf(t, "", "Incorrect test details for column [%s] in query '%s':%+v", cols[i].Name(), test.query, a)
}
}
}
}
}
}
}
func TestControlCharacterBehavior(t *testing.T) {
type ccbTest struct {
raw string
replaced string
removed string
consecutivereplaced string
}
tests := []ccbTest{
{"no control", "no control", "no control", "no control"},
{string(rune(1)) + "tabs\t\treturns\r\n\r\n", " tabs returns ", "tabsreturns", " tabs returns "},
}
for _, test := range tests {
s := applyControlCharacterBehavior(test.raw, ControlReplace)
assert.Equalf(t, test.replaced, s, "Incorrect Replaced for '%s'", test.raw)
s = applyControlCharacterBehavior(test.raw, ControlRemove)
assert.Equalf(t, test.removed, s, "Incorrect Remove for '%s'", test.raw)
s = applyControlCharacterBehavior(test.raw, ControlReplaceConsecutive)
assert.Equalf(t, test.consecutivereplaced, s, "Incorrect ReplaceConsecutive for '%s'", test.raw)
}
}
func TestDecodeBinary(t *testing.T) {
type decodeTest struct {
b []byte
s string
}
tests := []decodeTest{
{[]byte("123456ABCDEF"), "313233343536414243444546"},
{[]byte{0x12, 0x34, 0x56}, "123456"},
}
for _, test := range tests {
a := decodeBinary(test.b)
assert.Equalf(t, test.s, a, "Incorrect decoded binary string for %v", test.b)
}
}
func BenchmarkDecodeBinary(b *testing.B) {
b.ReportAllocs()
bytes := make([]byte, 10000)
for i := 0; i < 10000; i++ {
bytes[i] = byte(i % 0xff)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
s := decodeBinary(bytes)
if len(s) != 20000 {
b.Fatalf("Incorrect length of returned string. Should be 20k, was %d", len(s))
}
}
}

Просмотреть файл

@ -0,0 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package sqlcmd
// SqlcmdEol is the end-of-line marker for sqlcmd output
const SqlcmdEol = "\r\n"

147
pkg/sqlcmd/parse.go Normal file
Просмотреть файл

@ -0,0 +1,147 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package sqlcmd
import "unicode"
// grab grabs i from r, or returns 0 if i >= end.
func grab(r []rune, i, end int) rune {
if i < end {
return r[i]
}
return 0
}
// findNonSpace finds first non space rune in r, returning end if not found.
func findNonSpace(r []rune, i, end int) (int, bool) {
for ; i < end; i++ {
if !isSpaceOrControl(r[i]) {
return i, true
}
}
return i, false
}
/*
// findSpace finds first space rune in r, returning end if not found.
func findSpace(r []rune, i, end int) (int, bool) {
for ; i < end; i++ {
if IsSpaceOrControl(r[i]) {
return i, true
}
}
return i, false
}
// findRune finds the next rune c in r, returning end if not found.
func findRune(r []rune, i, end int, c rune) (int, bool) {
for ; i < end; i++ {
if r[i] == c {
return i, true
}
}
return i, false
}
*/
// isEmptyLine returns true when r is empty or composed of only whitespace.
func isEmptyLine(r []rune, i, end int) bool {
_, ok := findNonSpace(r, i, end)
return !ok
}
// readString seeks to the end of a string returning the position and whether
// or not the string's end was found.
//
// If the string's terminator was not found, then the result will be the passed
// end.
func readString(r []rune, i, end int, quote rune) (int, bool) {
var prev, c, next rune
for ; i < end; i++ {
c, next = r[i], grab(r, i+1, end)
switch {
case quote == '\'' && c == '\\':
i++
prev = 0
continue
case quote == '\'' && c == '\'' && next == '\'':
i++
continue
case quote == '\'' && c == '\'' && prev != '\'',
quote == '"' && c == '"':
return i, true
}
prev = c
}
return end, false
}
// readMultilineComment finds the end of a multiline comment (ie, '*/').
func readMultilineComment(r []rune, i, end int) (int, bool) {
i++
for ; i < end; i++ {
if r[i-1] == '*' && r[i] == '/' {
return i, true
}
}
return end, false
}
// Read to the next control character and try to find
// a command in the string. Command regexes constrain matches
// to the beginning of the string, and all commands consume
// an entire line.
func readCommand(r []rune, i, end int) (*Command, []string, int) {
for ; i < end; i++ {
next := grab(r, i, end)
if next == 0 || unicode.IsControl(next) {
break
}
}
cmd, args := matchCommand(string(r[:i]))
return cmd, args, i
}
func max64(a, b int64) int64 {
if a > b {
return a
}
return b
}
// min returns the minimum of a, b.
func min(a, b int) int {
if a < b {
return a
}
return b
}
func min64(a, b int64) int64 {
if a < b {
return a
}
return b
}
// isSpaceOrControl is a special test for either a space or a control (ie, \b)
// characters.
func isSpaceOrControl(r rune) bool {
return unicode.IsSpace(r) || unicode.IsControl(r)
}
/*
// runesLastIndex returns the last index in r of needle, or -1 if not found.
func runesLastIndex(r []rune, needle rune) int {
i := len(r) - 1
for ; i >= 0; i-- {
if r[i] == needle {
return i
}
}
return i
}
*/

61
pkg/sqlcmd/parse_test.go Normal file
Просмотреть файл

@ -0,0 +1,61 @@
package sqlcmd
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestReadString(t *testing.T) {
tests := []struct {
// input string
s string
// index to start inside s
i int
// expected return string
exp string
// expected return bool
ok bool
}{
{`'`, 0, ``, false},
{` '`, 1, ``, false},
{`'str' `, 0, `'str'`, true},
{` 'str' `, 1, `'str'`, true},
{`"str"`, 0, `"str"`, true},
{`'str''str'`, 0, `'str''str'`, true},
{` 'str''str' `, 1, `'str''str'`, true},
{` "str''str" `, 1, `"str''str"`, true},
// escaped \" aren't allowed in strings, so the second " would be next
// double quoted string
{`"str\""`, 0, `"str\"`, true},
{` "str\"" `, 1, `"str\"`, true},
{`''''`, 0, `''''`, true},
{` '''' `, 1, `''''`, true},
{`''''''`, 0, `''''''`, true},
{` '''''' `, 1, `''''''`, true},
{`'''`, 0, ``, false},
{` ''' `, 1, ``, false},
{`'''''`, 0, ``, false},
{` ''''' `, 1, ``, false},
{`"st'r"`, 0, `"st'r"`, true},
{` "st'r" `, 1, `"st'r"`, true},
{`"st''r"`, 0, `"st''r"`, true},
{` "st''r" `, 1, `"st''r"`, true},
}
for _, test := range tests {
r := []rune(test.s)
c, end := rune(strings.TrimSpace(test.s)[0]), len(r)
if c != '\'' && c != '"' {
t.Fatalf("test %+v incorrect!", test)
}
pos, ok := readString(r, test.i+1, end, c)
assert.Equal(t, test.ok, ok, "test %+v ok", test)
if !ok {
continue
}
assert.Equal(t, c, r[pos], "test %+v last character")
v := string(r[test.i : pos+1])
assert.Equal(t, test.exp, v, "test %+v returned string", test)
}
}

287
pkg/sqlcmd/sqlcmd.go Normal file
Просмотреть файл

@ -0,0 +1,287 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package sqlcmd
import (
"database/sql"
"errors"
"fmt"
"io"
"net/url"
"os"
"os/signal"
osuser "os/user"
"syscall"
mssql "github.com/denisenkom/go-mssqldb"
"github.com/xo/usql/rline"
)
var (
// ErrExitRequested tells the hosting application to exit immediately
ErrExitRequested = errors.New("exit")
// ErrNeedPassword indicates the user should provide a password to enable the connection
ErrNeedPassword = errors.New("need password")
// ErrCtrlC indicates execution was ended by ctrl-c or ctrl-break
ErrCtrlC = errors.New(WarningPrefix + "The last operation was terminated because the user pressed CTRL+C")
)
// ConnectSettings are the settings for connections that can't be
// inferred from scripting variables
type ConnectSettings struct {
UseTrustedConnection bool
TrustServerCertificate bool
}
// Sqlcmd is the core processor for text lines.
//
// It accumulates non-command lines in a buffer and and sends command lines to the appropriate command runner.
// When the batch delimiter is encountered it sends the current batch to the active connection and prints
// the results to the output writer
type Sqlcmd struct {
lineIo rline.IO
workingDirectory string
db *sql.DB
out io.WriteCloser
err io.WriteCloser
batch *Batch
// Exitcode is returned to the operating system when the process exits
Exitcode int
Connect ConnectSettings
vars *Variables
Format Formatter
Query string
}
// New creates a new Sqlcmd instance
func New(l rline.IO, workingDirectory string, vars *Variables) *Sqlcmd {
return &Sqlcmd{
lineIo: l,
workingDirectory: workingDirectory,
batch: NewBatch(l.Next),
vars: vars,
}
}
// Run processes all available batches.
// When once is true it stops after the first query runs.
func (s *Sqlcmd) Run(once bool) error {
setupCloseHandler(s)
stderr, iactive := s.GetError(), s.lineIo.Interactive()
var lastError error
for {
var execute bool
if iactive {
s.lineIo.Prompt(s.Prompt())
}
var cmd *Command
var args []string
var err error
if s.Query != "" {
cmd = Commands["GO"]
args = make([]string, 0)
} else {
cmd, args, err = s.batch.Next()
}
switch {
case err == rline.ErrInterrupt:
// Ignore any error printing the ctrl-c notice since we are exiting
_, _ = s.GetOutput().Write([]byte(ErrCtrlC.Error()))
return nil
case err != nil:
if err == io.EOF {
if s.batch.Length == 0 {
return lastError
}
execute = true
} else {
return err
}
}
if cmd != nil {
err = s.RunCommand(cmd, args)
if err == ErrExitRequested || once {
s.SetOutput(nil)
s.SetError(nil)
break
}
if err != nil {
fmt.Fprintln(stderr, err)
lastError = err
continue
}
}
if execute {
s.Query = s.batch.String()
once = true
s.batch.Reset(nil)
}
}
return lastError
}
// Prompt returns the current user prompt message
func (s *Sqlcmd) Prompt() string {
ch := ">"
if s.batch.quote != 0 || s.batch.comment {
ch = "~"
}
return fmt.Sprint(s.batch.batchline) + ch + " "
}
// RunCommand performs the given Command
func (s *Sqlcmd) RunCommand(cmd *Command, args []string) error {
return cmd.action(s, args, s.batch.linecount)
}
// GetOutput returns the io.Writer to use for non-error output
func (s *Sqlcmd) GetOutput() io.Writer {
if s.out == nil {
return s.lineIo.Stdout()
}
return s.out
}
// SetOutput sets the io.WriteCloser to use for non-error output
func (s *Sqlcmd) SetOutput(o io.WriteCloser) {
if s.out != nil {
s.out.Close()
}
s.out = o
}
// GetError returns the io.Writer to use for errors
func (s *Sqlcmd) GetError() io.Writer {
if s.err == nil {
return s.lineIo.Stderr()
}
return s.err
}
// SetError sets the io.WriteCloser to use for errors
func (s *Sqlcmd) SetError(e io.WriteCloser) {
if s.err != nil {
s.err.Close()
}
s.err = e
}
// ConnectionString returns the go-mssql connection string to use for queries
func (s *Sqlcmd) ConnectionString() (connectionString string, err error) {
serverName, instance, port, err := s.vars.SQLCmdServer()
if serverName == "" {
serverName = "."
}
if err != nil {
return "", err
}
query := url.Values{}
connectionURL := &url.URL{
Scheme: "sqlserver",
Path: instance,
}
useTrustedConnection := s.Connect.UseTrustedConnection || (s.vars.SQLCmdUser() == "" && !s.vars.UseAad())
if !useTrustedConnection {
connectionURL.User = url.UserPassword(s.vars.SQLCmdUser(), s.vars.Password())
}
if port > 0 {
connectionURL.Host = fmt.Sprintf("%s:%d", serverName, port)
} else {
connectionURL.Host = serverName
}
if s.vars.SQLCmdDatabase() != "" {
query.Add("database", s.vars.SQLCmdDatabase())
}
if s.Connect.TrustServerCertificate {
query.Add("trustservercertificate", "true")
}
connectionURL.RawQuery = query.Encode()
return connectionURL.String(), nil
}
// ConnectDb opens a connection to the database with the given modifications to the connection
func (s *Sqlcmd) ConnectDb(server string, user string, password string, nopw bool) error {
if user != "" && password == "" && !nopw {
return ErrNeedPassword
}
connstr, err := s.ConnectionString()
if err != nil {
return err
}
connectionURL, err := url.Parse(connstr)
if err != nil {
return err
}
if server != "" {
serverName, instance, port, err := splitServer(server)
if err != nil {
return err
}
connectionURL.Path = instance
if port > 0 {
connectionURL.Host = fmt.Sprintf("%s:%d", serverName, port)
} else {
connectionURL.Host = serverName
}
}
if password == "" {
password = s.vars.Password()
}
if user != "" {
connectionURL.User = url.UserPassword(user, password)
}
connector, err := mssql.NewConnector(connectionURL.String())
if err != nil {
return err
}
db := sql.OpenDB(connector)
err = db.Ping()
if err != nil {
return err
}
// we got a good connection so we can update the Sqlcmd
if s.db != nil {
s.db.Close()
}
s.db = db
if server != "" {
s.vars.Set(SQLCMDSERVER, server)
}
if user != "" {
s.vars.Set(SQLCMDUSER, user)
s.Connect.UseTrustedConnection = false
if password != "" {
s.vars.Set(SQLCMDPASSWORD, password)
}
} else if s.vars.SQLCmdUser() == "" {
u, e := osuser.Current()
if e != nil {
panic("Unable to get user name")
}
s.Connect.UseTrustedConnection = true
s.vars.Set(SQLCMDUSER, u.Username)
}
if s.batch != nil {
s.batch.batchline = 1
}
return nil
}
func setupCloseHandler(s *Sqlcmd) {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
_, _ = s.GetOutput().Write([]byte(ErrCtrlC.Error()))
os.Exit(0)
}()
}

136
pkg/sqlcmd/sqlcmd_test.go Normal file
Просмотреть файл

@ -0,0 +1,136 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package sqlcmd
import (
"database/sql"
"fmt"
"os"
"os/user"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/xo/usql/rline"
)
func TestConnectionStringFromSqlCmd(t *testing.T) {
type connectionStringTest struct {
settings *ConnectSettings
setup func(*Variables)
connectionString string
}
pwd := uuid.New().String()
commands := []connectionStringTest{
{nil, nil, "sqlserver://."},
{
&ConnectSettings{TrustServerCertificate: true},
func(vars *Variables) {
_ = Setvar(SQLCMDDBNAME, "somedatabase")
},
"sqlserver://.?database=somedatabase&trustservercertificate=true",
},
{
&ConnectSettings{TrustServerCertificate: true},
func(vars *Variables) {
vars.Set(SQLCMDSERVER, `someserver/instance`)
vars.Set(SQLCMDDBNAME, "somedatabase")
vars.Set(SQLCMDUSER, "someuser")
vars.Set(SQLCMDPASSWORD, pwd)
},
fmt.Sprintf("sqlserver://someuser:%s@someserver/instance?database=somedatabase&trustservercertificate=true", pwd),
},
{
&ConnectSettings{TrustServerCertificate: true, UseTrustedConnection: true},
func(vars *Variables) {
vars.Set(SQLCMDSERVER, `tcp:someserver,1045`)
vars.Set(SQLCMDUSER, "someuser")
vars.Set(SQLCMDPASSWORD, pwd)
},
"sqlserver://someserver:1045?trustservercertificate=true",
},
{
nil,
func(vars *Variables) {
vars.Set(SQLCMDSERVER, `tcp:someserver,1045`)
},
"sqlserver://someserver:1045",
},
}
for _, test := range commands {
v := InitializeVariables(false)
if test.setup != nil {
test.setup(v)
}
s := &Sqlcmd{vars: v}
if test.settings != nil {
s.Connect = *test.settings
}
connectionString, err := s.ConnectionString()
if assert.NoError(t, err, "Unexpected error from %+v", s) {
assert.Equal(t, test.connectionString, connectionString, "Wrong connection string from: %+v", *s)
}
}
}
/* The following tests require a working SQL instance and rely on SqlCmd environment variables
to manage the initial connection string. The default connection when no environment variables are
set will be to localhost using Windows auth.
*/
func TestSqlCmdConnectDb(t *testing.T) {
v := InitializeVariables(true)
s := &Sqlcmd{vars: v}
err := s.ConnectDb("", "", "", false)
if assert.NoError(t, err, "ConnectDb should succeed") {
sqlcmduser := os.Getenv(SQLCMDUSER)
if sqlcmduser == "" {
u, _ := user.Current()
sqlcmduser = u.Username
}
assert.Equal(t, sqlcmduser, s.vars.SQLCmdUser(), "SQLCMDUSER variable should match connected user")
}
}
func ConnectDb() (*sql.DB, error) {
v := InitializeVariables(true)
s := &Sqlcmd{vars: v}
err := s.ConnectDb("", "", "", false)
return s.db, err
}
func TestSqlCmdQueryAndExit(t *testing.T) {
v := InitializeVariables(true)
v.Set(SQLCMDMAXVARTYPEWIDTH, "0")
line, err := rline.New(false, "", "")
if !assert.NoError(t, err, "rline.New") {
return
}
s := New(line, "", v)
s.Format = NewSQLCmdDefaultFormatter(true)
s.Query = "select 100"
file, err := os.CreateTemp("", "sqlcmdout")
if !assert.NoError(t, err, "os.CreateTemp") {
return
}
defer file.Close()
defer os.Remove(file.Name())
s.SetOutput(file)
err = s.ConnectDb("", "", "", true)
if !assert.NoError(t, err, "s.ConnectDB") {
return
}
err = s.Run(true)
if assert.NoError(t, err, "s.Run(once = true)") {
s.SetOutput(nil)
bytes, err := os.ReadFile(file.Name())
if assert.NoError(t, err, "os.ReadFile") {
assert.Equal(t, "100"+SqlcmdEol+SqlcmdEol, string(bytes), "Incorrect output from Run")
}
}
}

64
pkg/sqlcmd/util.go Normal file
Просмотреть файл

@ -0,0 +1,64 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package sqlcmd
import (
"strconv"
"strings"
)
// splitServer extracts connection parameters from a server name input
func splitServer(serverName string) (string, string, uint64, error) {
instance := ""
port := uint64(0)
if strings.HasPrefix(serverName, "tcp:") {
if len(serverName) == 4 {
return "", "", 0, &InvalidServerName
}
serverName = serverName[4:]
}
serverNameParts := strings.Split(serverName, ",")
if len(serverNameParts) > 2 {
return "", "", 0, &InvalidServerName
}
if len(serverNameParts) == 2 {
var err error
port, err = strconv.ParseUint(serverNameParts[1], 10, 16)
if err != nil {
return "", "", 0, &InvalidServerName
}
serverName = serverNameParts[0]
} else {
serverNameParts = strings.Split(serverName, "/")
if len(serverNameParts) > 2 {
return "", "", 0, &InvalidServerName
}
if len(serverNameParts) == 2 {
instance = serverNameParts[1]
serverName = serverNameParts[0]
}
}
return serverName, instance, port, nil
}
// padRight appends c instances of s to builder
func padRight(builder *strings.Builder, c int64, s string) *strings.Builder {
var i int64
for ; i < c; i++ {
builder.WriteString(s)
}
return builder
}
// padLeft prepends c instances of s to builder
func padLeft(builder *strings.Builder, c int64, s string) *strings.Builder {
newBuilder := new(strings.Builder)
newBuilder.Grow(builder.Len())
var i int64
for ; i < c; i++ {
newBuilder.WriteString(s)
}
newBuilder.WriteString(builder.String())
return newBuilder
}

209
pkg/sqlcmd/variables.go Normal file
Просмотреть файл

@ -0,0 +1,209 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package sqlcmd
import (
"fmt"
"os"
"strings"
)
// Variables provides set and get of sqlcmd scripting variables
type Variables map[string]string
var variables Variables
// Built-in scripting variables
const (
SQLCMDDBNAME = "SQLCMDDBNAME"
SQLCMDINI = "SQLCMDINI"
SQLCMDPACKETSIZE = "SQLCMDPACKETSIZE"
SQLCMDPASSWORD = "SQLCMDPASSWORD"
SQLCMDSERVER = "SQLCMDSERVER"
SQLCMDUSER = "SQLCMDUSER"
SQLCMDWORKSTATION = "SQLCMDWORKSTATION"
SQLCMDLOGINTIMEOUT = "SQLCMDLOGINTIMEOUT"
SQLCMDSTATTIMEOUT = "SQLCMDSTATTIMEOUT"
SQLCMDHEADERS = "SQLCMDHEADERS"
SQLCMDCOLSEP = "SQLCMDCOLSEP"
SQLCMDCOLWIDTH = "SQLCMDCOLWIDTH"
SQLCMDERRORLEVEL = "SQLCMDERRORLEVEL"
SQLCMDMAXVARTYPEWIDTH = "SQLCMDMAXVARTYPEWIDTH"
SQLCMDMAXFIXEDTYPEWIDTH = "SQLCMDMAXFIXEDTYPEWIDTH"
SQLCMDEDITOR = "SQLCMDEDITOR"
SQLCMDUSEAAD = "SQLCMDUSEAAD"
)
// Variables that can only be set at startup
var readOnlyVariables = []string{
SQLCMDDBNAME,
SQLCMDINI,
SQLCMDPACKETSIZE,
SQLCMDPASSWORD,
SQLCMDSERVER,
SQLCMDUSER,
SQLCMDWORKSTATION,
}
func (v Variables) checkReadOnly(key string) error {
currentValue, hasValue := v[key]
if hasValue {
for _, variable := range readOnlyVariables {
if variable == key && currentValue != "" {
return ReadOnlyVariable(key)
}
}
}
return nil
}
// Set sets or adds the value in the map.
func (v Variables) Set(name, value string) {
key := strings.ToUpper(name)
v[key] = value
}
// Unset removes the value from the map
func (v Variables) Unset(name string) {
key := strings.ToUpper(name)
delete(v, key)
}
// All returns a copy of the current variables
func (v Variables) All() map[string]string {
return map[string]string(v)
}
// SQLCmdUser returns the SQLCMDUSER variable value
func (v Variables) SQLCmdUser() string {
return v[SQLCMDUSER]
}
// SQLCmdServer returns the server connection parameters derived from the SQLCMDSERVER variable value
func (v Variables) SQLCmdServer() (serverName string, instance string, port uint64, err error) {
serverName = v[SQLCMDSERVER]
return splitServer(serverName)
}
// SQLCmdDatabase returns the SQLCMDDBNAME variable value
func (v Variables) SQLCmdDatabase() string {
return v[SQLCMDDBNAME]
}
// UseAad returns whether the SQLCMDUSEAAD variable value is set to "true"
func (v Variables) UseAad() bool {
return strings.EqualFold(v[SQLCMDUSEAAD], "true")
}
// Password returns the password used for connections as specified by SQLCMDPASSWORD variable
func (v Variables) Password() string {
return v[SQLCMDPASSWORD]
}
// ColumnSeparator is the value of SQLCMDCOLSEP variable. It can have 0 or 1 characters
func (v Variables) ColumnSeparator() string {
sep := v[SQLCMDCOLSEP]
if len(sep) > 1 {
return sep[:1]
}
return sep
}
// MaxFixedColumnWidth is the value of SQLCMDMAXFIXEDTYPEWIDTH variable.
// When non-zero, it limits the width of columns for types CHAR, NCHAR, NVARCHAR, VARCHAR, VARBINARY, VARIANT
func (v Variables) MaxFixedColumnWidth() int64 {
w := v[SQLCMDMAXFIXEDTYPEWIDTH]
return mustValue(w)
}
// MaxVarColumnWidth is the value of SQLCMDMAXVARTYPEWIDTH variable.
// When non-zero, it limits the width of columns for (max) versions of CHAR, NCHAR, VARBINARY.
// It also limits the width of xml, UDT, text, ntext, and image
func (v Variables) MaxVarColumnWidth() int64 {
w := v[SQLCMDMAXVARTYPEWIDTH]
return mustValue(w)
}
// ScreenWidth is the value of SQLCMDCOLWIDTH variable.
// It tells the formatter how many characters wide to limit all screen output.
func (v Variables) ScreenWidth() int64 {
w := v[SQLCMDCOLWIDTH]
return mustValue(w)
}
// RowsBetweenHeaders is the value of SQLCMDHEADERS variable.
// When MaxVarColumnWidth() is 0, it returns -1
func (v Variables) RowsBetweenHeaders() int64 {
if v.MaxVarColumnWidth() == 0 {
return -1
}
h := mustValue(v[SQLCMDHEADERS])
return h
}
func mustValue(val string) int64 {
var n int64
_, err := fmt.Sscanf(val, "%d", &n)
if err == nil {
return n
}
panic(err)
}
// InitializeVariables initializes variables with default values.
// When fromEnvironment is true, then loads from the runtime environment
func InitializeVariables(fromEnvironment bool) *Variables {
variables = Variables{
SQLCMDCOLSEP: " ",
SQLCMDCOLWIDTH: "0",
SQLCMDDBNAME: "",
SQLCMDEDITOR: "edit.com",
SQLCMDERRORLEVEL: "0",
SQLCMDHEADERS: "0",
SQLCMDINI: "",
SQLCMDLOGINTIMEOUT: "30",
SQLCMDMAXFIXEDTYPEWIDTH: "0",
SQLCMDMAXVARTYPEWIDTH: "256",
SQLCMDPACKETSIZE: "4096",
SQLCMDSERVER: "",
SQLCMDSTATTIMEOUT: "0",
SQLCMDUSER: "",
SQLCMDPASSWORD: "",
SQLCMDUSEAAD: "",
}
hostname, _ := os.Hostname()
variables.Set(SQLCMDWORKSTATION, hostname)
if fromEnvironment {
for v := range variables.All() {
envVar, ok := os.LookupEnv(v)
if ok {
variables.Set(v, envVar)
}
}
}
return &variables
}
// Setvar implements the :Setvar command
// TODO: Add validation functions for the variables.
func Setvar(name, value string) error {
err := ValidIdentifier(name)
if err == nil {
err = variables.checkReadOnly(name)
}
if err != nil {
return err
}
variables.Set(name, value)
return nil
}
// ValidIdentifier determines if a given string can be used as a variable name
func ValidIdentifier(name string) error {
if strings.HasPrefix(name, "$(") || strings.ContainsAny(name, "'\"\t\n\r ") {
return InvalidCommandError(":setvar", 0)
}
return nil
}

Просмотреть файл

@ -0,0 +1,64 @@
package sqlcmd
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
)
func TestBasicVariableOperations(t *testing.T) {
variables = Variables{
"var1": "val1",
}
variables.Set("var2", "val2")
assert.Contains(t, variables, "VAR2", "Set should add a capitalized key")
all := variables.All()
keys := make([]string, 0, len(all))
for k := range all {
keys = append(keys, k)
}
assert.ElementsMatch(t, []string{"var1", "VAR2"}, keys, "All returns every key")
assert.Equal(t, "val2", all["VAR2"], "VAR2 set value")
}
func TestSetvarFailsForReadOnlyVariables(t *testing.T) {
variables = Variables{}
err := Setvar("SQLCMDDBNAME", "somedatabase")
assert.NoError(t, err, "Setvar should succeed when SQLCMDDBNAME is not set")
err = Setvar("SQLCMDDBNAME", "newdatabase")
assert.EqualError(t, err, "Sqlcmd: Error: The scripting variable: 'SQLCMDDBNAME' is read-only")
}
func TestEnvironmentVariablesAsInput(t *testing.T) {
os.Setenv("SQLCMDSERVER", "someserver")
defer os.Unsetenv("SQLCMDSERVER")
os.Setenv("x", "somevalue")
defer os.Unsetenv("x")
vars := InitializeVariables(true).All()
assert.Equal(t, "someserver", vars["SQLCMDSERVER"], "InitializeVariables should read a valid environment variable from the known list")
_, ok := vars["x"]
assert.False(t, ok, "InitializeVariables should skip variables not in the known list")
}
func TestSqlServerSplitsName(t *testing.T) {
vars := Variables{
SQLCMDSERVER: `tcp:someserver/someinstance`,
}
serverName, instance, port, err := vars.SQLCmdServer()
if assert.NoError(t, err, "tcp:server/someinstance") {
assert.Equal(t, "someserver", serverName, "server name for instance")
assert.Equal(t, uint64(0), port, "port for instance")
assert.Equal(t, "someinstance", instance, "instance for instance")
}
vars = Variables{
SQLCMDSERVER: `tcp:someserver,1111`,
}
serverName, instance, port, err = vars.SQLCmdServer()
if assert.NoError(t, err, "tcp:server,1111") {
assert.Equal(t, "someserver", serverName, "server name for port number")
assert.Equal(t, uint64(1111), port, "port for port number")
assert.Equal(t, "", instance, "instance for port number")
}
}

3
testdata/sql.txt поставляемый Normal file
Просмотреть файл

@ -0,0 +1,3 @@
select 1 as col1
go