зеркало из https://github.com/microsoft/go-sqlcmd.git
Initial version of go-sqlcmd (#1)
* commandline and variables * add displayname to build tasks * implement quit command * initial version of batch parsing * move variables * new package for variables * add vscode helpers for debug * fix go and quit processing * implement Out and Error commands * fix custom batch separator * move connectionString to sqlcmd * Add sql connection and print column headers * add row and error processing * remove unused package * fix binary rendering and screen fitting * rewrite decodeBinary for performance * fix test pipeline * remove password command line param * exit on ctrl-c * implement -q and -Q * de-lint and update readme * add lint for PRs * separate pkg and cmd folders * fix pipeline for new folder layout * de-lint and reduce public surface * more de-linting * hopefully last round of de-linting
This commit is contained in:
Родитель
3d4fc056ea
Коммит
76685e94af
|
@ -0,0 +1,17 @@
|
|||
name: golangci-lint
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
jobs:
|
||||
golangci-pr:
|
||||
name: lint-pr-changes
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v2
|
||||
with:
|
||||
version: v1.42.0
|
||||
only-new-issues: true
|
|
@ -13,3 +13,8 @@
|
|||
|
||||
# Dependency directories (remove the comment below to include it)
|
||||
# vendor/
|
||||
|
||||
coverage.json
|
||||
coverage.txt
|
||||
coverage.xml
|
||||
testresults.xml
|
|
@ -0,0 +1,80 @@
|
|||
pool:
|
||||
vmImage: 'ubuntu-latest'
|
||||
|
||||
steps:
|
||||
- task: GoTool@0
|
||||
inputs:
|
||||
version: '1.16.5'
|
||||
- task: Go@0
|
||||
displayName: 'Go: get dependencies'
|
||||
inputs:
|
||||
command: 'get'
|
||||
arguments: '-d'
|
||||
workingDirectory: '$(Build.SourcesDirectory)/cmd/sqlcmd'
|
||||
|
||||
|
||||
- task: Go@0
|
||||
displayName: 'Go: install gotest.tools/gotestsum'
|
||||
inputs:
|
||||
command: 'custom'
|
||||
customCommand: 'install'
|
||||
arguments: 'gotest.tools/gotestsum@latest'
|
||||
workingDirectory: '$(System.DefaultWorkingDirectory)'
|
||||
|
||||
- task: Go@0
|
||||
displayName: 'Go: install github.com/axw/gocov/gocov'
|
||||
inputs:
|
||||
command: 'custom'
|
||||
customCommand: 'install'
|
||||
arguments: 'github.com/axw/gocov/gocov@latest'
|
||||
workingDirectory: '$(System.DefaultWorkingDirectory)'
|
||||
|
||||
- task: Go@0
|
||||
displayName: 'Go: install github.com/axw/gocov/gocov'
|
||||
inputs:
|
||||
command: 'custom'
|
||||
customCommand: 'install'
|
||||
arguments: 'github.com/AlekSi/gocov-xml@latest'
|
||||
workingDirectory: '$(System.DefaultWorkingDirectory)'
|
||||
|
||||
#Your build pipeline references an undefined variables named SQLPASSWORD.
|
||||
#Create or edit the build pipeline for this YAML file, define the variable on the Variables tab. See https://go.microsoft.com/fwlink/?linkid=865972
|
||||
|
||||
- task: Docker@2
|
||||
displayName: 'Run SQL 2017 docker image'
|
||||
inputs:
|
||||
command: run
|
||||
arguments: '-m 2GB -e ACCEPT_EULA=1 -d --name sql2017 -p:1433:1433 -e SA_PASSWORD=$(SQLPASSWORD) mcr.microsoft.com/mssql/server:2017-latest'
|
||||
|
||||
- script: |
|
||||
~/go/bin/gotestsum --junitfile testresults.xml -- ./... -coverprofile=coverage.txt -covermode count
|
||||
~/go/bin/gocov convert coverage.txt > coverage.json
|
||||
~/go/bin/gocov-xml < coverage.json > coverage.xml
|
||||
mkdir coverage
|
||||
workingDirectory: '$(Build.SourcesDirectory)'
|
||||
displayName: 'run tests'
|
||||
env:
|
||||
SQLPASSWORD: $(SQLPASSWORD)
|
||||
SQLCMDUSER: sa
|
||||
SQLCMDPASSWORD: $(SQLPASSWORD)
|
||||
continueOnError: true
|
||||
- task: PublishTestResults@2
|
||||
displayName: "Publish junit-style results"
|
||||
inputs:
|
||||
testResultsFiles: 'testresults.xml'
|
||||
testResultsFormat: JUnit
|
||||
searchFolder: '$(Build.SourcesDirectory)'
|
||||
testRunTitle: 'SQL 2017 - $(Build.SourceBranchName)'
|
||||
condition: always()
|
||||
continueOnError: true
|
||||
|
||||
- task: PublishCodeCoverageResults@1
|
||||
inputs:
|
||||
codeCoverageTool: Cobertura
|
||||
pathToSources: '$(Build.SourcesDirectory)'
|
||||
summaryFileLocation: $(Build.SourcesDirectory)/**/coverage.xml
|
||||
reportDirectory: $(Build.SourcesDirectory)/**/coverage
|
||||
failIfCoverageEmpty: true
|
||||
condition: always()
|
||||
continueOnError: true
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Attach using delve",
|
||||
"type": "go",
|
||||
"request": "attach",
|
||||
"preLaunchTask": "delve",
|
||||
"mode": "remote",
|
||||
"remotePath": "${workspaceFolder}",
|
||||
"port" : 23456,
|
||||
"host" : "127.0.0.1",
|
||||
"cwd" : "${workspaceFolder}",
|
||||
},
|
||||
{
|
||||
"name" : "Run query and exit",
|
||||
"type" : "go",
|
||||
"request": "launch",
|
||||
"mode" : "auto",
|
||||
"program": "${fileDirname}",
|
||||
"args" : ["-Q", "\"select 100 as Count\""],
|
||||
}
|
||||
]
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
{
|
||||
// See https://go.microsoft.com/fwlink/?LinkId=733558
|
||||
// for the documentation about the tasks.json format
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"label": "delve",
|
||||
"type": "shell",
|
||||
"command": "dlv debug --headless --listen=:23456 --api-version=2 \"${workspaceFolder}\"",
|
||||
"isBackground": true,
|
||||
"presentation": {
|
||||
"focus": true,
|
||||
"panel": "dedicated",
|
||||
"clear": false
|
||||
},
|
||||
"group": {
|
||||
"kind": "build",
|
||||
"isDefault": true
|
||||
},
|
||||
"problemMatcher": {
|
||||
"pattern": {
|
||||
"regexp": ""
|
||||
},
|
||||
"background": {
|
||||
"activeOnStart": true,
|
||||
"beginsPattern": {
|
||||
"regexp": ".*"
|
||||
},
|
||||
"endsPattern": {
|
||||
"regexp": ".*server listening.*"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
36
README.md
36
README.md
|
@ -1,14 +1,29 @@
|
|||
# Project
|
||||
# SQL Utilities - Go edition
|
||||
|
||||
> This repo has been populated by an initial template to help get you started. Please
|
||||
> make sure to update the content to build a great experience for community-building.
|
||||
This repo contains command line tools and go packages for working with Microsoft SQL Server, Azure SQL Database, and Azure Synapse.
|
||||
|
||||
As the maintainer of this project, please make a few updates:
|
||||
## Sqlcmd
|
||||
|
||||
- Improving this README.MD file to provide a great experience
|
||||
- Updating SUPPORT.MD with content about this project's support experience
|
||||
- Understanding the security reporting process in SECURITY.MD
|
||||
- Remove this section from the README
|
||||
The `sqlcmd` project aims to be a complete port of the native sqlcmd to the `go` language, utilizing the [go-mssqldb](https://github.com/denisenkom/go-mssqldb) driver. For full documentation of the tool, see https://docs.microsoft.com/sql/tools/sqlcmd-utility
|
||||
|
||||
### Breaking changes
|
||||
|
||||
We will be implementing as many command line switches and behaviors as possible over time. Several switches and behaviors are expected to change in this implementation.
|
||||
|
||||
- `-P` switch will be removed. Passwords for SQL authentication can only be provided through these mechanisms:
|
||||
|
||||
-The `SQLCMDPASSWORD` environment variable
|
||||
-The `:CONNECT` command
|
||||
-When prompted, the user can type the password to complete a connection
|
||||
|
||||
- `-R` switch will be removed. The go runtime does not provide access to user locale information, and it's not readily available through syscall on all supported platforms.
|
||||
- Some behaviors that were kept to maintain compatibility with `OSQL` may be changed, such as alignment of column headers for some data types.
|
||||
|
||||
### Packages
|
||||
|
||||
#### sqlcmd
|
||||
|
||||
#### batch
|
||||
|
||||
## Contributing
|
||||
|
||||
|
@ -26,8 +41,9 @@ contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additio
|
|||
|
||||
## Trademarks
|
||||
|
||||
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
|
||||
trademarks or logos is subject to and must follow
|
||||
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
|
||||
trademarks or logos is subject to and must follow
|
||||
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
|
||||
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
|
||||
Any use of third-party trademarks or logos are subject to those third-party's policies.
|
||||
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/alecthomas/kong"
|
||||
"github.com/microsoft/go-sqlcmd/pkg/sqlcmd"
|
||||
"github.com/xo/usql/rline"
|
||||
)
|
||||
|
||||
// SQLCmdArguments defines the command line arguments for sqlcmd
|
||||
// The exhaustive list is at https://docs.microsoft.com/sql/tools/sqlcmd-utility?view=sql-server-ver15
|
||||
type SQLCmdArguments struct {
|
||||
// Which batch terminator to use. Default is GO
|
||||
BatchTerminator string `short:"c" default:"GO" arghelp:"Specifies the batch terminator. The default value is GO."`
|
||||
// Whether to trust the server certificate on an encrypted connection
|
||||
TrustServerCertificate bool `short:"C" help:"Implicitly trust the server certificate without validation."`
|
||||
DatabaseName string `short:"d" help:"This option sets the sqlcmd scripting variable SQLCMDDBNAME. This parameter specifies the initial database. The default is your login's default-database property. If the database does not exist, an error message is generated and sqlcmd exits."`
|
||||
UseTrustedConnection bool `short:"E" xor:"uid" help:"Uses a trusted connection instead of using a user name and password to sign in to SQL Server, ignoring any any environment variables that define user name and password."`
|
||||
UserName string `short:"U" xor:"uid" help:"The login name or contained database user name. For contained database users, you must provide the database name option"`
|
||||
// Files from which to read query text
|
||||
InputFile []string `short:"i" xor:"input1, input2" type:"existingFile" help:"Identifies one or more files that contain batches of SQL statements. If one or more files do not exist, sqlcmd will exit. Mutually exclusive with -Q/-q."`
|
||||
OutputFile string `short:"o" type:"path" help:"Identifies the file that receives output from sqlcmd."`
|
||||
// First query to run in interactive mode
|
||||
InitialQuery string `short:"q" xor:"input1" help:"Executes a query when sqlcmd starts, but does not exit sqlcmd when the query has finished running. Multiple-semicolon-delimited queries can be executed."`
|
||||
// Query to run then exit
|
||||
Query string `short:"Q" xor:"input2" help:"Executes a query when sqlcmd starts and then immediately exits sqlcmd. Multiple-semicolon-delimited queries can be executed."`
|
||||
Server string `short:"S" help:"[tcp:]server[\\instance_name][,port]Specifies the instance of SQL Server to which to connect. It sets the sqlcmd scripting variable SQLCMDSERVER."`
|
||||
// Disable syscommands with a warning
|
||||
DisableCmdAndWarn bool `short:"X" xor:"syscmd" help:"Disables commands that might compromise system security. Sqlcmd issues a warning and continues."`
|
||||
}
|
||||
|
||||
// Breaking changes in command line are listed here.
|
||||
// Any switch not listed in breaking changes and not also included in SqlCmdArguments just has not been implemented yet
|
||||
// 1. -P: Passwords have to be provided through SQLCMDPASSWORD environment variable or typed when prompted
|
||||
// 2. -R: Go runtime doesn't expose user locale information and syscall would only enable it on Windows, so we won't try to implement it
|
||||
|
||||
var args SQLCmdArguments
|
||||
|
||||
func main() {
|
||||
kong.Parse(&args)
|
||||
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
|
||||
setVars(vars, &args)
|
||||
|
||||
exitCode, err := run(vars)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
}
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
// Initializes scripting variables from command line arguments
|
||||
func setVars(vars *sqlcmd.Variables, args *SQLCmdArguments) {
|
||||
varmap := map[string]func(*SQLCmdArguments) string{
|
||||
sqlcmd.SQLCMDDBNAME: func(a *SQLCmdArguments) string { return a.DatabaseName },
|
||||
sqlcmd.SQLCMDLOGINTIMEOUT: func(a *SQLCmdArguments) string { return "" },
|
||||
sqlcmd.SQLCMDUSEAAD: func(a *SQLCmdArguments) string { return "" },
|
||||
sqlcmd.SQLCMDWORKSTATION: func(a *SQLCmdArguments) string { return "" },
|
||||
sqlcmd.SQLCMDSERVER: func(a *SQLCmdArguments) string { return a.Server },
|
||||
sqlcmd.SQLCMDERRORLEVEL: func(a *SQLCmdArguments) string { return "" },
|
||||
sqlcmd.SQLCMDPACKETSIZE: func(a *SQLCmdArguments) string { return "" },
|
||||
sqlcmd.SQLCMDUSER: func(a *SQLCmdArguments) string { return a.UserName },
|
||||
sqlcmd.SQLCMDSTATTIMEOUT: func(a *SQLCmdArguments) string { return "" },
|
||||
sqlcmd.SQLCMDHEADERS: func(a *SQLCmdArguments) string { return "" },
|
||||
sqlcmd.SQLCMDCOLSEP: func(a *SQLCmdArguments) string { return "" },
|
||||
sqlcmd.SQLCMDCOLWIDTH: func(a *SQLCmdArguments) string { return "" },
|
||||
sqlcmd.SQLCMDMAXVARTYPEWIDTH: func(a *SQLCmdArguments) string { return "" },
|
||||
sqlcmd.SQLCMDMAXFIXEDTYPEWIDTH: func(a *SQLCmdArguments) string { return "" },
|
||||
}
|
||||
for varname, set := range varmap {
|
||||
val := set(args)
|
||||
if val != "" {
|
||||
vars.Set(varname, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func run(vars *sqlcmd.Variables) (exitcode int, err error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
if args.BatchTerminator != "GO" {
|
||||
err = sqlcmd.SetBatchTerminator(args.BatchTerminator)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid batch terminator '%s'", args.BatchTerminator)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
|
||||
iactive := args.InputFile == nil
|
||||
line, err := rline.New(!iactive, "", "")
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
defer line.Close()
|
||||
|
||||
s := sqlcmd.New(line, wd, vars)
|
||||
s.Connect.UseTrustedConnection = args.UseTrustedConnection
|
||||
s.Connect.TrustServerCertificate = args.TrustServerCertificate
|
||||
s.Format = sqlcmd.NewSQLCmdDefaultFormatter(false)
|
||||
if args.OutputFile != "" {
|
||||
err = s.RunCommand(sqlcmd.Commands["OUT"], []string{args.OutputFile})
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
}
|
||||
once := false
|
||||
if args.InitialQuery != "" {
|
||||
s.Query = args.InitialQuery
|
||||
} else if args.Query != "" {
|
||||
once = true
|
||||
s.Query = args.Query
|
||||
}
|
||||
err = s.ConnectDb("", "", "", !iactive)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
if iactive {
|
||||
err = s.Run(once)
|
||||
}
|
||||
return s.Exitcode, err
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/alecthomas/kong"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newKong(t *testing.T, cli interface{}, options ...kong.Option) *kong.Kong {
|
||||
t.Helper()
|
||||
options = append([]kong.Option{
|
||||
kong.Name("test"),
|
||||
kong.Exit(func(int) {
|
||||
t.Helper()
|
||||
t.Fatalf("unexpected exit()")
|
||||
}),
|
||||
}, options...)
|
||||
parser, err := kong.New(cli, options...)
|
||||
require.NoError(t, err)
|
||||
return parser
|
||||
}
|
||||
|
||||
func TestValidCommandLineToArgsConversion(t *testing.T) {
|
||||
type cmdLineTest struct {
|
||||
commandLine []string
|
||||
check func(SQLCmdArguments) bool
|
||||
}
|
||||
|
||||
// These tests only cover compatibility with the native sqlcmd, which only supports the short flags
|
||||
// The long flag names are up for debate.
|
||||
commands := []cmdLineTest{
|
||||
{[]string{}, func(args SQLCmdArguments) bool {
|
||||
return args.Server == "" && !args.UseTrustedConnection && args.UserName == ""
|
||||
}},
|
||||
{[]string{"-c", "MYGO", "-C", "-E", "-i", "file1", "-o", "outfile", "-i", "file2"}, func(args SQLCmdArguments) bool {
|
||||
return args.BatchTerminator == "MYGO" && args.TrustServerCertificate && len(args.InputFile) == 2 && strings.HasSuffix(args.OutputFile, "outfile")
|
||||
}},
|
||||
{[]string{"-U", "someuser", "-d", "somedatabase", "-S", "someserver"}, func(args SQLCmdArguments) bool {
|
||||
return args.BatchTerminator == "GO" && !args.TrustServerCertificate && args.UserName == "someuser" && args.DatabaseName == "somedatabase" && args.Server == "someserver"
|
||||
}},
|
||||
// native sqlcmd allows both -q and -Q but only runs the -Q query and exits. We could make them mutually exclusive if desired.
|
||||
{[]string{"-q", "select 1", "-Q", "select 2"}, func(args SQLCmdArguments) bool {
|
||||
return args.Server == "" && args.InitialQuery == "select 1" && args.Query == "select 2"
|
||||
}},
|
||||
{[]string{"-S", "someserver/someinstance"}, func(args SQLCmdArguments) bool {
|
||||
return args.Server == "someserver/someinstance"
|
||||
}},
|
||||
{[]string{"-S", "tcp:someserver,10245"}, func(args SQLCmdArguments) bool {
|
||||
return args.Server == "tcp:someserver,10245"
|
||||
}},
|
||||
{[]string{"-X"}, func(args SQLCmdArguments) bool {
|
||||
return args.DisableCmdAndWarn
|
||||
}},
|
||||
}
|
||||
|
||||
for _, test := range commands {
|
||||
arguments := &SQLCmdArguments{}
|
||||
parser := newKong(t, arguments)
|
||||
_, err := parser.Parse(test.commandLine)
|
||||
msg := ""
|
||||
if err != nil {
|
||||
msg = err.Error()
|
||||
}
|
||||
if assert.Nil(t, err, "Unable to parse commandLine:%v\n%s", test.commandLine, msg) {
|
||||
assert.True(t, test.check(*arguments), "Unexpected SqlCmdArguments from: %v\n%+v", test.commandLine, *arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidCommandLine(t *testing.T) {
|
||||
type cmdLineTest struct {
|
||||
commandLine []string
|
||||
errorMessage string
|
||||
}
|
||||
|
||||
commands := []cmdLineTest{
|
||||
{[]string{"-E", "-U", "someuser"}, "--use-trusted-connection and --user-name can't be used together"},
|
||||
}
|
||||
|
||||
for _, test := range commands {
|
||||
arguments := &SQLCmdArguments{}
|
||||
parser := newKong(t, arguments)
|
||||
_, err := parser.Parse(test.commandLine)
|
||||
assert.EqualError(t, err, test.errorMessage, "Command line:%v", test.commandLine)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
module github.com/microsoft/go-sqlcmd
|
||||
|
||||
go 1.16
|
||||
|
||||
require (
|
||||
github.com/alecthomas/kong v0.2.17
|
||||
github.com/denisenkom/go-mssqldb v0.10.0
|
||||
github.com/google/uuid v1.2.0
|
||||
github.com/stretchr/testify v1.7.0
|
||||
github.com/xo/usql v0.9.1
|
||||
)
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,182 @@
|
|||
package sqlcmd
|
||||
|
||||
const minCapIncrease = 512
|
||||
|
||||
// lineend is the slice to use when appending a line.
|
||||
var lineend = []rune{'\n'}
|
||||
|
||||
// Batch provides the query text to run
|
||||
type Batch struct {
|
||||
// read provides the next chunk of runes
|
||||
read func() ([]rune, error)
|
||||
// Buffer is the current batch text
|
||||
Buffer []rune
|
||||
// Length is the length of the statement
|
||||
Length int
|
||||
// raw is the unprocessed runes
|
||||
raw []rune
|
||||
// rawlen is the number of unprocessed runes
|
||||
rawlen int
|
||||
// quote indicates currently processing a quoted string
|
||||
quote rune
|
||||
// comment is the state of multi-line comment processing
|
||||
comment bool
|
||||
// batchline is the 1-based index of the next line.
|
||||
// Used for the prompt in interactive mode
|
||||
batchline int
|
||||
// linecount is the total number of batch lines processed in the session
|
||||
linecount uint
|
||||
}
|
||||
|
||||
// NewBatch creates a Batch which converts runes provided by reader into SQL batches
|
||||
func NewBatch(reader func() ([]rune, error)) *Batch {
|
||||
b := &Batch{
|
||||
read: reader,
|
||||
}
|
||||
b.Reset(nil)
|
||||
return b
|
||||
}
|
||||
|
||||
// String returns the current SQL batch text
|
||||
func (b *Batch) String() string {
|
||||
return string(b.Buffer)
|
||||
}
|
||||
|
||||
// Reset clears the current batch text and replaces it with new runes
|
||||
func (b *Batch) Reset(r []rune) {
|
||||
b.Buffer, b.Length = nil, 0
|
||||
b.quote = 0
|
||||
b.comment = false
|
||||
b.batchline = 1
|
||||
if r != nil {
|
||||
b.raw, b.rawlen = r, len(r)
|
||||
}
|
||||
}
|
||||
|
||||
// Next processes the next chunk of input and sets the Batch state accordingly.
|
||||
// If the input contains a command to run, Next returns the Command and its
|
||||
// parameters.
|
||||
// Upon exit from Next, the caller can use the State method to determine if
|
||||
// it represents a runnable SQL batch text.
|
||||
func (b *Batch) Next() (*Command, []string, error) {
|
||||
var err error
|
||||
var i int
|
||||
if b.rawlen == 0 {
|
||||
b.raw, err = b.read()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
b.rawlen = len(b.raw)
|
||||
}
|
||||
|
||||
var command *Command
|
||||
var args []string
|
||||
var ok bool
|
||||
var scannedCommand bool
|
||||
parse:
|
||||
for ; i < b.rawlen; i++ {
|
||||
c, next := b.raw[i], grab(b.raw, i+1, b.rawlen)
|
||||
switch {
|
||||
// we're in a quoted string
|
||||
case b.quote != 0:
|
||||
i, ok = readString(b.raw, i, b.rawlen, b.quote)
|
||||
if ok {
|
||||
b.quote = 0
|
||||
}
|
||||
// inside a multiline comment
|
||||
case b.comment:
|
||||
i, ok = readMultilineComment(b.raw, i, b.rawlen)
|
||||
b.comment = !ok
|
||||
// start of a string
|
||||
case c == '\'' || c == '"':
|
||||
b.quote = c
|
||||
// inline sql comment, skip to end of line
|
||||
case c == '-' && next == '-':
|
||||
i = b.rawlen
|
||||
// start a multi-line comment
|
||||
case c == '/' && next == '*':
|
||||
b.comment = true
|
||||
i++
|
||||
// continue processing quoted string or multiline comment
|
||||
case b.quote != 0 || b.comment:
|
||||
// Commands have to be alone on the line
|
||||
case !scannedCommand:
|
||||
var cend int
|
||||
scannedCommand = true
|
||||
command, args, cend = readCommand(b.raw, i, b.rawlen)
|
||||
if command != nil {
|
||||
// remove the command from raw
|
||||
b.raw = append(b.raw[:i], b.raw[cend:]...)
|
||||
break parse
|
||||
}
|
||||
}
|
||||
}
|
||||
i = min(i, b.rawlen)
|
||||
empty := isEmptyLine(b.raw, 0, i)
|
||||
appendLine := b.quote != 0 || b.comment || !empty
|
||||
if !b.comment && command != nil && empty {
|
||||
appendLine = false
|
||||
}
|
||||
if appendLine {
|
||||
// skip leading space when empty
|
||||
st := 0
|
||||
if b.Length == 0 {
|
||||
st, _ = findNonSpace(b.raw, 0, i)
|
||||
}
|
||||
// log.Printf(">> appending: `%s`", string(r[st:i]))
|
||||
b.append(b.raw[st:i], lineend)
|
||||
b.batchline++
|
||||
}
|
||||
b.raw = b.raw[i:]
|
||||
b.rawlen = len(b.raw)
|
||||
b.linecount++
|
||||
return command, args, nil
|
||||
}
|
||||
|
||||
// append appends r to b.Buffer separated by sep when b.Buffer is not already empty.
|
||||
//
|
||||
// Dynamically grows b.Buf as necessary to accommodate r and the separator.
|
||||
// Specifically, when b.Buf is not empty, b.Buf will grow by increments of
|
||||
// MinCapIncrease.
|
||||
//
|
||||
// After a call to append, b.Len will be len(b.Buf)+len(sep)+len(r). Call Reset
|
||||
// to reset the Buf.
|
||||
func (b *Batch) append(r, sep []rune) {
|
||||
rlen := len(r)
|
||||
// initial
|
||||
if b.Buffer == nil {
|
||||
b.Buffer, b.Length = r, rlen
|
||||
return
|
||||
}
|
||||
blen, seplen := b.Length, len(sep)
|
||||
tlen := blen + rlen + seplen
|
||||
// grow
|
||||
if bcap := cap(b.Buffer); tlen > bcap {
|
||||
n := tlen + 2*rlen
|
||||
n += minCapIncrease - (n % minCapIncrease)
|
||||
z := make([]rune, blen, n)
|
||||
copy(z, b.Buffer)
|
||||
b.Buffer = z
|
||||
}
|
||||
b.Buffer = b.Buffer[:tlen]
|
||||
copy(b.Buffer[blen:], sep)
|
||||
copy(b.Buffer[blen+seplen:], r)
|
||||
b.Length = tlen
|
||||
}
|
||||
|
||||
// State returns a string representing the state of statement parsing.
|
||||
// * Is in the middle of a multi-line comment
|
||||
// - Has a non-empty batch ready to run
|
||||
// = Is empty
|
||||
// ' " Is in the middle of a multi-line quoted string
|
||||
func (b *Batch) State() string {
|
||||
switch {
|
||||
case b.quote != 0:
|
||||
return string(b.quote)
|
||||
case b.comment:
|
||||
return "*"
|
||||
case b.Length != 0:
|
||||
return "-"
|
||||
}
|
||||
return "="
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
package sqlcmd
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBatchNextReset(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
stmts []string
|
||||
cmds []string
|
||||
state string
|
||||
}{
|
||||
{"", nil, nil, "="},
|
||||
{"select 1", []string{"select 1"}, nil, "-"},
|
||||
{"select 1\nquit", []string{"select 1"}, []string{"QUIT"}, "="},
|
||||
{"select 1\nquite", []string{"select 1\nquite"}, nil, "-"},
|
||||
{"select 1\nquit\nselect 2", []string{"select 1", "select 2"}, []string{"QUIT"}, "-"},
|
||||
{"select '1\n", []string{"select '1\n"}, nil, "'"},
|
||||
{"select 1 /* comment\nGO", []string{"select 1 /* comment\nGO"}, nil, "*"},
|
||||
{"select '1\n00' \n/* comm\nent*/\nGO 4", []string{"select '1\n00' \n/* comm\nent*/"}, []string{"GO"}, "="},
|
||||
}
|
||||
for _, test := range tests {
|
||||
b := NewBatch(sp(test.s, "\n"))
|
||||
var stmts, cmds []string
|
||||
loop:
|
||||
for {
|
||||
cmd, _, err := b.Next()
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
// if we get EOF before a command we will try to run
|
||||
// whatever is in the buffer
|
||||
if s := b.String(); s != "" {
|
||||
stmts = append(stmts, s)
|
||||
}
|
||||
break loop
|
||||
case err != nil:
|
||||
t.Fatalf("test %s did not expect error, got: %v", test.s, err)
|
||||
}
|
||||
// resetting the buffer for every command purely for test purposes
|
||||
if cmd != nil {
|
||||
stmts = append(stmts, b.String())
|
||||
cmds = append(cmds, cmd.name)
|
||||
b.Reset(nil)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, test.stmts, stmts, "Statements for %s", test.s)
|
||||
assert.Equal(t, test.state, b.State(), "State for %s", test.s)
|
||||
assert.Equal(t, test.cmds, cmds, "Commands for %s", test.s)
|
||||
b.Reset(nil)
|
||||
assert.Zero(t, b.Length, "Length after Reset")
|
||||
assert.Zero(t, len(b.Buffer), "len(Buffer) after Reset")
|
||||
assert.Zero(t, b.quote, "quote after Reset")
|
||||
assert.False(t, b.comment, "comment after Reset")
|
||||
assert.Equal(t, "=", b.State(), "State() after Reset")
|
||||
}
|
||||
}
|
||||
|
||||
func sp(a, sep string) func() ([]rune, error) {
|
||||
s := strings.Split(a, sep)
|
||||
return func() ([]rune, error) {
|
||||
if len(s) > 0 {
|
||||
z := s[0]
|
||||
s = s[1:]
|
||||
return []rune(z), nil
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
}
|
|
@ -0,0 +1,172 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
package sqlcmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Command defines a sqlcmd action which can be intermixed with the SQL batch
|
||||
// Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands
|
||||
type Command struct {
|
||||
// regex must include at least one group if it has parameters
|
||||
// Will be matched using FindStringSubmatch
|
||||
regex *regexp.Regexp
|
||||
// The function that implements the command. Third parameter is the line number
|
||||
action func(*Sqlcmd, []string, uint) error
|
||||
// Name of the command
|
||||
name string
|
||||
}
|
||||
|
||||
// Commands is the set of Command implementations
|
||||
var Commands = map[string]*Command{
|
||||
"QUIT": {
|
||||
regex: regexp.MustCompile(`(?im)^[\t ]*?:?QUIT(?:[ \t]+(.*$)|$)`),
|
||||
action: quitCommand,
|
||||
name: "QUIT",
|
||||
},
|
||||
"GO": {
|
||||
regex: regexp.MustCompile(batchTerminatorRegex("GO")),
|
||||
action: goCommand,
|
||||
name: "GO",
|
||||
},
|
||||
"OUT": {
|
||||
regex: regexp.MustCompile(`(?im)^[ \t]*:OUT(?:[ \t]+(.*$)|$)`),
|
||||
action: outCommand,
|
||||
name: "OUT",
|
||||
},
|
||||
"ERROR": {
|
||||
regex: regexp.MustCompile(`(?im)^[ \t]*:ERROR(?:[ \t]+(.*$)|$)`),
|
||||
action: errorCommand,
|
||||
name: "ERROR",
|
||||
},
|
||||
}
|
||||
|
||||
func matchCommand(line string) (*Command, []string) {
|
||||
for _, cmd := range Commands {
|
||||
matchedCommand := cmd.regex.FindStringSubmatch(line)
|
||||
if matchedCommand != nil {
|
||||
return cmd, matchedCommand[1:]
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func batchTerminatorRegex(terminator string) string {
|
||||
return fmt.Sprintf(`(?im)^[\t ]*?%s(?:[ ]+(.*$)|$)`, regexp.QuoteMeta(terminator))
|
||||
}
|
||||
|
||||
// SetBatchTerminator attempts to set the batch terminator to the given value
|
||||
// Returns an error if the new value is not usable in the regex
|
||||
func SetBatchTerminator(terminator string) error {
|
||||
cmd := Commands["GO"]
|
||||
regex, err := regexp.Compile(batchTerminatorRegex(terminator))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.regex = regex
|
||||
return nil
|
||||
}
|
||||
|
||||
// quitCommand immediately exits the program without running any more batches
|
||||
func quitCommand(s *Sqlcmd, args []string, line uint) error {
|
||||
if args != nil && strings.TrimSpace(args[0]) != "" {
|
||||
return InvalidCommandError("QUIT", line)
|
||||
}
|
||||
return ErrExitRequested
|
||||
}
|
||||
|
||||
// goCommand runs the current batch the number of times specified
|
||||
func goCommand(s *Sqlcmd, args []string, line uint) error {
|
||||
// default to 1 execution
|
||||
n := 1
|
||||
var err error
|
||||
if len(args) > 0 {
|
||||
cnt := strings.TrimSpace(args[0])
|
||||
if cnt != "" {
|
||||
_, err = fmt.Sscanf(cnt, "%d", &n)
|
||||
}
|
||||
}
|
||||
if err != nil || n < 1 {
|
||||
return InvalidCommandError("GO", line)
|
||||
}
|
||||
|
||||
// This loop will likely be refactored to a helper when we implement -Q and :EXIT(query)
|
||||
for i := 0; i < n; i++ {
|
||||
query := s.Query
|
||||
if query == "" {
|
||||
query = s.batch.String()
|
||||
}
|
||||
s.Format.BeginBatch(query, s.vars, s.GetOutput(), s.GetError())
|
||||
rows, qe := s.db.Query(query)
|
||||
if qe != nil {
|
||||
s.Format.AddError(qe)
|
||||
}
|
||||
|
||||
results := true
|
||||
for qe == nil && results {
|
||||
cols, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
s.Format.AddError(err)
|
||||
} else {
|
||||
s.Format.BeginResultSet(cols)
|
||||
active := rows.Next()
|
||||
for active {
|
||||
s.Format.AddRow(rows)
|
||||
active = rows.Next()
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
s.Format.AddError(err)
|
||||
}
|
||||
s.Format.EndResultSet()
|
||||
}
|
||||
results = rows.NextResultSet()
|
||||
if err = rows.Err(); err != nil {
|
||||
s.Format.AddError(err)
|
||||
}
|
||||
}
|
||||
s.Format.EndBatch()
|
||||
}
|
||||
s.Query = ""
|
||||
s.batch.Reset(nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// outCommand changes the output writer to use a file
|
||||
func outCommand(s *Sqlcmd, args []string, line uint) error {
|
||||
switch {
|
||||
case strings.EqualFold(args[0], "stdout"):
|
||||
s.SetOutput(nil)
|
||||
case strings.EqualFold(args[0], "stderr"):
|
||||
s.SetOutput(os.NewFile(uintptr(syscall.Stderr), "/dev/stderr"))
|
||||
default:
|
||||
o, err := os.OpenFile(args[0], os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return InvalidFileError(err, args[0])
|
||||
}
|
||||
s.SetOutput(o)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// errorCommand changes the error writer to use a file
|
||||
func errorCommand(s *Sqlcmd, args []string, line uint) error {
|
||||
switch {
|
||||
case strings.EqualFold(args[0], "stderr"):
|
||||
s.SetError(nil)
|
||||
case strings.EqualFold(args[0], "stdout"):
|
||||
s.SetError(os.NewFile(uintptr(syscall.Stderr), "/dev/stdout"))
|
||||
default:
|
||||
o, err := os.OpenFile(args[0], os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
return InvalidFileError(err, args[0])
|
||||
}
|
||||
s.SetError(o)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
package sqlcmd
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestQuitCommand(t *testing.T) {
|
||||
s := &Sqlcmd{}
|
||||
err := quitCommand(s, nil, 1)
|
||||
require.ErrorIs(t, err, ErrExitRequested)
|
||||
err = quitCommand(s, []string{"extra parameters"}, 2)
|
||||
require.Error(t, err, "Quit should error out with extra parameters")
|
||||
assert.NotErrorIs(t, err, ErrExitRequested, "Error with extra arguments")
|
||||
}
|
||||
|
||||
func TestCommandParsing(t *testing.T) {
|
||||
type commandTest struct {
|
||||
line string
|
||||
cmd string
|
||||
args []string
|
||||
}
|
||||
|
||||
commands := []commandTest{
|
||||
{"quite", "", nil},
|
||||
{"quit", "QUIT", []string{""}},
|
||||
{":QUIT\n", "QUIT", []string{""}},
|
||||
{" QUIT \n", "QUIT", []string{""}},
|
||||
{"quit extra\n", "QUIT", []string{"extra"}},
|
||||
{`:Out c:\folder\file`, "OUT", []string{`c:\folder\file`}},
|
||||
{` :Error c:\folder\file`, "ERROR", []string{`c:\folder\file`}},
|
||||
}
|
||||
|
||||
for _, test := range commands {
|
||||
cmd, args := matchCommand(test.line)
|
||||
if test.cmd != "" {
|
||||
if assert.NotNil(t, cmd, "No command found for `%s`", test.line) {
|
||||
assert.Equal(t, test.cmd, cmd.name, "Incorrect command for `%s`", test.line)
|
||||
assert.Equal(t, test.args, args, "Incorrect arguments for `%s`", test.line)
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, cmd, "Unexpected match for %s", test.line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomBatchSeparator(t *testing.T) {
|
||||
err := SetBatchTerminator("me!")
|
||||
if assert.NoError(t, err, "SetBatchTerminator should succeed") {
|
||||
cmd, args := matchCommand(" me! 5 \n")
|
||||
if assert.NotNil(t, cmd, "matchCommand didn't find GO for custom batch separator") {
|
||||
assert.Equal(t, "GO", cmd.name, "command name")
|
||||
assert.Equal(t, "5", strings.TrimSpace(args[0]), "go argument")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
package sqlcmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ErrorPrefix is the prefix for all sqlcmd-generated errors
|
||||
const ErrorPrefix = "Sqlcmd: Error: "
|
||||
|
||||
// WarningPrefix is the prefix for all sqlcmd-generated warnings
|
||||
const WarningPrefix = "Sqlcmd: Warning: "
|
||||
|
||||
// ArgumentError is related to command line switch validation not handled by kong
|
||||
type ArgumentError struct {
|
||||
Parameter string
|
||||
Rule string
|
||||
}
|
||||
|
||||
func (e *ArgumentError) Error() string {
|
||||
return ErrorPrefix + e.Rule
|
||||
}
|
||||
|
||||
// InvalidServerName indicates the SQLCMDSERVER variable has an incorrect format
|
||||
var InvalidServerName = ArgumentError{
|
||||
Parameter: "server",
|
||||
Rule: "server must be of the form [tcp]:server[[/instance]|[,port]]",
|
||||
}
|
||||
|
||||
// VariableError is an error about scripting variables
|
||||
type VariableError struct {
|
||||
Variable string
|
||||
MessageFormat string
|
||||
}
|
||||
|
||||
func (e *VariableError) Error() string {
|
||||
return ErrorPrefix + fmt.Sprintf(e.MessageFormat, e.Variable)
|
||||
}
|
||||
|
||||
// ReadOnlyVariable indicates the user tried to set a value to a read-only variable
|
||||
func ReadOnlyVariable(variable string) *VariableError {
|
||||
return &VariableError{
|
||||
Variable: variable,
|
||||
MessageFormat: "The scripting variable: '%s' is read-only",
|
||||
}
|
||||
}
|
||||
|
||||
// CommandError indicates syntax errors for specific sqlcmd commands
|
||||
type CommandError struct {
|
||||
Command string
|
||||
LineNumber uint
|
||||
}
|
||||
|
||||
func (e *CommandError) Error() string {
|
||||
return ErrorPrefix + fmt.Sprintf("Syntax error at line %d near command '%s'.", e.LineNumber, e.Command)
|
||||
}
|
||||
|
||||
// InvalidCommandError creates a SQLCmdCommandError
|
||||
func InvalidCommandError(command string, lineNumber uint) *CommandError {
|
||||
return &CommandError{
|
||||
Command: command,
|
||||
LineNumber: lineNumber,
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidFileError indicates a file could not be opened
|
||||
func InvalidFileError(err error, path string) error {
|
||||
return errors.New(ErrorPrefix + " Error occurred while opening or operating on file " + path + " (Reason: " + err.Error() + ").")
|
||||
}
|
|
@ -0,0 +1,562 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
package sqlcmd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
mssql "github.com/denisenkom/go-mssqldb"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxDisplayWidth = 1024 * 1024
|
||||
maxPadWidth = 8000
|
||||
)
|
||||
|
||||
// Formatter defines methods to process query output
|
||||
type Formatter interface {
|
||||
// BeginBatch is called before the query runs
|
||||
BeginBatch(query string, vars *Variables, out io.Writer, err io.Writer)
|
||||
// EndBatch is the last function called during batch execution and signals the end of the batch
|
||||
EndBatch()
|
||||
// BeginResultSet is called when a new result set is encountered
|
||||
BeginResultSet([]*sql.ColumnType)
|
||||
// EndResultSet is called after all rows in a result set have been processed
|
||||
EndResultSet()
|
||||
// AddRow is called for each row in a result set
|
||||
AddRow(*sql.Rows)
|
||||
// AddMessage is called for every information message returned by the server during the batch
|
||||
AddMessage(string)
|
||||
// AddError is called for each error encountered during batch execution
|
||||
AddError(err error)
|
||||
}
|
||||
|
||||
// ControlCharacterBehavior specifies the text handling required for control characters in the output
|
||||
type ControlCharacterBehavior int
|
||||
|
||||
const (
|
||||
// ControlIgnore preserves control characters in the output
|
||||
ControlIgnore ControlCharacterBehavior = iota
|
||||
// ControlReplace replaces control characters with spaces, 1 space per character
|
||||
ControlReplace
|
||||
// ControlRemove removes control characters from the output
|
||||
ControlRemove
|
||||
// ControlReplaceConsecutive replaces multiple consecutive control characters with a single space
|
||||
ControlReplaceConsecutive
|
||||
)
|
||||
|
||||
type columnDetail struct {
|
||||
displayWidth int64
|
||||
leftJustify bool
|
||||
zeroesAfterDecimal bool
|
||||
col sql.ColumnType
|
||||
}
|
||||
|
||||
// The default formatter based on the native sqlcmd style
|
||||
type sqlCmdFormatterType struct {
|
||||
out io.Writer
|
||||
err io.Writer
|
||||
vars *Variables
|
||||
colsep string
|
||||
removeTrailingSpaces bool
|
||||
ccb ControlCharacterBehavior
|
||||
columnDetails []columnDetail
|
||||
rowcount int
|
||||
writepos int64
|
||||
}
|
||||
|
||||
// NewSQLCmdDefaultFormatter returns a Formatter that mimics the original ODBC-based sqlcmd formatter
|
||||
func NewSQLCmdDefaultFormatter(removeTrailingSpaces bool) Formatter {
|
||||
return &sqlCmdFormatterType{
|
||||
removeTrailingSpaces: removeTrailingSpaces,
|
||||
}
|
||||
}
|
||||
|
||||
// Adds the given string to the current line, wrapping it based on the screen width setting
|
||||
func (f *sqlCmdFormatterType) writeOut(s string) {
|
||||
w := f.vars.ScreenWidth()
|
||||
if w == 0 {
|
||||
f.mustWriteOut(s)
|
||||
return
|
||||
}
|
||||
|
||||
r := []rune(s)
|
||||
for i := 0; true; {
|
||||
if i == len(r) {
|
||||
f.mustWriteOut(string(r))
|
||||
return
|
||||
} else if f.writepos == w {
|
||||
f.mustWriteOut(string(r[:i]))
|
||||
f.mustWriteOut(SqlcmdEol)
|
||||
r = []rune(string(r[i:]))
|
||||
f.writepos = 0
|
||||
i = 0
|
||||
} else {
|
||||
c := r[i]
|
||||
if c != '\r' && c != '\n' {
|
||||
f.writepos++
|
||||
} else {
|
||||
f.writepos = 0
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stores the settings to use for processing the current batch
|
||||
// TODO: add a third io.Writer for messages when we add -r support
|
||||
func (f *sqlCmdFormatterType) BeginBatch(_ string, vars *Variables, out io.Writer, err io.Writer) {
|
||||
f.out = out
|
||||
f.err = err
|
||||
f.vars = vars
|
||||
f.colsep = vars.ColumnSeparator()
|
||||
}
|
||||
|
||||
func (f *sqlCmdFormatterType) EndBatch() {
|
||||
}
|
||||
|
||||
// Calculate the widths for each column and print the column names
|
||||
// Since sql.ColumnType only provides sizes for variable length types we will
|
||||
// base our numbers for most types on https://docs.microsoft.com/sql/odbc/reference/appendixes/column-size
|
||||
func (f *sqlCmdFormatterType) BeginResultSet(cols []*sql.ColumnType) {
|
||||
f.rowcount = 0
|
||||
f.columnDetails = calcColumnDetails(cols, f.vars.MaxFixedColumnWidth(), f.vars.MaxVarColumnWidth())
|
||||
if f.vars.RowsBetweenHeaders() > -1 {
|
||||
f.printColumnHeadings()
|
||||
}
|
||||
}
|
||||
|
||||
// Writes a blank line to the designated output writer
|
||||
func (f *sqlCmdFormatterType) EndResultSet() {
|
||||
f.writeOut(SqlcmdEol)
|
||||
}
|
||||
|
||||
// Writes the current row to the designated output writer
|
||||
func (f *sqlCmdFormatterType) AddRow(row *sql.Rows) {
|
||||
|
||||
f.writepos = 0
|
||||
values, err := f.scanRow(row)
|
||||
if err != nil {
|
||||
f.mustWriteErr(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// values are the full values, look at the displaywidth of each column and truncate accordingly
|
||||
for i, v := range values {
|
||||
if i > 0 {
|
||||
f.writeOut(f.vars.ColumnSeparator())
|
||||
}
|
||||
f.printColumnValue(v, i)
|
||||
}
|
||||
f.rowcount++
|
||||
gap := f.vars.RowsBetweenHeaders()
|
||||
if gap > 0 && (int64(f.rowcount)%gap == 0) {
|
||||
f.writeOut(SqlcmdEol)
|
||||
f.printColumnHeadings()
|
||||
}
|
||||
f.writeOut(SqlcmdEol)
|
||||
}
|
||||
|
||||
// Writes a non-error message to the designated message writer
|
||||
func (f *sqlCmdFormatterType) AddMessage(string) {}
|
||||
|
||||
// Writes an error to the designated err Writer
|
||||
func (f *sqlCmdFormatterType) AddError(err error) {
|
||||
b := new(strings.Builder)
|
||||
msg := err.Error()
|
||||
switch e := (err).(type) {
|
||||
case mssql.Error:
|
||||
b.WriteString(fmt.Sprintf("Msg %d, Level %d, State %d, Server %s, Line %d%s", e.Number, e.Class, e.State, e.ServerName, e.LineNo, SqlcmdEol))
|
||||
msg = strings.TrimPrefix(msg, "mssql: ")
|
||||
}
|
||||
b.WriteString(msg)
|
||||
b.WriteString(SqlcmdEol)
|
||||
f.mustWriteErr(fitToScreen(b, f.vars.ScreenWidth()).String())
|
||||
}
|
||||
|
||||
// Prints column headings based on columnDetail, variables, and command line arguments
|
||||
func (f *sqlCmdFormatterType) printColumnHeadings() {
|
||||
names := new(strings.Builder)
|
||||
sep := new(strings.Builder)
|
||||
|
||||
var leftPad, rightPad int64
|
||||
for i, c := range f.columnDetails {
|
||||
nameLen := int64(len([]rune(c.col.Name())))
|
||||
if f.removeTrailingSpaces {
|
||||
if nameLen == 0 {
|
||||
// special case for unnamed columns when using -W
|
||||
// print a single -
|
||||
rightPad = 1
|
||||
sep = padRight(sep, 1, "-")
|
||||
} else {
|
||||
sep = padRight(sep, nameLen, "-")
|
||||
}
|
||||
} else {
|
||||
length := min64(c.displayWidth, maxPadWidth)
|
||||
if nameLen < length {
|
||||
rightPad = length - nameLen
|
||||
}
|
||||
sep = padRight(sep, length, "-")
|
||||
}
|
||||
names = padRight(names, leftPad, " ")
|
||||
names.WriteString(c.col.Name()[:min64(nameLen, c.displayWidth)])
|
||||
names = padRight(names, rightPad, " ")
|
||||
if i != len(f.columnDetails)-1 {
|
||||
names.WriteString(f.colsep)
|
||||
sep.WriteString(f.colsep)
|
||||
}
|
||||
}
|
||||
names.WriteString(SqlcmdEol)
|
||||
sep.WriteString(SqlcmdEol)
|
||||
names = fitToScreen(names, f.vars.ScreenWidth())
|
||||
sep = fitToScreen(sep, f.vars.ScreenWidth())
|
||||
f.mustWriteOut(names.String())
|
||||
f.mustWriteOut(sep.String())
|
||||
}
|
||||
|
||||
// Wraps the input string every width characters when width > 0
|
||||
// When width == 0 returns the input Builder
|
||||
// When width > 0 returns a new Builder containing the wrapped string
|
||||
func fitToScreen(s *strings.Builder, width int64) *strings.Builder {
|
||||
str := s.String()
|
||||
runes := []rune(str)
|
||||
if width == 0 || int64(len(runes)) < width {
|
||||
return s
|
||||
}
|
||||
|
||||
line := new(strings.Builder)
|
||||
line.Grow(len(str))
|
||||
var c int64
|
||||
for i, r := range runes {
|
||||
if c == width {
|
||||
// We have printed a line's worth
|
||||
// if the next character is not part of a carriage return write our Eol
|
||||
if (SqlcmdEol == "\r\n" && (i == len(runes)-1 || (i < len(runes)-1 && string(runes[i:i+2]) != SqlcmdEol))) || (SqlcmdEol == "\n" && r != '\n') {
|
||||
line.WriteString(SqlcmdEol)
|
||||
c = 0
|
||||
}
|
||||
}
|
||||
line.WriteRune(r)
|
||||
if r == '\n' {
|
||||
c = 0
|
||||
// we are assuming \r is a non-printed character
|
||||
// The likelihood of a \r not being followed by \n is low
|
||||
} else if r == '\r' && SqlcmdEol == "\r\n" {
|
||||
c = 0
|
||||
} else {
|
||||
c++
|
||||
}
|
||||
}
|
||||
return line
|
||||
}
|
||||
|
||||
// Given the array of driver-provided columnType values and the sqlcmd size limits,
|
||||
// return an array of columnDetail objects describing the output format for each column
|
||||
func calcColumnDetails(cols []*sql.ColumnType, fixed int64, variable int64) (columnDetails []columnDetail) {
|
||||
columnDetails = make([]columnDetail, len(cols))
|
||||
for i, c := range cols {
|
||||
length, _ := c.Length()
|
||||
nameLen := int64(len([]rune(c.Name())))
|
||||
columnDetails[i].col = *c
|
||||
columnDetails[i].leftJustify = true
|
||||
columnDetails[i].zeroesAfterDecimal = false
|
||||
if length == 0 {
|
||||
columnDetails[i].displayWidth = defaultMaxDisplayWidth
|
||||
} else {
|
||||
columnDetails[i].displayWidth = length
|
||||
}
|
||||
switch c.DatabaseTypeName() {
|
||||
// Types with 0 size from sql.ColumnType
|
||||
case "BIT":
|
||||
columnDetails[i].leftJustify = false
|
||||
columnDetails[i].displayWidth = max64(1, nameLen)
|
||||
case "TINYINT":
|
||||
columnDetails[i].leftJustify = false
|
||||
columnDetails[i].displayWidth = max64(3, nameLen)
|
||||
case "SMALLINT":
|
||||
columnDetails[i].leftJustify = false
|
||||
columnDetails[i].displayWidth = max64(6, nameLen)
|
||||
case "INT":
|
||||
columnDetails[i].leftJustify = false
|
||||
columnDetails[i].displayWidth = max64(11, nameLen)
|
||||
case "BIGINT":
|
||||
columnDetails[i].leftJustify = false
|
||||
columnDetails[i].displayWidth = max64(21, nameLen)
|
||||
case "REAL":
|
||||
columnDetails[i].leftJustify = false
|
||||
columnDetails[i].displayWidth = max64(14, nameLen)
|
||||
columnDetails[i].zeroesAfterDecimal = true
|
||||
case "FLOAT":
|
||||
columnDetails[i].leftJustify = false
|
||||
columnDetails[i].displayWidth = max64(24, nameLen)
|
||||
columnDetails[i].zeroesAfterDecimal = true
|
||||
case "DECIMAL":
|
||||
columnDetails[i].leftJustify = false
|
||||
d, _, ok := c.DecimalSize()
|
||||
// maybe panic on !ok?
|
||||
if !ok {
|
||||
d = 24
|
||||
}
|
||||
columnDetails[i].displayWidth = max64(d+2, nameLen)
|
||||
columnDetails[i].zeroesAfterDecimal = true
|
||||
case "DATE":
|
||||
columnDetails[i].leftJustify = false
|
||||
columnDetails[i].displayWidth = max64(16, nameLen)
|
||||
case "DATETIME":
|
||||
columnDetails[i].leftJustify = false
|
||||
columnDetails[i].displayWidth = max64(23, nameLen)
|
||||
case "SMALLDATETIME":
|
||||
columnDetails[i].leftJustify = false
|
||||
columnDetails[i].displayWidth = max64(19, nameLen)
|
||||
columnDetails[i].zeroesAfterDecimal = true
|
||||
case "DATETIME2":
|
||||
columnDetails[i].leftJustify = false
|
||||
columnDetails[i].displayWidth = max64(38, nameLen)
|
||||
columnDetails[i].zeroesAfterDecimal = true
|
||||
case "DATETIMEOFFSET":
|
||||
columnDetails[i].leftJustify = false
|
||||
columnDetails[i].displayWidth = max64(45, nameLen)
|
||||
case "UNIQUEIDENTIFIER":
|
||||
columnDetails[i].displayWidth = max64(36, nameLen)
|
||||
// Types that can be fixed or variable
|
||||
case "VARCHAR":
|
||||
if length > 8000 {
|
||||
columnDetails[i].displayWidth = variable
|
||||
} else {
|
||||
if fixed > 0 {
|
||||
length = min64(fixed, length)
|
||||
}
|
||||
columnDetails[i].displayWidth = max64(length, nameLen)
|
||||
}
|
||||
case "NVARCHAR":
|
||||
if length > 4000 {
|
||||
columnDetails[i].displayWidth = variable
|
||||
} else {
|
||||
if fixed > 0 {
|
||||
length = min64(fixed, length)
|
||||
}
|
||||
columnDetails[i].displayWidth = max64(length, nameLen)
|
||||
}
|
||||
case "VARBINARY":
|
||||
if length <= 8000 {
|
||||
if fixed > 0 {
|
||||
length = min64(fixed, length)
|
||||
}
|
||||
columnDetails[i].displayWidth = max64(length, nameLen)
|
||||
} else {
|
||||
columnDetails[i].displayWidth = variable
|
||||
}
|
||||
// Fixed length types
|
||||
case "CHAR", "NCHAR", "VARIANT":
|
||||
if fixed > 0 {
|
||||
length = min64(fixed, length)
|
||||
}
|
||||
columnDetails[i].displayWidth = max64(length, nameLen)
|
||||
// Variable length types
|
||||
// TODO: Fix BINARY once we have a driver with fix for https://github.com/denisenkom/go-mssqldb/issues/685
|
||||
case "XML", "TEXT", "NTEXT", "IMAGE", "BINARY":
|
||||
columnDetails[i].displayWidth = variable
|
||||
default:
|
||||
columnDetails[i].displayWidth = length
|
||||
}
|
||||
// When max var length is 0 we don't print column headers and print every value with unlimited width
|
||||
if variable == 0 {
|
||||
columnDetails[i].displayWidth = 0
|
||||
}
|
||||
}
|
||||
return columnDetails
|
||||
}
|
||||
|
||||
// scanRow fetches the next row and converts each value to the appropriate string representation
|
||||
func (f *sqlCmdFormatterType) scanRow(rows *sql.Rows) ([]string, error) {
|
||||
r := make([]interface{}, len(f.columnDetails))
|
||||
for i := range r {
|
||||
r[i] = new(interface{})
|
||||
}
|
||||
if err := rows.Scan(r...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
row := make([]string, len(f.columnDetails))
|
||||
for n, z := range r {
|
||||
j := z.(*interface{})
|
||||
if *j == nil {
|
||||
row[n] = "NULL"
|
||||
} else {
|
||||
switch x := (*j).(type) {
|
||||
case []byte:
|
||||
if isBinaryDataType(&f.columnDetails[n].col) {
|
||||
row[n] = decodeBinary(x)
|
||||
} else {
|
||||
row[n] = string(x)
|
||||
}
|
||||
case string:
|
||||
row[n] = x
|
||||
case time.Time:
|
||||
// Go lacks any way to get the user's preferred time format or even the system default
|
||||
row[n] = x.String()
|
||||
case fmt.Stringer:
|
||||
row[n] = x.String()
|
||||
// not sure why go-mssql reports bit as bool
|
||||
case bool:
|
||||
if x {
|
||||
row[n] = "1"
|
||||
} else {
|
||||
row[n] = "0"
|
||||
}
|
||||
default:
|
||||
var err error
|
||||
if row[n], err = fmt.Sprintf("%v", x), nil; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return row, nil
|
||||
}
|
||||
|
||||
// Prints the final version of a cell based on formatting variables and command line parameters
|
||||
func (f *sqlCmdFormatterType) printColumnValue(val string, col int) {
|
||||
c := f.columnDetails[col]
|
||||
s := new(strings.Builder)
|
||||
if isNeedingControlCharacterTreatment(&c.col) {
|
||||
val = applyControlCharacterBehavior(val, f.ccb)
|
||||
}
|
||||
|
||||
if isNeedingHexPrefix(&c.col) {
|
||||
val = "0x" + val
|
||||
}
|
||||
|
||||
s.WriteString(val)
|
||||
r := []rune(val)
|
||||
if !f.removeTrailingSpaces {
|
||||
if f.vars.MaxVarColumnWidth() != 0 || !isLargeVariableType(&c.col) {
|
||||
padding := c.displayWidth - min64(c.displayWidth, int64(len(r)))
|
||||
if padding > 0 {
|
||||
if c.leftJustify {
|
||||
s = padRight(s, padding, " ")
|
||||
} else {
|
||||
s = padLeft(s, padding, " ")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r = []rune(s.String())
|
||||
if c.displayWidth > 0 && int64(len(r)) > c.displayWidth {
|
||||
s.Reset()
|
||||
s.WriteString(string(r[:c.displayWidth]))
|
||||
}
|
||||
f.writeOut(s.String())
|
||||
}
|
||||
|
||||
func (f *sqlCmdFormatterType) mustWriteOut(s string) {
|
||||
_, err := f.out.Write([]byte(s))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *sqlCmdFormatterType) mustWriteErr(s string) {
|
||||
_, err := f.err.Write([]byte(s))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func isLargeVariableType(col *sql.ColumnType) bool {
|
||||
l, _ := col.Length()
|
||||
switch col.DatabaseTypeName() {
|
||||
|
||||
case "VARCHAR", "VARBINARY":
|
||||
return l > 8000
|
||||
case "NVARCHAR":
|
||||
return l > 4000
|
||||
case "XML", "TEXT", "NTEXT", "IMAGE":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isNeedingControlCharacterTreatment(col *sql.ColumnType) bool {
|
||||
switch col.DatabaseTypeName() {
|
||||
case "CHAR", "VARCHAR", "TEXT", "NTEXT", "NCHAR", "NVARCHAR", "XML":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
func isBinaryDataType(col *sql.ColumnType) bool {
|
||||
switch col.DatabaseTypeName() {
|
||||
case "BINARY", "VARBINARY":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isNeedingHexPrefix(col *sql.ColumnType) bool {
|
||||
return isBinaryDataType(col) // || col.DatabaseTypeName() == "UDT"
|
||||
}
|
||||
|
||||
func isControlChar(r rune) bool {
|
||||
c := int(r)
|
||||
return c == 0x7f || (c >= 0 && c <= 0x1f)
|
||||
}
|
||||
|
||||
func applyControlCharacterBehavior(val string, ccb ControlCharacterBehavior) string {
|
||||
if ccb == ControlIgnore {
|
||||
return val
|
||||
}
|
||||
b := new(strings.Builder)
|
||||
r := []rune(val)
|
||||
if ccb == ControlReplace {
|
||||
for _, l := range r {
|
||||
if isControlChar(l) {
|
||||
b.WriteRune(' ')
|
||||
} else {
|
||||
b.WriteRune(l)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < len(r); {
|
||||
if !isControlChar(r[i]) {
|
||||
b.WriteRune(r[i])
|
||||
i++
|
||||
} else {
|
||||
for ; i < len(r) && isControlChar(r[i]); i++ {
|
||||
}
|
||||
if ccb == ControlReplaceConsecutive {
|
||||
b.WriteRune(' ')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Per https://docs.microsoft.com/sql/odbc/reference/appendixes/sql-to-c-binary
|
||||
var hexDigits = []rune{'A', 'B', 'C', 'D', 'E', 'F'}
|
||||
|
||||
func decodeBinary(b []byte) string {
|
||||
|
||||
s := new(strings.Builder)
|
||||
s.Grow(len(b) * 2)
|
||||
for _, ch := range b {
|
||||
b1 := ch >> 4
|
||||
b2 := ch & 0x0f
|
||||
if b1 >= 10 {
|
||||
s.WriteRune(hexDigits[b1-10])
|
||||
} else {
|
||||
s.WriteRune(rune('0' + b1))
|
||||
}
|
||||
if b2 >= 10 {
|
||||
s.WriteRune(hexDigits[b2-10])
|
||||
} else {
|
||||
s.WriteRune(rune('0' + b2))
|
||||
}
|
||||
}
|
||||
return s.String()
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
package sqlcmd
|
||||
|
||||
// SqlcmdEol is the end-of-line marker for sqlcmd output
|
||||
const SqlcmdEol = "\n"
|
|
@ -0,0 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
package sqlcmd
|
||||
|
||||
// SqlcmdEol is the end-of-line marker for sqlcmd output
|
||||
const SqlcmdEol = "\n"
|
|
@ -0,0 +1,136 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
package sqlcmd
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFitToScreen(t *testing.T) {
|
||||
type fitTest struct {
|
||||
width int64
|
||||
raw string
|
||||
fit string
|
||||
}
|
||||
|
||||
tests := []fitTest{
|
||||
{0, "this is a string", "this is a string"},
|
||||
{9, "12345678", "12345678"},
|
||||
{9, "123456789", "123456789"},
|
||||
{9, "123456789A", "123456789" + SqlcmdEol + "A"},
|
||||
{9, "123456789" + SqlcmdEol, "123456789" + SqlcmdEol},
|
||||
{9, "12345678" + SqlcmdEol + "9A", "12345678" + SqlcmdEol + "9A"},
|
||||
{9, "123456789\rA", "123456789" + SqlcmdEol + "\rA"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
||||
line := new(strings.Builder)
|
||||
line.WriteString(test.raw)
|
||||
t.Log(test.raw)
|
||||
f := fitToScreen(line, test.width).String()
|
||||
assert.Equal(t, test.fit, f, "Mismatched fit for raw string: '%s'", test.raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalcColumnDetails(t *testing.T) {
|
||||
type colTest struct {
|
||||
fixed int64
|
||||
variable int64
|
||||
query string
|
||||
details []columnDetail
|
||||
}
|
||||
|
||||
tests := []colTest{
|
||||
{8, 8,
|
||||
"select 100 as '123456789ABC', getdate() as '987654321', 'string' as col1",
|
||||
[]columnDetail{
|
||||
{leftJustify: false, displayWidth: 12},
|
||||
{leftJustify: false, displayWidth: 23},
|
||||
{leftJustify: true, displayWidth: 6},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
db, err := ConnectDb()
|
||||
if assert.NoError(t, err, "ConnectDB failed") {
|
||||
defer db.Close()
|
||||
for _, test := range tests {
|
||||
rows, err := db.Query(test.query)
|
||||
if assert.NoError(t, err, "Query failed: %s", test.query) {
|
||||
defer rows.Close()
|
||||
cols, err := rows.ColumnTypes()
|
||||
if assert.NoError(t, err, "ColumnTypes failed:%s", test.query) {
|
||||
actual := calcColumnDetails(cols, test.fixed, test.variable)
|
||||
for i, a := range actual {
|
||||
if test.details[i].displayWidth != a.displayWidth ||
|
||||
test.details[i].leftJustify != a.leftJustify ||
|
||||
test.details[i].zeroesAfterDecimal != a.zeroesAfterDecimal {
|
||||
assert.Failf(t, "", "Incorrect test details for column [%s] in query '%s':%+v", cols[i].Name(), test.query, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlCharacterBehavior(t *testing.T) {
|
||||
type ccbTest struct {
|
||||
raw string
|
||||
replaced string
|
||||
removed string
|
||||
consecutivereplaced string
|
||||
}
|
||||
|
||||
tests := []ccbTest{
|
||||
{"no control", "no control", "no control", "no control"},
|
||||
{string(rune(1)) + "tabs\t\treturns\r\n\r\n", " tabs returns ", "tabsreturns", " tabs returns "},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := applyControlCharacterBehavior(test.raw, ControlReplace)
|
||||
assert.Equalf(t, test.replaced, s, "Incorrect Replaced for '%s'", test.raw)
|
||||
s = applyControlCharacterBehavior(test.raw, ControlRemove)
|
||||
assert.Equalf(t, test.removed, s, "Incorrect Remove for '%s'", test.raw)
|
||||
s = applyControlCharacterBehavior(test.raw, ControlReplaceConsecutive)
|
||||
assert.Equalf(t, test.consecutivereplaced, s, "Incorrect ReplaceConsecutive for '%s'", test.raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeBinary(t *testing.T) {
|
||||
type decodeTest struct {
|
||||
b []byte
|
||||
s string
|
||||
}
|
||||
|
||||
tests := []decodeTest{
|
||||
{[]byte("123456ABCDEF"), "313233343536414243444546"},
|
||||
{[]byte{0x12, 0x34, 0x56}, "123456"},
|
||||
}
|
||||
for _, test := range tests {
|
||||
a := decodeBinary(test.b)
|
||||
assert.Equalf(t, test.s, a, "Incorrect decoded binary string for %v", test.b)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeBinary(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
bytes := make([]byte, 10000)
|
||||
for i := 0; i < 10000; i++ {
|
||||
bytes[i] = byte(i % 0xff)
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
s := decodeBinary(bytes)
|
||||
if len(s) != 20000 {
|
||||
b.Fatalf("Incorrect length of returned string. Should be 20k, was %d", len(s))
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
package sqlcmd
|
||||
|
||||
// SqlcmdEol is the end-of-line marker for sqlcmd output
|
||||
const SqlcmdEol = "\r\n"
|
|
@ -0,0 +1,147 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
package sqlcmd
|
||||
|
||||
import "unicode"
|
||||
|
||||
// grab grabs i from r, or returns 0 if i >= end.
|
||||
func grab(r []rune, i, end int) rune {
|
||||
if i < end {
|
||||
return r[i]
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// findNonSpace finds first non space rune in r, returning end if not found.
|
||||
func findNonSpace(r []rune, i, end int) (int, bool) {
|
||||
for ; i < end; i++ {
|
||||
if !isSpaceOrControl(r[i]) {
|
||||
return i, true
|
||||
}
|
||||
}
|
||||
return i, false
|
||||
}
|
||||
|
||||
/*
|
||||
// findSpace finds first space rune in r, returning end if not found.
|
||||
func findSpace(r []rune, i, end int) (int, bool) {
|
||||
for ; i < end; i++ {
|
||||
if IsSpaceOrControl(r[i]) {
|
||||
return i, true
|
||||
}
|
||||
}
|
||||
return i, false
|
||||
}
|
||||
|
||||
|
||||
// findRune finds the next rune c in r, returning end if not found.
|
||||
func findRune(r []rune, i, end int, c rune) (int, bool) {
|
||||
for ; i < end; i++ {
|
||||
if r[i] == c {
|
||||
return i, true
|
||||
}
|
||||
}
|
||||
return i, false
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
// isEmptyLine returns true when r is empty or composed of only whitespace.
|
||||
func isEmptyLine(r []rune, i, end int) bool {
|
||||
_, ok := findNonSpace(r, i, end)
|
||||
return !ok
|
||||
}
|
||||
|
||||
// readString seeks to the end of a string returning the position and whether
|
||||
// or not the string's end was found.
|
||||
//
|
||||
// If the string's terminator was not found, then the result will be the passed
|
||||
// end.
|
||||
func readString(r []rune, i, end int, quote rune) (int, bool) {
|
||||
var prev, c, next rune
|
||||
for ; i < end; i++ {
|
||||
c, next = r[i], grab(r, i+1, end)
|
||||
switch {
|
||||
case quote == '\'' && c == '\\':
|
||||
i++
|
||||
prev = 0
|
||||
continue
|
||||
case quote == '\'' && c == '\'' && next == '\'':
|
||||
i++
|
||||
continue
|
||||
case quote == '\'' && c == '\'' && prev != '\'',
|
||||
quote == '"' && c == '"':
|
||||
return i, true
|
||||
}
|
||||
prev = c
|
||||
}
|
||||
return end, false
|
||||
}
|
||||
|
||||
// readMultilineComment finds the end of a multiline comment (ie, '*/').
|
||||
func readMultilineComment(r []rune, i, end int) (int, bool) {
|
||||
i++
|
||||
for ; i < end; i++ {
|
||||
if r[i-1] == '*' && r[i] == '/' {
|
||||
return i, true
|
||||
}
|
||||
}
|
||||
return end, false
|
||||
}
|
||||
|
||||
// Read to the next control character and try to find
|
||||
// a command in the string. Command regexes constrain matches
|
||||
// to the beginning of the string, and all commands consume
|
||||
// an entire line.
|
||||
func readCommand(r []rune, i, end int) (*Command, []string, int) {
|
||||
for ; i < end; i++ {
|
||||
next := grab(r, i, end)
|
||||
if next == 0 || unicode.IsControl(next) {
|
||||
break
|
||||
}
|
||||
}
|
||||
cmd, args := matchCommand(string(r[:i]))
|
||||
return cmd, args, i
|
||||
}
|
||||
|
||||
func max64(a, b int64) int64 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// min returns the minimum of a, b.
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func min64(a, b int64) int64 {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// isSpaceOrControl is a special test for either a space or a control (ie, \b)
|
||||
// characters.
|
||||
func isSpaceOrControl(r rune) bool {
|
||||
return unicode.IsSpace(r) || unicode.IsControl(r)
|
||||
}
|
||||
|
||||
/*
|
||||
// runesLastIndex returns the last index in r of needle, or -1 if not found.
|
||||
func runesLastIndex(r []rune, needle rune) int {
|
||||
i := len(r) - 1
|
||||
for ; i >= 0; i-- {
|
||||
if r[i] == needle {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
*/
|
|
@ -0,0 +1,61 @@
|
|||
package sqlcmd
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestReadString(t *testing.T) {
|
||||
tests := []struct {
|
||||
// input string
|
||||
s string
|
||||
// index to start inside s
|
||||
i int
|
||||
// expected return string
|
||||
exp string
|
||||
// expected return bool
|
||||
ok bool
|
||||
}{
|
||||
{`'`, 0, ``, false},
|
||||
{` '`, 1, ``, false},
|
||||
{`'str' `, 0, `'str'`, true},
|
||||
{` 'str' `, 1, `'str'`, true},
|
||||
{`"str"`, 0, `"str"`, true},
|
||||
{`'str''str'`, 0, `'str''str'`, true},
|
||||
{` 'str''str' `, 1, `'str''str'`, true},
|
||||
{` "str''str" `, 1, `"str''str"`, true},
|
||||
// escaped \" aren't allowed in strings, so the second " would be next
|
||||
// double quoted string
|
||||
{`"str\""`, 0, `"str\"`, true},
|
||||
{` "str\"" `, 1, `"str\"`, true},
|
||||
{`''''`, 0, `''''`, true},
|
||||
{` '''' `, 1, `''''`, true},
|
||||
{`''''''`, 0, `''''''`, true},
|
||||
{` '''''' `, 1, `''''''`, true},
|
||||
{`'''`, 0, ``, false},
|
||||
{` ''' `, 1, ``, false},
|
||||
{`'''''`, 0, ``, false},
|
||||
{` ''''' `, 1, ``, false},
|
||||
{`"st'r"`, 0, `"st'r"`, true},
|
||||
{` "st'r" `, 1, `"st'r"`, true},
|
||||
{`"st''r"`, 0, `"st''r"`, true},
|
||||
{` "st''r" `, 1, `"st''r"`, true},
|
||||
}
|
||||
for _, test := range tests {
|
||||
r := []rune(test.s)
|
||||
c, end := rune(strings.TrimSpace(test.s)[0]), len(r)
|
||||
if c != '\'' && c != '"' {
|
||||
t.Fatalf("test %+v incorrect!", test)
|
||||
}
|
||||
pos, ok := readString(r, test.i+1, end, c)
|
||||
assert.Equal(t, test.ok, ok, "test %+v ok", test)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, c, r[pos], "test %+v last character")
|
||||
v := string(r[test.i : pos+1])
|
||||
assert.Equal(t, test.exp, v, "test %+v returned string", test)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,287 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
package sqlcmd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
osuser "os/user"
|
||||
"syscall"
|
||||
|
||||
mssql "github.com/denisenkom/go-mssqldb"
|
||||
"github.com/xo/usql/rline"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrExitRequested tells the hosting application to exit immediately
|
||||
ErrExitRequested = errors.New("exit")
|
||||
// ErrNeedPassword indicates the user should provide a password to enable the connection
|
||||
ErrNeedPassword = errors.New("need password")
|
||||
// ErrCtrlC indicates execution was ended by ctrl-c or ctrl-break
|
||||
ErrCtrlC = errors.New(WarningPrefix + "The last operation was terminated because the user pressed CTRL+C")
|
||||
)
|
||||
|
||||
// ConnectSettings are the settings for connections that can't be
|
||||
// inferred from scripting variables
|
||||
type ConnectSettings struct {
|
||||
UseTrustedConnection bool
|
||||
TrustServerCertificate bool
|
||||
}
|
||||
|
||||
// Sqlcmd is the core processor for text lines.
|
||||
//
|
||||
// It accumulates non-command lines in a buffer and and sends command lines to the appropriate command runner.
|
||||
// When the batch delimiter is encountered it sends the current batch to the active connection and prints
|
||||
// the results to the output writer
|
||||
type Sqlcmd struct {
|
||||
lineIo rline.IO
|
||||
workingDirectory string
|
||||
db *sql.DB
|
||||
out io.WriteCloser
|
||||
err io.WriteCloser
|
||||
batch *Batch
|
||||
// Exitcode is returned to the operating system when the process exits
|
||||
Exitcode int
|
||||
Connect ConnectSettings
|
||||
vars *Variables
|
||||
Format Formatter
|
||||
Query string
|
||||
}
|
||||
|
||||
// New creates a new Sqlcmd instance
|
||||
func New(l rline.IO, workingDirectory string, vars *Variables) *Sqlcmd {
|
||||
return &Sqlcmd{
|
||||
lineIo: l,
|
||||
workingDirectory: workingDirectory,
|
||||
batch: NewBatch(l.Next),
|
||||
vars: vars,
|
||||
}
|
||||
}
|
||||
|
||||
// Run processes all available batches.
|
||||
// When once is true it stops after the first query runs.
|
||||
func (s *Sqlcmd) Run(once bool) error {
|
||||
setupCloseHandler(s)
|
||||
stderr, iactive := s.GetError(), s.lineIo.Interactive()
|
||||
var lastError error
|
||||
for {
|
||||
var execute bool
|
||||
if iactive {
|
||||
s.lineIo.Prompt(s.Prompt())
|
||||
}
|
||||
var cmd *Command
|
||||
var args []string
|
||||
var err error
|
||||
if s.Query != "" {
|
||||
cmd = Commands["GO"]
|
||||
args = make([]string, 0)
|
||||
} else {
|
||||
cmd, args, err = s.batch.Next()
|
||||
}
|
||||
switch {
|
||||
case err == rline.ErrInterrupt:
|
||||
// Ignore any error printing the ctrl-c notice since we are exiting
|
||||
_, _ = s.GetOutput().Write([]byte(ErrCtrlC.Error()))
|
||||
return nil
|
||||
case err != nil:
|
||||
if err == io.EOF {
|
||||
if s.batch.Length == 0 {
|
||||
return lastError
|
||||
}
|
||||
execute = true
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if cmd != nil {
|
||||
err = s.RunCommand(cmd, args)
|
||||
if err == ErrExitRequested || once {
|
||||
s.SetOutput(nil)
|
||||
s.SetError(nil)
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintln(stderr, err)
|
||||
lastError = err
|
||||
continue
|
||||
}
|
||||
}
|
||||
if execute {
|
||||
s.Query = s.batch.String()
|
||||
once = true
|
||||
s.batch.Reset(nil)
|
||||
}
|
||||
}
|
||||
return lastError
|
||||
}
|
||||
|
||||
// Prompt returns the current user prompt message
|
||||
func (s *Sqlcmd) Prompt() string {
|
||||
ch := ">"
|
||||
if s.batch.quote != 0 || s.batch.comment {
|
||||
ch = "~"
|
||||
}
|
||||
return fmt.Sprint(s.batch.batchline) + ch + " "
|
||||
}
|
||||
|
||||
// RunCommand performs the given Command
|
||||
func (s *Sqlcmd) RunCommand(cmd *Command, args []string) error {
|
||||
return cmd.action(s, args, s.batch.linecount)
|
||||
}
|
||||
|
||||
// GetOutput returns the io.Writer to use for non-error output
|
||||
func (s *Sqlcmd) GetOutput() io.Writer {
|
||||
if s.out == nil {
|
||||
return s.lineIo.Stdout()
|
||||
}
|
||||
return s.out
|
||||
}
|
||||
|
||||
// SetOutput sets the io.WriteCloser to use for non-error output
|
||||
func (s *Sqlcmd) SetOutput(o io.WriteCloser) {
|
||||
if s.out != nil {
|
||||
s.out.Close()
|
||||
}
|
||||
s.out = o
|
||||
}
|
||||
|
||||
// GetError returns the io.Writer to use for errors
|
||||
func (s *Sqlcmd) GetError() io.Writer {
|
||||
if s.err == nil {
|
||||
return s.lineIo.Stderr()
|
||||
}
|
||||
return s.err
|
||||
}
|
||||
|
||||
// SetError sets the io.WriteCloser to use for errors
|
||||
func (s *Sqlcmd) SetError(e io.WriteCloser) {
|
||||
if s.err != nil {
|
||||
s.err.Close()
|
||||
}
|
||||
s.err = e
|
||||
}
|
||||
|
||||
// ConnectionString returns the go-mssql connection string to use for queries
|
||||
func (s *Sqlcmd) ConnectionString() (connectionString string, err error) {
|
||||
serverName, instance, port, err := s.vars.SQLCmdServer()
|
||||
if serverName == "" {
|
||||
serverName = "."
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
query := url.Values{}
|
||||
connectionURL := &url.URL{
|
||||
Scheme: "sqlserver",
|
||||
Path: instance,
|
||||
}
|
||||
useTrustedConnection := s.Connect.UseTrustedConnection || (s.vars.SQLCmdUser() == "" && !s.vars.UseAad())
|
||||
if !useTrustedConnection {
|
||||
connectionURL.User = url.UserPassword(s.vars.SQLCmdUser(), s.vars.Password())
|
||||
}
|
||||
if port > 0 {
|
||||
connectionURL.Host = fmt.Sprintf("%s:%d", serverName, port)
|
||||
} else {
|
||||
connectionURL.Host = serverName
|
||||
}
|
||||
if s.vars.SQLCmdDatabase() != "" {
|
||||
query.Add("database", s.vars.SQLCmdDatabase())
|
||||
}
|
||||
|
||||
if s.Connect.TrustServerCertificate {
|
||||
query.Add("trustservercertificate", "true")
|
||||
}
|
||||
connectionURL.RawQuery = query.Encode()
|
||||
return connectionURL.String(), nil
|
||||
}
|
||||
|
||||
// ConnectDb opens a connection to the database with the given modifications to the connection
|
||||
func (s *Sqlcmd) ConnectDb(server string, user string, password string, nopw bool) error {
|
||||
if user != "" && password == "" && !nopw {
|
||||
return ErrNeedPassword
|
||||
}
|
||||
|
||||
connstr, err := s.ConnectionString()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connectionURL, err := url.Parse(connstr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if server != "" {
|
||||
serverName, instance, port, err := splitServer(server)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
connectionURL.Path = instance
|
||||
if port > 0 {
|
||||
connectionURL.Host = fmt.Sprintf("%s:%d", serverName, port)
|
||||
} else {
|
||||
connectionURL.Host = serverName
|
||||
}
|
||||
}
|
||||
|
||||
if password == "" {
|
||||
password = s.vars.Password()
|
||||
}
|
||||
|
||||
if user != "" {
|
||||
connectionURL.User = url.UserPassword(user, password)
|
||||
}
|
||||
|
||||
connector, err := mssql.NewConnector(connectionURL.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db := sql.OpenDB(connector)
|
||||
err = db.Ping()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// we got a good connection so we can update the Sqlcmd
|
||||
if s.db != nil {
|
||||
s.db.Close()
|
||||
}
|
||||
s.db = db
|
||||
if server != "" {
|
||||
s.vars.Set(SQLCMDSERVER, server)
|
||||
}
|
||||
if user != "" {
|
||||
s.vars.Set(SQLCMDUSER, user)
|
||||
s.Connect.UseTrustedConnection = false
|
||||
if password != "" {
|
||||
s.vars.Set(SQLCMDPASSWORD, password)
|
||||
}
|
||||
} else if s.vars.SQLCmdUser() == "" {
|
||||
u, e := osuser.Current()
|
||||
if e != nil {
|
||||
panic("Unable to get user name")
|
||||
}
|
||||
s.Connect.UseTrustedConnection = true
|
||||
s.vars.Set(SQLCMDUSER, u.Username)
|
||||
}
|
||||
|
||||
if s.batch != nil {
|
||||
s.batch.batchline = 1
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupCloseHandler(s *Sqlcmd) {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-c
|
||||
_, _ = s.GetOutput().Write([]byte(ErrCtrlC.Error()))
|
||||
os.Exit(0)
|
||||
}()
|
||||
}
|
|
@ -0,0 +1,136 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
package sqlcmd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/xo/usql/rline"
|
||||
)
|
||||
|
||||
func TestConnectionStringFromSqlCmd(t *testing.T) {
|
||||
type connectionStringTest struct {
|
||||
settings *ConnectSettings
|
||||
setup func(*Variables)
|
||||
connectionString string
|
||||
}
|
||||
|
||||
pwd := uuid.New().String()
|
||||
|
||||
commands := []connectionStringTest{
|
||||
|
||||
{nil, nil, "sqlserver://."},
|
||||
{
|
||||
&ConnectSettings{TrustServerCertificate: true},
|
||||
func(vars *Variables) {
|
||||
_ = Setvar(SQLCMDDBNAME, "somedatabase")
|
||||
},
|
||||
"sqlserver://.?database=somedatabase&trustservercertificate=true",
|
||||
},
|
||||
{
|
||||
&ConnectSettings{TrustServerCertificate: true},
|
||||
func(vars *Variables) {
|
||||
vars.Set(SQLCMDSERVER, `someserver/instance`)
|
||||
vars.Set(SQLCMDDBNAME, "somedatabase")
|
||||
vars.Set(SQLCMDUSER, "someuser")
|
||||
vars.Set(SQLCMDPASSWORD, pwd)
|
||||
},
|
||||
fmt.Sprintf("sqlserver://someuser:%s@someserver/instance?database=somedatabase&trustservercertificate=true", pwd),
|
||||
},
|
||||
{
|
||||
&ConnectSettings{TrustServerCertificate: true, UseTrustedConnection: true},
|
||||
func(vars *Variables) {
|
||||
vars.Set(SQLCMDSERVER, `tcp:someserver,1045`)
|
||||
vars.Set(SQLCMDUSER, "someuser")
|
||||
vars.Set(SQLCMDPASSWORD, pwd)
|
||||
},
|
||||
"sqlserver://someserver:1045?trustservercertificate=true",
|
||||
},
|
||||
{
|
||||
nil,
|
||||
func(vars *Variables) {
|
||||
vars.Set(SQLCMDSERVER, `tcp:someserver,1045`)
|
||||
},
|
||||
"sqlserver://someserver:1045",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range commands {
|
||||
v := InitializeVariables(false)
|
||||
if test.setup != nil {
|
||||
test.setup(v)
|
||||
}
|
||||
s := &Sqlcmd{vars: v}
|
||||
if test.settings != nil {
|
||||
s.Connect = *test.settings
|
||||
}
|
||||
connectionString, err := s.ConnectionString()
|
||||
if assert.NoError(t, err, "Unexpected error from %+v", s) {
|
||||
assert.Equal(t, test.connectionString, connectionString, "Wrong connection string from: %+v", *s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* The following tests require a working SQL instance and rely on SqlCmd environment variables
|
||||
to manage the initial connection string. The default connection when no environment variables are
|
||||
set will be to localhost using Windows auth.
|
||||
|
||||
*/
|
||||
func TestSqlCmdConnectDb(t *testing.T) {
|
||||
v := InitializeVariables(true)
|
||||
s := &Sqlcmd{vars: v}
|
||||
err := s.ConnectDb("", "", "", false)
|
||||
if assert.NoError(t, err, "ConnectDb should succeed") {
|
||||
sqlcmduser := os.Getenv(SQLCMDUSER)
|
||||
if sqlcmduser == "" {
|
||||
u, _ := user.Current()
|
||||
sqlcmduser = u.Username
|
||||
}
|
||||
assert.Equal(t, sqlcmduser, s.vars.SQLCmdUser(), "SQLCMDUSER variable should match connected user")
|
||||
}
|
||||
}
|
||||
|
||||
func ConnectDb() (*sql.DB, error) {
|
||||
v := InitializeVariables(true)
|
||||
s := &Sqlcmd{vars: v}
|
||||
err := s.ConnectDb("", "", "", false)
|
||||
return s.db, err
|
||||
}
|
||||
|
||||
func TestSqlCmdQueryAndExit(t *testing.T) {
|
||||
v := InitializeVariables(true)
|
||||
v.Set(SQLCMDMAXVARTYPEWIDTH, "0")
|
||||
line, err := rline.New(false, "", "")
|
||||
if !assert.NoError(t, err, "rline.New") {
|
||||
return
|
||||
}
|
||||
s := New(line, "", v)
|
||||
s.Format = NewSQLCmdDefaultFormatter(true)
|
||||
s.Query = "select 100"
|
||||
file, err := os.CreateTemp("", "sqlcmdout")
|
||||
if !assert.NoError(t, err, "os.CreateTemp") {
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
defer os.Remove(file.Name())
|
||||
s.SetOutput(file)
|
||||
err = s.ConnectDb("", "", "", true)
|
||||
if !assert.NoError(t, err, "s.ConnectDB") {
|
||||
return
|
||||
}
|
||||
err = s.Run(true)
|
||||
if assert.NoError(t, err, "s.Run(once = true)") {
|
||||
s.SetOutput(nil)
|
||||
bytes, err := os.ReadFile(file.Name())
|
||||
if assert.NoError(t, err, "os.ReadFile") {
|
||||
assert.Equal(t, "100"+SqlcmdEol+SqlcmdEol, string(bytes), "Incorrect output from Run")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
package sqlcmd
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// splitServer extracts connection parameters from a server name input
|
||||
func splitServer(serverName string) (string, string, uint64, error) {
|
||||
instance := ""
|
||||
port := uint64(0)
|
||||
if strings.HasPrefix(serverName, "tcp:") {
|
||||
if len(serverName) == 4 {
|
||||
return "", "", 0, &InvalidServerName
|
||||
}
|
||||
serverName = serverName[4:]
|
||||
}
|
||||
serverNameParts := strings.Split(serverName, ",")
|
||||
if len(serverNameParts) > 2 {
|
||||
return "", "", 0, &InvalidServerName
|
||||
}
|
||||
if len(serverNameParts) == 2 {
|
||||
var err error
|
||||
port, err = strconv.ParseUint(serverNameParts[1], 10, 16)
|
||||
if err != nil {
|
||||
return "", "", 0, &InvalidServerName
|
||||
}
|
||||
serverName = serverNameParts[0]
|
||||
} else {
|
||||
serverNameParts = strings.Split(serverName, "/")
|
||||
if len(serverNameParts) > 2 {
|
||||
return "", "", 0, &InvalidServerName
|
||||
}
|
||||
if len(serverNameParts) == 2 {
|
||||
instance = serverNameParts[1]
|
||||
serverName = serverNameParts[0]
|
||||
}
|
||||
}
|
||||
return serverName, instance, port, nil
|
||||
}
|
||||
|
||||
// padRight appends c instances of s to builder
|
||||
func padRight(builder *strings.Builder, c int64, s string) *strings.Builder {
|
||||
var i int64
|
||||
for ; i < c; i++ {
|
||||
builder.WriteString(s)
|
||||
}
|
||||
return builder
|
||||
}
|
||||
|
||||
// padLeft prepends c instances of s to builder
|
||||
func padLeft(builder *strings.Builder, c int64, s string) *strings.Builder {
|
||||
newBuilder := new(strings.Builder)
|
||||
newBuilder.Grow(builder.Len())
|
||||
var i int64
|
||||
for ; i < c; i++ {
|
||||
newBuilder.WriteString(s)
|
||||
}
|
||||
newBuilder.WriteString(builder.String())
|
||||
return newBuilder
|
||||
}
|
|
@ -0,0 +1,209 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
package sqlcmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Variables provides set and get of sqlcmd scripting variables
|
||||
type Variables map[string]string
|
||||
|
||||
var variables Variables
|
||||
|
||||
// Built-in scripting variables
|
||||
const (
|
||||
SQLCMDDBNAME = "SQLCMDDBNAME"
|
||||
SQLCMDINI = "SQLCMDINI"
|
||||
SQLCMDPACKETSIZE = "SQLCMDPACKETSIZE"
|
||||
SQLCMDPASSWORD = "SQLCMDPASSWORD"
|
||||
SQLCMDSERVER = "SQLCMDSERVER"
|
||||
SQLCMDUSER = "SQLCMDUSER"
|
||||
SQLCMDWORKSTATION = "SQLCMDWORKSTATION"
|
||||
SQLCMDLOGINTIMEOUT = "SQLCMDLOGINTIMEOUT"
|
||||
SQLCMDSTATTIMEOUT = "SQLCMDSTATTIMEOUT"
|
||||
SQLCMDHEADERS = "SQLCMDHEADERS"
|
||||
SQLCMDCOLSEP = "SQLCMDCOLSEP"
|
||||
SQLCMDCOLWIDTH = "SQLCMDCOLWIDTH"
|
||||
SQLCMDERRORLEVEL = "SQLCMDERRORLEVEL"
|
||||
SQLCMDMAXVARTYPEWIDTH = "SQLCMDMAXVARTYPEWIDTH"
|
||||
SQLCMDMAXFIXEDTYPEWIDTH = "SQLCMDMAXFIXEDTYPEWIDTH"
|
||||
SQLCMDEDITOR = "SQLCMDEDITOR"
|
||||
SQLCMDUSEAAD = "SQLCMDUSEAAD"
|
||||
)
|
||||
|
||||
// Variables that can only be set at startup
|
||||
var readOnlyVariables = []string{
|
||||
SQLCMDDBNAME,
|
||||
SQLCMDINI,
|
||||
SQLCMDPACKETSIZE,
|
||||
SQLCMDPASSWORD,
|
||||
SQLCMDSERVER,
|
||||
SQLCMDUSER,
|
||||
SQLCMDWORKSTATION,
|
||||
}
|
||||
|
||||
func (v Variables) checkReadOnly(key string) error {
|
||||
currentValue, hasValue := v[key]
|
||||
if hasValue {
|
||||
for _, variable := range readOnlyVariables {
|
||||
if variable == key && currentValue != "" {
|
||||
return ReadOnlyVariable(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set sets or adds the value in the map.
|
||||
func (v Variables) Set(name, value string) {
|
||||
key := strings.ToUpper(name)
|
||||
v[key] = value
|
||||
}
|
||||
|
||||
// Unset removes the value from the map
|
||||
func (v Variables) Unset(name string) {
|
||||
key := strings.ToUpper(name)
|
||||
delete(v, key)
|
||||
}
|
||||
|
||||
// All returns a copy of the current variables
|
||||
func (v Variables) All() map[string]string {
|
||||
return map[string]string(v)
|
||||
}
|
||||
|
||||
// SQLCmdUser returns the SQLCMDUSER variable value
|
||||
func (v Variables) SQLCmdUser() string {
|
||||
return v[SQLCMDUSER]
|
||||
}
|
||||
|
||||
// SQLCmdServer returns the server connection parameters derived from the SQLCMDSERVER variable value
|
||||
func (v Variables) SQLCmdServer() (serverName string, instance string, port uint64, err error) {
|
||||
serverName = v[SQLCMDSERVER]
|
||||
return splitServer(serverName)
|
||||
}
|
||||
|
||||
// SQLCmdDatabase returns the SQLCMDDBNAME variable value
|
||||
func (v Variables) SQLCmdDatabase() string {
|
||||
return v[SQLCMDDBNAME]
|
||||
}
|
||||
|
||||
// UseAad returns whether the SQLCMDUSEAAD variable value is set to "true"
|
||||
func (v Variables) UseAad() bool {
|
||||
return strings.EqualFold(v[SQLCMDUSEAAD], "true")
|
||||
}
|
||||
|
||||
// Password returns the password used for connections as specified by SQLCMDPASSWORD variable
|
||||
func (v Variables) Password() string {
|
||||
return v[SQLCMDPASSWORD]
|
||||
}
|
||||
|
||||
// ColumnSeparator is the value of SQLCMDCOLSEP variable. It can have 0 or 1 characters
|
||||
func (v Variables) ColumnSeparator() string {
|
||||
sep := v[SQLCMDCOLSEP]
|
||||
if len(sep) > 1 {
|
||||
return sep[:1]
|
||||
}
|
||||
return sep
|
||||
}
|
||||
|
||||
// MaxFixedColumnWidth is the value of SQLCMDMAXFIXEDTYPEWIDTH variable.
|
||||
// When non-zero, it limits the width of columns for types CHAR, NCHAR, NVARCHAR, VARCHAR, VARBINARY, VARIANT
|
||||
func (v Variables) MaxFixedColumnWidth() int64 {
|
||||
w := v[SQLCMDMAXFIXEDTYPEWIDTH]
|
||||
return mustValue(w)
|
||||
}
|
||||
|
||||
// MaxVarColumnWidth is the value of SQLCMDMAXVARTYPEWIDTH variable.
|
||||
// When non-zero, it limits the width of columns for (max) versions of CHAR, NCHAR, VARBINARY.
|
||||
// It also limits the width of xml, UDT, text, ntext, and image
|
||||
func (v Variables) MaxVarColumnWidth() int64 {
|
||||
w := v[SQLCMDMAXVARTYPEWIDTH]
|
||||
return mustValue(w)
|
||||
}
|
||||
|
||||
// ScreenWidth is the value of SQLCMDCOLWIDTH variable.
|
||||
// It tells the formatter how many characters wide to limit all screen output.
|
||||
func (v Variables) ScreenWidth() int64 {
|
||||
w := v[SQLCMDCOLWIDTH]
|
||||
return mustValue(w)
|
||||
}
|
||||
|
||||
// RowsBetweenHeaders is the value of SQLCMDHEADERS variable.
|
||||
// When MaxVarColumnWidth() is 0, it returns -1
|
||||
func (v Variables) RowsBetweenHeaders() int64 {
|
||||
if v.MaxVarColumnWidth() == 0 {
|
||||
return -1
|
||||
}
|
||||
h := mustValue(v[SQLCMDHEADERS])
|
||||
return h
|
||||
}
|
||||
|
||||
func mustValue(val string) int64 {
|
||||
var n int64
|
||||
_, err := fmt.Sscanf(val, "%d", &n)
|
||||
if err == nil {
|
||||
return n
|
||||
}
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// InitializeVariables initializes variables with default values.
|
||||
// When fromEnvironment is true, then loads from the runtime environment
|
||||
func InitializeVariables(fromEnvironment bool) *Variables {
|
||||
variables = Variables{
|
||||
SQLCMDCOLSEP: " ",
|
||||
SQLCMDCOLWIDTH: "0",
|
||||
SQLCMDDBNAME: "",
|
||||
SQLCMDEDITOR: "edit.com",
|
||||
SQLCMDERRORLEVEL: "0",
|
||||
SQLCMDHEADERS: "0",
|
||||
SQLCMDINI: "",
|
||||
SQLCMDLOGINTIMEOUT: "30",
|
||||
SQLCMDMAXFIXEDTYPEWIDTH: "0",
|
||||
SQLCMDMAXVARTYPEWIDTH: "256",
|
||||
SQLCMDPACKETSIZE: "4096",
|
||||
SQLCMDSERVER: "",
|
||||
SQLCMDSTATTIMEOUT: "0",
|
||||
SQLCMDUSER: "",
|
||||
SQLCMDPASSWORD: "",
|
||||
SQLCMDUSEAAD: "",
|
||||
}
|
||||
hostname, _ := os.Hostname()
|
||||
variables.Set(SQLCMDWORKSTATION, hostname)
|
||||
|
||||
if fromEnvironment {
|
||||
for v := range variables.All() {
|
||||
envVar, ok := os.LookupEnv(v)
|
||||
if ok {
|
||||
variables.Set(v, envVar)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &variables
|
||||
}
|
||||
|
||||
// Setvar implements the :Setvar command
|
||||
// TODO: Add validation functions for the variables.
|
||||
func Setvar(name, value string) error {
|
||||
err := ValidIdentifier(name)
|
||||
if err == nil {
|
||||
err = variables.checkReadOnly(name)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
variables.Set(name, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidIdentifier determines if a given string can be used as a variable name
|
||||
func ValidIdentifier(name string) error {
|
||||
if strings.HasPrefix(name, "$(") || strings.ContainsAny(name, "'\"\t\n\r ") {
|
||||
return InvalidCommandError(":setvar", 0)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
package sqlcmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBasicVariableOperations(t *testing.T) {
|
||||
variables = Variables{
|
||||
"var1": "val1",
|
||||
}
|
||||
variables.Set("var2", "val2")
|
||||
assert.Contains(t, variables, "VAR2", "Set should add a capitalized key")
|
||||
all := variables.All()
|
||||
keys := make([]string, 0, len(all))
|
||||
for k := range all {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
assert.ElementsMatch(t, []string{"var1", "VAR2"}, keys, "All returns every key")
|
||||
assert.Equal(t, "val2", all["VAR2"], "VAR2 set value")
|
||||
|
||||
}
|
||||
|
||||
func TestSetvarFailsForReadOnlyVariables(t *testing.T) {
|
||||
variables = Variables{}
|
||||
err := Setvar("SQLCMDDBNAME", "somedatabase")
|
||||
assert.NoError(t, err, "Setvar should succeed when SQLCMDDBNAME is not set")
|
||||
err = Setvar("SQLCMDDBNAME", "newdatabase")
|
||||
assert.EqualError(t, err, "Sqlcmd: Error: The scripting variable: 'SQLCMDDBNAME' is read-only")
|
||||
}
|
||||
|
||||
func TestEnvironmentVariablesAsInput(t *testing.T) {
|
||||
os.Setenv("SQLCMDSERVER", "someserver")
|
||||
defer os.Unsetenv("SQLCMDSERVER")
|
||||
os.Setenv("x", "somevalue")
|
||||
defer os.Unsetenv("x")
|
||||
vars := InitializeVariables(true).All()
|
||||
assert.Equal(t, "someserver", vars["SQLCMDSERVER"], "InitializeVariables should read a valid environment variable from the known list")
|
||||
_, ok := vars["x"]
|
||||
assert.False(t, ok, "InitializeVariables should skip variables not in the known list")
|
||||
}
|
||||
|
||||
func TestSqlServerSplitsName(t *testing.T) {
|
||||
vars := Variables{
|
||||
SQLCMDSERVER: `tcp:someserver/someinstance`,
|
||||
}
|
||||
serverName, instance, port, err := vars.SQLCmdServer()
|
||||
if assert.NoError(t, err, "tcp:server/someinstance") {
|
||||
assert.Equal(t, "someserver", serverName, "server name for instance")
|
||||
assert.Equal(t, uint64(0), port, "port for instance")
|
||||
assert.Equal(t, "someinstance", instance, "instance for instance")
|
||||
}
|
||||
vars = Variables{
|
||||
SQLCMDSERVER: `tcp:someserver,1111`,
|
||||
}
|
||||
serverName, instance, port, err = vars.SQLCmdServer()
|
||||
if assert.NoError(t, err, "tcp:server,1111") {
|
||||
assert.Equal(t, "someserver", serverName, "server name for port number")
|
||||
assert.Equal(t, uint64(1111), port, "port for port number")
|
||||
assert.Equal(t, "", instance, "instance for port number")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
select 1 as col1
|
||||
go
|
||||
|
Загрузка…
Ссылка в новой задаче