github.com/koding/cache 2014-09-12 2016-12-22
This commit is contained in:
Родитель
41865ec83a
Коммит
dea955a107
3
go.mod
3
go.mod
|
@ -17,7 +17,7 @@ require (
|
||||||
github.com/ianschenck/envflag v0.0.0-20140720210342-9111d830d133
|
github.com/ianschenck/envflag v0.0.0-20140720210342-9111d830d133
|
||||||
github.com/joho/godotenv v1.3.0
|
github.com/joho/godotenv v1.3.0
|
||||||
github.com/json-iterator/go v1.1.10 // indirect
|
github.com/json-iterator/go v1.1.10 // indirect
|
||||||
github.com/koding/cache v0.0.0-20140912085602-487fc0ca06f9
|
github.com/koding/cache v0.0.0-20161222233015-e8a81b0b3f20
|
||||||
github.com/lib/pq v0.0.0-20161103024354-d8eeeb8bae88
|
github.com/lib/pq v0.0.0-20161103024354-d8eeeb8bae88
|
||||||
github.com/mattn/go-sqlite3 v0.0.0-20150427235825-542ae647f860
|
github.com/mattn/go-sqlite3 v0.0.0-20150427235825-542ae647f860
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
|
@ -35,5 +35,6 @@ require (
|
||||||
google.golang.org/protobuf v1.25.0 // indirect
|
google.golang.org/protobuf v1.25.0 // indirect
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||||
gopkg.in/gorp.v1 v1.7.1 // indirect
|
gopkg.in/gorp.v1 v1.7.1 // indirect
|
||||||
|
gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22 // indirect
|
||||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
6
go.sum
6
go.sum
|
@ -61,8 +61,8 @@ github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqx
|
||||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||||
github.com/json-iterator/go v1.1.10 h1:Kz6Cvnvv2wGdaG/V8yMvfkmNiXq9Ya2KUv4rouJJr68=
|
github.com/json-iterator/go v1.1.10 h1:Kz6Cvnvv2wGdaG/V8yMvfkmNiXq9Ya2KUv4rouJJr68=
|
||||||
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||||
github.com/koding/cache v0.0.0-20140912085602-487fc0ca06f9 h1:wAnE6YoDViB3XyrvdSL9M+9EuVnWeAe9LQT1SgZm+pU=
|
github.com/koding/cache v0.0.0-20161222233015-e8a81b0b3f20 h1:R7RAW1p8wjhlHKFhS4X7h8EePqADev/PltCmW9qlJoM=
|
||||||
github.com/koding/cache v0.0.0-20140912085602-487fc0ca06f9/go.mod h1:sh5SGGmQVGUkWDnxevz0I2FJ4TeC18hRPRjKVBMb2kA=
|
github.com/koding/cache v0.0.0-20161222233015-e8a81b0b3f20/go.mod h1:sh5SGGmQVGUkWDnxevz0I2FJ4TeC18hRPRjKVBMb2kA=
|
||||||
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
|
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
|
||||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
|
@ -165,6 +165,8 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
gopkg.in/gorp.v1 v1.7.1 h1:GBB9KrWRATQZh95HJyVGUZrWwOPswitEYEyqlK8JbAA=
|
gopkg.in/gorp.v1 v1.7.1 h1:GBB9KrWRATQZh95HJyVGUZrWwOPswitEYEyqlK8JbAA=
|
||||||
gopkg.in/gorp.v1 v1.7.1/go.mod h1:Wo3h+DBQZIxATwftsglhdD/62zRFPhGhTiu5jUJmCaw=
|
gopkg.in/gorp.v1 v1.7.1/go.mod h1:Wo3h+DBQZIxATwftsglhdD/62zRFPhGhTiu5jUJmCaw=
|
||||||
|
gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22 h1:VpOs+IwYnYBaFnrNAeB8UUWtL3vEUnzSCL1nVjPhqrw=
|
||||||
|
gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA=
|
||||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
language: go
|
language: go
|
||||||
go: 1.3
|
|
||||||
script: go test ./...
|
sudo: false
|
||||||
|
|
||||||
|
services:
|
||||||
|
- mongodb
|
||||||
|
|
||||||
|
install:
|
||||||
|
- go get -t -v ./...
|
||||||
|
|
||||||
|
go:
|
||||||
|
- 1.4.3
|
||||||
|
- 1.5.4
|
||||||
|
- 1.6.2
|
||||||
|
|
||||||
|
script:
|
||||||
|
- export GOMAXPROCS=$(nproc) # go1.4
|
||||||
|
- go test -race ./...
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2016 Koding, Inc.
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
|
@ -30,3 +30,16 @@ err := cache.Set("test_key", "test_data")
|
||||||
// get item
|
// get item
|
||||||
data, err := cache.Get("test_key")
|
data, err := cache.Get("test_key")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Supported caching algorithms:
|
||||||
|
|
||||||
|
- MemoryNoTS : provides a non-thread safe in-memory caching system
|
||||||
|
- Memory : provides a thread safe in-memory caching system, built on top of MemoryNoTS cache
|
||||||
|
- LRUNoTS : provides a non-thread safe, fixed size in-memory caching system, built on top of MemoryNoTS cache
|
||||||
|
- LRU : provides a thread safe, fixed size in-memory caching system, built on top of LRUNoTS cache
|
||||||
|
- MemoryTTL : provides a thread safe, expiring in-memory caching system, built on top of MemoryNoTS cache
|
||||||
|
- ShardedNoTS : provides a non-thread safe sharded cache system, built on top of a cache interface
|
||||||
|
- ShardedTTL : provides a thread safe, expiring in-memory sharded cache system, built on top of ShardedNoTS over MemoryNoTS
|
||||||
|
- LFUNoTS : provides a non-thread safe, fixed size in-memory caching system, built on top of MemoryNoTS cache
|
||||||
|
- LFU : provides a thread safe, fixed size in-memory caching system, built on top of LFUNoTS cache
|
||||||
|
|
|
@ -1,9 +1,14 @@
|
||||||
// Package cache provides basic caching mechanisms for Go(lang) projects.
|
// Package cache provides basic caching mechanisms for Go(lang) projects.
|
||||||
//
|
//
|
||||||
// Currently supported caching algorithms:
|
// Currently supported caching algorithms:
|
||||||
// MemoryNoTS: provides a non-thread safe in-memory caching system
|
// MemoryNoTS : provides a non-thread safe in-memory caching system
|
||||||
// Memory : provides a thread safe in-memory caching system, built on top of MemoryNoTS cache
|
// Memory : provides a thread safe in-memory caching system, built on top of MemoryNoTS cache
|
||||||
// LRUNoTS : provides a non-thread safe, fixed size in-memory caching system, built on top of MemoryNoTS cache
|
// LRUNoTS : provides a non-thread safe, fixed size in-memory caching system, built on top of MemoryNoTS cache
|
||||||
// LRU : provides a thread safe, fixed size in-memory caching system, built on top of LRUNoTS cache
|
// LRU : provides a thread safe, fixed size in-memory caching system, built on top of LRUNoTS cache
|
||||||
// MemoryTTL : provides a thread safe, expiring in-memory caching system, built on top of MemoryNoTS cache
|
// MemoryTTL : provides a thread safe, expiring in-memory caching system, built on top of MemoryNoTS cache
|
||||||
|
// ShardedNoTS : provides a non-thread safe sharded cache system, built on top of a cache interface
|
||||||
|
// ShardedTTL : provides a thread safe, expiring in-memory sharded cache system, built on top of ShardedNoTS over MemoryNoTS
|
||||||
|
// LFUNoTS : provides a non-thread safe, fixed size in-memory caching system, built on top of MemoryNoTS cache
|
||||||
|
// LFU : provides a thread safe, fixed size in-memory caching system, built on top of LFUNoTS cache
|
||||||
|
//
|
||||||
package cache
|
package cache
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// LFU holds the Least frequently used cache values
|
||||||
|
type LFU struct {
|
||||||
|
// Mutex is used for handling the concurrent
|
||||||
|
// read/write requests for cache
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
|
// cache holds the all cache values
|
||||||
|
cache Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLFU creates a thread-safe LFU cache
|
||||||
|
func NewLFU(size int) Cache {
|
||||||
|
return &LRU{
|
||||||
|
cache: NewLFUNoTS(size),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the value of a given key if it exists, every get item will be
|
||||||
|
// increased for every usage
|
||||||
|
func (l *LFU) Get(key string) (interface{}, error) {
|
||||||
|
l.Lock()
|
||||||
|
defer l.Unlock()
|
||||||
|
|
||||||
|
return l.cache.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets or overrides the given key with the given value, every set item will
|
||||||
|
// be increased as usage.
|
||||||
|
// when the cache is full, least frequently used items will be evicted from
|
||||||
|
// linked list
|
||||||
|
func (l *LFU) Set(key string, val interface{}) error {
|
||||||
|
l.Lock()
|
||||||
|
defer l.Unlock()
|
||||||
|
|
||||||
|
return l.cache.Set(key, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete deletes the given key-value pair from cache, this function doesnt
|
||||||
|
// return an error if item is not in the cache
|
||||||
|
func (l *LFU) Delete(key string) error {
|
||||||
|
l.Lock()
|
||||||
|
defer l.Unlock()
|
||||||
|
|
||||||
|
return l.cache.Delete(key)
|
||||||
|
}
|
|
@ -0,0 +1,226 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import "container/list"
|
||||||
|
|
||||||
|
// LFUNoTS holds the cache struct
|
||||||
|
type LFUNoTS struct {
|
||||||
|
// list holds all items in a linked list
|
||||||
|
frequencyList *list.List
|
||||||
|
|
||||||
|
// holds the all cache values
|
||||||
|
cache Cache
|
||||||
|
|
||||||
|
// size holds the limit of the LFU cache
|
||||||
|
size int
|
||||||
|
|
||||||
|
// currentSize holds the current item size in the list
|
||||||
|
// after each adding of item, currentSize will be increased
|
||||||
|
currentSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
type cacheItem struct {
|
||||||
|
// key of cache value
|
||||||
|
k string
|
||||||
|
|
||||||
|
// value of cache value
|
||||||
|
v interface{}
|
||||||
|
|
||||||
|
// holds the frequency elements
|
||||||
|
// it holds the element's usage as count
|
||||||
|
// if cacheItems is used 4 times (with set or get operations)
|
||||||
|
// the freqElement's frequency counter will be 4
|
||||||
|
// it holds entry struct inside Value of list.Element
|
||||||
|
freqElement *list.Element
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLFUNoTS creates a new LFU cache struct for further cache operations. Size
|
||||||
|
// is used for limiting the upper bound of the cache
|
||||||
|
func NewLFUNoTS(size int) Cache {
|
||||||
|
if size < 1 {
|
||||||
|
panic("invalid cache size")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &LFUNoTS{
|
||||||
|
frequencyList: list.New(),
|
||||||
|
cache: NewMemoryNoTS(),
|
||||||
|
size: size,
|
||||||
|
currentSize: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get gets value of cache item
|
||||||
|
// then increments the usage of the item
|
||||||
|
func (l *LFUNoTS) Get(key string) (interface{}, error) {
|
||||||
|
res, err := l.cache.Get(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ci := res.(*cacheItem)
|
||||||
|
|
||||||
|
// increase usage of cache item
|
||||||
|
l.incr(ci)
|
||||||
|
return ci.v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets a new key-value pair
|
||||||
|
// Set increments the key usage count too
|
||||||
|
//
|
||||||
|
// eg:
|
||||||
|
// cache.Set("test_key","2")
|
||||||
|
// cache.Set("test_key","1")
|
||||||
|
// if you try to set a value into same key
|
||||||
|
// its usage count will be increased
|
||||||
|
// and usage count of "test_key" will be 2 in this example
|
||||||
|
func (l *LFUNoTS) Set(key string, value interface{}) error {
|
||||||
|
return l.set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete deletes the key and its dependencies
|
||||||
|
func (l *LFUNoTS) Delete(key string) error {
|
||||||
|
res, err := l.cache.Get(key)
|
||||||
|
if err != nil && err != ErrNotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// we dont need to delete if already doesn't exist
|
||||||
|
if err == ErrNotFound {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ci := res.(*cacheItem)
|
||||||
|
|
||||||
|
l.remove(ci, ci.freqElement)
|
||||||
|
l.currentSize--
|
||||||
|
return l.cache.Delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// set sets a new key-value pair
|
||||||
|
func (l *LFUNoTS) set(key string, value interface{}) error {
|
||||||
|
res, err := l.cache.Get(key)
|
||||||
|
if err != nil && err != ErrNotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == ErrNotFound {
|
||||||
|
//create new cache item
|
||||||
|
ci := newCacheItem(key, value)
|
||||||
|
|
||||||
|
// if cache size si reached to max size
|
||||||
|
// then first remove lfu item from the list
|
||||||
|
if l.currentSize >= l.size {
|
||||||
|
// then evict some data from head of linked list.
|
||||||
|
l.evict(l.frequencyList.Front())
|
||||||
|
}
|
||||||
|
|
||||||
|
l.cache.Set(key, ci)
|
||||||
|
l.incr(ci)
|
||||||
|
|
||||||
|
} else {
|
||||||
|
//update existing one
|
||||||
|
val := res.(*cacheItem)
|
||||||
|
val.v = value
|
||||||
|
l.cache.Set(key, val)
|
||||||
|
l.incr(res.(*cacheItem))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// entry holds the frequency node informations
|
||||||
|
type entry struct {
|
||||||
|
// freqCount holds the frequency number
|
||||||
|
freqCount int
|
||||||
|
|
||||||
|
// itemCount holds the items how many exist in list
|
||||||
|
listEntry map[*cacheItem]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// incr increments the usage of cache items
|
||||||
|
// incrementing will be used in 'Get' & 'Set' functions
|
||||||
|
// whenever these functions are used, usage count of any key
|
||||||
|
// will be increased
|
||||||
|
func (l *LFUNoTS) incr(ci *cacheItem) {
|
||||||
|
var nextValue int
|
||||||
|
var nextPosition *list.Element
|
||||||
|
// update existing one
|
||||||
|
if ci.freqElement != nil {
|
||||||
|
nextValue = ci.freqElement.Value.(*entry).freqCount + 1
|
||||||
|
// replace the position of frequency element
|
||||||
|
nextPosition = ci.freqElement.Next()
|
||||||
|
} else {
|
||||||
|
// create new frequency element for cache item
|
||||||
|
// ci.freqElement is nil so next value of freq will be 1
|
||||||
|
nextValue = 1
|
||||||
|
// we created new element and its position will be head of linked list
|
||||||
|
nextPosition = l.frequencyList.Front()
|
||||||
|
l.currentSize++
|
||||||
|
}
|
||||||
|
|
||||||
|
// we need to check position first, otherwise it will panic if we try to fetch value of entry
|
||||||
|
if nextPosition == nil || nextPosition.Value.(*entry).freqCount != nextValue {
|
||||||
|
// create new entry node for linked list
|
||||||
|
entry := newEntry(nextValue)
|
||||||
|
if ci.freqElement == nil {
|
||||||
|
nextPosition = l.frequencyList.PushFront(entry)
|
||||||
|
} else {
|
||||||
|
nextPosition = l.frequencyList.InsertAfter(entry, ci.freqElement)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nextPosition.Value.(*entry).listEntry[ci] = struct{}{}
|
||||||
|
ci.freqElement = nextPosition
|
||||||
|
|
||||||
|
// we have moved the cache item to the next position,
|
||||||
|
// then we need to remove old position of the cacheItem from the list
|
||||||
|
// then we deleted previous position of cacheItem
|
||||||
|
if ci.freqElement.Prev() != nil {
|
||||||
|
l.remove(ci, ci.freqElement.Prev())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove removes the cache item from the cache list
|
||||||
|
// after deleting key from the list, if its linked list has no any item no longer
|
||||||
|
// then that linked list elemnet will be removed from the list too
|
||||||
|
func (l *LFUNoTS) remove(ci *cacheItem, position *list.Element) {
|
||||||
|
entry := position.Value.(*entry).listEntry
|
||||||
|
delete(entry, ci)
|
||||||
|
if len(entry) == 0 {
|
||||||
|
l.frequencyList.Remove(position)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// evict deletes the element from list with given linked list element
|
||||||
|
func (l *LFUNoTS) evict(e *list.Element) error {
|
||||||
|
// ne need to return err if list element is already nil
|
||||||
|
if e == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove the first item of the linked list
|
||||||
|
for entry := range e.Value.(*entry).listEntry {
|
||||||
|
l.cache.Delete(entry.k)
|
||||||
|
l.remove(entry, e)
|
||||||
|
l.currentSize--
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newEntry creates a new entry with frequency count
|
||||||
|
func newEntry(freqCount int) *entry {
|
||||||
|
return &entry{
|
||||||
|
freqCount: freqCount,
|
||||||
|
listEntry: make(map[*cacheItem]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newCacheItem creates a new cache item with key and value
|
||||||
|
func newCacheItem(key string, value interface{}) *cacheItem {
|
||||||
|
return &cacheItem{
|
||||||
|
k: key,
|
||||||
|
v: value,
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -13,6 +13,12 @@ func NewMemoryNoTS() *MemoryNoTS {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewMemNoTSCache is a helper method to return a Cache interface, so callers
|
||||||
|
// don't have to typecast
|
||||||
|
func NewMemNoTSCache() Cache {
|
||||||
|
return NewMemoryNoTS()
|
||||||
|
}
|
||||||
|
|
||||||
// Get returns a value of a given key if it exists
|
// Get returns a value of a given key if it exists
|
||||||
// and valid for the time being
|
// and valid for the time being
|
||||||
func (r *MemoryNoTS) Get(key string) (interface{}, error) {
|
func (r *MemoryNoTS) Get(key string) (interface{}, error) {
|
||||||
|
|
|
@ -12,7 +12,7 @@ var zeroTTL = time.Duration(0)
|
||||||
type MemoryTTL struct {
|
type MemoryTTL struct {
|
||||||
// Mutex is used for handling the concurrent
|
// Mutex is used for handling the concurrent
|
||||||
// read/write requests for cache
|
// read/write requests for cache
|
||||||
sync.Mutex
|
sync.RWMutex
|
||||||
|
|
||||||
// cache holds the cache data
|
// cache holds the cache data
|
||||||
cache *MemoryNoTS
|
cache *MemoryNoTS
|
||||||
|
@ -23,8 +23,11 @@ type MemoryTTL struct {
|
||||||
// ttl is a duration for a cache key to expire
|
// ttl is a duration for a cache key to expire
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
|
|
||||||
// gcInterval is a duration for garbage collection
|
// gcTicker controls gc intervals
|
||||||
gcInterval time.Duration
|
gcTicker *time.Ticker
|
||||||
|
|
||||||
|
// done controls sweeping goroutine lifetime
|
||||||
|
done chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMemoryWithTTL creates an inmemory cache system
|
// NewMemoryWithTTL creates an inmemory cache system
|
||||||
|
@ -41,29 +44,71 @@ func NewMemoryWithTTL(ttl time.Duration) *MemoryTTL {
|
||||||
|
|
||||||
// StartGC starts the garbage collection process in a go routine
|
// StartGC starts the garbage collection process in a go routine
|
||||||
func (r *MemoryTTL) StartGC(gcInterval time.Duration) {
|
func (r *MemoryTTL) StartGC(gcInterval time.Duration) {
|
||||||
r.gcInterval = gcInterval
|
if gcInterval <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(gcInterval)
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
r.Lock()
|
||||||
|
r.gcTicker = ticker
|
||||||
|
r.done = done
|
||||||
|
r.Unlock()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for _ = range time.Tick(gcInterval) {
|
for {
|
||||||
for key := range r.cache.items {
|
select {
|
||||||
if !r.isValid(key) {
|
case <-ticker.C:
|
||||||
r.Delete(key)
|
now := time.Now()
|
||||||
|
|
||||||
|
r.Lock()
|
||||||
|
for key := range r.cache.items {
|
||||||
|
if !r.isValidTime(key, now) {
|
||||||
|
r.delete(key)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
r.Unlock()
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StopGC stops sweeping goroutine.
|
||||||
|
func (r *MemoryTTL) StopGC() {
|
||||||
|
if r.gcTicker != nil {
|
||||||
|
r.Lock()
|
||||||
|
r.gcTicker.Stop()
|
||||||
|
r.gcTicker = nil
|
||||||
|
close(r.done)
|
||||||
|
r.done = nil
|
||||||
|
r.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Get returns a value of a given key if it exists
|
// Get returns a value of a given key if it exists
|
||||||
// and valid for the time being
|
// and valid for the time being
|
||||||
func (r *MemoryTTL) Get(key string) (interface{}, error) {
|
func (r *MemoryTTL) Get(key string) (interface{}, error) {
|
||||||
r.Lock()
|
r.RLock()
|
||||||
defer r.Unlock()
|
|
||||||
|
|
||||||
if !r.isValid(key) {
|
for !r.isValid(key) {
|
||||||
r.delete(key)
|
r.RUnlock()
|
||||||
return nil, ErrNotFound
|
// Need write lock to delete key, so need to unlock, relock and recheck
|
||||||
|
r.Lock()
|
||||||
|
if !r.isValid(key) {
|
||||||
|
r.delete(key)
|
||||||
|
r.Unlock()
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
r.Unlock()
|
||||||
|
// Could become invalid again in this window
|
||||||
|
r.RLock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer r.RUnlock()
|
||||||
|
|
||||||
value, err := r.cache.Get(key)
|
value, err := r.cache.Get(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -98,6 +143,10 @@ func (r *MemoryTTL) delete(key string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *MemoryTTL) isValid(key string) bool {
|
func (r *MemoryTTL) isValid(key string) bool {
|
||||||
|
return r.isValidTime(key, time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MemoryTTL) isValidTime(key string, t time.Time) bool {
|
||||||
setAt, ok := r.setAts[key]
|
setAt, ok := r.setAts[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
|
@ -107,5 +156,5 @@ func (r *MemoryTTL) isValid(key string) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return setAt.Add(r.ttl).After(time.Now())
|
return setAt.Add(r.ttl).After(t)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,227 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
mgo "gopkg.in/mgo.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultExpireDuration = time.Minute
|
||||||
|
defaultCollectionName = "jCache"
|
||||||
|
defaultGCInterval = time.Minute
|
||||||
|
indexExpireAt = "expireAt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MongoCache holds the cache values that will be stored in mongoDB
|
||||||
|
type MongoCache struct {
|
||||||
|
// mongeSession specifies the mongoDB connection
|
||||||
|
mongeSession *mgo.Session
|
||||||
|
|
||||||
|
// CollectionName speficies the optional collection name for mongoDB
|
||||||
|
// if CollectionName is not set, then default value will be set
|
||||||
|
CollectionName string
|
||||||
|
|
||||||
|
// ttl is a duration for a cache key to expire
|
||||||
|
TTL time.Duration
|
||||||
|
|
||||||
|
// GCInterval specifies the time duration for garbage collector time interval
|
||||||
|
GCInterval time.Duration
|
||||||
|
|
||||||
|
// GCStart starts the garbage collector and deletes the
|
||||||
|
// expired keys from mongo with given time interval
|
||||||
|
GCStart bool
|
||||||
|
|
||||||
|
// gcTicker controls gc intervals
|
||||||
|
gcTicker *time.Ticker
|
||||||
|
|
||||||
|
// done controls sweeping goroutine lifetime
|
||||||
|
done chan struct{}
|
||||||
|
|
||||||
|
// Mutex is used for handling the concurrent
|
||||||
|
// read/write requests for cache
|
||||||
|
sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option sets the options specified.
|
||||||
|
type Option func(*MongoCache)
|
||||||
|
|
||||||
|
// NewMongoCacheWithTTL creates a caching layer backed by mongo. TTL's are
|
||||||
|
// managed either by a background cleaner or document is removed on the Get
|
||||||
|
// operation. Mongo TTL indexes are not utilized since there can be multiple
|
||||||
|
// systems using the same collection with different TTL values.
|
||||||
|
//
|
||||||
|
// The responsibility of stopping the GC process belongs to the user.
|
||||||
|
//
|
||||||
|
// Session is not closed while stopping the GC.
|
||||||
|
//
|
||||||
|
// This self-referential function satisfy you to avoid passing
|
||||||
|
// nil value to the function as parameter
|
||||||
|
// e.g (usage) :
|
||||||
|
// configure with defaults, just call;
|
||||||
|
// NewMongoCacheWithTTL(session)
|
||||||
|
//
|
||||||
|
// configure ttl duration with;
|
||||||
|
// NewMongoCacheWithTTL(session, func(m *MongoCache) {
|
||||||
|
// m.TTL = 2 * time.Minute
|
||||||
|
// })
|
||||||
|
// or
|
||||||
|
// NewMongoCacheWithTTL(session, SetTTL(time.Minute * 2))
|
||||||
|
//
|
||||||
|
// configure collection name with;
|
||||||
|
// NewMongoCacheWithTTL(session, func(m *MongoCache) {
|
||||||
|
// m.CollectionName = "MongoCacheCollectionName"
|
||||||
|
// })
|
||||||
|
func NewMongoCacheWithTTL(session *mgo.Session, configs ...Option) *MongoCache {
|
||||||
|
if session == nil {
|
||||||
|
panic("session must be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
mc := &MongoCache{
|
||||||
|
mongeSession: session,
|
||||||
|
TTL: defaultExpireDuration,
|
||||||
|
CollectionName: defaultCollectionName,
|
||||||
|
GCInterval: defaultGCInterval,
|
||||||
|
GCStart: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, configFunc := range configs {
|
||||||
|
configFunc(mc)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mc.GCStart {
|
||||||
|
mc.StartGC(mc.GCInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
return mc
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustEnsureIndexExpireAt ensures the expireAt index
|
||||||
|
// usage:
|
||||||
|
// NewMongoCacheWithTTL(mongoSession, MustEnsureIndexExpireAt())
|
||||||
|
func MustEnsureIndexExpireAt() Option {
|
||||||
|
return func(m *MongoCache) {
|
||||||
|
if err := m.EnsureIndex(); err != nil {
|
||||||
|
panic(fmt.Sprintf("index must ensure %q", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartGC enables the garbage collector in MongoCache struct
|
||||||
|
// usage:
|
||||||
|
// NewMongoCacheWithTTL(mongoSession, StartGC())
|
||||||
|
func StartGC() Option {
|
||||||
|
return func(m *MongoCache) {
|
||||||
|
m.GCStart = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTTL sets the ttl duration in MongoCache as option
|
||||||
|
// usage:
|
||||||
|
// NewMongoCacheWithTTL(mongoSession, SetTTL(time*Minute))
|
||||||
|
func SetTTL(duration time.Duration) Option {
|
||||||
|
return func(m *MongoCache) {
|
||||||
|
m.TTL = duration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetGCInterval sets the garbage collector interval in MongoCache struct as option
|
||||||
|
// usage:
|
||||||
|
// NewMongoCacheWithTTL(mongoSession, SetGCInterval(time*Minute))
|
||||||
|
func SetGCInterval(duration time.Duration) Option {
|
||||||
|
return func(m *MongoCache) {
|
||||||
|
m.GCInterval = duration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCollectionName sets the collection name for mongoDB in MongoCache struct as option
|
||||||
|
// usage:
|
||||||
|
// NewMongoCacheWithTTL(mongoSession, SetCollectionName("mongoCollName"))
|
||||||
|
func SetCollectionName(collName string) Option {
|
||||||
|
return func(m *MongoCache) {
|
||||||
|
m.CollectionName = collName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a value of a given key if it exists
|
||||||
|
func (m *MongoCache) Get(key string) (interface{}, error) {
|
||||||
|
data, err := m.get(key)
|
||||||
|
if err == mgo.ErrNotFound {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return data.Value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set will persist a value to the cache or override existing one with the new
|
||||||
|
// one
|
||||||
|
func (m *MongoCache) Set(key string, value interface{}) error {
|
||||||
|
return m.set(key, m.TTL, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEx will persist a value to the cache or override existing one with the new
|
||||||
|
// one with ttl duration
|
||||||
|
func (m *MongoCache) SetEx(key string, duration time.Duration, value interface{}) error {
|
||||||
|
return m.set(key, duration, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete deletes a given key if exists
|
||||||
|
func (m *MongoCache) Delete(key string) error {
|
||||||
|
return m.delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureIndex ensures the index with expireAt key
|
||||||
|
func (m *MongoCache) EnsureIndex() error {
|
||||||
|
query := func(c *mgo.Collection) error {
|
||||||
|
return c.EnsureIndexKey(indexExpireAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.run(m.CollectionName, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartGC starts the garbage collector with given time interval The
|
||||||
|
// expired data will be checked & deleted with given interval time
|
||||||
|
func (m *MongoCache) StartGC(gcInterval time.Duration) {
|
||||||
|
if gcInterval <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(gcInterval)
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
m.Lock()
|
||||||
|
m.gcTicker = ticker
|
||||||
|
m.done = done
|
||||||
|
m.Unlock()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
m.Lock()
|
||||||
|
m.deleteExpiredKeys()
|
||||||
|
m.Unlock()
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopGC stops sweeping goroutine.
|
||||||
|
func (m *MongoCache) StopGC() {
|
||||||
|
if m.gcTicker != nil {
|
||||||
|
m.Lock()
|
||||||
|
m.gcTicker.Stop()
|
||||||
|
m.gcTicker = nil
|
||||||
|
close(m.done)
|
||||||
|
m.done = nil
|
||||||
|
m.Unlock()
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,81 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
mgo "gopkg.in/mgo.v2"
|
||||||
|
"gopkg.in/mgo.v2/bson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Document holds the key-value pair for mongo cache
|
||||||
|
type Document struct {
|
||||||
|
Key string `bson:"_id" json:"_id"`
|
||||||
|
Value interface{} `bson:"value" json:"value"`
|
||||||
|
ExpireAt time.Time `bson:"expireAt" json:"expireAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// getKey fetches the key with its key
|
||||||
|
func (m *MongoCache) get(key string) (*Document, error) {
|
||||||
|
keyValue := new(Document)
|
||||||
|
|
||||||
|
query := func(c *mgo.Collection) error {
|
||||||
|
return c.Find(bson.M{
|
||||||
|
"_id": key,
|
||||||
|
"expireAt": bson.M{
|
||||||
|
"$gt": time.Now().UTC(),
|
||||||
|
}}).One(&keyValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.run(m.CollectionName, query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return keyValue, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MongoCache) set(key string, duration time.Duration, value interface{}) error {
|
||||||
|
update := bson.M{
|
||||||
|
"_id": key,
|
||||||
|
"value": value,
|
||||||
|
"expireAt": time.Now().Add(duration),
|
||||||
|
}
|
||||||
|
|
||||||
|
query := func(c *mgo.Collection) error {
|
||||||
|
_, err := c.UpsertId(key, update)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.run(m.CollectionName, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteKey removes the key-value from mongoDB
|
||||||
|
func (m *MongoCache) delete(key string) error {
|
||||||
|
query := func(c *mgo.Collection) error {
|
||||||
|
err := c.RemoveId(key)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.run(m.CollectionName, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MongoCache) deleteExpiredKeys() error {
|
||||||
|
var selector = bson.M{"expireAt": bson.M{
|
||||||
|
"$lte": time.Now().UTC(),
|
||||||
|
}}
|
||||||
|
|
||||||
|
query := func(c *mgo.Collection) error {
|
||||||
|
_, err := c.RemoveAll(selector)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.run(m.CollectionName, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MongoCache) run(collection string, s func(*mgo.Collection) error) error {
|
||||||
|
session := m.mongeSession.Copy()
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
c := session.DB("").C(collection)
|
||||||
|
return s(c)
|
||||||
|
}
|
|
@ -0,0 +1,18 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
// ShardedCache is the contract for all of the sharded cache backends that are supported by
|
||||||
|
// this package
|
||||||
|
type ShardedCache interface {
|
||||||
|
// Get returns single item from the backend if the requested item is not
|
||||||
|
// found, returns NotFound err
|
||||||
|
Get(shardID, key string) (interface{}, error)
|
||||||
|
|
||||||
|
// Set sets a single item to the backend
|
||||||
|
Set(shardID, key string, value interface{}) error
|
||||||
|
|
||||||
|
// Delete deletes single item from backend
|
||||||
|
Delete(shardID, key string) error
|
||||||
|
|
||||||
|
// Deletes all items in that shard
|
||||||
|
DeleteShard(shardID string) error
|
||||||
|
}
|
|
@ -0,0 +1,67 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
// ShardedNoTS ; the concept behind this storage is that each cache entry is
|
||||||
|
// associated with a tenantID and this enables fast purging for just that
|
||||||
|
// tenantID
|
||||||
|
type ShardedNoTS struct {
|
||||||
|
cache map[string]Cache
|
||||||
|
itemCount map[string]int
|
||||||
|
constructor func() Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewShardedNoTS inits ShardedNoTS struct
|
||||||
|
func NewShardedNoTS(c func() Cache) *ShardedNoTS {
|
||||||
|
return &ShardedNoTS{
|
||||||
|
constructor: c,
|
||||||
|
cache: make(map[string]Cache),
|
||||||
|
itemCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a value of a given key if it exists
|
||||||
|
// and valid for the time being
|
||||||
|
func (l *ShardedNoTS) Get(tenantID, key string) (interface{}, error) {
|
||||||
|
cache, ok := l.cache[tenantID]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return cache.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set will persist a value to the cache or override existing one with the new
|
||||||
|
// one
|
||||||
|
func (l *ShardedNoTS) Set(tenantID, key string, val interface{}) error {
|
||||||
|
_, ok := l.cache[tenantID]
|
||||||
|
if !ok {
|
||||||
|
l.cache[tenantID] = l.constructor()
|
||||||
|
l.itemCount[tenantID] = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
l.itemCount[tenantID]++
|
||||||
|
return l.cache[tenantID].Set(key, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete deletes a given key
|
||||||
|
func (l *ShardedNoTS) Delete(tenantID, key string) error {
|
||||||
|
_, ok := l.cache[tenantID]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
l.itemCount[tenantID]--
|
||||||
|
|
||||||
|
if l.itemCount[tenantID] == 0 {
|
||||||
|
return l.DeleteShard(tenantID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return l.cache[tenantID].Delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteShard deletes the keys inside from maps of cache & itemCount
|
||||||
|
func (l *ShardedNoTS) DeleteShard(tenantID string) error {
|
||||||
|
delete(l.cache, tenantID)
|
||||||
|
delete(l.itemCount, tenantID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,148 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ShardedTTL holds the required variables to compose an in memory sharded cache system
|
||||||
|
// which also provides expiring key mechanism
|
||||||
|
type ShardedTTL struct {
|
||||||
|
// Mutex is used for handling the concurrent
|
||||||
|
// read/write requests for cache
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
|
// cache holds the cache data
|
||||||
|
cache ShardedCache
|
||||||
|
|
||||||
|
// setAts holds the time that related item's set at, indexed by tenantID
|
||||||
|
setAts map[string]map[string]time.Time
|
||||||
|
|
||||||
|
// ttl is a duration for a cache key to expire
|
||||||
|
ttl time.Duration
|
||||||
|
|
||||||
|
// gcInterval is a duration for garbage collection
|
||||||
|
gcInterval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewShardedCacheWithTTL creates a sharded cache system with TTL based on specified Cache constructor
|
||||||
|
// Which everytime will return the true values about a cache hit
|
||||||
|
// and never will leak memory
|
||||||
|
// ttl is used for expiration of a key from cache
|
||||||
|
func NewShardedCacheWithTTL(ttl time.Duration, f func() Cache) *ShardedTTL {
|
||||||
|
return &ShardedTTL{
|
||||||
|
cache: NewShardedNoTS(f),
|
||||||
|
setAts: map[string]map[string]time.Time{},
|
||||||
|
ttl: ttl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewShardedWithTTL creates an in-memory sharded cache system
|
||||||
|
// ttl is used for expiration of a key from cache
|
||||||
|
func NewShardedWithTTL(ttl time.Duration) *ShardedTTL {
|
||||||
|
return NewShardedCacheWithTTL(ttl, NewMemNoTSCache)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartGC starts the garbage collection process in a go routine
|
||||||
|
func (r *ShardedTTL) StartGC(gcInterval time.Duration) {
|
||||||
|
r.gcInterval = gcInterval
|
||||||
|
go func() {
|
||||||
|
for _ = range time.Tick(gcInterval) {
|
||||||
|
r.Lock()
|
||||||
|
for tenantID := range r.setAts {
|
||||||
|
for key := range r.setAts[tenantID] {
|
||||||
|
if !r.isValid(tenantID, key) {
|
||||||
|
r.delete(tenantID, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a value of a given key if it exists
|
||||||
|
// and valid for the time being
|
||||||
|
func (r *ShardedTTL) Get(tenantID, key string) (interface{}, error) {
|
||||||
|
r.Lock()
|
||||||
|
defer r.Unlock()
|
||||||
|
|
||||||
|
if !r.isValid(tenantID, key) {
|
||||||
|
r.delete(tenantID, key)
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := r.cache.Get(tenantID, key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set will persist a value to the cache or
|
||||||
|
// override existing one with the new one
|
||||||
|
func (r *ShardedTTL) Set(tenantID, key string, value interface{}) error {
|
||||||
|
r.Lock()
|
||||||
|
defer r.Unlock()
|
||||||
|
|
||||||
|
r.cache.Set(tenantID, key, value)
|
||||||
|
_, ok := r.setAts[tenantID]
|
||||||
|
if !ok {
|
||||||
|
r.setAts[tenantID] = make(map[string]time.Time)
|
||||||
|
}
|
||||||
|
r.setAts[tenantID][key] = time.Now()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete deletes a given key if exists
|
||||||
|
func (r *ShardedTTL) Delete(tenantID, key string) error {
|
||||||
|
r.Lock()
|
||||||
|
defer r.Unlock()
|
||||||
|
|
||||||
|
r.delete(tenantID, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ShardedTTL) delete(tenantID, key string) {
|
||||||
|
_, ok := r.setAts[tenantID]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.cache.Delete(tenantID, key)
|
||||||
|
delete(r.setAts[tenantID], key)
|
||||||
|
if len(r.setAts[tenantID]) == 0 {
|
||||||
|
delete(r.setAts, tenantID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ShardedTTL) isValid(tenantID, key string) bool {
|
||||||
|
|
||||||
|
_, ok := r.setAts[tenantID]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
setAt, ok := r.setAts[tenantID][key]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if r.ttl == zeroTTL {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return setAt.Add(r.ttl).After(time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteShard deletes with given tenantID without key
|
||||||
|
func (r *ShardedTTL) DeleteShard(tenantID string) error {
|
||||||
|
r.Lock()
|
||||||
|
defer r.Unlock()
|
||||||
|
|
||||||
|
_, ok := r.setAts[tenantID]
|
||||||
|
if ok {
|
||||||
|
for key := range r.setAts[tenantID] {
|
||||||
|
r.delete(tenantID, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
language: go
|
||||||
|
|
||||||
|
go_import_path: gopkg.in/mgo.v2
|
||||||
|
|
||||||
|
addons:
|
||||||
|
apt:
|
||||||
|
packages:
|
||||||
|
|
||||||
|
env:
|
||||||
|
global:
|
||||||
|
- BUCKET=https://niemeyer.s3.amazonaws.com
|
||||||
|
matrix:
|
||||||
|
- GO=1.4.1 MONGODB=x86_64-2.2.7
|
||||||
|
- GO=1.4.1 MONGODB=x86_64-2.4.14
|
||||||
|
- GO=1.4.1 MONGODB=x86_64-2.6.11
|
||||||
|
- GO=1.4.1 MONGODB=x86_64-3.0.9
|
||||||
|
- GO=1.4.1 MONGODB=x86_64-3.2.3-nojournal
|
||||||
|
- GO=1.5.3 MONGODB=x86_64-3.0.9
|
||||||
|
- GO=1.6 MONGODB=x86_64-3.0.9
|
||||||
|
|
||||||
|
install:
|
||||||
|
- eval "$(gimme $GO)"
|
||||||
|
|
||||||
|
- wget $BUCKET/mongodb-linux-$MONGODB.tgz
|
||||||
|
- tar xzvf mongodb-linux-$MONGODB.tgz
|
||||||
|
- export PATH=$PWD/mongodb-linux-$MONGODB/bin:$PATH
|
||||||
|
|
||||||
|
- wget $BUCKET/daemontools.tar.gz
|
||||||
|
- tar xzvf daemontools.tar.gz
|
||||||
|
- export PATH=$PWD/daemontools:$PATH
|
||||||
|
|
||||||
|
- go get gopkg.in/check.v1
|
||||||
|
- go get gopkg.in/yaml.v2
|
||||||
|
- go get gopkg.in/tomb.v2
|
||||||
|
|
||||||
|
before_script:
|
||||||
|
- export NOIPV6=1
|
||||||
|
- make startdb
|
||||||
|
|
||||||
|
script:
|
||||||
|
- (cd bson && go test -check.v)
|
||||||
|
- go test -check.v -fast
|
||||||
|
- (cd txn && go test -check.v)
|
||||||
|
|
||||||
|
# vim:sw=4:ts=4:et
|
|
@ -0,0 +1,25 @@
|
||||||
|
mgo - MongoDB driver for Go
|
||||||
|
|
||||||
|
Copyright (c) 2010-2013 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer.
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,5 @@
|
||||||
|
startdb:
|
||||||
|
@harness/setup.sh start
|
||||||
|
|
||||||
|
stopdb:
|
||||||
|
@harness/setup.sh stop
|
|
@ -0,0 +1,4 @@
|
||||||
|
The MongoDB driver for Go
|
||||||
|
-------------------------
|
||||||
|
|
||||||
|
Please go to [http://labix.org/mgo](http://labix.org/mgo) for all project details.
|
|
@ -0,0 +1,467 @@
|
||||||
|
// mgo - MongoDB driver for Go
|
||||||
|
//
|
||||||
|
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
//
|
||||||
|
// All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
// list of conditions and the following disclaimer.
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
package mgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"gopkg.in/mgo.v2/bson"
|
||||||
|
"gopkg.in/mgo.v2/internal/scram"
|
||||||
|
)
|
||||||
|
|
||||||
|
type authCmd struct {
|
||||||
|
Authenticate int
|
||||||
|
|
||||||
|
Nonce string
|
||||||
|
User string
|
||||||
|
Key string
|
||||||
|
}
|
||||||
|
|
||||||
|
type startSaslCmd struct {
|
||||||
|
StartSASL int `bson:"startSasl"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type authResult struct {
|
||||||
|
ErrMsg string
|
||||||
|
Ok bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type getNonceCmd struct {
|
||||||
|
GetNonce int
|
||||||
|
}
|
||||||
|
|
||||||
|
type getNonceResult struct {
|
||||||
|
Nonce string
|
||||||
|
Err string "$err"
|
||||||
|
Code int
|
||||||
|
}
|
||||||
|
|
||||||
|
type logoutCmd struct {
|
||||||
|
Logout int
|
||||||
|
}
|
||||||
|
|
||||||
|
type saslCmd struct {
|
||||||
|
Start int `bson:"saslStart,omitempty"`
|
||||||
|
Continue int `bson:"saslContinue,omitempty"`
|
||||||
|
ConversationId int `bson:"conversationId,omitempty"`
|
||||||
|
Mechanism string `bson:"mechanism,omitempty"`
|
||||||
|
Payload []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type saslResult struct {
|
||||||
|
Ok bool `bson:"ok"`
|
||||||
|
NotOk bool `bson:"code"` // Server <= 2.3.2 returns ok=1 & code>0 on errors (WTF?)
|
||||||
|
Done bool
|
||||||
|
|
||||||
|
ConversationId int `bson:"conversationId"`
|
||||||
|
Payload []byte
|
||||||
|
ErrMsg string
|
||||||
|
}
|
||||||
|
|
||||||
|
type saslStepper interface {
|
||||||
|
Step(serverData []byte) (clientData []byte, done bool, err error)
|
||||||
|
Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) getNonce() (nonce string, err error) {
|
||||||
|
socket.Lock()
|
||||||
|
for socket.cachedNonce == "" && socket.dead == nil {
|
||||||
|
debugf("Socket %p to %s: waiting for nonce", socket, socket.addr)
|
||||||
|
socket.gotNonce.Wait()
|
||||||
|
}
|
||||||
|
if socket.cachedNonce == "mongos" {
|
||||||
|
socket.Unlock()
|
||||||
|
return "", errors.New("Can't authenticate with mongos; see http://j.mp/mongos-auth")
|
||||||
|
}
|
||||||
|
debugf("Socket %p to %s: got nonce", socket, socket.addr)
|
||||||
|
nonce, err = socket.cachedNonce, socket.dead
|
||||||
|
socket.cachedNonce = ""
|
||||||
|
socket.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
nonce = ""
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) resetNonce() {
|
||||||
|
debugf("Socket %p to %s: requesting a new nonce", socket, socket.addr)
|
||||||
|
op := &queryOp{}
|
||||||
|
op.query = &getNonceCmd{GetNonce: 1}
|
||||||
|
op.collection = "admin.$cmd"
|
||||||
|
op.limit = -1
|
||||||
|
op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) {
|
||||||
|
if err != nil {
|
||||||
|
socket.kill(errors.New("getNonce: "+err.Error()), true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result := &getNonceResult{}
|
||||||
|
err = bson.Unmarshal(docData, &result)
|
||||||
|
if err != nil {
|
||||||
|
socket.kill(errors.New("Failed to unmarshal nonce: "+err.Error()), true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
debugf("Socket %p to %s: nonce unmarshalled: %#v", socket, socket.addr, result)
|
||||||
|
if result.Code == 13390 {
|
||||||
|
// mongos doesn't yet support auth (see http://j.mp/mongos-auth)
|
||||||
|
result.Nonce = "mongos"
|
||||||
|
} else if result.Nonce == "" {
|
||||||
|
var msg string
|
||||||
|
if result.Err != "" {
|
||||||
|
msg = fmt.Sprintf("Got an empty nonce: %s (%d)", result.Err, result.Code)
|
||||||
|
} else {
|
||||||
|
msg = "Got an empty nonce"
|
||||||
|
}
|
||||||
|
socket.kill(errors.New(msg), true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
socket.Lock()
|
||||||
|
if socket.cachedNonce != "" {
|
||||||
|
socket.Unlock()
|
||||||
|
panic("resetNonce: nonce already cached")
|
||||||
|
}
|
||||||
|
socket.cachedNonce = result.Nonce
|
||||||
|
socket.gotNonce.Signal()
|
||||||
|
socket.Unlock()
|
||||||
|
}
|
||||||
|
err := socket.Query(op)
|
||||||
|
if err != nil {
|
||||||
|
socket.kill(errors.New("resetNonce: "+err.Error()), true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) Login(cred Credential) error {
|
||||||
|
socket.Lock()
|
||||||
|
if cred.Mechanism == "" && socket.serverInfo.MaxWireVersion >= 3 {
|
||||||
|
cred.Mechanism = "SCRAM-SHA-1"
|
||||||
|
}
|
||||||
|
for _, sockCred := range socket.creds {
|
||||||
|
if sockCred == cred {
|
||||||
|
debugf("Socket %p to %s: login: db=%q user=%q (already logged in)", socket, socket.addr, cred.Source, cred.Username)
|
||||||
|
socket.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if socket.dropLogout(cred) {
|
||||||
|
debugf("Socket %p to %s: login: db=%q user=%q (cached)", socket, socket.addr, cred.Source, cred.Username)
|
||||||
|
socket.creds = append(socket.creds, cred)
|
||||||
|
socket.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
socket.Unlock()
|
||||||
|
|
||||||
|
debugf("Socket %p to %s: login: db=%q user=%q", socket, socket.addr, cred.Source, cred.Username)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
switch cred.Mechanism {
|
||||||
|
case "", "MONGODB-CR", "MONGO-CR": // Name changed to MONGODB-CR in SERVER-8501.
|
||||||
|
err = socket.loginClassic(cred)
|
||||||
|
case "PLAIN":
|
||||||
|
err = socket.loginPlain(cred)
|
||||||
|
case "MONGODB-X509":
|
||||||
|
err = socket.loginX509(cred)
|
||||||
|
default:
|
||||||
|
// Try SASL for everything else, if it is available.
|
||||||
|
err = socket.loginSASL(cred)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
debugf("Socket %p to %s: login error: %s", socket, socket.addr, err)
|
||||||
|
} else {
|
||||||
|
debugf("Socket %p to %s: login successful", socket, socket.addr)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) loginClassic(cred Credential) error {
|
||||||
|
// Note that this only works properly because this function is
|
||||||
|
// synchronous, which means the nonce won't get reset while we're
|
||||||
|
// using it and any other login requests will block waiting for a
|
||||||
|
// new nonce provided in the defer call below.
|
||||||
|
nonce, err := socket.getNonce()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer socket.resetNonce()
|
||||||
|
|
||||||
|
psum := md5.New()
|
||||||
|
psum.Write([]byte(cred.Username + ":mongo:" + cred.Password))
|
||||||
|
|
||||||
|
ksum := md5.New()
|
||||||
|
ksum.Write([]byte(nonce + cred.Username))
|
||||||
|
ksum.Write([]byte(hex.EncodeToString(psum.Sum(nil))))
|
||||||
|
|
||||||
|
key := hex.EncodeToString(ksum.Sum(nil))
|
||||||
|
|
||||||
|
cmd := authCmd{Authenticate: 1, User: cred.Username, Nonce: nonce, Key: key}
|
||||||
|
res := authResult{}
|
||||||
|
return socket.loginRun(cred.Source, &cmd, &res, func() error {
|
||||||
|
if !res.Ok {
|
||||||
|
return errors.New(res.ErrMsg)
|
||||||
|
}
|
||||||
|
socket.Lock()
|
||||||
|
socket.dropAuth(cred.Source)
|
||||||
|
socket.creds = append(socket.creds, cred)
|
||||||
|
socket.Unlock()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type authX509Cmd struct {
|
||||||
|
Authenticate int
|
||||||
|
User string
|
||||||
|
Mechanism string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) loginX509(cred Credential) error {
|
||||||
|
cmd := authX509Cmd{Authenticate: 1, User: cred.Username, Mechanism: "MONGODB-X509"}
|
||||||
|
res := authResult{}
|
||||||
|
return socket.loginRun(cred.Source, &cmd, &res, func() error {
|
||||||
|
if !res.Ok {
|
||||||
|
return errors.New(res.ErrMsg)
|
||||||
|
}
|
||||||
|
socket.Lock()
|
||||||
|
socket.dropAuth(cred.Source)
|
||||||
|
socket.creds = append(socket.creds, cred)
|
||||||
|
socket.Unlock()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) loginPlain(cred Credential) error {
|
||||||
|
cmd := saslCmd{Start: 1, Mechanism: "PLAIN", Payload: []byte("\x00" + cred.Username + "\x00" + cred.Password)}
|
||||||
|
res := authResult{}
|
||||||
|
return socket.loginRun(cred.Source, &cmd, &res, func() error {
|
||||||
|
if !res.Ok {
|
||||||
|
return errors.New(res.ErrMsg)
|
||||||
|
}
|
||||||
|
socket.Lock()
|
||||||
|
socket.dropAuth(cred.Source)
|
||||||
|
socket.creds = append(socket.creds, cred)
|
||||||
|
socket.Unlock()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) loginSASL(cred Credential) error {
|
||||||
|
var sasl saslStepper
|
||||||
|
var err error
|
||||||
|
if cred.Mechanism == "SCRAM-SHA-1" {
|
||||||
|
// SCRAM is handled without external libraries.
|
||||||
|
sasl = saslNewScram(cred)
|
||||||
|
} else if len(cred.ServiceHost) > 0 {
|
||||||
|
sasl, err = saslNew(cred, cred.ServiceHost)
|
||||||
|
} else {
|
||||||
|
sasl, err = saslNew(cred, socket.Server().Addr)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer sasl.Close()
|
||||||
|
|
||||||
|
// The goal of this logic is to carry a locked socket until the
|
||||||
|
// local SASL step confirms the auth is valid; the socket needs to be
|
||||||
|
// locked so that concurrent action doesn't leave the socket in an
|
||||||
|
// auth state that doesn't reflect the operations that took place.
|
||||||
|
// As a simple case, imagine inverting login=>logout to logout=>login.
|
||||||
|
//
|
||||||
|
// The logic below works because the lock func isn't called concurrently.
|
||||||
|
locked := false
|
||||||
|
lock := func(b bool) {
|
||||||
|
if locked != b {
|
||||||
|
locked = b
|
||||||
|
if b {
|
||||||
|
socket.Lock()
|
||||||
|
} else {
|
||||||
|
socket.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lock(true)
|
||||||
|
defer lock(false)
|
||||||
|
|
||||||
|
start := 1
|
||||||
|
cmd := saslCmd{}
|
||||||
|
res := saslResult{}
|
||||||
|
for {
|
||||||
|
payload, done, err := sasl.Step(res.Payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if done && res.Done {
|
||||||
|
socket.dropAuth(cred.Source)
|
||||||
|
socket.creds = append(socket.creds, cred)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
lock(false)
|
||||||
|
|
||||||
|
cmd = saslCmd{
|
||||||
|
Start: start,
|
||||||
|
Continue: 1 - start,
|
||||||
|
ConversationId: res.ConversationId,
|
||||||
|
Mechanism: cred.Mechanism,
|
||||||
|
Payload: payload,
|
||||||
|
}
|
||||||
|
start = 0
|
||||||
|
err = socket.loginRun(cred.Source, &cmd, &res, func() error {
|
||||||
|
// See the comment on lock for why this is necessary.
|
||||||
|
lock(true)
|
||||||
|
if !res.Ok || res.NotOk {
|
||||||
|
return fmt.Errorf("server returned error on SASL authentication step: %s", res.ErrMsg)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if done && res.Done {
|
||||||
|
socket.dropAuth(cred.Source)
|
||||||
|
socket.creds = append(socket.creds, cred)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func saslNewScram(cred Credential) *saslScram {
|
||||||
|
credsum := md5.New()
|
||||||
|
credsum.Write([]byte(cred.Username + ":mongo:" + cred.Password))
|
||||||
|
client := scram.NewClient(sha1.New, cred.Username, hex.EncodeToString(credsum.Sum(nil)))
|
||||||
|
return &saslScram{cred: cred, client: client}
|
||||||
|
}
|
||||||
|
|
||||||
|
type saslScram struct {
|
||||||
|
cred Credential
|
||||||
|
client *scram.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *saslScram) Close() {}
|
||||||
|
|
||||||
|
func (s *saslScram) Step(serverData []byte) (clientData []byte, done bool, err error) {
|
||||||
|
more := s.client.Step(serverData)
|
||||||
|
return s.client.Out(), !more, s.client.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) loginRun(db string, query, result interface{}, f func() error) error {
|
||||||
|
var mutex sync.Mutex
|
||||||
|
var replyErr error
|
||||||
|
mutex.Lock()
|
||||||
|
|
||||||
|
op := queryOp{}
|
||||||
|
op.query = query
|
||||||
|
op.collection = db + ".$cmd"
|
||||||
|
op.limit = -1
|
||||||
|
op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) {
|
||||||
|
defer mutex.Unlock()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
replyErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = bson.Unmarshal(docData, result)
|
||||||
|
if err != nil {
|
||||||
|
replyErr = err
|
||||||
|
} else {
|
||||||
|
// Must handle this within the read loop for the socket, so
|
||||||
|
// that concurrent login requests are properly ordered.
|
||||||
|
replyErr = f()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := socket.Query(&op)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
mutex.Lock() // Wait.
|
||||||
|
return replyErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) Logout(db string) {
|
||||||
|
socket.Lock()
|
||||||
|
cred, found := socket.dropAuth(db)
|
||||||
|
if found {
|
||||||
|
debugf("Socket %p to %s: logout: db=%q (flagged)", socket, socket.addr, db)
|
||||||
|
socket.logout = append(socket.logout, cred)
|
||||||
|
}
|
||||||
|
socket.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) LogoutAll() {
|
||||||
|
socket.Lock()
|
||||||
|
if l := len(socket.creds); l > 0 {
|
||||||
|
debugf("Socket %p to %s: logout all (flagged %d)", socket, socket.addr, l)
|
||||||
|
socket.logout = append(socket.logout, socket.creds...)
|
||||||
|
socket.creds = socket.creds[0:0]
|
||||||
|
}
|
||||||
|
socket.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) flushLogout() (ops []interface{}) {
|
||||||
|
socket.Lock()
|
||||||
|
if l := len(socket.logout); l > 0 {
|
||||||
|
debugf("Socket %p to %s: logout all (flushing %d)", socket, socket.addr, l)
|
||||||
|
for i := 0; i != l; i++ {
|
||||||
|
op := queryOp{}
|
||||||
|
op.query = &logoutCmd{1}
|
||||||
|
op.collection = socket.logout[i].Source + ".$cmd"
|
||||||
|
op.limit = -1
|
||||||
|
ops = append(ops, &op)
|
||||||
|
}
|
||||||
|
socket.logout = socket.logout[0:0]
|
||||||
|
}
|
||||||
|
socket.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) dropAuth(db string) (cred Credential, found bool) {
|
||||||
|
for i, sockCred := range socket.creds {
|
||||||
|
if sockCred.Source == db {
|
||||||
|
copy(socket.creds[i:], socket.creds[i+1:])
|
||||||
|
socket.creds = socket.creds[:len(socket.creds)-1]
|
||||||
|
return sockCred, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cred, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) dropLogout(cred Credential) (found bool) {
|
||||||
|
for i, sockCred := range socket.logout {
|
||||||
|
if sockCred == cred {
|
||||||
|
copy(socket.logout[i:], socket.logout[i+1:])
|
||||||
|
socket.logout = socket.logout[:len(socket.logout)-1]
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
BSON library for Go
|
||||||
|
|
||||||
|
Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer.
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,734 @@
|
||||||
|
// BSON library for Go
|
||||||
|
//
|
||||||
|
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
//
|
||||||
|
// All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
// list of conditions and the following disclaimer.
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
// Package bson is an implementation of the BSON specification for Go:
|
||||||
|
//
|
||||||
|
// http://bsonspec.org
|
||||||
|
//
|
||||||
|
// It was created as part of the mgo MongoDB driver for Go, but is standalone
|
||||||
|
// and may be used on its own without the driver.
|
||||||
|
package bson
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/md5"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// The public API.
|
||||||
|
|
||||||
|
// A value implementing the bson.Getter interface will have its GetBSON
|
||||||
|
// method called when the given value has to be marshalled, and the result
|
||||||
|
// of this method will be marshaled in place of the actual object.
|
||||||
|
//
|
||||||
|
// If GetBSON returns return a non-nil error, the marshalling procedure
|
||||||
|
// will stop and error out with the provided value.
|
||||||
|
type Getter interface {
|
||||||
|
GetBSON() (interface{}, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A value implementing the bson.Setter interface will receive the BSON
|
||||||
|
// value via the SetBSON method during unmarshaling, and the object
|
||||||
|
// itself will not be changed as usual.
|
||||||
|
//
|
||||||
|
// If setting the value works, the method should return nil or alternatively
|
||||||
|
// bson.SetZero to set the respective field to its zero value (nil for
|
||||||
|
// pointer types). If SetBSON returns a value of type bson.TypeError, the
|
||||||
|
// BSON value will be omitted from a map or slice being decoded and the
|
||||||
|
// unmarshalling will continue. If it returns any other non-nil error, the
|
||||||
|
// unmarshalling procedure will stop and error out with the provided value.
|
||||||
|
//
|
||||||
|
// This interface is generally useful in pointer receivers, since the method
|
||||||
|
// will want to change the receiver. A type field that implements the Setter
|
||||||
|
// interface doesn't have to be a pointer, though.
|
||||||
|
//
|
||||||
|
// Unlike the usual behavior, unmarshalling onto a value that implements a
|
||||||
|
// Setter interface will NOT reset the value to its zero state. This allows
|
||||||
|
// the value to decide by itself how to be unmarshalled.
|
||||||
|
//
|
||||||
|
// For example:
|
||||||
|
//
|
||||||
|
// type MyString string
|
||||||
|
//
|
||||||
|
// func (s *MyString) SetBSON(raw bson.Raw) error {
|
||||||
|
// return raw.Unmarshal(s)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
type Setter interface {
|
||||||
|
SetBSON(raw Raw) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetZero may be returned from a SetBSON method to have the value set to
|
||||||
|
// its respective zero value. When used in pointer values, this will set the
|
||||||
|
// field to nil rather than to the pre-allocated value.
|
||||||
|
var SetZero = errors.New("set to zero")
|
||||||
|
|
||||||
|
// M is a convenient alias for a map[string]interface{} map, useful for
|
||||||
|
// dealing with BSON in a native way. For instance:
|
||||||
|
//
|
||||||
|
// bson.M{"a": 1, "b": true}
|
||||||
|
//
|
||||||
|
// There's no special handling for this type in addition to what's done anyway
|
||||||
|
// for an equivalent map type. Elements in the map will be dumped in an
|
||||||
|
// undefined ordered. See also the bson.D type for an ordered alternative.
|
||||||
|
type M map[string]interface{}
|
||||||
|
|
||||||
|
// D represents a BSON document containing ordered elements. For example:
|
||||||
|
//
|
||||||
|
// bson.D{{"a", 1}, {"b", true}}
|
||||||
|
//
|
||||||
|
// In some situations, such as when creating indexes for MongoDB, the order in
|
||||||
|
// which the elements are defined is important. If the order is not important,
|
||||||
|
// using a map is generally more comfortable. See bson.M and bson.RawD.
|
||||||
|
type D []DocElem
|
||||||
|
|
||||||
|
// DocElem is an element of the bson.D document representation.
|
||||||
|
type DocElem struct {
|
||||||
|
Name string
|
||||||
|
Value interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map returns a map out of the ordered element name/value pairs in d.
|
||||||
|
func (d D) Map() (m M) {
|
||||||
|
m = make(M, len(d))
|
||||||
|
for _, item := range d {
|
||||||
|
m[item.Name] = item.Value
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// The Raw type represents raw unprocessed BSON documents and elements.
|
||||||
|
// Kind is the kind of element as defined per the BSON specification, and
|
||||||
|
// Data is the raw unprocessed data for the respective element.
|
||||||
|
// Using this type it is possible to unmarshal or marshal values partially.
|
||||||
|
//
|
||||||
|
// Relevant documentation:
|
||||||
|
//
|
||||||
|
// http://bsonspec.org/#/specification
|
||||||
|
//
|
||||||
|
type Raw struct {
|
||||||
|
Kind byte
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// RawD represents a BSON document containing raw unprocessed elements.
|
||||||
|
// This low-level representation may be useful when lazily processing
|
||||||
|
// documents of uncertain content, or when manipulating the raw content
|
||||||
|
// documents in general.
|
||||||
|
type RawD []RawDocElem
|
||||||
|
|
||||||
|
// See the RawD type.
|
||||||
|
type RawDocElem struct {
|
||||||
|
Name string
|
||||||
|
Value Raw
|
||||||
|
}
|
||||||
|
|
||||||
|
// ObjectId is a unique ID identifying a BSON value. It must be exactly 12 bytes
|
||||||
|
// long. MongoDB objects by default have such a property set in their "_id"
|
||||||
|
// property.
|
||||||
|
//
|
||||||
|
// http://www.mongodb.org/display/DOCS/Object+IDs
|
||||||
|
type ObjectId string
|
||||||
|
|
||||||
|
// ObjectIdHex returns an ObjectId from the provided hex representation.
|
||||||
|
// Calling this function with an invalid hex representation will
|
||||||
|
// cause a runtime panic. See the IsObjectIdHex function.
|
||||||
|
func ObjectIdHex(s string) ObjectId {
|
||||||
|
d, err := hex.DecodeString(s)
|
||||||
|
if err != nil || len(d) != 12 {
|
||||||
|
panic(fmt.Sprintf("invalid input to ObjectIdHex: %q", s))
|
||||||
|
}
|
||||||
|
return ObjectId(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsObjectIdHex returns whether s is a valid hex representation of
|
||||||
|
// an ObjectId. See the ObjectIdHex function.
|
||||||
|
func IsObjectIdHex(s string) bool {
|
||||||
|
if len(s) != 24 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, err := hex.DecodeString(s)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// objectIdCounter is atomically incremented when generating a new ObjectId
|
||||||
|
// using NewObjectId() function. It's used as a counter part of an id.
|
||||||
|
var objectIdCounter uint32 = readRandomUint32()
|
||||||
|
|
||||||
|
// readRandomUint32 returns a random objectIdCounter.
|
||||||
|
func readRandomUint32() uint32 {
|
||||||
|
// We've found systems hanging in this function due to lack of entropy.
|
||||||
|
// The randomness of these bytes is just preventing nearby clashes, so
|
||||||
|
// just look at the time.
|
||||||
|
return uint32(time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
// machineId stores machine id generated once and used in subsequent calls
|
||||||
|
// to NewObjectId function.
|
||||||
|
var machineId = readMachineId()
|
||||||
|
var processId = os.Getpid()
|
||||||
|
|
||||||
|
// readMachineId generates and returns a machine id.
|
||||||
|
// If this function fails to get the hostname it will cause a runtime error.
|
||||||
|
func readMachineId() []byte {
|
||||||
|
var sum [3]byte
|
||||||
|
id := sum[:]
|
||||||
|
hostname, err1 := os.Hostname()
|
||||||
|
if err1 != nil {
|
||||||
|
n := uint32(time.Now().UnixNano())
|
||||||
|
sum[0] = byte(n >> 0)
|
||||||
|
sum[1] = byte(n >> 8)
|
||||||
|
sum[2] = byte(n >> 16)
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
hw := md5.New()
|
||||||
|
hw.Write([]byte(hostname))
|
||||||
|
copy(id, hw.Sum(nil))
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewObjectId returns a new unique ObjectId.
|
||||||
|
func NewObjectId() ObjectId {
|
||||||
|
var b [12]byte
|
||||||
|
// Timestamp, 4 bytes, big endian
|
||||||
|
binary.BigEndian.PutUint32(b[:], uint32(time.Now().Unix()))
|
||||||
|
// Machine, first 3 bytes of md5(hostname)
|
||||||
|
b[4] = machineId[0]
|
||||||
|
b[5] = machineId[1]
|
||||||
|
b[6] = machineId[2]
|
||||||
|
// Pid, 2 bytes, specs don't specify endianness, but we use big endian.
|
||||||
|
b[7] = byte(processId >> 8)
|
||||||
|
b[8] = byte(processId)
|
||||||
|
// Increment, 3 bytes, big endian
|
||||||
|
i := atomic.AddUint32(&objectIdCounter, 1)
|
||||||
|
b[9] = byte(i >> 16)
|
||||||
|
b[10] = byte(i >> 8)
|
||||||
|
b[11] = byte(i)
|
||||||
|
return ObjectId(b[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewObjectIdWithTime returns a dummy ObjectId with the timestamp part filled
|
||||||
|
// with the provided number of seconds from epoch UTC, and all other parts
|
||||||
|
// filled with zeroes. It's not safe to insert a document with an id generated
|
||||||
|
// by this method, it is useful only for queries to find documents with ids
|
||||||
|
// generated before or after the specified timestamp.
|
||||||
|
func NewObjectIdWithTime(t time.Time) ObjectId {
|
||||||
|
var b [12]byte
|
||||||
|
binary.BigEndian.PutUint32(b[:4], uint32(t.Unix()))
|
||||||
|
return ObjectId(string(b[:]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns a hex string representation of the id.
|
||||||
|
// Example: ObjectIdHex("4d88e15b60f486e428412dc9").
|
||||||
|
func (id ObjectId) String() string {
|
||||||
|
return fmt.Sprintf(`ObjectIdHex("%x")`, string(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hex returns a hex representation of the ObjectId.
|
||||||
|
func (id ObjectId) Hex() string {
|
||||||
|
return hex.EncodeToString([]byte(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON turns a bson.ObjectId into a json.Marshaller.
|
||||||
|
func (id ObjectId) MarshalJSON() ([]byte, error) {
|
||||||
|
return []byte(fmt.Sprintf(`"%x"`, string(id))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var nullBytes = []byte("null")
|
||||||
|
|
||||||
|
// UnmarshalJSON turns *bson.ObjectId into a json.Unmarshaller.
|
||||||
|
func (id *ObjectId) UnmarshalJSON(data []byte) error {
|
||||||
|
if len(data) > 0 && (data[0] == '{' || data[0] == 'O') {
|
||||||
|
var v struct {
|
||||||
|
Id json.RawMessage `json:"$oid"`
|
||||||
|
Func struct {
|
||||||
|
Id json.RawMessage
|
||||||
|
} `json:"$oidFunc"`
|
||||||
|
}
|
||||||
|
err := jdec(data, &v)
|
||||||
|
if err == nil {
|
||||||
|
if len(v.Id) > 0 {
|
||||||
|
data = []byte(v.Id)
|
||||||
|
} else {
|
||||||
|
data = []byte(v.Func.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(data) == 2 && data[0] == '"' && data[1] == '"' || bytes.Equal(data, nullBytes) {
|
||||||
|
*id = ""
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(data) != 26 || data[0] != '"' || data[25] != '"' {
|
||||||
|
return errors.New(fmt.Sprintf("invalid ObjectId in JSON: %s", string(data)))
|
||||||
|
}
|
||||||
|
var buf [12]byte
|
||||||
|
_, err := hex.Decode(buf[:], data[1:25])
|
||||||
|
if err != nil {
|
||||||
|
return errors.New(fmt.Sprintf("invalid ObjectId in JSON: %s (%s)", string(data), err))
|
||||||
|
}
|
||||||
|
*id = ObjectId(string(buf[:]))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalText turns bson.ObjectId into an encoding.TextMarshaler.
|
||||||
|
func (id ObjectId) MarshalText() ([]byte, error) {
|
||||||
|
return []byte(fmt.Sprintf("%x", string(id))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalText turns *bson.ObjectId into an encoding.TextUnmarshaler.
|
||||||
|
func (id *ObjectId) UnmarshalText(data []byte) error {
|
||||||
|
if len(data) == 1 && data[0] == ' ' || len(data) == 0 {
|
||||||
|
*id = ""
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(data) != 24 {
|
||||||
|
return fmt.Errorf("invalid ObjectId: %s", data)
|
||||||
|
}
|
||||||
|
var buf [12]byte
|
||||||
|
_, err := hex.Decode(buf[:], data[:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid ObjectId: %s (%s)", data, err)
|
||||||
|
}
|
||||||
|
*id = ObjectId(string(buf[:]))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid returns true if id is valid. A valid id must contain exactly 12 bytes.
|
||||||
|
func (id ObjectId) Valid() bool {
|
||||||
|
return len(id) == 12
|
||||||
|
}
|
||||||
|
|
||||||
|
// byteSlice returns byte slice of id from start to end.
|
||||||
|
// Calling this function with an invalid id will cause a runtime panic.
|
||||||
|
func (id ObjectId) byteSlice(start, end int) []byte {
|
||||||
|
if len(id) != 12 {
|
||||||
|
panic(fmt.Sprintf("invalid ObjectId: %q", string(id)))
|
||||||
|
}
|
||||||
|
return []byte(string(id)[start:end])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Time returns the timestamp part of the id.
|
||||||
|
// It's a runtime error to call this method with an invalid id.
|
||||||
|
func (id ObjectId) Time() time.Time {
|
||||||
|
// First 4 bytes of ObjectId is 32-bit big-endian seconds from epoch.
|
||||||
|
secs := int64(binary.BigEndian.Uint32(id.byteSlice(0, 4)))
|
||||||
|
return time.Unix(secs, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Machine returns the 3-byte machine id part of the id.
|
||||||
|
// It's a runtime error to call this method with an invalid id.
|
||||||
|
func (id ObjectId) Machine() []byte {
|
||||||
|
return id.byteSlice(4, 7)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pid returns the process id part of the id.
|
||||||
|
// It's a runtime error to call this method with an invalid id.
|
||||||
|
func (id ObjectId) Pid() uint16 {
|
||||||
|
return binary.BigEndian.Uint16(id.byteSlice(7, 9))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Counter returns the incrementing value part of the id.
|
||||||
|
// It's a runtime error to call this method with an invalid id.
|
||||||
|
func (id ObjectId) Counter() int32 {
|
||||||
|
b := id.byteSlice(9, 12)
|
||||||
|
// Counter is stored as big-endian 3-byte value
|
||||||
|
return int32(uint32(b[0])<<16 | uint32(b[1])<<8 | uint32(b[2]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// The Symbol type is similar to a string and is used in languages with a
|
||||||
|
// distinct symbol type.
|
||||||
|
type Symbol string
|
||||||
|
|
||||||
|
// Now returns the current time with millisecond precision. MongoDB stores
|
||||||
|
// timestamps with the same precision, so a Time returned from this method
|
||||||
|
// will not change after a roundtrip to the database. That's the only reason
|
||||||
|
// why this function exists. Using the time.Now function also works fine
|
||||||
|
// otherwise.
|
||||||
|
func Now() time.Time {
|
||||||
|
return time.Unix(0, time.Now().UnixNano()/1e6*1e6)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MongoTimestamp is a special internal type used by MongoDB that for some
|
||||||
|
// strange reason has its own datatype defined in BSON.
|
||||||
|
type MongoTimestamp int64
|
||||||
|
|
||||||
|
type orderKey int64
|
||||||
|
|
||||||
|
// MaxKey is a special value that compares higher than all other possible BSON
|
||||||
|
// values in a MongoDB database.
|
||||||
|
var MaxKey = orderKey(1<<63 - 1)
|
||||||
|
|
||||||
|
// MinKey is a special value that compares lower than all other possible BSON
|
||||||
|
// values in a MongoDB database.
|
||||||
|
var MinKey = orderKey(-1 << 63)
|
||||||
|
|
||||||
|
type undefined struct{}
|
||||||
|
|
||||||
|
// Undefined represents the undefined BSON value.
|
||||||
|
var Undefined undefined
|
||||||
|
|
||||||
|
// Binary is a representation for non-standard binary values. Any kind should
|
||||||
|
// work, but the following are known as of this writing:
|
||||||
|
//
|
||||||
|
// 0x00 - Generic. This is decoded as []byte(data), not Binary{0x00, data}.
|
||||||
|
// 0x01 - Function (!?)
|
||||||
|
// 0x02 - Obsolete generic.
|
||||||
|
// 0x03 - UUID
|
||||||
|
// 0x05 - MD5
|
||||||
|
// 0x80 - User defined.
|
||||||
|
//
|
||||||
|
type Binary struct {
|
||||||
|
Kind byte
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegEx represents a regular expression. The Options field may contain
|
||||||
|
// individual characters defining the way in which the pattern should be
|
||||||
|
// applied, and must be sorted. Valid options as of this writing are 'i' for
|
||||||
|
// case insensitive matching, 'm' for multi-line matching, 'x' for verbose
|
||||||
|
// mode, 'l' to make \w, \W, and similar be locale-dependent, 's' for dot-all
|
||||||
|
// mode (a '.' matches everything), and 'u' to make \w, \W, and similar match
|
||||||
|
// unicode. The value of the Options parameter is not verified before being
|
||||||
|
// marshaled into the BSON format.
|
||||||
|
type RegEx struct {
|
||||||
|
Pattern string
|
||||||
|
Options string
|
||||||
|
}
|
||||||
|
|
||||||
|
// JavaScript is a type that holds JavaScript code. If Scope is non-nil, it
|
||||||
|
// will be marshaled as a mapping from identifiers to values that may be
|
||||||
|
// used when evaluating the provided Code.
|
||||||
|
type JavaScript struct {
|
||||||
|
Code string
|
||||||
|
Scope interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DBPointer refers to a document id in a namespace.
|
||||||
|
//
|
||||||
|
// This type is deprecated in the BSON specification and should not be used
|
||||||
|
// except for backwards compatibility with ancient applications.
|
||||||
|
type DBPointer struct {
|
||||||
|
Namespace string
|
||||||
|
Id ObjectId
|
||||||
|
}
|
||||||
|
|
||||||
|
const initialBufferSize = 64
|
||||||
|
|
||||||
|
func handleErr(err *error) {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
if _, ok := r.(runtime.Error); ok {
|
||||||
|
panic(r)
|
||||||
|
} else if _, ok := r.(externalPanic); ok {
|
||||||
|
panic(r)
|
||||||
|
} else if s, ok := r.(string); ok {
|
||||||
|
*err = errors.New(s)
|
||||||
|
} else if e, ok := r.(error); ok {
|
||||||
|
*err = e
|
||||||
|
} else {
|
||||||
|
panic(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal serializes the in value, which may be a map or a struct value.
|
||||||
|
// In the case of struct values, only exported fields will be serialized,
|
||||||
|
// and the order of serialized fields will match that of the struct itself.
|
||||||
|
// The lowercased field name is used as the key for each exported field,
|
||||||
|
// but this behavior may be changed using the respective field tag.
|
||||||
|
// The tag may also contain flags to tweak the marshalling behavior for
|
||||||
|
// the field. The tag formats accepted are:
|
||||||
|
//
|
||||||
|
// "[<key>][,<flag1>[,<flag2>]]"
|
||||||
|
//
|
||||||
|
// `(...) bson:"[<key>][,<flag1>[,<flag2>]]" (...)`
|
||||||
|
//
|
||||||
|
// The following flags are currently supported:
|
||||||
|
//
|
||||||
|
// omitempty Only include the field if it's not set to the zero
|
||||||
|
// value for the type or to empty slices or maps.
|
||||||
|
//
|
||||||
|
// minsize Marshal an int64 value as an int32, if that's feasible
|
||||||
|
// while preserving the numeric value.
|
||||||
|
//
|
||||||
|
// inline Inline the field, which must be a struct or a map,
|
||||||
|
// causing all of its fields or keys to be processed as if
|
||||||
|
// they were part of the outer struct. For maps, keys must
|
||||||
|
// not conflict with the bson keys of other struct fields.
|
||||||
|
//
|
||||||
|
// Some examples:
|
||||||
|
//
|
||||||
|
// type T struct {
|
||||||
|
// A bool
|
||||||
|
// B int "myb"
|
||||||
|
// C string "myc,omitempty"
|
||||||
|
// D string `bson:",omitempty" json:"jsonkey"`
|
||||||
|
// E int64 ",minsize"
|
||||||
|
// F int64 "myf,omitempty,minsize"
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
func Marshal(in interface{}) (out []byte, err error) {
|
||||||
|
defer handleErr(&err)
|
||||||
|
e := &encoder{make([]byte, 0, initialBufferSize)}
|
||||||
|
e.addDoc(reflect.ValueOf(in))
|
||||||
|
return e.out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal deserializes data from in into the out value. The out value
|
||||||
|
// must be a map, a pointer to a struct, or a pointer to a bson.D value.
|
||||||
|
// In the case of struct values, only exported fields will be deserialized.
|
||||||
|
// The lowercased field name is used as the key for each exported field,
|
||||||
|
// but this behavior may be changed using the respective field tag.
|
||||||
|
// The tag may also contain flags to tweak the marshalling behavior for
|
||||||
|
// the field. The tag formats accepted are:
|
||||||
|
//
|
||||||
|
// "[<key>][,<flag1>[,<flag2>]]"
|
||||||
|
//
|
||||||
|
// `(...) bson:"[<key>][,<flag1>[,<flag2>]]" (...)`
|
||||||
|
//
|
||||||
|
// The following flags are currently supported during unmarshal (see the
|
||||||
|
// Marshal method for other flags):
|
||||||
|
//
|
||||||
|
// inline Inline the field, which must be a struct or a map.
|
||||||
|
// Inlined structs are handled as if its fields were part
|
||||||
|
// of the outer struct. An inlined map causes keys that do
|
||||||
|
// not match any other struct field to be inserted in the
|
||||||
|
// map rather than being discarded as usual.
|
||||||
|
//
|
||||||
|
// The target field or element types of out may not necessarily match
|
||||||
|
// the BSON values of the provided data. The following conversions are
|
||||||
|
// made automatically:
|
||||||
|
//
|
||||||
|
// - Numeric types are converted if at least the integer part of the
|
||||||
|
// value would be preserved correctly
|
||||||
|
// - Bools are converted to numeric types as 1 or 0
|
||||||
|
// - Numeric types are converted to bools as true if not 0 or false otherwise
|
||||||
|
// - Binary and string BSON data is converted to a string, array or byte slice
|
||||||
|
//
|
||||||
|
// If the value would not fit the type and cannot be converted, it's
|
||||||
|
// silently skipped.
|
||||||
|
//
|
||||||
|
// Pointer values are initialized when necessary.
|
||||||
|
func Unmarshal(in []byte, out interface{}) (err error) {
|
||||||
|
if raw, ok := out.(*Raw); ok {
|
||||||
|
raw.Kind = 3
|
||||||
|
raw.Data = in
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer handleErr(&err)
|
||||||
|
v := reflect.ValueOf(out)
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Ptr:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Map:
|
||||||
|
d := newDecoder(in)
|
||||||
|
d.readDocTo(v)
|
||||||
|
case reflect.Struct:
|
||||||
|
return errors.New("Unmarshal can't deal with struct values. Use a pointer.")
|
||||||
|
default:
|
||||||
|
return errors.New("Unmarshal needs a map or a pointer to a struct.")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal deserializes raw into the out value. If the out value type
|
||||||
|
// is not compatible with raw, a *bson.TypeError is returned.
|
||||||
|
//
|
||||||
|
// See the Unmarshal function documentation for more details on the
|
||||||
|
// unmarshalling process.
|
||||||
|
func (raw Raw) Unmarshal(out interface{}) (err error) {
|
||||||
|
defer handleErr(&err)
|
||||||
|
v := reflect.ValueOf(out)
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Ptr:
|
||||||
|
v = v.Elem()
|
||||||
|
fallthrough
|
||||||
|
case reflect.Map:
|
||||||
|
d := newDecoder(raw.Data)
|
||||||
|
good := d.readElemTo(v, raw.Kind)
|
||||||
|
if !good {
|
||||||
|
return &TypeError{v.Type(), raw.Kind}
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
return errors.New("Raw Unmarshal can't deal with struct values. Use a pointer.")
|
||||||
|
default:
|
||||||
|
return errors.New("Raw Unmarshal needs a map or a valid pointer.")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type TypeError struct {
|
||||||
|
Type reflect.Type
|
||||||
|
Kind byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *TypeError) Error() string {
|
||||||
|
return fmt.Sprintf("BSON kind 0x%02x isn't compatible with type %s", e.Kind, e.Type.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Maintain a mapping of keys to structure field indexes
|
||||||
|
|
||||||
|
type structInfo struct {
|
||||||
|
FieldsMap map[string]fieldInfo
|
||||||
|
FieldsList []fieldInfo
|
||||||
|
InlineMap int
|
||||||
|
Zero reflect.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
type fieldInfo struct {
|
||||||
|
Key string
|
||||||
|
Num int
|
||||||
|
OmitEmpty bool
|
||||||
|
MinSize bool
|
||||||
|
Inline []int
|
||||||
|
}
|
||||||
|
|
||||||
|
var structMap = make(map[reflect.Type]*structInfo)
|
||||||
|
var structMapMutex sync.RWMutex
|
||||||
|
|
||||||
|
type externalPanic string
|
||||||
|
|
||||||
|
func (e externalPanic) String() string {
|
||||||
|
return string(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStructInfo(st reflect.Type) (*structInfo, error) {
|
||||||
|
structMapMutex.RLock()
|
||||||
|
sinfo, found := structMap[st]
|
||||||
|
structMapMutex.RUnlock()
|
||||||
|
if found {
|
||||||
|
return sinfo, nil
|
||||||
|
}
|
||||||
|
n := st.NumField()
|
||||||
|
fieldsMap := make(map[string]fieldInfo)
|
||||||
|
fieldsList := make([]fieldInfo, 0, n)
|
||||||
|
inlineMap := -1
|
||||||
|
for i := 0; i != n; i++ {
|
||||||
|
field := st.Field(i)
|
||||||
|
if field.PkgPath != "" && !field.Anonymous {
|
||||||
|
continue // Private field
|
||||||
|
}
|
||||||
|
|
||||||
|
info := fieldInfo{Num: i}
|
||||||
|
|
||||||
|
tag := field.Tag.Get("bson")
|
||||||
|
if tag == "" && strings.Index(string(field.Tag), ":") < 0 {
|
||||||
|
tag = string(field.Tag)
|
||||||
|
}
|
||||||
|
if tag == "-" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
inline := false
|
||||||
|
fields := strings.Split(tag, ",")
|
||||||
|
if len(fields) > 1 {
|
||||||
|
for _, flag := range fields[1:] {
|
||||||
|
switch flag {
|
||||||
|
case "omitempty":
|
||||||
|
info.OmitEmpty = true
|
||||||
|
case "minsize":
|
||||||
|
info.MinSize = true
|
||||||
|
case "inline":
|
||||||
|
inline = true
|
||||||
|
default:
|
||||||
|
msg := fmt.Sprintf("Unsupported flag %q in tag %q of type %s", flag, tag, st)
|
||||||
|
panic(externalPanic(msg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tag = fields[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if inline {
|
||||||
|
switch field.Type.Kind() {
|
||||||
|
case reflect.Map:
|
||||||
|
if inlineMap >= 0 {
|
||||||
|
return nil, errors.New("Multiple ,inline maps in struct " + st.String())
|
||||||
|
}
|
||||||
|
if field.Type.Key() != reflect.TypeOf("") {
|
||||||
|
return nil, errors.New("Option ,inline needs a map with string keys in struct " + st.String())
|
||||||
|
}
|
||||||
|
inlineMap = info.Num
|
||||||
|
case reflect.Struct:
|
||||||
|
sinfo, err := getStructInfo(field.Type)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, finfo := range sinfo.FieldsList {
|
||||||
|
if _, found := fieldsMap[finfo.Key]; found {
|
||||||
|
msg := "Duplicated key '" + finfo.Key + "' in struct " + st.String()
|
||||||
|
return nil, errors.New(msg)
|
||||||
|
}
|
||||||
|
if finfo.Inline == nil {
|
||||||
|
finfo.Inline = []int{i, finfo.Num}
|
||||||
|
} else {
|
||||||
|
finfo.Inline = append([]int{i}, finfo.Inline...)
|
||||||
|
}
|
||||||
|
fieldsMap[finfo.Key] = finfo
|
||||||
|
fieldsList = append(fieldsList, finfo)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic("Option ,inline needs a struct value or map field")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if tag != "" {
|
||||||
|
info.Key = tag
|
||||||
|
} else {
|
||||||
|
info.Key = strings.ToLower(field.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, found = fieldsMap[info.Key]; found {
|
||||||
|
msg := "Duplicated key '" + info.Key + "' in struct " + st.String()
|
||||||
|
return nil, errors.New(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldsList = append(fieldsList, info)
|
||||||
|
fieldsMap[info.Key] = info
|
||||||
|
}
|
||||||
|
sinfo = &structInfo{
|
||||||
|
fieldsMap,
|
||||||
|
fieldsList,
|
||||||
|
inlineMap,
|
||||||
|
reflect.New(st).Elem(),
|
||||||
|
}
|
||||||
|
structMapMutex.Lock()
|
||||||
|
structMap[st] = sinfo
|
||||||
|
structMapMutex.Unlock()
|
||||||
|
return sinfo, nil
|
||||||
|
}
|
|
@ -0,0 +1,310 @@
|
||||||
|
// BSON library for Go
|
||||||
|
//
|
||||||
|
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
//
|
||||||
|
// All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
// list of conditions and the following disclaimer.
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
package bson
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Decimal128 holds decimal128 BSON values.
|
||||||
|
type Decimal128 struct {
|
||||||
|
h, l uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Decimal128) String() string {
|
||||||
|
var pos int // positive sign
|
||||||
|
var e int // exponent
|
||||||
|
var h, l uint64 // significand high/low
|
||||||
|
|
||||||
|
if d.h>>63&1 == 0 {
|
||||||
|
pos = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
switch d.h >> 58 & (1<<5 - 1) {
|
||||||
|
case 0x1F:
|
||||||
|
return "NaN"
|
||||||
|
case 0x1E:
|
||||||
|
return "-Inf"[pos:]
|
||||||
|
}
|
||||||
|
|
||||||
|
l = d.l
|
||||||
|
if d.h>>61&3 == 3 {
|
||||||
|
// Bits: 1*sign 2*ignored 14*exponent 111*significand.
|
||||||
|
// Implicit 0b100 prefix in significand.
|
||||||
|
e = int(d.h>>47&(1<<14-1)) - 6176
|
||||||
|
//h = 4<<47 | d.h&(1<<47-1)
|
||||||
|
// Spec says all of these values are out of range.
|
||||||
|
h, l = 0, 0
|
||||||
|
} else {
|
||||||
|
// Bits: 1*sign 14*exponent 113*significand
|
||||||
|
e = int(d.h>>49&(1<<14-1)) - 6176
|
||||||
|
h = d.h & (1<<49 - 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Would be handled by the logic below, but that's trivial and common.
|
||||||
|
if h == 0 && l == 0 && e == 0 {
|
||||||
|
return "-0"[pos:]
|
||||||
|
}
|
||||||
|
|
||||||
|
var repr [48]byte // Loop 5 times over 9 digits plus dot, negative sign, and leading zero.
|
||||||
|
var last = len(repr)
|
||||||
|
var i = len(repr)
|
||||||
|
var dot = len(repr) + e
|
||||||
|
var rem uint32
|
||||||
|
Loop:
|
||||||
|
for d9 := 0; d9 < 5; d9++ {
|
||||||
|
h, l, rem = divmod(h, l, 1e9)
|
||||||
|
for d1 := 0; d1 < 9; d1++ {
|
||||||
|
// Handle "-0.0", "0.00123400", "-1.00E-6", "1.050E+3", etc.
|
||||||
|
if i < len(repr) && (dot == i || l == 0 && h == 0 && rem > 0 && rem < 10 && (dot < i-6 || e > 0)) {
|
||||||
|
e += len(repr) - i
|
||||||
|
i--
|
||||||
|
repr[i] = '.'
|
||||||
|
last = i - 1
|
||||||
|
dot = len(repr) // Unmark.
|
||||||
|
}
|
||||||
|
c := '0' + byte(rem%10)
|
||||||
|
rem /= 10
|
||||||
|
i--
|
||||||
|
repr[i] = c
|
||||||
|
// Handle "0E+3", "1E+3", etc.
|
||||||
|
if l == 0 && h == 0 && rem == 0 && i == len(repr)-1 && (dot < i-5 || e > 0) {
|
||||||
|
last = i
|
||||||
|
break Loop
|
||||||
|
}
|
||||||
|
if c != '0' {
|
||||||
|
last = i
|
||||||
|
}
|
||||||
|
// Break early. Works without it, but why.
|
||||||
|
if dot > i && l == 0 && h == 0 && rem == 0 {
|
||||||
|
break Loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
repr[last-1] = '-'
|
||||||
|
last--
|
||||||
|
|
||||||
|
if e > 0 {
|
||||||
|
return string(repr[last+pos:]) + "E+" + strconv.Itoa(e)
|
||||||
|
}
|
||||||
|
if e < 0 {
|
||||||
|
return string(repr[last+pos:]) + "E" + strconv.Itoa(e)
|
||||||
|
}
|
||||||
|
return string(repr[last+pos:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func divmod(h, l uint64, div uint32) (qh, ql uint64, rem uint32) {
|
||||||
|
div64 := uint64(div)
|
||||||
|
a := h >> 32
|
||||||
|
aq := a / div64
|
||||||
|
ar := a % div64
|
||||||
|
b := ar<<32 + h&(1<<32-1)
|
||||||
|
bq := b / div64
|
||||||
|
br := b % div64
|
||||||
|
c := br<<32 + l>>32
|
||||||
|
cq := c / div64
|
||||||
|
cr := c % div64
|
||||||
|
d := cr<<32 + l&(1<<32-1)
|
||||||
|
dq := d / div64
|
||||||
|
dr := d % div64
|
||||||
|
return (aq<<32 | bq), (cq<<32 | dq), uint32(dr)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dNaN = Decimal128{0x1F << 58, 0}
|
||||||
|
var dPosInf = Decimal128{0x1E << 58, 0}
|
||||||
|
var dNegInf = Decimal128{0x3E << 58, 0}
|
||||||
|
|
||||||
|
func dErr(s string) (Decimal128, error) {
|
||||||
|
return dNaN, fmt.Errorf("cannot parse %q as a decimal128", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseDecimal128(s string) (Decimal128, error) {
|
||||||
|
orig := s
|
||||||
|
if s == "" {
|
||||||
|
return dErr(orig)
|
||||||
|
}
|
||||||
|
neg := s[0] == '-'
|
||||||
|
if neg || s[0] == '+' {
|
||||||
|
s = s[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if (len(s) == 3 || len(s) == 8) && (s[0] == 'N' || s[0] == 'n' || s[0] == 'I' || s[0] == 'i') {
|
||||||
|
if s == "NaN" || s == "nan" || strings.EqualFold(s, "nan") {
|
||||||
|
return dNaN, nil
|
||||||
|
}
|
||||||
|
if s == "Inf" || s == "inf" || strings.EqualFold(s, "inf") || strings.EqualFold(s, "infinity") {
|
||||||
|
if neg {
|
||||||
|
return dNegInf, nil
|
||||||
|
}
|
||||||
|
return dPosInf, nil
|
||||||
|
}
|
||||||
|
return dErr(orig)
|
||||||
|
}
|
||||||
|
|
||||||
|
var h, l uint64
|
||||||
|
var e int
|
||||||
|
|
||||||
|
var add, ovr uint32
|
||||||
|
var mul uint32 = 1
|
||||||
|
var dot = -1
|
||||||
|
var digits = 0
|
||||||
|
var i = 0
|
||||||
|
for i < len(s) {
|
||||||
|
c := s[i]
|
||||||
|
if mul == 1e9 {
|
||||||
|
h, l, ovr = muladd(h, l, mul, add)
|
||||||
|
mul, add = 1, 0
|
||||||
|
if ovr > 0 || h&((1<<15-1)<<49) > 0 {
|
||||||
|
return dErr(orig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c >= '0' && c <= '9' {
|
||||||
|
i++
|
||||||
|
if c > '0' || digits > 0 {
|
||||||
|
digits++
|
||||||
|
}
|
||||||
|
if digits > 34 {
|
||||||
|
if c == '0' {
|
||||||
|
// Exact rounding.
|
||||||
|
e++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return dErr(orig)
|
||||||
|
}
|
||||||
|
mul *= 10
|
||||||
|
add *= 10
|
||||||
|
add += uint32(c - '0')
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '.' {
|
||||||
|
i++
|
||||||
|
if dot >= 0 || i == 1 && len(s) == 1 {
|
||||||
|
return dErr(orig)
|
||||||
|
}
|
||||||
|
if i == len(s) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if s[i] < '0' || s[i] > '9' || e > 0 {
|
||||||
|
return dErr(orig)
|
||||||
|
}
|
||||||
|
dot = i
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if i == 0 {
|
||||||
|
return dErr(orig)
|
||||||
|
}
|
||||||
|
if mul > 1 {
|
||||||
|
h, l, ovr = muladd(h, l, mul, add)
|
||||||
|
if ovr > 0 || h&((1<<15-1)<<49) > 0 {
|
||||||
|
return dErr(orig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if dot >= 0 {
|
||||||
|
e += dot - i
|
||||||
|
}
|
||||||
|
if i+1 < len(s) && (s[i] == 'E' || s[i] == 'e') {
|
||||||
|
i++
|
||||||
|
eneg := s[i] == '-'
|
||||||
|
if eneg || s[i] == '+' {
|
||||||
|
i++
|
||||||
|
if i == len(s) {
|
||||||
|
return dErr(orig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n := 0
|
||||||
|
for i < len(s) && n < 1e4 {
|
||||||
|
c := s[i]
|
||||||
|
i++
|
||||||
|
if c < '0' || c > '9' {
|
||||||
|
return dErr(orig)
|
||||||
|
}
|
||||||
|
n *= 10
|
||||||
|
n += int(c - '0')
|
||||||
|
}
|
||||||
|
if eneg {
|
||||||
|
n = -n
|
||||||
|
}
|
||||||
|
e += n
|
||||||
|
for e < -6176 {
|
||||||
|
// Subnormal.
|
||||||
|
var div uint32 = 1
|
||||||
|
for div < 1e9 && e < -6176 {
|
||||||
|
div *= 10
|
||||||
|
e++
|
||||||
|
}
|
||||||
|
var rem uint32
|
||||||
|
h, l, rem = divmod(h, l, div)
|
||||||
|
if rem > 0 {
|
||||||
|
return dErr(orig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for e > 6111 {
|
||||||
|
// Clamped.
|
||||||
|
var mul uint32 = 1
|
||||||
|
for mul < 1e9 && e > 6111 {
|
||||||
|
mul *= 10
|
||||||
|
e--
|
||||||
|
}
|
||||||
|
h, l, ovr = muladd(h, l, mul, 0)
|
||||||
|
if ovr > 0 || h&((1<<15-1)<<49) > 0 {
|
||||||
|
return dErr(orig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if e < -6176 || e > 6111 {
|
||||||
|
return dErr(orig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if i < len(s) {
|
||||||
|
return dErr(orig)
|
||||||
|
}
|
||||||
|
|
||||||
|
h |= uint64(e+6176) & uint64(1<<14-1) << 49
|
||||||
|
if neg {
|
||||||
|
h |= 1 << 63
|
||||||
|
}
|
||||||
|
return Decimal128{h, l}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func muladd(h, l uint64, mul uint32, add uint32) (resh, resl uint64, overflow uint32) {
|
||||||
|
mul64 := uint64(mul)
|
||||||
|
a := mul64 * (l & (1<<32 - 1))
|
||||||
|
b := a>>32 + mul64*(l>>32)
|
||||||
|
c := b>>32 + mul64*(h&(1<<32-1))
|
||||||
|
d := c>>32 + mul64*(h>>32)
|
||||||
|
|
||||||
|
a = a&(1<<32-1) + uint64(add)
|
||||||
|
b = b&(1<<32-1) + a>>32
|
||||||
|
c = c&(1<<32-1) + b>>32
|
||||||
|
d = d&(1<<32-1) + c>>32
|
||||||
|
|
||||||
|
return (d<<32 | c&(1<<32-1)), (b<<32 | a&(1<<32-1)), uint32(d >> 32)
|
||||||
|
}
|
|
@ -0,0 +1,849 @@
|
||||||
|
// BSON library for Go
|
||||||
|
//
|
||||||
|
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
//
|
||||||
|
// All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
// list of conditions and the following disclaimer.
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
// gobson - BSON library for Go.
|
||||||
|
|
||||||
|
package bson
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type decoder struct {
|
||||||
|
in []byte
|
||||||
|
i int
|
||||||
|
docType reflect.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
var typeM = reflect.TypeOf(M{})
|
||||||
|
|
||||||
|
func newDecoder(in []byte) *decoder {
|
||||||
|
return &decoder{in, 0, typeM}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Some helper functions.
|
||||||
|
|
||||||
|
func corrupted() {
|
||||||
|
panic("Document is corrupted")
|
||||||
|
}
|
||||||
|
|
||||||
|
func settableValueOf(i interface{}) reflect.Value {
|
||||||
|
v := reflect.ValueOf(i)
|
||||||
|
sv := reflect.New(v.Type()).Elem()
|
||||||
|
sv.Set(v)
|
||||||
|
return sv
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Unmarshaling of documents.
|
||||||
|
|
||||||
|
const (
|
||||||
|
setterUnknown = iota
|
||||||
|
setterNone
|
||||||
|
setterType
|
||||||
|
setterAddr
|
||||||
|
)
|
||||||
|
|
||||||
|
var setterStyles map[reflect.Type]int
|
||||||
|
var setterIface reflect.Type
|
||||||
|
var setterMutex sync.RWMutex
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var iface Setter
|
||||||
|
setterIface = reflect.TypeOf(&iface).Elem()
|
||||||
|
setterStyles = make(map[reflect.Type]int)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setterStyle(outt reflect.Type) int {
|
||||||
|
setterMutex.RLock()
|
||||||
|
style := setterStyles[outt]
|
||||||
|
setterMutex.RUnlock()
|
||||||
|
if style == setterUnknown {
|
||||||
|
setterMutex.Lock()
|
||||||
|
defer setterMutex.Unlock()
|
||||||
|
if outt.Implements(setterIface) {
|
||||||
|
setterStyles[outt] = setterType
|
||||||
|
} else if reflect.PtrTo(outt).Implements(setterIface) {
|
||||||
|
setterStyles[outt] = setterAddr
|
||||||
|
} else {
|
||||||
|
setterStyles[outt] = setterNone
|
||||||
|
}
|
||||||
|
style = setterStyles[outt]
|
||||||
|
}
|
||||||
|
return style
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSetter(outt reflect.Type, out reflect.Value) Setter {
|
||||||
|
style := setterStyle(outt)
|
||||||
|
if style == setterNone {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if style == setterAddr {
|
||||||
|
if !out.CanAddr() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out = out.Addr()
|
||||||
|
} else if outt.Kind() == reflect.Ptr && out.IsNil() {
|
||||||
|
out.Set(reflect.New(outt.Elem()))
|
||||||
|
}
|
||||||
|
return out.Interface().(Setter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func clearMap(m reflect.Value) {
|
||||||
|
var none reflect.Value
|
||||||
|
for _, k := range m.MapKeys() {
|
||||||
|
m.SetMapIndex(k, none)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decoder) readDocTo(out reflect.Value) {
|
||||||
|
var elemType reflect.Type
|
||||||
|
outt := out.Type()
|
||||||
|
outk := outt.Kind()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if outk == reflect.Ptr && out.IsNil() {
|
||||||
|
out.Set(reflect.New(outt.Elem()))
|
||||||
|
}
|
||||||
|
if setter := getSetter(outt, out); setter != nil {
|
||||||
|
var raw Raw
|
||||||
|
d.readDocTo(reflect.ValueOf(&raw))
|
||||||
|
err := setter.SetBSON(raw)
|
||||||
|
if _, ok := err.(*TypeError); err != nil && !ok {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if outk == reflect.Ptr {
|
||||||
|
out = out.Elem()
|
||||||
|
outt = out.Type()
|
||||||
|
outk = out.Kind()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var fieldsMap map[string]fieldInfo
|
||||||
|
var inlineMap reflect.Value
|
||||||
|
start := d.i
|
||||||
|
|
||||||
|
origout := out
|
||||||
|
if outk == reflect.Interface {
|
||||||
|
if d.docType.Kind() == reflect.Map {
|
||||||
|
mv := reflect.MakeMap(d.docType)
|
||||||
|
out.Set(mv)
|
||||||
|
out = mv
|
||||||
|
} else {
|
||||||
|
dv := reflect.New(d.docType).Elem()
|
||||||
|
out.Set(dv)
|
||||||
|
out = dv
|
||||||
|
}
|
||||||
|
outt = out.Type()
|
||||||
|
outk = outt.Kind()
|
||||||
|
}
|
||||||
|
|
||||||
|
docType := d.docType
|
||||||
|
keyType := typeString
|
||||||
|
convertKey := false
|
||||||
|
switch outk {
|
||||||
|
case reflect.Map:
|
||||||
|
keyType = outt.Key()
|
||||||
|
if keyType.Kind() != reflect.String {
|
||||||
|
panic("BSON map must have string keys. Got: " + outt.String())
|
||||||
|
}
|
||||||
|
if keyType != typeString {
|
||||||
|
convertKey = true
|
||||||
|
}
|
||||||
|
elemType = outt.Elem()
|
||||||
|
if elemType == typeIface {
|
||||||
|
d.docType = outt
|
||||||
|
}
|
||||||
|
if out.IsNil() {
|
||||||
|
out.Set(reflect.MakeMap(out.Type()))
|
||||||
|
} else if out.Len() > 0 {
|
||||||
|
clearMap(out)
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if outt != typeRaw {
|
||||||
|
sinfo, err := getStructInfo(out.Type())
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fieldsMap = sinfo.FieldsMap
|
||||||
|
out.Set(sinfo.Zero)
|
||||||
|
if sinfo.InlineMap != -1 {
|
||||||
|
inlineMap = out.Field(sinfo.InlineMap)
|
||||||
|
if !inlineMap.IsNil() && inlineMap.Len() > 0 {
|
||||||
|
clearMap(inlineMap)
|
||||||
|
}
|
||||||
|
elemType = inlineMap.Type().Elem()
|
||||||
|
if elemType == typeIface {
|
||||||
|
d.docType = inlineMap.Type()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Slice:
|
||||||
|
switch outt.Elem() {
|
||||||
|
case typeDocElem:
|
||||||
|
origout.Set(d.readDocElems(outt))
|
||||||
|
return
|
||||||
|
case typeRawDocElem:
|
||||||
|
origout.Set(d.readRawDocElems(outt))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
panic("Unsupported document type for unmarshalling: " + out.Type().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
end := int(d.readInt32())
|
||||||
|
end += d.i - 4
|
||||||
|
if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
for d.in[d.i] != '\x00' {
|
||||||
|
kind := d.readByte()
|
||||||
|
name := d.readCStr()
|
||||||
|
if d.i >= end {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch outk {
|
||||||
|
case reflect.Map:
|
||||||
|
e := reflect.New(elemType).Elem()
|
||||||
|
if d.readElemTo(e, kind) {
|
||||||
|
k := reflect.ValueOf(name)
|
||||||
|
if convertKey {
|
||||||
|
k = k.Convert(keyType)
|
||||||
|
}
|
||||||
|
out.SetMapIndex(k, e)
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if outt == typeRaw {
|
||||||
|
d.dropElem(kind)
|
||||||
|
} else {
|
||||||
|
if info, ok := fieldsMap[name]; ok {
|
||||||
|
if info.Inline == nil {
|
||||||
|
d.readElemTo(out.Field(info.Num), kind)
|
||||||
|
} else {
|
||||||
|
d.readElemTo(out.FieldByIndex(info.Inline), kind)
|
||||||
|
}
|
||||||
|
} else if inlineMap.IsValid() {
|
||||||
|
if inlineMap.IsNil() {
|
||||||
|
inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
|
||||||
|
}
|
||||||
|
e := reflect.New(elemType).Elem()
|
||||||
|
if d.readElemTo(e, kind) {
|
||||||
|
inlineMap.SetMapIndex(reflect.ValueOf(name), e)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
d.dropElem(kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Slice:
|
||||||
|
}
|
||||||
|
|
||||||
|
if d.i >= end {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d.i++ // '\x00'
|
||||||
|
if d.i != end {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
d.docType = docType
|
||||||
|
|
||||||
|
if outt == typeRaw {
|
||||||
|
out.Set(reflect.ValueOf(Raw{0x03, d.in[start:d.i]}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decoder) readArrayDocTo(out reflect.Value) {
|
||||||
|
end := int(d.readInt32())
|
||||||
|
end += d.i - 4
|
||||||
|
if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
i := 0
|
||||||
|
l := out.Len()
|
||||||
|
for d.in[d.i] != '\x00' {
|
||||||
|
if i >= l {
|
||||||
|
panic("Length mismatch on array field")
|
||||||
|
}
|
||||||
|
kind := d.readByte()
|
||||||
|
for d.i < end && d.in[d.i] != '\x00' {
|
||||||
|
d.i++
|
||||||
|
}
|
||||||
|
if d.i >= end {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
d.i++
|
||||||
|
d.readElemTo(out.Index(i), kind)
|
||||||
|
if d.i >= end {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i != l {
|
||||||
|
panic("Length mismatch on array field")
|
||||||
|
}
|
||||||
|
d.i++ // '\x00'
|
||||||
|
if d.i != end {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decoder) readSliceDoc(t reflect.Type) interface{} {
|
||||||
|
tmp := make([]reflect.Value, 0, 8)
|
||||||
|
elemType := t.Elem()
|
||||||
|
if elemType == typeRawDocElem {
|
||||||
|
d.dropElem(0x04)
|
||||||
|
return reflect.Zero(t).Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
end := int(d.readInt32())
|
||||||
|
end += d.i - 4
|
||||||
|
if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
for d.in[d.i] != '\x00' {
|
||||||
|
kind := d.readByte()
|
||||||
|
for d.i < end && d.in[d.i] != '\x00' {
|
||||||
|
d.i++
|
||||||
|
}
|
||||||
|
if d.i >= end {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
d.i++
|
||||||
|
e := reflect.New(elemType).Elem()
|
||||||
|
if d.readElemTo(e, kind) {
|
||||||
|
tmp = append(tmp, e)
|
||||||
|
}
|
||||||
|
if d.i >= end {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d.i++ // '\x00'
|
||||||
|
if d.i != end {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
|
||||||
|
n := len(tmp)
|
||||||
|
slice := reflect.MakeSlice(t, n, n)
|
||||||
|
for i := 0; i != n; i++ {
|
||||||
|
slice.Index(i).Set(tmp[i])
|
||||||
|
}
|
||||||
|
return slice.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
var typeSlice = reflect.TypeOf([]interface{}{})
|
||||||
|
var typeIface = typeSlice.Elem()
|
||||||
|
|
||||||
|
func (d *decoder) readDocElems(typ reflect.Type) reflect.Value {
|
||||||
|
docType := d.docType
|
||||||
|
d.docType = typ
|
||||||
|
slice := make([]DocElem, 0, 8)
|
||||||
|
d.readDocWith(func(kind byte, name string) {
|
||||||
|
e := DocElem{Name: name}
|
||||||
|
v := reflect.ValueOf(&e.Value)
|
||||||
|
if d.readElemTo(v.Elem(), kind) {
|
||||||
|
slice = append(slice, e)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
slicev := reflect.New(typ).Elem()
|
||||||
|
slicev.Set(reflect.ValueOf(slice))
|
||||||
|
d.docType = docType
|
||||||
|
return slicev
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decoder) readRawDocElems(typ reflect.Type) reflect.Value {
|
||||||
|
docType := d.docType
|
||||||
|
d.docType = typ
|
||||||
|
slice := make([]RawDocElem, 0, 8)
|
||||||
|
d.readDocWith(func(kind byte, name string) {
|
||||||
|
e := RawDocElem{Name: name}
|
||||||
|
v := reflect.ValueOf(&e.Value)
|
||||||
|
if d.readElemTo(v.Elem(), kind) {
|
||||||
|
slice = append(slice, e)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
slicev := reflect.New(typ).Elem()
|
||||||
|
slicev.Set(reflect.ValueOf(slice))
|
||||||
|
d.docType = docType
|
||||||
|
return slicev
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decoder) readDocWith(f func(kind byte, name string)) {
|
||||||
|
end := int(d.readInt32())
|
||||||
|
end += d.i - 4
|
||||||
|
if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
for d.in[d.i] != '\x00' {
|
||||||
|
kind := d.readByte()
|
||||||
|
name := d.readCStr()
|
||||||
|
if d.i >= end {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
f(kind, name)
|
||||||
|
if d.i >= end {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d.i++ // '\x00'
|
||||||
|
if d.i != end {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Unmarshaling of individual elements within a document.
|
||||||
|
|
||||||
|
var blackHole = settableValueOf(struct{}{})
|
||||||
|
|
||||||
|
func (d *decoder) dropElem(kind byte) {
|
||||||
|
d.readElemTo(blackHole, kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to decode an element from the document and put it into out.
|
||||||
|
// If the types are not compatible, the returned ok value will be
|
||||||
|
// false and out will be unchanged.
|
||||||
|
func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) {
|
||||||
|
|
||||||
|
start := d.i
|
||||||
|
|
||||||
|
if kind == 0x03 {
|
||||||
|
// Delegate unmarshaling of documents.
|
||||||
|
outt := out.Type()
|
||||||
|
outk := out.Kind()
|
||||||
|
switch outk {
|
||||||
|
case reflect.Interface, reflect.Ptr, reflect.Struct, reflect.Map:
|
||||||
|
d.readDocTo(out)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if setterStyle(outt) != setterNone {
|
||||||
|
d.readDocTo(out)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if outk == reflect.Slice {
|
||||||
|
switch outt.Elem() {
|
||||||
|
case typeDocElem:
|
||||||
|
out.Set(d.readDocElems(outt))
|
||||||
|
case typeRawDocElem:
|
||||||
|
out.Set(d.readRawDocElems(outt))
|
||||||
|
default:
|
||||||
|
d.readDocTo(blackHole)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
d.readDocTo(blackHole)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var in interface{}
|
||||||
|
|
||||||
|
switch kind {
|
||||||
|
case 0x01: // Float64
|
||||||
|
in = d.readFloat64()
|
||||||
|
case 0x02: // UTF-8 string
|
||||||
|
in = d.readStr()
|
||||||
|
case 0x03: // Document
|
||||||
|
panic("Can't happen. Handled above.")
|
||||||
|
case 0x04: // Array
|
||||||
|
outt := out.Type()
|
||||||
|
if setterStyle(outt) != setterNone {
|
||||||
|
// Skip the value so its data is handed to the setter below.
|
||||||
|
d.dropElem(kind)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
for outt.Kind() == reflect.Ptr {
|
||||||
|
outt = outt.Elem()
|
||||||
|
}
|
||||||
|
switch outt.Kind() {
|
||||||
|
case reflect.Array:
|
||||||
|
d.readArrayDocTo(out)
|
||||||
|
return true
|
||||||
|
case reflect.Slice:
|
||||||
|
in = d.readSliceDoc(outt)
|
||||||
|
default:
|
||||||
|
in = d.readSliceDoc(typeSlice)
|
||||||
|
}
|
||||||
|
case 0x05: // Binary
|
||||||
|
b := d.readBinary()
|
||||||
|
if b.Kind == 0x00 || b.Kind == 0x02 {
|
||||||
|
in = b.Data
|
||||||
|
} else {
|
||||||
|
in = b
|
||||||
|
}
|
||||||
|
case 0x06: // Undefined (obsolete, but still seen in the wild)
|
||||||
|
in = Undefined
|
||||||
|
case 0x07: // ObjectId
|
||||||
|
in = ObjectId(d.readBytes(12))
|
||||||
|
case 0x08: // Bool
|
||||||
|
in = d.readBool()
|
||||||
|
case 0x09: // Timestamp
|
||||||
|
// MongoDB handles timestamps as milliseconds.
|
||||||
|
i := d.readInt64()
|
||||||
|
if i == -62135596800000 {
|
||||||
|
in = time.Time{} // In UTC for convenience.
|
||||||
|
} else {
|
||||||
|
in = time.Unix(i/1e3, i%1e3*1e6)
|
||||||
|
}
|
||||||
|
case 0x0A: // Nil
|
||||||
|
in = nil
|
||||||
|
case 0x0B: // RegEx
|
||||||
|
in = d.readRegEx()
|
||||||
|
case 0x0C:
|
||||||
|
in = DBPointer{Namespace: d.readStr(), Id: ObjectId(d.readBytes(12))}
|
||||||
|
case 0x0D: // JavaScript without scope
|
||||||
|
in = JavaScript{Code: d.readStr()}
|
||||||
|
case 0x0E: // Symbol
|
||||||
|
in = Symbol(d.readStr())
|
||||||
|
case 0x0F: // JavaScript with scope
|
||||||
|
d.i += 4 // Skip length
|
||||||
|
js := JavaScript{d.readStr(), make(M)}
|
||||||
|
d.readDocTo(reflect.ValueOf(js.Scope))
|
||||||
|
in = js
|
||||||
|
case 0x10: // Int32
|
||||||
|
in = int(d.readInt32())
|
||||||
|
case 0x11: // Mongo-specific timestamp
|
||||||
|
in = MongoTimestamp(d.readInt64())
|
||||||
|
case 0x12: // Int64
|
||||||
|
in = d.readInt64()
|
||||||
|
case 0x13: // Decimal128
|
||||||
|
in = Decimal128{
|
||||||
|
l: uint64(d.readInt64()),
|
||||||
|
h: uint64(d.readInt64()),
|
||||||
|
}
|
||||||
|
case 0x7F: // Max key
|
||||||
|
in = MaxKey
|
||||||
|
case 0xFF: // Min key
|
||||||
|
in = MinKey
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("Unknown element kind (0x%02X)", kind))
|
||||||
|
}
|
||||||
|
|
||||||
|
outt := out.Type()
|
||||||
|
|
||||||
|
if outt == typeRaw {
|
||||||
|
out.Set(reflect.ValueOf(Raw{kind, d.in[start:d.i]}))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if setter := getSetter(outt, out); setter != nil {
|
||||||
|
err := setter.SetBSON(Raw{kind, d.in[start:d.i]})
|
||||||
|
if err == SetZero {
|
||||||
|
out.Set(reflect.Zero(outt))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, ok := err.(*TypeError); !ok {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if in == nil {
|
||||||
|
out.Set(reflect.Zero(outt))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
outk := outt.Kind()
|
||||||
|
|
||||||
|
// Dereference and initialize pointer if necessary.
|
||||||
|
first := true
|
||||||
|
for outk == reflect.Ptr {
|
||||||
|
if !out.IsNil() {
|
||||||
|
out = out.Elem()
|
||||||
|
} else {
|
||||||
|
elem := reflect.New(outt.Elem())
|
||||||
|
if first {
|
||||||
|
// Only set if value is compatible.
|
||||||
|
first = false
|
||||||
|
defer func(out, elem reflect.Value) {
|
||||||
|
if good {
|
||||||
|
out.Set(elem)
|
||||||
|
}
|
||||||
|
}(out, elem)
|
||||||
|
} else {
|
||||||
|
out.Set(elem)
|
||||||
|
}
|
||||||
|
out = elem
|
||||||
|
}
|
||||||
|
outt = out.Type()
|
||||||
|
outk = outt.Kind()
|
||||||
|
}
|
||||||
|
|
||||||
|
inv := reflect.ValueOf(in)
|
||||||
|
if outt == inv.Type() {
|
||||||
|
out.Set(inv)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
switch outk {
|
||||||
|
case reflect.Interface:
|
||||||
|
out.Set(inv)
|
||||||
|
return true
|
||||||
|
case reflect.String:
|
||||||
|
switch inv.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
out.SetString(inv.String())
|
||||||
|
return true
|
||||||
|
case reflect.Slice:
|
||||||
|
if b, ok := in.([]byte); ok {
|
||||||
|
out.SetString(string(b))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
case reflect.Int, reflect.Int64:
|
||||||
|
if outt == typeJSONNumber {
|
||||||
|
out.SetString(strconv.FormatInt(inv.Int(), 10))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
case reflect.Float64:
|
||||||
|
if outt == typeJSONNumber {
|
||||||
|
out.SetString(strconv.FormatFloat(inv.Float(), 'f', -1, 64))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
// Remember, array (0x04) slices are built with the correct
|
||||||
|
// element type. If we are here, must be a cross BSON kind
|
||||||
|
// conversion (e.g. 0x05 unmarshalling on string).
|
||||||
|
if outt.Elem().Kind() != reflect.Uint8 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
switch inv.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
slice := []byte(inv.String())
|
||||||
|
out.Set(reflect.ValueOf(slice))
|
||||||
|
return true
|
||||||
|
case reflect.Slice:
|
||||||
|
switch outt.Kind() {
|
||||||
|
case reflect.Array:
|
||||||
|
reflect.Copy(out, inv)
|
||||||
|
case reflect.Slice:
|
||||||
|
out.SetBytes(inv.Bytes())
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
switch inv.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
out.SetInt(inv.Int())
|
||||||
|
return true
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
out.SetInt(int64(inv.Float()))
|
||||||
|
return true
|
||||||
|
case reflect.Bool:
|
||||||
|
if inv.Bool() {
|
||||||
|
out.SetInt(1)
|
||||||
|
} else {
|
||||||
|
out.SetInt(0)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||||
|
panic("can't happen: no uint types in BSON (!?)")
|
||||||
|
}
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||||
|
switch inv.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
out.SetUint(uint64(inv.Int()))
|
||||||
|
return true
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
out.SetUint(uint64(inv.Float()))
|
||||||
|
return true
|
||||||
|
case reflect.Bool:
|
||||||
|
if inv.Bool() {
|
||||||
|
out.SetUint(1)
|
||||||
|
} else {
|
||||||
|
out.SetUint(0)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||||
|
panic("Can't happen. No uint types in BSON.")
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
switch inv.Kind() {
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
out.SetFloat(inv.Float())
|
||||||
|
return true
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
out.SetFloat(float64(inv.Int()))
|
||||||
|
return true
|
||||||
|
case reflect.Bool:
|
||||||
|
if inv.Bool() {
|
||||||
|
out.SetFloat(1)
|
||||||
|
} else {
|
||||||
|
out.SetFloat(0)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||||
|
panic("Can't happen. No uint types in BSON?")
|
||||||
|
}
|
||||||
|
case reflect.Bool:
|
||||||
|
switch inv.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
out.SetBool(inv.Bool())
|
||||||
|
return true
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
out.SetBool(inv.Int() != 0)
|
||||||
|
return true
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
out.SetBool(inv.Float() != 0)
|
||||||
|
return true
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||||
|
panic("Can't happen. No uint types in BSON?")
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if outt == typeURL && inv.Kind() == reflect.String {
|
||||||
|
u, err := url.Parse(inv.String())
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
out.Set(reflect.ValueOf(u).Elem())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if outt == typeBinary {
|
||||||
|
if b, ok := in.([]byte); ok {
|
||||||
|
out.Set(reflect.ValueOf(Binary{Data: b}))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Parsers of basic types.
|
||||||
|
|
||||||
|
func (d *decoder) readRegEx() RegEx {
|
||||||
|
re := RegEx{}
|
||||||
|
re.Pattern = d.readCStr()
|
||||||
|
re.Options = d.readCStr()
|
||||||
|
return re
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decoder) readBinary() Binary {
|
||||||
|
l := d.readInt32()
|
||||||
|
b := Binary{}
|
||||||
|
b.Kind = d.readByte()
|
||||||
|
b.Data = d.readBytes(l)
|
||||||
|
if b.Kind == 0x02 && len(b.Data) >= 4 {
|
||||||
|
// Weird obsolete format with redundant length.
|
||||||
|
b.Data = b.Data[4:]
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decoder) readStr() string {
|
||||||
|
l := d.readInt32()
|
||||||
|
b := d.readBytes(l - 1)
|
||||||
|
if d.readByte() != '\x00' {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decoder) readCStr() string {
|
||||||
|
start := d.i
|
||||||
|
end := start
|
||||||
|
l := len(d.in)
|
||||||
|
for ; end != l; end++ {
|
||||||
|
if d.in[end] == '\x00' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d.i = end + 1
|
||||||
|
if d.i > l {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
return string(d.in[start:end])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decoder) readBool() bool {
|
||||||
|
b := d.readByte()
|
||||||
|
if b == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if b == 1 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
panic(fmt.Sprintf("encoded boolean must be 1 or 0, found %d", b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decoder) readFloat64() float64 {
|
||||||
|
return math.Float64frombits(uint64(d.readInt64()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decoder) readInt32() int32 {
|
||||||
|
b := d.readBytes(4)
|
||||||
|
return int32((uint32(b[0]) << 0) |
|
||||||
|
(uint32(b[1]) << 8) |
|
||||||
|
(uint32(b[2]) << 16) |
|
||||||
|
(uint32(b[3]) << 24))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decoder) readInt64() int64 {
|
||||||
|
b := d.readBytes(8)
|
||||||
|
return int64((uint64(b[0]) << 0) |
|
||||||
|
(uint64(b[1]) << 8) |
|
||||||
|
(uint64(b[2]) << 16) |
|
||||||
|
(uint64(b[3]) << 24) |
|
||||||
|
(uint64(b[4]) << 32) |
|
||||||
|
(uint64(b[5]) << 40) |
|
||||||
|
(uint64(b[6]) << 48) |
|
||||||
|
(uint64(b[7]) << 56))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decoder) readByte() byte {
|
||||||
|
i := d.i
|
||||||
|
d.i++
|
||||||
|
if d.i > len(d.in) {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
return d.in[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *decoder) readBytes(length int32) []byte {
|
||||||
|
if length < 0 {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
start := d.i
|
||||||
|
d.i += int(length)
|
||||||
|
if d.i < start || d.i > len(d.in) {
|
||||||
|
corrupted()
|
||||||
|
}
|
||||||
|
return d.in[start : start+int(length)]
|
||||||
|
}
|
|
@ -0,0 +1,514 @@
|
||||||
|
// BSON library for Go
|
||||||
|
//
|
||||||
|
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
//
|
||||||
|
// All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
// list of conditions and the following disclaimer.
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
// gobson - BSON library for Go.
|
||||||
|
|
||||||
|
package bson
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Some internal infrastructure.
|
||||||
|
|
||||||
|
var (
|
||||||
|
typeBinary = reflect.TypeOf(Binary{})
|
||||||
|
typeObjectId = reflect.TypeOf(ObjectId(""))
|
||||||
|
typeDBPointer = reflect.TypeOf(DBPointer{"", ObjectId("")})
|
||||||
|
typeSymbol = reflect.TypeOf(Symbol(""))
|
||||||
|
typeMongoTimestamp = reflect.TypeOf(MongoTimestamp(0))
|
||||||
|
typeOrderKey = reflect.TypeOf(MinKey)
|
||||||
|
typeDocElem = reflect.TypeOf(DocElem{})
|
||||||
|
typeRawDocElem = reflect.TypeOf(RawDocElem{})
|
||||||
|
typeRaw = reflect.TypeOf(Raw{})
|
||||||
|
typeURL = reflect.TypeOf(url.URL{})
|
||||||
|
typeTime = reflect.TypeOf(time.Time{})
|
||||||
|
typeString = reflect.TypeOf("")
|
||||||
|
typeJSONNumber = reflect.TypeOf(json.Number(""))
|
||||||
|
)
|
||||||
|
|
||||||
|
const itoaCacheSize = 32
|
||||||
|
|
||||||
|
var itoaCache []string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
itoaCache = make([]string, itoaCacheSize)
|
||||||
|
for i := 0; i != itoaCacheSize; i++ {
|
||||||
|
itoaCache[i] = strconv.Itoa(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func itoa(i int) string {
|
||||||
|
if i < itoaCacheSize {
|
||||||
|
return itoaCache[i]
|
||||||
|
}
|
||||||
|
return strconv.Itoa(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Marshaling of the document value itself.
|
||||||
|
|
||||||
|
type encoder struct {
|
||||||
|
out []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encoder) addDoc(v reflect.Value) {
|
||||||
|
for {
|
||||||
|
if vi, ok := v.Interface().(Getter); ok {
|
||||||
|
getv, err := vi.GetBSON()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
v = reflect.ValueOf(getv)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if v.Kind() == reflect.Ptr {
|
||||||
|
v = v.Elem()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if v.Type() == typeRaw {
|
||||||
|
raw := v.Interface().(Raw)
|
||||||
|
if raw.Kind != 0x03 && raw.Kind != 0x00 {
|
||||||
|
panic("Attempted to marshal Raw kind " + strconv.Itoa(int(raw.Kind)) + " as a document")
|
||||||
|
}
|
||||||
|
if len(raw.Data) == 0 {
|
||||||
|
panic("Attempted to marshal empty Raw document")
|
||||||
|
}
|
||||||
|
e.addBytes(raw.Data...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
start := e.reserveInt32()
|
||||||
|
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Map:
|
||||||
|
e.addMap(v)
|
||||||
|
case reflect.Struct:
|
||||||
|
e.addStruct(v)
|
||||||
|
case reflect.Array, reflect.Slice:
|
||||||
|
e.addSlice(v)
|
||||||
|
default:
|
||||||
|
panic("Can't marshal " + v.Type().String() + " as a BSON document")
|
||||||
|
}
|
||||||
|
|
||||||
|
e.addBytes(0)
|
||||||
|
e.setInt32(start, int32(len(e.out)-start))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encoder) addMap(v reflect.Value) {
|
||||||
|
for _, k := range v.MapKeys() {
|
||||||
|
e.addElem(k.String(), v.MapIndex(k), false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encoder) addStruct(v reflect.Value) {
|
||||||
|
sinfo, err := getStructInfo(v.Type())
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
var value reflect.Value
|
||||||
|
if sinfo.InlineMap >= 0 {
|
||||||
|
m := v.Field(sinfo.InlineMap)
|
||||||
|
if m.Len() > 0 {
|
||||||
|
for _, k := range m.MapKeys() {
|
||||||
|
ks := k.String()
|
||||||
|
if _, found := sinfo.FieldsMap[ks]; found {
|
||||||
|
panic(fmt.Sprintf("Can't have key %q in inlined map; conflicts with struct field", ks))
|
||||||
|
}
|
||||||
|
e.addElem(ks, m.MapIndex(k), false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, info := range sinfo.FieldsList {
|
||||||
|
if info.Inline == nil {
|
||||||
|
value = v.Field(info.Num)
|
||||||
|
} else {
|
||||||
|
value = v.FieldByIndex(info.Inline)
|
||||||
|
}
|
||||||
|
if info.OmitEmpty && isZero(value) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
e.addElem(info.Key, value, info.MinSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isZero(v reflect.Value) bool {
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
return len(v.String()) == 0
|
||||||
|
case reflect.Ptr, reflect.Interface:
|
||||||
|
return v.IsNil()
|
||||||
|
case reflect.Slice:
|
||||||
|
return v.Len() == 0
|
||||||
|
case reflect.Map:
|
||||||
|
return v.Len() == 0
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
return v.Int() == 0
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||||
|
return v.Uint() == 0
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return v.Float() == 0
|
||||||
|
case reflect.Bool:
|
||||||
|
return !v.Bool()
|
||||||
|
case reflect.Struct:
|
||||||
|
vt := v.Type()
|
||||||
|
if vt == typeTime {
|
||||||
|
return v.Interface().(time.Time).IsZero()
|
||||||
|
}
|
||||||
|
for i := 0; i < v.NumField(); i++ {
|
||||||
|
if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous {
|
||||||
|
continue // Private field
|
||||||
|
}
|
||||||
|
if !isZero(v.Field(i)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encoder) addSlice(v reflect.Value) {
|
||||||
|
vi := v.Interface()
|
||||||
|
if d, ok := vi.(D); ok {
|
||||||
|
for _, elem := range d {
|
||||||
|
e.addElem(elem.Name, reflect.ValueOf(elem.Value), false)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if d, ok := vi.(RawD); ok {
|
||||||
|
for _, elem := range d {
|
||||||
|
e.addElem(elem.Name, reflect.ValueOf(elem.Value), false)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
l := v.Len()
|
||||||
|
et := v.Type().Elem()
|
||||||
|
if et == typeDocElem {
|
||||||
|
for i := 0; i < l; i++ {
|
||||||
|
elem := v.Index(i).Interface().(DocElem)
|
||||||
|
e.addElem(elem.Name, reflect.ValueOf(elem.Value), false)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if et == typeRawDocElem {
|
||||||
|
for i := 0; i < l; i++ {
|
||||||
|
elem := v.Index(i).Interface().(RawDocElem)
|
||||||
|
e.addElem(elem.Name, reflect.ValueOf(elem.Value), false)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := 0; i < l; i++ {
|
||||||
|
e.addElem(itoa(i), v.Index(i), false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Marshaling of elements in a document.
|
||||||
|
|
||||||
|
func (e *encoder) addElemName(kind byte, name string) {
|
||||||
|
e.addBytes(kind)
|
||||||
|
e.addBytes([]byte(name)...)
|
||||||
|
e.addBytes(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encoder) addElem(name string, v reflect.Value, minSize bool) {
|
||||||
|
|
||||||
|
if !v.IsValid() {
|
||||||
|
e.addElemName(0x0A, name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if getter, ok := v.Interface().(Getter); ok {
|
||||||
|
getv, err := getter.GetBSON()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
e.addElem(name, reflect.ValueOf(getv), minSize)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v.Kind() {
|
||||||
|
|
||||||
|
case reflect.Interface:
|
||||||
|
e.addElem(name, v.Elem(), minSize)
|
||||||
|
|
||||||
|
case reflect.Ptr:
|
||||||
|
e.addElem(name, v.Elem(), minSize)
|
||||||
|
|
||||||
|
case reflect.String:
|
||||||
|
s := v.String()
|
||||||
|
switch v.Type() {
|
||||||
|
case typeObjectId:
|
||||||
|
if len(s) != 12 {
|
||||||
|
panic("ObjectIDs must be exactly 12 bytes long (got " +
|
||||||
|
strconv.Itoa(len(s)) + ")")
|
||||||
|
}
|
||||||
|
e.addElemName(0x07, name)
|
||||||
|
e.addBytes([]byte(s)...)
|
||||||
|
case typeSymbol:
|
||||||
|
e.addElemName(0x0E, name)
|
||||||
|
e.addStr(s)
|
||||||
|
case typeJSONNumber:
|
||||||
|
n := v.Interface().(json.Number)
|
||||||
|
if i, err := n.Int64(); err == nil {
|
||||||
|
e.addElemName(0x12, name)
|
||||||
|
e.addInt64(i)
|
||||||
|
} else if f, err := n.Float64(); err == nil {
|
||||||
|
e.addElemName(0x01, name)
|
||||||
|
e.addFloat64(f)
|
||||||
|
} else {
|
||||||
|
panic("failed to convert json.Number to a number: " + s)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
e.addElemName(0x02, name)
|
||||||
|
e.addStr(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
e.addElemName(0x01, name)
|
||||||
|
e.addFloat64(v.Float())
|
||||||
|
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||||
|
u := v.Uint()
|
||||||
|
if int64(u) < 0 {
|
||||||
|
panic("BSON has no uint64 type, and value is too large to fit correctly in an int64")
|
||||||
|
} else if u <= math.MaxInt32 && (minSize || v.Kind() <= reflect.Uint32) {
|
||||||
|
e.addElemName(0x10, name)
|
||||||
|
e.addInt32(int32(u))
|
||||||
|
} else {
|
||||||
|
e.addElemName(0x12, name)
|
||||||
|
e.addInt64(int64(u))
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
switch v.Type() {
|
||||||
|
case typeMongoTimestamp:
|
||||||
|
e.addElemName(0x11, name)
|
||||||
|
e.addInt64(v.Int())
|
||||||
|
|
||||||
|
case typeOrderKey:
|
||||||
|
if v.Int() == int64(MaxKey) {
|
||||||
|
e.addElemName(0x7F, name)
|
||||||
|
} else {
|
||||||
|
e.addElemName(0xFF, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
i := v.Int()
|
||||||
|
if (minSize || v.Type().Kind() != reflect.Int64) && i >= math.MinInt32 && i <= math.MaxInt32 {
|
||||||
|
// It fits into an int32, encode as such.
|
||||||
|
e.addElemName(0x10, name)
|
||||||
|
e.addInt32(int32(i))
|
||||||
|
} else {
|
||||||
|
e.addElemName(0x12, name)
|
||||||
|
e.addInt64(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Bool:
|
||||||
|
e.addElemName(0x08, name)
|
||||||
|
if v.Bool() {
|
||||||
|
e.addBytes(1)
|
||||||
|
} else {
|
||||||
|
e.addBytes(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Map:
|
||||||
|
e.addElemName(0x03, name)
|
||||||
|
e.addDoc(v)
|
||||||
|
|
||||||
|
case reflect.Slice:
|
||||||
|
vt := v.Type()
|
||||||
|
et := vt.Elem()
|
||||||
|
if et.Kind() == reflect.Uint8 {
|
||||||
|
e.addElemName(0x05, name)
|
||||||
|
e.addBinary(0x00, v.Bytes())
|
||||||
|
} else if et == typeDocElem || et == typeRawDocElem {
|
||||||
|
e.addElemName(0x03, name)
|
||||||
|
e.addDoc(v)
|
||||||
|
} else {
|
||||||
|
e.addElemName(0x04, name)
|
||||||
|
e.addDoc(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Array:
|
||||||
|
et := v.Type().Elem()
|
||||||
|
if et.Kind() == reflect.Uint8 {
|
||||||
|
e.addElemName(0x05, name)
|
||||||
|
if v.CanAddr() {
|
||||||
|
e.addBinary(0x00, v.Slice(0, v.Len()).Interface().([]byte))
|
||||||
|
} else {
|
||||||
|
n := v.Len()
|
||||||
|
e.addInt32(int32(n))
|
||||||
|
e.addBytes(0x00)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
el := v.Index(i)
|
||||||
|
e.addBytes(byte(el.Uint()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
e.addElemName(0x04, name)
|
||||||
|
e.addDoc(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Struct:
|
||||||
|
switch s := v.Interface().(type) {
|
||||||
|
|
||||||
|
case Raw:
|
||||||
|
kind := s.Kind
|
||||||
|
if kind == 0x00 {
|
||||||
|
kind = 0x03
|
||||||
|
}
|
||||||
|
if len(s.Data) == 0 && kind != 0x06 && kind != 0x0A && kind != 0xFF && kind != 0x7F {
|
||||||
|
panic("Attempted to marshal empty Raw document")
|
||||||
|
}
|
||||||
|
e.addElemName(kind, name)
|
||||||
|
e.addBytes(s.Data...)
|
||||||
|
|
||||||
|
case Binary:
|
||||||
|
e.addElemName(0x05, name)
|
||||||
|
e.addBinary(s.Kind, s.Data)
|
||||||
|
|
||||||
|
case Decimal128:
|
||||||
|
e.addElemName(0x13, name)
|
||||||
|
e.addInt64(int64(s.l))
|
||||||
|
e.addInt64(int64(s.h))
|
||||||
|
|
||||||
|
case DBPointer:
|
||||||
|
e.addElemName(0x0C, name)
|
||||||
|
e.addStr(s.Namespace)
|
||||||
|
if len(s.Id) != 12 {
|
||||||
|
panic("ObjectIDs must be exactly 12 bytes long (got " +
|
||||||
|
strconv.Itoa(len(s.Id)) + ")")
|
||||||
|
}
|
||||||
|
e.addBytes([]byte(s.Id)...)
|
||||||
|
|
||||||
|
case RegEx:
|
||||||
|
e.addElemName(0x0B, name)
|
||||||
|
e.addCStr(s.Pattern)
|
||||||
|
e.addCStr(s.Options)
|
||||||
|
|
||||||
|
case JavaScript:
|
||||||
|
if s.Scope == nil {
|
||||||
|
e.addElemName(0x0D, name)
|
||||||
|
e.addStr(s.Code)
|
||||||
|
} else {
|
||||||
|
e.addElemName(0x0F, name)
|
||||||
|
start := e.reserveInt32()
|
||||||
|
e.addStr(s.Code)
|
||||||
|
e.addDoc(reflect.ValueOf(s.Scope))
|
||||||
|
e.setInt32(start, int32(len(e.out)-start))
|
||||||
|
}
|
||||||
|
|
||||||
|
case time.Time:
|
||||||
|
// MongoDB handles timestamps as milliseconds.
|
||||||
|
e.addElemName(0x09, name)
|
||||||
|
e.addInt64(s.Unix()*1000 + int64(s.Nanosecond()/1e6))
|
||||||
|
|
||||||
|
case url.URL:
|
||||||
|
e.addElemName(0x02, name)
|
||||||
|
e.addStr(s.String())
|
||||||
|
|
||||||
|
case undefined:
|
||||||
|
e.addElemName(0x06, name)
|
||||||
|
|
||||||
|
default:
|
||||||
|
e.addElemName(0x03, name)
|
||||||
|
e.addDoc(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
panic("Can't marshal " + v.Type().String() + " in a BSON document")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Marshaling of base types.
|
||||||
|
|
||||||
|
func (e *encoder) addBinary(subtype byte, v []byte) {
|
||||||
|
if subtype == 0x02 {
|
||||||
|
// Wonder how that brilliant idea came to life. Obsolete, luckily.
|
||||||
|
e.addInt32(int32(len(v) + 4))
|
||||||
|
e.addBytes(subtype)
|
||||||
|
e.addInt32(int32(len(v)))
|
||||||
|
} else {
|
||||||
|
e.addInt32(int32(len(v)))
|
||||||
|
e.addBytes(subtype)
|
||||||
|
}
|
||||||
|
e.addBytes(v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encoder) addStr(v string) {
|
||||||
|
e.addInt32(int32(len(v) + 1))
|
||||||
|
e.addCStr(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encoder) addCStr(v string) {
|
||||||
|
e.addBytes([]byte(v)...)
|
||||||
|
e.addBytes(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encoder) reserveInt32() (pos int) {
|
||||||
|
pos = len(e.out)
|
||||||
|
e.addBytes(0, 0, 0, 0)
|
||||||
|
return pos
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encoder) setInt32(pos int, v int32) {
|
||||||
|
e.out[pos+0] = byte(v)
|
||||||
|
e.out[pos+1] = byte(v >> 8)
|
||||||
|
e.out[pos+2] = byte(v >> 16)
|
||||||
|
e.out[pos+3] = byte(v >> 24)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encoder) addInt32(v int32) {
|
||||||
|
u := uint32(v)
|
||||||
|
e.addBytes(byte(u), byte(u>>8), byte(u>>16), byte(u>>24))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encoder) addInt64(v int64) {
|
||||||
|
u := uint64(v)
|
||||||
|
e.addBytes(byte(u), byte(u>>8), byte(u>>16), byte(u>>24),
|
||||||
|
byte(u>>32), byte(u>>40), byte(u>>48), byte(u>>56))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encoder) addFloat64(v float64) {
|
||||||
|
e.addInt64(int64(math.Float64bits(v)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *encoder) addBytes(v ...byte) {
|
||||||
|
e.out = append(e.out, v...)
|
||||||
|
}
|
|
@ -0,0 +1,380 @@
|
||||||
|
package bson
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"gopkg.in/mgo.v2/internal/json"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UnmarshalJSON unmarshals a JSON value that may hold non-standard
|
||||||
|
// syntax as defined in BSON's extended JSON specification.
|
||||||
|
func UnmarshalJSON(data []byte, value interface{}) error {
|
||||||
|
d := json.NewDecoder(bytes.NewBuffer(data))
|
||||||
|
d.Extend(&jsonExt)
|
||||||
|
return d.Decode(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON marshals a JSON value that may hold non-standard
|
||||||
|
// syntax as defined in BSON's extended JSON specification.
|
||||||
|
func MarshalJSON(value interface{}) ([]byte, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
e := json.NewEncoder(&buf)
|
||||||
|
e.Extend(&jsonExt)
|
||||||
|
err := e.Encode(value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// jdec is used internally by the JSON decoding functions
|
||||||
|
// so they may unmarshal functions without getting into endless
|
||||||
|
// recursion due to keyed objects.
|
||||||
|
func jdec(data []byte, value interface{}) error {
|
||||||
|
d := json.NewDecoder(bytes.NewBuffer(data))
|
||||||
|
d.Extend(&funcExt)
|
||||||
|
return d.Decode(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
var jsonExt json.Extension
|
||||||
|
var funcExt json.Extension
|
||||||
|
|
||||||
|
// TODO
|
||||||
|
// - Shell regular expressions ("/regexp/opts")
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
jsonExt.DecodeUnquotedKeys(true)
|
||||||
|
jsonExt.DecodeTrailingCommas(true)
|
||||||
|
|
||||||
|
funcExt.DecodeFunc("BinData", "$binaryFunc", "$type", "$binary")
|
||||||
|
jsonExt.DecodeKeyed("$binary", jdecBinary)
|
||||||
|
jsonExt.DecodeKeyed("$binaryFunc", jdecBinary)
|
||||||
|
jsonExt.EncodeType([]byte(nil), jencBinarySlice)
|
||||||
|
jsonExt.EncodeType(Binary{}, jencBinaryType)
|
||||||
|
|
||||||
|
funcExt.DecodeFunc("ISODate", "$dateFunc", "S")
|
||||||
|
funcExt.DecodeFunc("new Date", "$dateFunc", "S")
|
||||||
|
jsonExt.DecodeKeyed("$date", jdecDate)
|
||||||
|
jsonExt.DecodeKeyed("$dateFunc", jdecDate)
|
||||||
|
jsonExt.EncodeType(time.Time{}, jencDate)
|
||||||
|
|
||||||
|
funcExt.DecodeFunc("Timestamp", "$timestamp", "t", "i")
|
||||||
|
jsonExt.DecodeKeyed("$timestamp", jdecTimestamp)
|
||||||
|
jsonExt.EncodeType(MongoTimestamp(0), jencTimestamp)
|
||||||
|
|
||||||
|
funcExt.DecodeConst("undefined", Undefined)
|
||||||
|
|
||||||
|
jsonExt.DecodeKeyed("$regex", jdecRegEx)
|
||||||
|
jsonExt.EncodeType(RegEx{}, jencRegEx)
|
||||||
|
|
||||||
|
funcExt.DecodeFunc("ObjectId", "$oidFunc", "Id")
|
||||||
|
jsonExt.DecodeKeyed("$oid", jdecObjectId)
|
||||||
|
jsonExt.DecodeKeyed("$oidFunc", jdecObjectId)
|
||||||
|
jsonExt.EncodeType(ObjectId(""), jencObjectId)
|
||||||
|
|
||||||
|
funcExt.DecodeFunc("DBRef", "$dbrefFunc", "$ref", "$id")
|
||||||
|
jsonExt.DecodeKeyed("$dbrefFunc", jdecDBRef)
|
||||||
|
|
||||||
|
funcExt.DecodeFunc("NumberLong", "$numberLongFunc", "N")
|
||||||
|
jsonExt.DecodeKeyed("$numberLong", jdecNumberLong)
|
||||||
|
jsonExt.DecodeKeyed("$numberLongFunc", jdecNumberLong)
|
||||||
|
jsonExt.EncodeType(int64(0), jencNumberLong)
|
||||||
|
jsonExt.EncodeType(int(0), jencInt)
|
||||||
|
|
||||||
|
funcExt.DecodeConst("MinKey", MinKey)
|
||||||
|
funcExt.DecodeConst("MaxKey", MaxKey)
|
||||||
|
jsonExt.DecodeKeyed("$minKey", jdecMinKey)
|
||||||
|
jsonExt.DecodeKeyed("$maxKey", jdecMaxKey)
|
||||||
|
jsonExt.EncodeType(orderKey(0), jencMinMaxKey)
|
||||||
|
|
||||||
|
jsonExt.DecodeKeyed("$undefined", jdecUndefined)
|
||||||
|
jsonExt.EncodeType(Undefined, jencUndefined)
|
||||||
|
|
||||||
|
jsonExt.Extend(&funcExt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fbytes(format string, args ...interface{}) []byte {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
fmt.Fprintf(&buf, format, args...)
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func jdecBinary(data []byte) (interface{}, error) {
|
||||||
|
var v struct {
|
||||||
|
Binary []byte `json:"$binary"`
|
||||||
|
Type string `json:"$type"`
|
||||||
|
Func struct {
|
||||||
|
Binary []byte `json:"$binary"`
|
||||||
|
Type int64 `json:"$type"`
|
||||||
|
} `json:"$binaryFunc"`
|
||||||
|
}
|
||||||
|
err := jdec(data, &v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var binData []byte
|
||||||
|
var binKind int64
|
||||||
|
if v.Type == "" && v.Binary == nil {
|
||||||
|
binData = v.Func.Binary
|
||||||
|
binKind = v.Func.Type
|
||||||
|
} else if v.Type == "" {
|
||||||
|
return v.Binary, nil
|
||||||
|
} else {
|
||||||
|
binData = v.Binary
|
||||||
|
binKind, err = strconv.ParseInt(v.Type, 0, 64)
|
||||||
|
if err != nil {
|
||||||
|
binKind = -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if binKind == 0 {
|
||||||
|
return binData, nil
|
||||||
|
}
|
||||||
|
if binKind < 0 || binKind > 255 {
|
||||||
|
return nil, fmt.Errorf("invalid type in binary object: %s", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return Binary{Kind: byte(binKind), Data: binData}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jencBinarySlice(v interface{}) ([]byte, error) {
|
||||||
|
in := v.([]byte)
|
||||||
|
out := make([]byte, base64.StdEncoding.EncodedLen(len(in)))
|
||||||
|
base64.StdEncoding.Encode(out, in)
|
||||||
|
return fbytes(`{"$binary":"%s","$type":"0x0"}`, out), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jencBinaryType(v interface{}) ([]byte, error) {
|
||||||
|
in := v.(Binary)
|
||||||
|
out := make([]byte, base64.StdEncoding.EncodedLen(len(in.Data)))
|
||||||
|
base64.StdEncoding.Encode(out, in.Data)
|
||||||
|
return fbytes(`{"$binary":"%s","$type":"0x%x"}`, out, in.Kind), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const jdateFormat = "2006-01-02T15:04:05.999Z"
|
||||||
|
|
||||||
|
func jdecDate(data []byte) (interface{}, error) {
|
||||||
|
var v struct {
|
||||||
|
S string `json:"$date"`
|
||||||
|
Func struct {
|
||||||
|
S string
|
||||||
|
} `json:"$dateFunc"`
|
||||||
|
}
|
||||||
|
_ = jdec(data, &v)
|
||||||
|
if v.S == "" {
|
||||||
|
v.S = v.Func.S
|
||||||
|
}
|
||||||
|
if v.S != "" {
|
||||||
|
for _, format := range []string{jdateFormat, "2006-01-02"} {
|
||||||
|
t, err := time.Parse(format, v.S)
|
||||||
|
if err == nil {
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("cannot parse date: %q", v.S)
|
||||||
|
}
|
||||||
|
|
||||||
|
var vn struct {
|
||||||
|
Date struct {
|
||||||
|
N int64 `json:"$numberLong,string"`
|
||||||
|
} `json:"$date"`
|
||||||
|
Func struct {
|
||||||
|
S int64
|
||||||
|
} `json:"$dateFunc"`
|
||||||
|
}
|
||||||
|
err := jdec(data, &vn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot parse date: %q", data)
|
||||||
|
}
|
||||||
|
n := vn.Date.N
|
||||||
|
if n == 0 {
|
||||||
|
n = vn.Func.S
|
||||||
|
}
|
||||||
|
return time.Unix(n/1000, n%1000*1e6).UTC(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jencDate(v interface{}) ([]byte, error) {
|
||||||
|
t := v.(time.Time)
|
||||||
|
return fbytes(`{"$date":%q}`, t.Format(jdateFormat)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jdecTimestamp(data []byte) (interface{}, error) {
|
||||||
|
var v struct {
|
||||||
|
Func struct {
|
||||||
|
T int32 `json:"t"`
|
||||||
|
I int32 `json:"i"`
|
||||||
|
} `json:"$timestamp"`
|
||||||
|
}
|
||||||
|
err := jdec(data, &v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return MongoTimestamp(uint64(v.Func.T)<<32 | uint64(uint32(v.Func.I))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jencTimestamp(v interface{}) ([]byte, error) {
|
||||||
|
ts := uint64(v.(MongoTimestamp))
|
||||||
|
return fbytes(`{"$timestamp":{"t":%d,"i":%d}}`, ts>>32, uint32(ts)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jdecRegEx(data []byte) (interface{}, error) {
|
||||||
|
var v struct {
|
||||||
|
Regex string `json:"$regex"`
|
||||||
|
Options string `json:"$options"`
|
||||||
|
}
|
||||||
|
err := jdec(data, &v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return RegEx{v.Regex, v.Options}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jencRegEx(v interface{}) ([]byte, error) {
|
||||||
|
re := v.(RegEx)
|
||||||
|
type regex struct {
|
||||||
|
Regex string `json:"$regex"`
|
||||||
|
Options string `json:"$options"`
|
||||||
|
}
|
||||||
|
return json.Marshal(regex{re.Pattern, re.Options})
|
||||||
|
}
|
||||||
|
|
||||||
|
func jdecObjectId(data []byte) (interface{}, error) {
|
||||||
|
var v struct {
|
||||||
|
Id string `json:"$oid"`
|
||||||
|
Func struct {
|
||||||
|
Id string
|
||||||
|
} `json:"$oidFunc"`
|
||||||
|
}
|
||||||
|
err := jdec(data, &v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if v.Id == "" {
|
||||||
|
v.Id = v.Func.Id
|
||||||
|
}
|
||||||
|
return ObjectIdHex(v.Id), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jencObjectId(v interface{}) ([]byte, error) {
|
||||||
|
return fbytes(`{"$oid":"%s"}`, v.(ObjectId).Hex()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jdecDBRef(data []byte) (interface{}, error) {
|
||||||
|
// TODO Support unmarshaling $ref and $id into the input value.
|
||||||
|
var v struct {
|
||||||
|
Obj map[string]interface{} `json:"$dbrefFunc"`
|
||||||
|
}
|
||||||
|
// TODO Fix this. Must not be required.
|
||||||
|
v.Obj = make(map[string]interface{})
|
||||||
|
err := jdec(data, &v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return v.Obj, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jdecNumberLong(data []byte) (interface{}, error) {
|
||||||
|
var v struct {
|
||||||
|
N int64 `json:"$numberLong,string"`
|
||||||
|
Func struct {
|
||||||
|
N int64 `json:",string"`
|
||||||
|
} `json:"$numberLongFunc"`
|
||||||
|
}
|
||||||
|
var vn struct {
|
||||||
|
N int64 `json:"$numberLong"`
|
||||||
|
Func struct {
|
||||||
|
N int64
|
||||||
|
} `json:"$numberLongFunc"`
|
||||||
|
}
|
||||||
|
err := jdec(data, &v)
|
||||||
|
if err != nil {
|
||||||
|
err = jdec(data, &vn)
|
||||||
|
v.N = vn.N
|
||||||
|
v.Func.N = vn.Func.N
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if v.N != 0 {
|
||||||
|
return v.N, nil
|
||||||
|
}
|
||||||
|
return v.Func.N, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jencNumberLong(v interface{}) ([]byte, error) {
|
||||||
|
n := v.(int64)
|
||||||
|
f := `{"$numberLong":"%d"}`
|
||||||
|
if n <= 1<<53 {
|
||||||
|
f = `{"$numberLong":%d}`
|
||||||
|
}
|
||||||
|
return fbytes(f, n), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jencInt(v interface{}) ([]byte, error) {
|
||||||
|
n := v.(int)
|
||||||
|
f := `{"$numberLong":"%d"}`
|
||||||
|
if int64(n) <= 1<<53 {
|
||||||
|
f = `%d`
|
||||||
|
}
|
||||||
|
return fbytes(f, n), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jdecMinKey(data []byte) (interface{}, error) {
|
||||||
|
var v struct {
|
||||||
|
N int64 `json:"$minKey"`
|
||||||
|
}
|
||||||
|
err := jdec(data, &v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if v.N != 1 {
|
||||||
|
return nil, fmt.Errorf("invalid $minKey object: %s", data)
|
||||||
|
}
|
||||||
|
return MinKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jdecMaxKey(data []byte) (interface{}, error) {
|
||||||
|
var v struct {
|
||||||
|
N int64 `json:"$maxKey"`
|
||||||
|
}
|
||||||
|
err := jdec(data, &v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if v.N != 1 {
|
||||||
|
return nil, fmt.Errorf("invalid $maxKey object: %s", data)
|
||||||
|
}
|
||||||
|
return MaxKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jencMinMaxKey(v interface{}) ([]byte, error) {
|
||||||
|
switch v.(orderKey) {
|
||||||
|
case MinKey:
|
||||||
|
return []byte(`{"$minKey":1}`), nil
|
||||||
|
case MaxKey:
|
||||||
|
return []byte(`{"$maxKey":1}`), nil
|
||||||
|
}
|
||||||
|
panic(fmt.Sprintf("invalid $minKey/$maxKey value: %d", v))
|
||||||
|
}
|
||||||
|
|
||||||
|
func jdecUndefined(data []byte) (interface{}, error) {
|
||||||
|
var v struct {
|
||||||
|
B bool `json:"$undefined"`
|
||||||
|
}
|
||||||
|
err := jdec(data, &v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !v.B {
|
||||||
|
return nil, fmt.Errorf("invalid $undefined object: %s", data)
|
||||||
|
}
|
||||||
|
return Undefined, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jencUndefined(v interface{}) ([]byte, error) {
|
||||||
|
return []byte(`{"$undefined":true}`), nil
|
||||||
|
}
|
|
@ -0,0 +1,351 @@
|
||||||
|
package mgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"gopkg.in/mgo.v2/bson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Bulk represents an operation that can be prepared with several
|
||||||
|
// orthogonal changes before being delivered to the server.
|
||||||
|
//
|
||||||
|
// MongoDB servers older than version 2.6 do not have proper support for bulk
|
||||||
|
// operations, so the driver attempts to map its API as much as possible into
|
||||||
|
// the functionality that works. In particular, in those releases updates and
|
||||||
|
// removals are sent individually, and inserts are sent in bulk but have
|
||||||
|
// suboptimal error reporting compared to more recent versions of the server.
|
||||||
|
// See the documentation of BulkErrorCase for details on that.
|
||||||
|
//
|
||||||
|
// Relevant documentation:
|
||||||
|
//
|
||||||
|
// http://blog.mongodb.org/post/84922794768/mongodbs-new-bulk-api
|
||||||
|
//
|
||||||
|
type Bulk struct {
|
||||||
|
c *Collection
|
||||||
|
opcount int
|
||||||
|
actions []bulkAction
|
||||||
|
ordered bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type bulkOp int
|
||||||
|
|
||||||
|
const (
|
||||||
|
bulkInsert bulkOp = iota + 1
|
||||||
|
bulkUpdate
|
||||||
|
bulkUpdateAll
|
||||||
|
bulkRemove
|
||||||
|
)
|
||||||
|
|
||||||
|
type bulkAction struct {
|
||||||
|
op bulkOp
|
||||||
|
docs []interface{}
|
||||||
|
idxs []int
|
||||||
|
}
|
||||||
|
|
||||||
|
type bulkUpdateOp []interface{}
|
||||||
|
type bulkDeleteOp []interface{}
|
||||||
|
|
||||||
|
// BulkResult holds the results for a bulk operation.
|
||||||
|
type BulkResult struct {
|
||||||
|
Matched int
|
||||||
|
Modified int // Available only for MongoDB 2.6+
|
||||||
|
|
||||||
|
// Be conservative while we understand exactly how to report these
|
||||||
|
// results in a useful and convenient way, and also how to emulate
|
||||||
|
// them with prior servers.
|
||||||
|
private bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// BulkError holds an error returned from running a Bulk operation.
|
||||||
|
// Individual errors may be obtained and inspected via the Cases method.
|
||||||
|
type BulkError struct {
|
||||||
|
ecases []BulkErrorCase
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *BulkError) Error() string {
|
||||||
|
if len(e.ecases) == 0 {
|
||||||
|
return "invalid BulkError instance: no errors"
|
||||||
|
}
|
||||||
|
if len(e.ecases) == 1 {
|
||||||
|
return e.ecases[0].Err.Error()
|
||||||
|
}
|
||||||
|
msgs := make([]string, 0, len(e.ecases))
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, ecase := range e.ecases {
|
||||||
|
msg := ecase.Err.Error()
|
||||||
|
if !seen[msg] {
|
||||||
|
seen[msg] = true
|
||||||
|
msgs = append(msgs, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(msgs) == 1 {
|
||||||
|
return msgs[0]
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
buf.WriteString("multiple errors in bulk operation:\n")
|
||||||
|
for _, msg := range msgs {
|
||||||
|
buf.WriteString(" - ")
|
||||||
|
buf.WriteString(msg)
|
||||||
|
buf.WriteByte('\n')
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
type bulkErrorCases []BulkErrorCase
|
||||||
|
|
||||||
|
func (slice bulkErrorCases) Len() int { return len(slice) }
|
||||||
|
func (slice bulkErrorCases) Less(i, j int) bool { return slice[i].Index < slice[j].Index }
|
||||||
|
func (slice bulkErrorCases) Swap(i, j int) { slice[i], slice[j] = slice[j], slice[i] }
|
||||||
|
|
||||||
|
// BulkErrorCase holds an individual error found while attempting a single change
|
||||||
|
// within a bulk operation, and the position in which it was enqueued.
|
||||||
|
//
|
||||||
|
// MongoDB servers older than version 2.6 do not have proper support for bulk
|
||||||
|
// operations, so the driver attempts to map its API as much as possible into
|
||||||
|
// the functionality that works. In particular, only the last error is reported
|
||||||
|
// for bulk inserts and without any positional information, so the Index
|
||||||
|
// field is set to -1 in these cases.
|
||||||
|
type BulkErrorCase struct {
|
||||||
|
Index int // Position of operation that failed, or -1 if unknown.
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cases returns all individual errors found while attempting the requested changes.
|
||||||
|
//
|
||||||
|
// See the documentation of BulkErrorCase for limitations in older MongoDB releases.
|
||||||
|
func (e *BulkError) Cases() []BulkErrorCase {
|
||||||
|
return e.ecases
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bulk returns a value to prepare the execution of a bulk operation.
|
||||||
|
func (c *Collection) Bulk() *Bulk {
|
||||||
|
return &Bulk{c: c, ordered: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unordered puts the bulk operation in unordered mode.
|
||||||
|
//
|
||||||
|
// In unordered mode the indvidual operations may be sent
|
||||||
|
// out of order, which means latter operations may proceed
|
||||||
|
// even if prior ones have failed.
|
||||||
|
func (b *Bulk) Unordered() {
|
||||||
|
b.ordered = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bulk) action(op bulkOp, opcount int) *bulkAction {
|
||||||
|
var action *bulkAction
|
||||||
|
if len(b.actions) > 0 && b.actions[len(b.actions)-1].op == op {
|
||||||
|
action = &b.actions[len(b.actions)-1]
|
||||||
|
} else if !b.ordered {
|
||||||
|
for i := range b.actions {
|
||||||
|
if b.actions[i].op == op {
|
||||||
|
action = &b.actions[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if action == nil {
|
||||||
|
b.actions = append(b.actions, bulkAction{op: op})
|
||||||
|
action = &b.actions[len(b.actions)-1]
|
||||||
|
}
|
||||||
|
for i := 0; i < opcount; i++ {
|
||||||
|
action.idxs = append(action.idxs, b.opcount)
|
||||||
|
b.opcount++
|
||||||
|
}
|
||||||
|
return action
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert queues up the provided documents for insertion.
|
||||||
|
func (b *Bulk) Insert(docs ...interface{}) {
|
||||||
|
action := b.action(bulkInsert, len(docs))
|
||||||
|
action.docs = append(action.docs, docs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove queues up the provided selectors for removing matching documents.
|
||||||
|
// Each selector will remove only a single matching document.
|
||||||
|
func (b *Bulk) Remove(selectors ...interface{}) {
|
||||||
|
action := b.action(bulkRemove, len(selectors))
|
||||||
|
for _, selector := range selectors {
|
||||||
|
if selector == nil {
|
||||||
|
selector = bson.D{}
|
||||||
|
}
|
||||||
|
action.docs = append(action.docs, &deleteOp{
|
||||||
|
Collection: b.c.FullName,
|
||||||
|
Selector: selector,
|
||||||
|
Flags: 1,
|
||||||
|
Limit: 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAll queues up the provided selectors for removing all matching documents.
|
||||||
|
// Each selector will remove all matching documents.
|
||||||
|
func (b *Bulk) RemoveAll(selectors ...interface{}) {
|
||||||
|
action := b.action(bulkRemove, len(selectors))
|
||||||
|
for _, selector := range selectors {
|
||||||
|
if selector == nil {
|
||||||
|
selector = bson.D{}
|
||||||
|
}
|
||||||
|
action.docs = append(action.docs, &deleteOp{
|
||||||
|
Collection: b.c.FullName,
|
||||||
|
Selector: selector,
|
||||||
|
Flags: 0,
|
||||||
|
Limit: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update queues up the provided pairs of updating instructions.
|
||||||
|
// The first element of each pair selects which documents must be
|
||||||
|
// updated, and the second element defines how to update it.
|
||||||
|
// Each pair matches exactly one document for updating at most.
|
||||||
|
func (b *Bulk) Update(pairs ...interface{}) {
|
||||||
|
if len(pairs)%2 != 0 {
|
||||||
|
panic("Bulk.Update requires an even number of parameters")
|
||||||
|
}
|
||||||
|
action := b.action(bulkUpdate, len(pairs)/2)
|
||||||
|
for i := 0; i < len(pairs); i += 2 {
|
||||||
|
selector := pairs[i]
|
||||||
|
if selector == nil {
|
||||||
|
selector = bson.D{}
|
||||||
|
}
|
||||||
|
action.docs = append(action.docs, &updateOp{
|
||||||
|
Collection: b.c.FullName,
|
||||||
|
Selector: selector,
|
||||||
|
Update: pairs[i+1],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAll queues up the provided pairs of updating instructions.
|
||||||
|
// The first element of each pair selects which documents must be
|
||||||
|
// updated, and the second element defines how to update it.
|
||||||
|
// Each pair updates all documents matching the selector.
|
||||||
|
func (b *Bulk) UpdateAll(pairs ...interface{}) {
|
||||||
|
if len(pairs)%2 != 0 {
|
||||||
|
panic("Bulk.UpdateAll requires an even number of parameters")
|
||||||
|
}
|
||||||
|
action := b.action(bulkUpdate, len(pairs)/2)
|
||||||
|
for i := 0; i < len(pairs); i += 2 {
|
||||||
|
selector := pairs[i]
|
||||||
|
if selector == nil {
|
||||||
|
selector = bson.D{}
|
||||||
|
}
|
||||||
|
action.docs = append(action.docs, &updateOp{
|
||||||
|
Collection: b.c.FullName,
|
||||||
|
Selector: selector,
|
||||||
|
Update: pairs[i+1],
|
||||||
|
Flags: 2,
|
||||||
|
Multi: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upsert queues up the provided pairs of upserting instructions.
|
||||||
|
// The first element of each pair selects which documents must be
|
||||||
|
// updated, and the second element defines how to update it.
|
||||||
|
// Each pair matches exactly one document for updating at most.
|
||||||
|
func (b *Bulk) Upsert(pairs ...interface{}) {
|
||||||
|
if len(pairs)%2 != 0 {
|
||||||
|
panic("Bulk.Update requires an even number of parameters")
|
||||||
|
}
|
||||||
|
action := b.action(bulkUpdate, len(pairs)/2)
|
||||||
|
for i := 0; i < len(pairs); i += 2 {
|
||||||
|
selector := pairs[i]
|
||||||
|
if selector == nil {
|
||||||
|
selector = bson.D{}
|
||||||
|
}
|
||||||
|
action.docs = append(action.docs, &updateOp{
|
||||||
|
Collection: b.c.FullName,
|
||||||
|
Selector: selector,
|
||||||
|
Update: pairs[i+1],
|
||||||
|
Flags: 1,
|
||||||
|
Upsert: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run runs all the operations queued up.
|
||||||
|
//
|
||||||
|
// If an error is reported on an unordered bulk operation, the error value may
|
||||||
|
// be an aggregation of all issues observed. As an exception to that, Insert
|
||||||
|
// operations running on MongoDB versions prior to 2.6 will report the last
|
||||||
|
// error only due to a limitation in the wire protocol.
|
||||||
|
func (b *Bulk) Run() (*BulkResult, error) {
|
||||||
|
var result BulkResult
|
||||||
|
var berr BulkError
|
||||||
|
var failed bool
|
||||||
|
for i := range b.actions {
|
||||||
|
action := &b.actions[i]
|
||||||
|
var ok bool
|
||||||
|
switch action.op {
|
||||||
|
case bulkInsert:
|
||||||
|
ok = b.runInsert(action, &result, &berr)
|
||||||
|
case bulkUpdate:
|
||||||
|
ok = b.runUpdate(action, &result, &berr)
|
||||||
|
case bulkRemove:
|
||||||
|
ok = b.runRemove(action, &result, &berr)
|
||||||
|
default:
|
||||||
|
panic("unknown bulk operation")
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
failed = true
|
||||||
|
if b.ordered {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if failed {
|
||||||
|
sort.Sort(bulkErrorCases(berr.ecases))
|
||||||
|
return nil, &berr
|
||||||
|
}
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bulk) runInsert(action *bulkAction, result *BulkResult, berr *BulkError) bool {
|
||||||
|
op := &insertOp{b.c.FullName, action.docs, 0}
|
||||||
|
if !b.ordered {
|
||||||
|
op.flags = 1 // ContinueOnError
|
||||||
|
}
|
||||||
|
lerr, err := b.c.writeOp(op, b.ordered)
|
||||||
|
return b.checkSuccess(action, berr, lerr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bulk) runUpdate(action *bulkAction, result *BulkResult, berr *BulkError) bool {
|
||||||
|
lerr, err := b.c.writeOp(bulkUpdateOp(action.docs), b.ordered)
|
||||||
|
if lerr != nil {
|
||||||
|
result.Matched += lerr.N
|
||||||
|
result.Modified += lerr.modified
|
||||||
|
}
|
||||||
|
return b.checkSuccess(action, berr, lerr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bulk) runRemove(action *bulkAction, result *BulkResult, berr *BulkError) bool {
|
||||||
|
lerr, err := b.c.writeOp(bulkDeleteOp(action.docs), b.ordered)
|
||||||
|
if lerr != nil {
|
||||||
|
result.Matched += lerr.N
|
||||||
|
result.Modified += lerr.modified
|
||||||
|
}
|
||||||
|
return b.checkSuccess(action, berr, lerr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bulk) checkSuccess(action *bulkAction, berr *BulkError, lerr *LastError, err error) bool {
|
||||||
|
if lerr != nil && len(lerr.ecases) > 0 {
|
||||||
|
for i := 0; i < len(lerr.ecases); i++ {
|
||||||
|
// Map back from the local error index into the visible one.
|
||||||
|
ecase := lerr.ecases[i]
|
||||||
|
idx := ecase.Index
|
||||||
|
if idx >= 0 {
|
||||||
|
idx = action.idxs[idx]
|
||||||
|
}
|
||||||
|
berr.ecases = append(berr.ecases, BulkErrorCase{idx, ecase.Err})
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
} else if err != nil {
|
||||||
|
for i := 0; i < len(action.idxs); i++ {
|
||||||
|
berr.ecases = append(berr.ecases, BulkErrorCase{action.idxs[i], err})
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
|
@ -0,0 +1,682 @@
|
||||||
|
// mgo - MongoDB driver for Go
|
||||||
|
//
|
||||||
|
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
//
|
||||||
|
// All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
// list of conditions and the following disclaimer.
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
package mgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/mgo.v2/bson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Mongo cluster encapsulation.
|
||||||
|
//
|
||||||
|
// A cluster enables the communication with one or more servers participating
|
||||||
|
// in a mongo cluster. This works with individual servers, a replica set,
|
||||||
|
// a replica pair, one or multiple mongos routers, etc.
|
||||||
|
|
||||||
|
type mongoCluster struct {
|
||||||
|
sync.RWMutex
|
||||||
|
serverSynced sync.Cond
|
||||||
|
userSeeds []string
|
||||||
|
dynaSeeds []string
|
||||||
|
servers mongoServers
|
||||||
|
masters mongoServers
|
||||||
|
references int
|
||||||
|
syncing bool
|
||||||
|
direct bool
|
||||||
|
failFast bool
|
||||||
|
syncCount uint
|
||||||
|
setName string
|
||||||
|
cachedIndex map[string]bool
|
||||||
|
sync chan bool
|
||||||
|
dial dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCluster(userSeeds []string, direct, failFast bool, dial dialer, setName string) *mongoCluster {
|
||||||
|
cluster := &mongoCluster{
|
||||||
|
userSeeds: userSeeds,
|
||||||
|
references: 1,
|
||||||
|
direct: direct,
|
||||||
|
failFast: failFast,
|
||||||
|
dial: dial,
|
||||||
|
setName: setName,
|
||||||
|
}
|
||||||
|
cluster.serverSynced.L = cluster.RWMutex.RLocker()
|
||||||
|
cluster.sync = make(chan bool, 1)
|
||||||
|
stats.cluster(+1)
|
||||||
|
go cluster.syncServersLoop()
|
||||||
|
return cluster
|
||||||
|
}
|
||||||
|
|
||||||
|
// Acquire increases the reference count for the cluster.
|
||||||
|
func (cluster *mongoCluster) Acquire() {
|
||||||
|
cluster.Lock()
|
||||||
|
cluster.references++
|
||||||
|
debugf("Cluster %p acquired (refs=%d)", cluster, cluster.references)
|
||||||
|
cluster.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release decreases the reference count for the cluster. Once
|
||||||
|
// it reaches zero, all servers will be closed.
|
||||||
|
func (cluster *mongoCluster) Release() {
|
||||||
|
cluster.Lock()
|
||||||
|
if cluster.references == 0 {
|
||||||
|
panic("cluster.Release() with references == 0")
|
||||||
|
}
|
||||||
|
cluster.references--
|
||||||
|
debugf("Cluster %p released (refs=%d)", cluster, cluster.references)
|
||||||
|
if cluster.references == 0 {
|
||||||
|
for _, server := range cluster.servers.Slice() {
|
||||||
|
server.Close()
|
||||||
|
}
|
||||||
|
// Wake up the sync loop so it can die.
|
||||||
|
cluster.syncServers()
|
||||||
|
stats.cluster(-1)
|
||||||
|
}
|
||||||
|
cluster.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cluster *mongoCluster) LiveServers() (servers []string) {
|
||||||
|
cluster.RLock()
|
||||||
|
for _, serv := range cluster.servers.Slice() {
|
||||||
|
servers = append(servers, serv.Addr)
|
||||||
|
}
|
||||||
|
cluster.RUnlock()
|
||||||
|
return servers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cluster *mongoCluster) removeServer(server *mongoServer) {
|
||||||
|
cluster.Lock()
|
||||||
|
cluster.masters.Remove(server)
|
||||||
|
other := cluster.servers.Remove(server)
|
||||||
|
cluster.Unlock()
|
||||||
|
if other != nil {
|
||||||
|
other.Close()
|
||||||
|
log("Removed server ", server.Addr, " from cluster.")
|
||||||
|
}
|
||||||
|
server.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type isMasterResult struct {
|
||||||
|
IsMaster bool
|
||||||
|
Secondary bool
|
||||||
|
Primary string
|
||||||
|
Hosts []string
|
||||||
|
Passives []string
|
||||||
|
Tags bson.D
|
||||||
|
Msg string
|
||||||
|
SetName string `bson:"setName"`
|
||||||
|
MaxWireVersion int `bson:"maxWireVersion"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cluster *mongoCluster) isMaster(socket *mongoSocket, result *isMasterResult) error {
|
||||||
|
// Monotonic let's it talk to a slave and still hold the socket.
|
||||||
|
session := newSession(Monotonic, cluster, 10*time.Second)
|
||||||
|
session.setSocket(socket)
|
||||||
|
err := session.Run("ismaster", result)
|
||||||
|
session.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type possibleTimeout interface {
|
||||||
|
Timeout() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var syncSocketTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerInfo, hosts []string, err error) {
|
||||||
|
var syncTimeout time.Duration
|
||||||
|
if raceDetector {
|
||||||
|
// This variable is only ever touched by tests.
|
||||||
|
globalMutex.Lock()
|
||||||
|
syncTimeout = syncSocketTimeout
|
||||||
|
globalMutex.Unlock()
|
||||||
|
} else {
|
||||||
|
syncTimeout = syncSocketTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := server.Addr
|
||||||
|
log("SYNC Processing ", addr, "...")
|
||||||
|
|
||||||
|
// Retry a few times to avoid knocking a server down for a hiccup.
|
||||||
|
var result isMasterResult
|
||||||
|
var tryerr error
|
||||||
|
for retry := 0; ; retry++ {
|
||||||
|
if retry == 3 || retry == 1 && cluster.failFast {
|
||||||
|
return nil, nil, tryerr
|
||||||
|
}
|
||||||
|
if retry > 0 {
|
||||||
|
// Don't abuse the server needlessly if there's something actually wrong.
|
||||||
|
if err, ok := tryerr.(possibleTimeout); ok && err.Timeout() {
|
||||||
|
// Give a chance for waiters to timeout as well.
|
||||||
|
cluster.serverSynced.Broadcast()
|
||||||
|
}
|
||||||
|
time.Sleep(syncShortDelay)
|
||||||
|
}
|
||||||
|
|
||||||
|
// It's not clear what would be a good timeout here. Is it
|
||||||
|
// better to wait longer or to retry?
|
||||||
|
socket, _, err := server.AcquireSocket(0, syncTimeout)
|
||||||
|
if err != nil {
|
||||||
|
tryerr = err
|
||||||
|
logf("SYNC Failed to get socket to %s: %v", addr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = cluster.isMaster(socket, &result)
|
||||||
|
socket.Release()
|
||||||
|
if err != nil {
|
||||||
|
tryerr = err
|
||||||
|
logf("SYNC Command 'ismaster' to %s failed: %v", addr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
debugf("SYNC Result of 'ismaster' from %s: %#v", addr, result)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if cluster.setName != "" && result.SetName != cluster.setName {
|
||||||
|
logf("SYNC Server %s is not a member of replica set %q", addr, cluster.setName)
|
||||||
|
return nil, nil, fmt.Errorf("server %s is not a member of replica set %q", addr, cluster.setName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.IsMaster {
|
||||||
|
debugf("SYNC %s is a master.", addr)
|
||||||
|
if !server.info.Master {
|
||||||
|
// Made an incorrect assumption above, so fix stats.
|
||||||
|
stats.conn(-1, false)
|
||||||
|
stats.conn(+1, true)
|
||||||
|
}
|
||||||
|
} else if result.Secondary {
|
||||||
|
debugf("SYNC %s is a slave.", addr)
|
||||||
|
} else if cluster.direct {
|
||||||
|
logf("SYNC %s in unknown state. Pretending it's a slave due to direct connection.", addr)
|
||||||
|
} else {
|
||||||
|
logf("SYNC %s is neither a master nor a slave.", addr)
|
||||||
|
// Let stats track it as whatever was known before.
|
||||||
|
return nil, nil, errors.New(addr + " is not a master nor slave")
|
||||||
|
}
|
||||||
|
|
||||||
|
info = &mongoServerInfo{
|
||||||
|
Master: result.IsMaster,
|
||||||
|
Mongos: result.Msg == "isdbgrid",
|
||||||
|
Tags: result.Tags,
|
||||||
|
SetName: result.SetName,
|
||||||
|
MaxWireVersion: result.MaxWireVersion,
|
||||||
|
}
|
||||||
|
|
||||||
|
hosts = make([]string, 0, 1+len(result.Hosts)+len(result.Passives))
|
||||||
|
if result.Primary != "" {
|
||||||
|
// First in the list to speed up master discovery.
|
||||||
|
hosts = append(hosts, result.Primary)
|
||||||
|
}
|
||||||
|
hosts = append(hosts, result.Hosts...)
|
||||||
|
hosts = append(hosts, result.Passives...)
|
||||||
|
|
||||||
|
debugf("SYNC %s knows about the following peers: %#v", addr, hosts)
|
||||||
|
return info, hosts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type syncKind bool
|
||||||
|
|
||||||
|
const (
|
||||||
|
completeSync syncKind = true
|
||||||
|
partialSync syncKind = false
|
||||||
|
)
|
||||||
|
|
||||||
|
func (cluster *mongoCluster) addServer(server *mongoServer, info *mongoServerInfo, syncKind syncKind) {
|
||||||
|
cluster.Lock()
|
||||||
|
current := cluster.servers.Search(server.ResolvedAddr)
|
||||||
|
if current == nil {
|
||||||
|
if syncKind == partialSync {
|
||||||
|
cluster.Unlock()
|
||||||
|
server.Close()
|
||||||
|
log("SYNC Discarding unknown server ", server.Addr, " due to partial sync.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cluster.servers.Add(server)
|
||||||
|
if info.Master {
|
||||||
|
cluster.masters.Add(server)
|
||||||
|
log("SYNC Adding ", server.Addr, " to cluster as a master.")
|
||||||
|
} else {
|
||||||
|
log("SYNC Adding ", server.Addr, " to cluster as a slave.")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if server != current {
|
||||||
|
panic("addServer attempting to add duplicated server")
|
||||||
|
}
|
||||||
|
if server.Info().Master != info.Master {
|
||||||
|
if info.Master {
|
||||||
|
log("SYNC Server ", server.Addr, " is now a master.")
|
||||||
|
cluster.masters.Add(server)
|
||||||
|
} else {
|
||||||
|
log("SYNC Server ", server.Addr, " is now a slave.")
|
||||||
|
cluster.masters.Remove(server)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
server.SetInfo(info)
|
||||||
|
debugf("SYNC Broadcasting availability of server %s", server.Addr)
|
||||||
|
cluster.serverSynced.Broadcast()
|
||||||
|
cluster.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cluster *mongoCluster) getKnownAddrs() []string {
|
||||||
|
cluster.RLock()
|
||||||
|
max := len(cluster.userSeeds) + len(cluster.dynaSeeds) + cluster.servers.Len()
|
||||||
|
seen := make(map[string]bool, max)
|
||||||
|
known := make([]string, 0, max)
|
||||||
|
|
||||||
|
add := func(addr string) {
|
||||||
|
if _, found := seen[addr]; !found {
|
||||||
|
seen[addr] = true
|
||||||
|
known = append(known, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range cluster.userSeeds {
|
||||||
|
add(addr)
|
||||||
|
}
|
||||||
|
for _, addr := range cluster.dynaSeeds {
|
||||||
|
add(addr)
|
||||||
|
}
|
||||||
|
for _, serv := range cluster.servers.Slice() {
|
||||||
|
add(serv.Addr)
|
||||||
|
}
|
||||||
|
cluster.RUnlock()
|
||||||
|
|
||||||
|
return known
|
||||||
|
}
|
||||||
|
|
||||||
|
// syncServers injects a value into the cluster.sync channel to force
|
||||||
|
// an iteration of the syncServersLoop function.
|
||||||
|
func (cluster *mongoCluster) syncServers() {
|
||||||
|
select {
|
||||||
|
case cluster.sync <- true:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// How long to wait for a checkup of the cluster topology if nothing
|
||||||
|
// else kicks a synchronization before that.
|
||||||
|
const syncServersDelay = 30 * time.Second
|
||||||
|
const syncShortDelay = 500 * time.Millisecond
|
||||||
|
|
||||||
|
// syncServersLoop loops while the cluster is alive to keep its idea of
|
||||||
|
// the server topology up-to-date. It must be called just once from
|
||||||
|
// newCluster. The loop iterates once syncServersDelay has passed, or
|
||||||
|
// if somebody injects a value into the cluster.sync channel to force a
|
||||||
|
// synchronization. A loop iteration will contact all servers in
|
||||||
|
// parallel, ask them about known peers and their own role within the
|
||||||
|
// cluster, and then attempt to do the same with all the peers
|
||||||
|
// retrieved.
|
||||||
|
func (cluster *mongoCluster) syncServersLoop() {
|
||||||
|
for {
|
||||||
|
debugf("SYNC Cluster %p is starting a sync loop iteration.", cluster)
|
||||||
|
|
||||||
|
cluster.Lock()
|
||||||
|
if cluster.references == 0 {
|
||||||
|
cluster.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cluster.references++ // Keep alive while syncing.
|
||||||
|
direct := cluster.direct
|
||||||
|
cluster.Unlock()
|
||||||
|
|
||||||
|
cluster.syncServersIteration(direct)
|
||||||
|
|
||||||
|
// We just synchronized, so consume any outstanding requests.
|
||||||
|
select {
|
||||||
|
case <-cluster.sync:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
cluster.Release()
|
||||||
|
|
||||||
|
// Hold off before allowing another sync. No point in
|
||||||
|
// burning CPU looking for down servers.
|
||||||
|
if !cluster.failFast {
|
||||||
|
time.Sleep(syncShortDelay)
|
||||||
|
}
|
||||||
|
|
||||||
|
cluster.Lock()
|
||||||
|
if cluster.references == 0 {
|
||||||
|
cluster.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cluster.syncCount++
|
||||||
|
// Poke all waiters so they have a chance to timeout or
|
||||||
|
// restart syncing if they wish to.
|
||||||
|
cluster.serverSynced.Broadcast()
|
||||||
|
// Check if we have to restart immediately either way.
|
||||||
|
restart := !direct && cluster.masters.Empty() || cluster.servers.Empty()
|
||||||
|
cluster.Unlock()
|
||||||
|
|
||||||
|
if restart {
|
||||||
|
log("SYNC No masters found. Will synchronize again.")
|
||||||
|
time.Sleep(syncShortDelay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
debugf("SYNC Cluster %p waiting for next requested or scheduled sync.", cluster)
|
||||||
|
|
||||||
|
// Hold off until somebody explicitly requests a synchronization
|
||||||
|
// or it's time to check for a cluster topology change again.
|
||||||
|
select {
|
||||||
|
case <-cluster.sync:
|
||||||
|
case <-time.After(syncServersDelay):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
debugf("SYNC Cluster %p is stopping its sync loop.", cluster)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cluster *mongoCluster) server(addr string, tcpaddr *net.TCPAddr) *mongoServer {
|
||||||
|
cluster.RLock()
|
||||||
|
server := cluster.servers.Search(tcpaddr.String())
|
||||||
|
cluster.RUnlock()
|
||||||
|
if server != nil {
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
return newServer(addr, tcpaddr, cluster.sync, cluster.dial)
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveAddr(addr string) (*net.TCPAddr, error) {
|
||||||
|
// Simple cases that do not need actual resolution. Works with IPv4 and v6.
|
||||||
|
if host, port, err := net.SplitHostPort(addr); err == nil {
|
||||||
|
if port, _ := strconv.Atoi(port); port > 0 {
|
||||||
|
zone := ""
|
||||||
|
if i := strings.LastIndex(host, "%"); i >= 0 {
|
||||||
|
zone = host[i+1:]
|
||||||
|
host = host[:i]
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(host)
|
||||||
|
if ip != nil {
|
||||||
|
return &net.TCPAddr{IP: ip, Port: port, Zone: zone}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to resolve IPv4 and v6 concurrently.
|
||||||
|
addrChan := make(chan *net.TCPAddr, 2)
|
||||||
|
for _, network := range []string{"udp4", "udp6"} {
|
||||||
|
network := network
|
||||||
|
go func() {
|
||||||
|
// The unfortunate UDP dialing hack allows having a timeout on address resolution.
|
||||||
|
conn, err := net.DialTimeout(network, addr, 10*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
addrChan <- nil
|
||||||
|
} else {
|
||||||
|
addrChan <- (*net.TCPAddr)(conn.RemoteAddr().(*net.UDPAddr))
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the result of IPv4 and v6 resolution. Use IPv4 if available.
|
||||||
|
tcpaddr := <-addrChan
|
||||||
|
if tcpaddr == nil || len(tcpaddr.IP) != 4 {
|
||||||
|
var timeout <-chan time.Time
|
||||||
|
if tcpaddr != nil {
|
||||||
|
// Don't wait too long if an IPv6 address is known.
|
||||||
|
timeout = time.After(50 * time.Millisecond)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
case tcpaddr2 := <-addrChan:
|
||||||
|
if tcpaddr == nil || tcpaddr2 != nil {
|
||||||
|
// It's an IPv4 address or the only known address. Use it.
|
||||||
|
tcpaddr = tcpaddr2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tcpaddr == nil {
|
||||||
|
log("SYNC Failed to resolve server address: ", addr)
|
||||||
|
return nil, errors.New("failed to resolve server address: " + addr)
|
||||||
|
}
|
||||||
|
if tcpaddr.String() != addr {
|
||||||
|
debug("SYNC Address ", addr, " resolved as ", tcpaddr.String())
|
||||||
|
}
|
||||||
|
return tcpaddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type pendingAdd struct {
|
||||||
|
server *mongoServer
|
||||||
|
info *mongoServerInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cluster *mongoCluster) syncServersIteration(direct bool) {
|
||||||
|
log("SYNC Starting full topology synchronization...")
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var m sync.Mutex
|
||||||
|
notYetAdded := make(map[string]pendingAdd)
|
||||||
|
addIfFound := make(map[string]bool)
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
syncKind := partialSync
|
||||||
|
|
||||||
|
var spawnSync func(addr string, byMaster bool)
|
||||||
|
spawnSync = func(addr string, byMaster bool) {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
tcpaddr, err := resolveAddr(addr)
|
||||||
|
if err != nil {
|
||||||
|
log("SYNC Failed to start sync of ", addr, ": ", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resolvedAddr := tcpaddr.String()
|
||||||
|
|
||||||
|
m.Lock()
|
||||||
|
if byMaster {
|
||||||
|
if pending, ok := notYetAdded[resolvedAddr]; ok {
|
||||||
|
delete(notYetAdded, resolvedAddr)
|
||||||
|
m.Unlock()
|
||||||
|
cluster.addServer(pending.server, pending.info, completeSync)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addIfFound[resolvedAddr] = true
|
||||||
|
}
|
||||||
|
if seen[resolvedAddr] {
|
||||||
|
m.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[resolvedAddr] = true
|
||||||
|
m.Unlock()
|
||||||
|
|
||||||
|
server := cluster.server(addr, tcpaddr)
|
||||||
|
info, hosts, err := cluster.syncServer(server)
|
||||||
|
if err != nil {
|
||||||
|
cluster.removeServer(server)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Lock()
|
||||||
|
add := direct || info.Master || addIfFound[resolvedAddr]
|
||||||
|
if add {
|
||||||
|
syncKind = completeSync
|
||||||
|
} else {
|
||||||
|
notYetAdded[resolvedAddr] = pendingAdd{server, info}
|
||||||
|
}
|
||||||
|
m.Unlock()
|
||||||
|
if add {
|
||||||
|
cluster.addServer(server, info, completeSync)
|
||||||
|
}
|
||||||
|
if !direct {
|
||||||
|
for _, addr := range hosts {
|
||||||
|
spawnSync(addr, info.Master)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
knownAddrs := cluster.getKnownAddrs()
|
||||||
|
for _, addr := range knownAddrs {
|
||||||
|
spawnSync(addr, false)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if syncKind == completeSync {
|
||||||
|
logf("SYNC Synchronization was complete (got data from primary).")
|
||||||
|
for _, pending := range notYetAdded {
|
||||||
|
cluster.removeServer(pending.server)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logf("SYNC Synchronization was partial (cannot talk to primary).")
|
||||||
|
for _, pending := range notYetAdded {
|
||||||
|
cluster.addServer(pending.server, pending.info, partialSync)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cluster.Lock()
|
||||||
|
mastersLen := cluster.masters.Len()
|
||||||
|
logf("SYNC Synchronization completed: %d master(s) and %d slave(s) alive.", mastersLen, cluster.servers.Len()-mastersLen)
|
||||||
|
|
||||||
|
// Update dynamic seeds, but only if we have any good servers. Otherwise,
|
||||||
|
// leave them alone for better chances of a successful sync in the future.
|
||||||
|
if syncKind == completeSync {
|
||||||
|
dynaSeeds := make([]string, cluster.servers.Len())
|
||||||
|
for i, server := range cluster.servers.Slice() {
|
||||||
|
dynaSeeds[i] = server.Addr
|
||||||
|
}
|
||||||
|
cluster.dynaSeeds = dynaSeeds
|
||||||
|
debugf("SYNC New dynamic seeds: %#v\n", dynaSeeds)
|
||||||
|
}
|
||||||
|
cluster.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcquireSocket returns a socket to a server in the cluster. If slaveOk is
|
||||||
|
// true, it will attempt to return a socket to a slave server. If it is
|
||||||
|
// false, the socket will necessarily be to a master server.
|
||||||
|
func (cluster *mongoCluster) AcquireSocket(mode Mode, slaveOk bool, syncTimeout time.Duration, socketTimeout time.Duration, serverTags []bson.D, poolLimit int) (s *mongoSocket, err error) {
|
||||||
|
var started time.Time
|
||||||
|
var syncCount uint
|
||||||
|
warnedLimit := false
|
||||||
|
for {
|
||||||
|
cluster.RLock()
|
||||||
|
for {
|
||||||
|
mastersLen := cluster.masters.Len()
|
||||||
|
slavesLen := cluster.servers.Len() - mastersLen
|
||||||
|
debugf("Cluster has %d known masters and %d known slaves.", mastersLen, slavesLen)
|
||||||
|
if mastersLen > 0 && !(slaveOk && mode == Secondary) || slavesLen > 0 && slaveOk {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if mastersLen > 0 && mode == Secondary && cluster.masters.HasMongos() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if started.IsZero() {
|
||||||
|
// Initialize after fast path above.
|
||||||
|
started = time.Now()
|
||||||
|
syncCount = cluster.syncCount
|
||||||
|
} else if syncTimeout != 0 && started.Before(time.Now().Add(-syncTimeout)) || cluster.failFast && cluster.syncCount != syncCount {
|
||||||
|
cluster.RUnlock()
|
||||||
|
return nil, errors.New("no reachable servers")
|
||||||
|
}
|
||||||
|
log("Waiting for servers to synchronize...")
|
||||||
|
cluster.syncServers()
|
||||||
|
|
||||||
|
// Remember: this will release and reacquire the lock.
|
||||||
|
cluster.serverSynced.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
var server *mongoServer
|
||||||
|
if slaveOk {
|
||||||
|
server = cluster.servers.BestFit(mode, serverTags)
|
||||||
|
} else {
|
||||||
|
server = cluster.masters.BestFit(mode, nil)
|
||||||
|
}
|
||||||
|
cluster.RUnlock()
|
||||||
|
|
||||||
|
if server == nil {
|
||||||
|
// Must have failed the requested tags. Sleep to avoid spinning.
|
||||||
|
time.Sleep(1e8)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
s, abended, err := server.AcquireSocket(poolLimit, socketTimeout)
|
||||||
|
if err == errPoolLimit {
|
||||||
|
if !warnedLimit {
|
||||||
|
warnedLimit = true
|
||||||
|
log("WARNING: Per-server connection limit reached.")
|
||||||
|
}
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
cluster.removeServer(server)
|
||||||
|
cluster.syncServers()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if abended && !slaveOk {
|
||||||
|
var result isMasterResult
|
||||||
|
err := cluster.isMaster(s, &result)
|
||||||
|
if err != nil || !result.IsMaster {
|
||||||
|
logf("Cannot confirm server %s as master (%v)", server.Addr, err)
|
||||||
|
s.Release()
|
||||||
|
cluster.syncServers()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
panic("unreached")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cluster *mongoCluster) CacheIndex(cacheKey string, exists bool) {
|
||||||
|
cluster.Lock()
|
||||||
|
if cluster.cachedIndex == nil {
|
||||||
|
cluster.cachedIndex = make(map[string]bool)
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
cluster.cachedIndex[cacheKey] = true
|
||||||
|
} else {
|
||||||
|
delete(cluster.cachedIndex, cacheKey)
|
||||||
|
}
|
||||||
|
cluster.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cluster *mongoCluster) HasCachedIndex(cacheKey string) (result bool) {
|
||||||
|
cluster.RLock()
|
||||||
|
if cluster.cachedIndex != nil {
|
||||||
|
result = cluster.cachedIndex[cacheKey]
|
||||||
|
}
|
||||||
|
cluster.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cluster *mongoCluster) ResetIndexCache() {
|
||||||
|
cluster.Lock()
|
||||||
|
cluster.cachedIndex = make(map[string]bool)
|
||||||
|
cluster.Unlock()
|
||||||
|
}
|
|
@ -0,0 +1,34 @@
|
||||||
|
// Package mgo offers a rich MongoDB driver for Go.
|
||||||
|
//
|
||||||
|
// #########################################################
|
||||||
|
//
|
||||||
|
// THIS DRIVER IS UNMAINTAINED! See here for details:
|
||||||
|
//
|
||||||
|
// https://github.com/go-mgo/mgo/blob/v2-unstable/README.md
|
||||||
|
//
|
||||||
|
// #########################################################
|
||||||
|
//
|
||||||
|
// Usage of the driver revolves around the concept of sessions. To
|
||||||
|
// get started, obtain a session using the Dial function:
|
||||||
|
//
|
||||||
|
// session, err := mgo.Dial(url)
|
||||||
|
//
|
||||||
|
// This will establish one or more connections with the cluster of
|
||||||
|
// servers defined by the url parameter. From then on, the cluster
|
||||||
|
// may be queried with multiple consistency rules (see SetMode) and
|
||||||
|
// documents retrieved with statements such as:
|
||||||
|
//
|
||||||
|
// c := session.DB(database).C(collection)
|
||||||
|
// err := c.Find(query).One(&result)
|
||||||
|
//
|
||||||
|
// New sessions are typically created by calling session.Copy on the
|
||||||
|
// initial session obtained at dial time. These new sessions will share
|
||||||
|
// the same cluster information and connection pool, and may be easily
|
||||||
|
// handed into other methods and functions for organizing logic.
|
||||||
|
// Every session created must have its Close method called at the end
|
||||||
|
// of its life time, so its resources may be put back in the pool or
|
||||||
|
// collected, depending on the case.
|
||||||
|
//
|
||||||
|
// For more details, see the documentation for the types and methods.
|
||||||
|
//
|
||||||
|
package mgo
|
|
@ -0,0 +1,761 @@
|
||||||
|
// mgo - MongoDB driver for Go
|
||||||
|
//
|
||||||
|
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
//
|
||||||
|
// All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
// list of conditions and the following disclaimer.
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
package mgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"hash"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/mgo.v2/bson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GridFS struct {
|
||||||
|
Files *Collection
|
||||||
|
Chunks *Collection
|
||||||
|
}
|
||||||
|
|
||||||
|
type gfsFileMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
gfsClosed gfsFileMode = 0
|
||||||
|
gfsReading gfsFileMode = 1
|
||||||
|
gfsWriting gfsFileMode = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type GridFile struct {
|
||||||
|
m sync.Mutex
|
||||||
|
c sync.Cond
|
||||||
|
gfs *GridFS
|
||||||
|
mode gfsFileMode
|
||||||
|
err error
|
||||||
|
|
||||||
|
chunk int
|
||||||
|
offset int64
|
||||||
|
|
||||||
|
wpending int
|
||||||
|
wbuf []byte
|
||||||
|
wsum hash.Hash
|
||||||
|
|
||||||
|
rbuf []byte
|
||||||
|
rcache *gfsCachedChunk
|
||||||
|
|
||||||
|
doc gfsFile
|
||||||
|
}
|
||||||
|
|
||||||
|
type gfsFile struct {
|
||||||
|
Id interface{} "_id"
|
||||||
|
ChunkSize int "chunkSize"
|
||||||
|
UploadDate time.Time "uploadDate"
|
||||||
|
Length int64 ",minsize"
|
||||||
|
MD5 string
|
||||||
|
Filename string ",omitempty"
|
||||||
|
ContentType string "contentType,omitempty"
|
||||||
|
Metadata *bson.Raw ",omitempty"
|
||||||
|
}
|
||||||
|
|
||||||
|
type gfsChunk struct {
|
||||||
|
Id interface{} "_id"
|
||||||
|
FilesId interface{} "files_id"
|
||||||
|
N int
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type gfsCachedChunk struct {
|
||||||
|
wait sync.Mutex
|
||||||
|
n int
|
||||||
|
data []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGridFS(db *Database, prefix string) *GridFS {
|
||||||
|
return &GridFS{db.C(prefix + ".files"), db.C(prefix + ".chunks")}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (gfs *GridFS) newFile() *GridFile {
|
||||||
|
file := &GridFile{gfs: gfs}
|
||||||
|
file.c.L = &file.m
|
||||||
|
//runtime.SetFinalizer(file, finalizeFile)
|
||||||
|
return file
|
||||||
|
}
|
||||||
|
|
||||||
|
func finalizeFile(file *GridFile) {
|
||||||
|
file.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create creates a new file with the provided name in the GridFS. If the file
|
||||||
|
// name already exists, a new version will be inserted with an up-to-date
|
||||||
|
// uploadDate that will cause it to be atomically visible to the Open and
|
||||||
|
// OpenId methods. If the file name is not important, an empty name may be
|
||||||
|
// provided and the file Id used instead.
|
||||||
|
//
|
||||||
|
// It's important to Close files whether they are being written to
|
||||||
|
// or read from, and to check the err result to ensure the operation
|
||||||
|
// completed successfully.
|
||||||
|
//
|
||||||
|
// A simple example inserting a new file:
|
||||||
|
//
|
||||||
|
// func check(err error) {
|
||||||
|
// if err != nil {
|
||||||
|
// panic(err.String())
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// file, err := db.GridFS("fs").Create("myfile.txt")
|
||||||
|
// check(err)
|
||||||
|
// n, err := file.Write([]byte("Hello world!"))
|
||||||
|
// check(err)
|
||||||
|
// err = file.Close()
|
||||||
|
// check(err)
|
||||||
|
// fmt.Printf("%d bytes written\n", n)
|
||||||
|
//
|
||||||
|
// The io.Writer interface is implemented by *GridFile and may be used to
|
||||||
|
// help on the file creation. For example:
|
||||||
|
//
|
||||||
|
// file, err := db.GridFS("fs").Create("myfile.txt")
|
||||||
|
// check(err)
|
||||||
|
// messages, err := os.Open("/var/log/messages")
|
||||||
|
// check(err)
|
||||||
|
// defer messages.Close()
|
||||||
|
// err = io.Copy(file, messages)
|
||||||
|
// check(err)
|
||||||
|
// err = file.Close()
|
||||||
|
// check(err)
|
||||||
|
//
|
||||||
|
func (gfs *GridFS) Create(name string) (file *GridFile, err error) {
|
||||||
|
file = gfs.newFile()
|
||||||
|
file.mode = gfsWriting
|
||||||
|
file.wsum = md5.New()
|
||||||
|
file.doc = gfsFile{Id: bson.NewObjectId(), ChunkSize: 255 * 1024, Filename: name}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenId returns the file with the provided id, for reading.
|
||||||
|
// If the file isn't found, err will be set to mgo.ErrNotFound.
|
||||||
|
//
|
||||||
|
// It's important to Close files whether they are being written to
|
||||||
|
// or read from, and to check the err result to ensure the operation
|
||||||
|
// completed successfully.
|
||||||
|
//
|
||||||
|
// The following example will print the first 8192 bytes from the file:
|
||||||
|
//
|
||||||
|
// func check(err error) {
|
||||||
|
// if err != nil {
|
||||||
|
// panic(err.String())
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// file, err := db.GridFS("fs").OpenId(objid)
|
||||||
|
// check(err)
|
||||||
|
// b := make([]byte, 8192)
|
||||||
|
// n, err := file.Read(b)
|
||||||
|
// check(err)
|
||||||
|
// fmt.Println(string(b))
|
||||||
|
// check(err)
|
||||||
|
// err = file.Close()
|
||||||
|
// check(err)
|
||||||
|
// fmt.Printf("%d bytes read\n", n)
|
||||||
|
//
|
||||||
|
// The io.Reader interface is implemented by *GridFile and may be used to
|
||||||
|
// deal with it. As an example, the following snippet will dump the whole
|
||||||
|
// file into the standard output:
|
||||||
|
//
|
||||||
|
// file, err := db.GridFS("fs").OpenId(objid)
|
||||||
|
// check(err)
|
||||||
|
// err = io.Copy(os.Stdout, file)
|
||||||
|
// check(err)
|
||||||
|
// err = file.Close()
|
||||||
|
// check(err)
|
||||||
|
//
|
||||||
|
func (gfs *GridFS) OpenId(id interface{}) (file *GridFile, err error) {
|
||||||
|
var doc gfsFile
|
||||||
|
err = gfs.Files.Find(bson.M{"_id": id}).One(&doc)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file = gfs.newFile()
|
||||||
|
file.mode = gfsReading
|
||||||
|
file.doc = doc
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open returns the most recently uploaded file with the provided
|
||||||
|
// name, for reading. If the file isn't found, err will be set
|
||||||
|
// to mgo.ErrNotFound.
|
||||||
|
//
|
||||||
|
// It's important to Close files whether they are being written to
|
||||||
|
// or read from, and to check the err result to ensure the operation
|
||||||
|
// completed successfully.
|
||||||
|
//
|
||||||
|
// The following example will print the first 8192 bytes from the file:
|
||||||
|
//
|
||||||
|
// file, err := db.GridFS("fs").Open("myfile.txt")
|
||||||
|
// check(err)
|
||||||
|
// b := make([]byte, 8192)
|
||||||
|
// n, err := file.Read(b)
|
||||||
|
// check(err)
|
||||||
|
// fmt.Println(string(b))
|
||||||
|
// check(err)
|
||||||
|
// err = file.Close()
|
||||||
|
// check(err)
|
||||||
|
// fmt.Printf("%d bytes read\n", n)
|
||||||
|
//
|
||||||
|
// The io.Reader interface is implemented by *GridFile and may be used to
|
||||||
|
// deal with it. As an example, the following snippet will dump the whole
|
||||||
|
// file into the standard output:
|
||||||
|
//
|
||||||
|
// file, err := db.GridFS("fs").Open("myfile.txt")
|
||||||
|
// check(err)
|
||||||
|
// err = io.Copy(os.Stdout, file)
|
||||||
|
// check(err)
|
||||||
|
// err = file.Close()
|
||||||
|
// check(err)
|
||||||
|
//
|
||||||
|
func (gfs *GridFS) Open(name string) (file *GridFile, err error) {
|
||||||
|
var doc gfsFile
|
||||||
|
err = gfs.Files.Find(bson.M{"filename": name}).Sort("-uploadDate").One(&doc)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file = gfs.newFile()
|
||||||
|
file.mode = gfsReading
|
||||||
|
file.doc = doc
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenNext opens the next file from iter for reading, sets *file to it,
|
||||||
|
// and returns true on the success case. If no more documents are available
|
||||||
|
// on iter or an error occurred, *file is set to nil and the result is false.
|
||||||
|
// Errors will be available via iter.Err().
|
||||||
|
//
|
||||||
|
// The iter parameter must be an iterator on the GridFS files collection.
|
||||||
|
// Using the GridFS.Find method is an easy way to obtain such an iterator,
|
||||||
|
// but any iterator on the collection will work.
|
||||||
|
//
|
||||||
|
// If the provided *file is non-nil, OpenNext will close it before attempting
|
||||||
|
// to iterate to the next element. This means that in a loop one only
|
||||||
|
// has to worry about closing files when breaking out of the loop early
|
||||||
|
// (break, return, or panic).
|
||||||
|
//
|
||||||
|
// For example:
|
||||||
|
//
|
||||||
|
// gfs := db.GridFS("fs")
|
||||||
|
// query := gfs.Find(nil).Sort("filename")
|
||||||
|
// iter := query.Iter()
|
||||||
|
// var f *mgo.GridFile
|
||||||
|
// for gfs.OpenNext(iter, &f) {
|
||||||
|
// fmt.Printf("Filename: %s\n", f.Name())
|
||||||
|
// }
|
||||||
|
// if iter.Close() != nil {
|
||||||
|
// panic(iter.Close())
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
func (gfs *GridFS) OpenNext(iter *Iter, file **GridFile) bool {
|
||||||
|
if *file != nil {
|
||||||
|
// Ignoring the error here shouldn't be a big deal
|
||||||
|
// as we're reading the file and the loop iteration
|
||||||
|
// for this file is finished.
|
||||||
|
_ = (*file).Close()
|
||||||
|
}
|
||||||
|
var doc gfsFile
|
||||||
|
if !iter.Next(&doc) {
|
||||||
|
*file = nil
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
f := gfs.newFile()
|
||||||
|
f.mode = gfsReading
|
||||||
|
f.doc = doc
|
||||||
|
*file = f
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find runs query on GridFS's files collection and returns
|
||||||
|
// the resulting Query.
|
||||||
|
//
|
||||||
|
// This logic:
|
||||||
|
//
|
||||||
|
// gfs := db.GridFS("fs")
|
||||||
|
// iter := gfs.Find(nil).Iter()
|
||||||
|
//
|
||||||
|
// Is equivalent to:
|
||||||
|
//
|
||||||
|
// files := db.C("fs" + ".files")
|
||||||
|
// iter := files.Find(nil).Iter()
|
||||||
|
//
|
||||||
|
func (gfs *GridFS) Find(query interface{}) *Query {
|
||||||
|
return gfs.Files.Find(query)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveId deletes the file with the provided id from the GridFS.
|
||||||
|
func (gfs *GridFS) RemoveId(id interface{}) error {
|
||||||
|
err := gfs.Files.Remove(bson.M{"_id": id})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = gfs.Chunks.RemoveAll(bson.D{{"files_id", id}})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type gfsDocId struct {
|
||||||
|
Id interface{} "_id"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove deletes all files with the provided name from the GridFS.
|
||||||
|
func (gfs *GridFS) Remove(name string) (err error) {
|
||||||
|
iter := gfs.Files.Find(bson.M{"filename": name}).Select(bson.M{"_id": 1}).Iter()
|
||||||
|
var doc gfsDocId
|
||||||
|
for iter.Next(&doc) {
|
||||||
|
if e := gfs.RemoveId(doc.Id); e != nil {
|
||||||
|
err = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
err = iter.Close()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (file *GridFile) assertMode(mode gfsFileMode) {
|
||||||
|
switch file.mode {
|
||||||
|
case mode:
|
||||||
|
return
|
||||||
|
case gfsWriting:
|
||||||
|
panic("GridFile is open for writing")
|
||||||
|
case gfsReading:
|
||||||
|
panic("GridFile is open for reading")
|
||||||
|
case gfsClosed:
|
||||||
|
panic("GridFile is closed")
|
||||||
|
default:
|
||||||
|
panic("internal error: missing GridFile mode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetChunkSize sets size of saved chunks. Once the file is written to, it
|
||||||
|
// will be split in blocks of that size and each block saved into an
|
||||||
|
// independent chunk document. The default chunk size is 255kb.
|
||||||
|
//
|
||||||
|
// It is a runtime error to call this function once the file has started
|
||||||
|
// being written to.
|
||||||
|
func (file *GridFile) SetChunkSize(bytes int) {
|
||||||
|
file.assertMode(gfsWriting)
|
||||||
|
debugf("GridFile %p: setting chunk size to %d", file, bytes)
|
||||||
|
file.m.Lock()
|
||||||
|
file.doc.ChunkSize = bytes
|
||||||
|
file.m.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Id returns the current file Id.
|
||||||
|
func (file *GridFile) Id() interface{} {
|
||||||
|
return file.doc.Id
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetId changes the current file Id.
|
||||||
|
//
|
||||||
|
// It is a runtime error to call this function once the file has started
|
||||||
|
// being written to, or when the file is not open for writing.
|
||||||
|
func (file *GridFile) SetId(id interface{}) {
|
||||||
|
file.assertMode(gfsWriting)
|
||||||
|
file.m.Lock()
|
||||||
|
file.doc.Id = id
|
||||||
|
file.m.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the optional file name. An empty string will be returned
|
||||||
|
// in case it is unset.
|
||||||
|
func (file *GridFile) Name() string {
|
||||||
|
return file.doc.Filename
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetName changes the optional file name. An empty string may be used to
|
||||||
|
// unset it.
|
||||||
|
//
|
||||||
|
// It is a runtime error to call this function when the file is not open
|
||||||
|
// for writing.
|
||||||
|
func (file *GridFile) SetName(name string) {
|
||||||
|
file.assertMode(gfsWriting)
|
||||||
|
file.m.Lock()
|
||||||
|
file.doc.Filename = name
|
||||||
|
file.m.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContentType returns the optional file content type. An empty string will be
|
||||||
|
// returned in case it is unset.
|
||||||
|
func (file *GridFile) ContentType() string {
|
||||||
|
return file.doc.ContentType
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContentType changes the optional file content type. An empty string may be
|
||||||
|
// used to unset it.
|
||||||
|
//
|
||||||
|
// It is a runtime error to call this function when the file is not open
|
||||||
|
// for writing.
|
||||||
|
func (file *GridFile) SetContentType(ctype string) {
|
||||||
|
file.assertMode(gfsWriting)
|
||||||
|
file.m.Lock()
|
||||||
|
file.doc.ContentType = ctype
|
||||||
|
file.m.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMeta unmarshals the optional "metadata" field associated with the
|
||||||
|
// file into the result parameter. The meaning of keys under that field
|
||||||
|
// is user-defined. For example:
|
||||||
|
//
|
||||||
|
// result := struct{ INode int }{}
|
||||||
|
// err = file.GetMeta(&result)
|
||||||
|
// if err != nil {
|
||||||
|
// panic(err.String())
|
||||||
|
// }
|
||||||
|
// fmt.Printf("inode: %d\n", result.INode)
|
||||||
|
//
|
||||||
|
func (file *GridFile) GetMeta(result interface{}) (err error) {
|
||||||
|
file.m.Lock()
|
||||||
|
if file.doc.Metadata != nil {
|
||||||
|
err = bson.Unmarshal(file.doc.Metadata.Data, result)
|
||||||
|
}
|
||||||
|
file.m.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMeta changes the optional "metadata" field associated with the
|
||||||
|
// file. The meaning of keys under that field is user-defined.
|
||||||
|
// For example:
|
||||||
|
//
|
||||||
|
// file.SetMeta(bson.M{"inode": inode})
|
||||||
|
//
|
||||||
|
// It is a runtime error to call this function when the file is not open
|
||||||
|
// for writing.
|
||||||
|
func (file *GridFile) SetMeta(metadata interface{}) {
|
||||||
|
file.assertMode(gfsWriting)
|
||||||
|
data, err := bson.Marshal(metadata)
|
||||||
|
file.m.Lock()
|
||||||
|
if err != nil && file.err == nil {
|
||||||
|
file.err = err
|
||||||
|
} else {
|
||||||
|
file.doc.Metadata = &bson.Raw{Data: data}
|
||||||
|
}
|
||||||
|
file.m.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns the file size in bytes.
|
||||||
|
func (file *GridFile) Size() (bytes int64) {
|
||||||
|
file.m.Lock()
|
||||||
|
bytes = file.doc.Length
|
||||||
|
file.m.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MD5 returns the file MD5 as a hex-encoded string.
|
||||||
|
func (file *GridFile) MD5() (md5 string) {
|
||||||
|
return file.doc.MD5
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadDate returns the file upload time.
|
||||||
|
func (file *GridFile) UploadDate() time.Time {
|
||||||
|
return file.doc.UploadDate
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUploadDate changes the file upload time.
|
||||||
|
//
|
||||||
|
// It is a runtime error to call this function when the file is not open
|
||||||
|
// for writing.
|
||||||
|
func (file *GridFile) SetUploadDate(t time.Time) {
|
||||||
|
file.assertMode(gfsWriting)
|
||||||
|
file.m.Lock()
|
||||||
|
file.doc.UploadDate = t
|
||||||
|
file.m.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close flushes any pending changes in case the file is being written
|
||||||
|
// to, waits for any background operations to finish, and closes the file.
|
||||||
|
//
|
||||||
|
// It's important to Close files whether they are being written to
|
||||||
|
// or read from, and to check the err result to ensure the operation
|
||||||
|
// completed successfully.
|
||||||
|
func (file *GridFile) Close() (err error) {
|
||||||
|
file.m.Lock()
|
||||||
|
defer file.m.Unlock()
|
||||||
|
if file.mode == gfsWriting {
|
||||||
|
if len(file.wbuf) > 0 && file.err == nil {
|
||||||
|
file.insertChunk(file.wbuf)
|
||||||
|
file.wbuf = file.wbuf[0:0]
|
||||||
|
}
|
||||||
|
file.completeWrite()
|
||||||
|
} else if file.mode == gfsReading && file.rcache != nil {
|
||||||
|
file.rcache.wait.Lock()
|
||||||
|
file.rcache = nil
|
||||||
|
}
|
||||||
|
file.mode = gfsClosed
|
||||||
|
debugf("GridFile %p: closed", file)
|
||||||
|
return file.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (file *GridFile) completeWrite() {
|
||||||
|
for file.wpending > 0 {
|
||||||
|
debugf("GridFile %p: waiting for %d pending chunks to complete file write", file, file.wpending)
|
||||||
|
file.c.Wait()
|
||||||
|
}
|
||||||
|
if file.err == nil {
|
||||||
|
hexsum := hex.EncodeToString(file.wsum.Sum(nil))
|
||||||
|
if file.doc.UploadDate.IsZero() {
|
||||||
|
file.doc.UploadDate = bson.Now()
|
||||||
|
}
|
||||||
|
file.doc.MD5 = hexsum
|
||||||
|
file.err = file.gfs.Files.Insert(file.doc)
|
||||||
|
}
|
||||||
|
if file.err != nil {
|
||||||
|
file.gfs.Chunks.RemoveAll(bson.D{{"files_id", file.doc.Id}})
|
||||||
|
}
|
||||||
|
if file.err == nil {
|
||||||
|
index := Index{
|
||||||
|
Key: []string{"files_id", "n"},
|
||||||
|
Unique: true,
|
||||||
|
}
|
||||||
|
file.err = file.gfs.Chunks.EnsureIndex(index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Abort cancels an in-progress write, preventing the file from being
|
||||||
|
// automically created and ensuring previously written chunks are
|
||||||
|
// removed when the file is closed.
|
||||||
|
//
|
||||||
|
// It is a runtime error to call Abort when the file was not opened
|
||||||
|
// for writing.
|
||||||
|
func (file *GridFile) Abort() {
|
||||||
|
if file.mode != gfsWriting {
|
||||||
|
panic("file.Abort must be called on file opened for writing")
|
||||||
|
}
|
||||||
|
file.err = errors.New("write aborted")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes the provided data to the file and returns the
|
||||||
|
// number of bytes written and an error in case something
|
||||||
|
// wrong happened.
|
||||||
|
//
|
||||||
|
// The file will internally cache the data so that all but the last
|
||||||
|
// chunk sent to the database have the size defined by SetChunkSize.
|
||||||
|
// This also means that errors may be deferred until a future call
|
||||||
|
// to Write or Close.
|
||||||
|
//
|
||||||
|
// The parameters and behavior of this function turn the file
|
||||||
|
// into an io.Writer.
|
||||||
|
func (file *GridFile) Write(data []byte) (n int, err error) {
|
||||||
|
file.assertMode(gfsWriting)
|
||||||
|
file.m.Lock()
|
||||||
|
debugf("GridFile %p: writing %d bytes", file, len(data))
|
||||||
|
defer file.m.Unlock()
|
||||||
|
|
||||||
|
if file.err != nil {
|
||||||
|
return 0, file.err
|
||||||
|
}
|
||||||
|
|
||||||
|
n = len(data)
|
||||||
|
file.doc.Length += int64(n)
|
||||||
|
chunkSize := file.doc.ChunkSize
|
||||||
|
|
||||||
|
if len(file.wbuf)+len(data) < chunkSize {
|
||||||
|
file.wbuf = append(file.wbuf, data...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// First, flush file.wbuf complementing with data.
|
||||||
|
if len(file.wbuf) > 0 {
|
||||||
|
missing := chunkSize - len(file.wbuf)
|
||||||
|
if missing > len(data) {
|
||||||
|
missing = len(data)
|
||||||
|
}
|
||||||
|
file.wbuf = append(file.wbuf, data[:missing]...)
|
||||||
|
data = data[missing:]
|
||||||
|
file.insertChunk(file.wbuf)
|
||||||
|
file.wbuf = file.wbuf[0:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then, flush all chunks from data without copying.
|
||||||
|
for len(data) > chunkSize {
|
||||||
|
size := chunkSize
|
||||||
|
if size > len(data) {
|
||||||
|
size = len(data)
|
||||||
|
}
|
||||||
|
file.insertChunk(data[:size])
|
||||||
|
data = data[size:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// And append the rest for a future call.
|
||||||
|
file.wbuf = append(file.wbuf, data...)
|
||||||
|
|
||||||
|
return n, file.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (file *GridFile) insertChunk(data []byte) {
|
||||||
|
n := file.chunk
|
||||||
|
file.chunk++
|
||||||
|
debugf("GridFile %p: adding to checksum: %q", file, string(data))
|
||||||
|
file.wsum.Write(data)
|
||||||
|
|
||||||
|
for file.doc.ChunkSize*file.wpending >= 1024*1024 {
|
||||||
|
// Hold on.. we got a MB pending.
|
||||||
|
file.c.Wait()
|
||||||
|
if file.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
file.wpending++
|
||||||
|
|
||||||
|
debugf("GridFile %p: inserting chunk %d with %d bytes", file, n, len(data))
|
||||||
|
|
||||||
|
// We may not own the memory of data, so rather than
|
||||||
|
// simply copying it, we'll marshal the document ahead of time.
|
||||||
|
data, err := bson.Marshal(gfsChunk{bson.NewObjectId(), file.doc.Id, n, data})
|
||||||
|
if err != nil {
|
||||||
|
file.err = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := file.gfs.Chunks.Insert(bson.Raw{Data: data})
|
||||||
|
file.m.Lock()
|
||||||
|
file.wpending--
|
||||||
|
if err != nil && file.err == nil {
|
||||||
|
file.err = err
|
||||||
|
}
|
||||||
|
file.c.Broadcast()
|
||||||
|
file.m.Unlock()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Seek sets the offset for the next Read or Write on file to
|
||||||
|
// offset, interpreted according to whence: 0 means relative to
|
||||||
|
// the origin of the file, 1 means relative to the current offset,
|
||||||
|
// and 2 means relative to the end. It returns the new offset and
|
||||||
|
// an error, if any.
|
||||||
|
func (file *GridFile) Seek(offset int64, whence int) (pos int64, err error) {
|
||||||
|
file.m.Lock()
|
||||||
|
debugf("GridFile %p: seeking for %d (whence=%d)", file, offset, whence)
|
||||||
|
defer file.m.Unlock()
|
||||||
|
switch whence {
|
||||||
|
case os.SEEK_SET:
|
||||||
|
case os.SEEK_CUR:
|
||||||
|
offset += file.offset
|
||||||
|
case os.SEEK_END:
|
||||||
|
offset += file.doc.Length
|
||||||
|
default:
|
||||||
|
panic("unsupported whence value")
|
||||||
|
}
|
||||||
|
if offset > file.doc.Length {
|
||||||
|
return file.offset, errors.New("seek past end of file")
|
||||||
|
}
|
||||||
|
if offset == file.doc.Length {
|
||||||
|
// If we're seeking to the end of the file,
|
||||||
|
// no need to read anything. This enables
|
||||||
|
// a client to find the size of the file using only the
|
||||||
|
// io.ReadSeeker interface with low overhead.
|
||||||
|
file.offset = offset
|
||||||
|
return file.offset, nil
|
||||||
|
}
|
||||||
|
chunk := int(offset / int64(file.doc.ChunkSize))
|
||||||
|
if chunk+1 == file.chunk && offset >= file.offset {
|
||||||
|
file.rbuf = file.rbuf[int(offset-file.offset):]
|
||||||
|
file.offset = offset
|
||||||
|
return file.offset, nil
|
||||||
|
}
|
||||||
|
file.offset = offset
|
||||||
|
file.chunk = chunk
|
||||||
|
file.rbuf = nil
|
||||||
|
file.rbuf, err = file.getChunk()
|
||||||
|
if err == nil {
|
||||||
|
file.rbuf = file.rbuf[int(file.offset-int64(chunk)*int64(file.doc.ChunkSize)):]
|
||||||
|
}
|
||||||
|
return file.offset, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads into b the next available data from the file and
|
||||||
|
// returns the number of bytes written and an error in case
|
||||||
|
// something wrong happened. At the end of the file, n will
|
||||||
|
// be zero and err will be set to io.EOF.
|
||||||
|
//
|
||||||
|
// The parameters and behavior of this function turn the file
|
||||||
|
// into an io.Reader.
|
||||||
|
func (file *GridFile) Read(b []byte) (n int, err error) {
|
||||||
|
file.assertMode(gfsReading)
|
||||||
|
file.m.Lock()
|
||||||
|
debugf("GridFile %p: reading at offset %d into buffer of length %d", file, file.offset, len(b))
|
||||||
|
defer file.m.Unlock()
|
||||||
|
if file.offset == file.doc.Length {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
for err == nil {
|
||||||
|
i := copy(b, file.rbuf)
|
||||||
|
n += i
|
||||||
|
file.offset += int64(i)
|
||||||
|
file.rbuf = file.rbuf[i:]
|
||||||
|
if i == len(b) || file.offset == file.doc.Length {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
b = b[i:]
|
||||||
|
file.rbuf, err = file.getChunk()
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (file *GridFile) getChunk() (data []byte, err error) {
|
||||||
|
cache := file.rcache
|
||||||
|
file.rcache = nil
|
||||||
|
if cache != nil && cache.n == file.chunk {
|
||||||
|
debugf("GridFile %p: Getting chunk %d from cache", file, file.chunk)
|
||||||
|
cache.wait.Lock()
|
||||||
|
data, err = cache.data, cache.err
|
||||||
|
} else {
|
||||||
|
debugf("GridFile %p: Fetching chunk %d", file, file.chunk)
|
||||||
|
var doc gfsChunk
|
||||||
|
err = file.gfs.Chunks.Find(bson.D{{"files_id", file.doc.Id}, {"n", file.chunk}}).One(&doc)
|
||||||
|
data = doc.Data
|
||||||
|
}
|
||||||
|
file.chunk++
|
||||||
|
if int64(file.chunk)*int64(file.doc.ChunkSize) < file.doc.Length {
|
||||||
|
// Read the next one in background.
|
||||||
|
cache = &gfsCachedChunk{n: file.chunk}
|
||||||
|
cache.wait.Lock()
|
||||||
|
debugf("GridFile %p: Scheduling chunk %d for background caching", file, file.chunk)
|
||||||
|
// Clone the session to avoid having it closed in between.
|
||||||
|
chunks := file.gfs.Chunks
|
||||||
|
session := chunks.Database.Session.Clone()
|
||||||
|
go func(id interface{}, n int) {
|
||||||
|
defer session.Close()
|
||||||
|
chunks = chunks.With(session)
|
||||||
|
var doc gfsChunk
|
||||||
|
cache.err = chunks.Find(bson.D{{"files_id", id}, {"n", n}}).One(&doc)
|
||||||
|
cache.data = doc.Data
|
||||||
|
cache.wait.Unlock()
|
||||||
|
}(file.doc.Id, file.chunk)
|
||||||
|
file.rcache = cache
|
||||||
|
}
|
||||||
|
debugf("Returning err: %#v", err)
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
Copyright (c) 2012 The Go Authors. All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Google Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,95 @@
|
||||||
|
package json
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Extension holds a set of additional rules to be used when unmarshaling
|
||||||
|
// strict JSON or JSON-like content.
|
||||||
|
type Extension struct {
|
||||||
|
funcs map[string]funcExt
|
||||||
|
consts map[string]interface{}
|
||||||
|
keyed map[string]func([]byte) (interface{}, error)
|
||||||
|
encode map[reflect.Type]func(v interface{}) ([]byte, error)
|
||||||
|
|
||||||
|
unquotedKeys bool
|
||||||
|
trailingCommas bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type funcExt struct {
|
||||||
|
key string
|
||||||
|
args []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extend changes the decoder behavior to consider the provided extension.
|
||||||
|
func (dec *Decoder) Extend(ext *Extension) { dec.d.ext = *ext }
|
||||||
|
|
||||||
|
// Extend changes the encoder behavior to consider the provided extension.
|
||||||
|
func (enc *Encoder) Extend(ext *Extension) { enc.ext = *ext }
|
||||||
|
|
||||||
|
// Extend includes in e the extensions defined in ext.
|
||||||
|
func (e *Extension) Extend(ext *Extension) {
|
||||||
|
for name, fext := range ext.funcs {
|
||||||
|
e.DecodeFunc(name, fext.key, fext.args...)
|
||||||
|
}
|
||||||
|
for name, value := range ext.consts {
|
||||||
|
e.DecodeConst(name, value)
|
||||||
|
}
|
||||||
|
for key, decode := range ext.keyed {
|
||||||
|
e.DecodeKeyed(key, decode)
|
||||||
|
}
|
||||||
|
for typ, encode := range ext.encode {
|
||||||
|
if e.encode == nil {
|
||||||
|
e.encode = make(map[reflect.Type]func(v interface{}) ([]byte, error))
|
||||||
|
}
|
||||||
|
e.encode[typ] = encode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeFunc defines a function call that may be observed inside JSON content.
|
||||||
|
// A function with the provided name will be unmarshaled as the document
|
||||||
|
// {key: {args[0]: ..., args[N]: ...}}.
|
||||||
|
func (e *Extension) DecodeFunc(name string, key string, args ...string) {
|
||||||
|
if e.funcs == nil {
|
||||||
|
e.funcs = make(map[string]funcExt)
|
||||||
|
}
|
||||||
|
e.funcs[name] = funcExt{key, args}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeConst defines a constant name that may be observed inside JSON content
|
||||||
|
// and will be decoded with the provided value.
|
||||||
|
func (e *Extension) DecodeConst(name string, value interface{}) {
|
||||||
|
if e.consts == nil {
|
||||||
|
e.consts = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
e.consts[name] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeKeyed defines a key that when observed as the first element inside a
|
||||||
|
// JSON document triggers the decoding of that document via the provided
|
||||||
|
// decode function.
|
||||||
|
func (e *Extension) DecodeKeyed(key string, decode func(data []byte) (interface{}, error)) {
|
||||||
|
if e.keyed == nil {
|
||||||
|
e.keyed = make(map[string]func([]byte) (interface{}, error))
|
||||||
|
}
|
||||||
|
e.keyed[key] = decode
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeUnquotedKeys defines whether to accept map keys that are unquoted strings.
|
||||||
|
func (e *Extension) DecodeUnquotedKeys(accept bool) {
|
||||||
|
e.unquotedKeys = accept
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeTrailingCommas defines whether to accept trailing commas in maps and arrays.
|
||||||
|
func (e *Extension) DecodeTrailingCommas(accept bool) {
|
||||||
|
e.trailingCommas = accept
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeType registers a function to encode values with the same type of the
|
||||||
|
// provided sample.
|
||||||
|
func (e *Extension) EncodeType(sample interface{}, encode func(v interface{}) ([]byte, error)) {
|
||||||
|
if e.encode == nil {
|
||||||
|
e.encode = make(map[reflect.Type]func(v interface{}) ([]byte, error))
|
||||||
|
}
|
||||||
|
e.encode[reflect.TypeOf(sample)] = encode
|
||||||
|
}
|
|
@ -0,0 +1,143 @@
|
||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package json
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
caseMask = ^byte(0x20) // Mask to ignore case in ASCII.
|
||||||
|
kelvin = '\u212a'
|
||||||
|
smallLongEss = '\u017f'
|
||||||
|
)
|
||||||
|
|
||||||
|
// foldFunc returns one of four different case folding equivalence
|
||||||
|
// functions, from most general (and slow) to fastest:
|
||||||
|
//
|
||||||
|
// 1) bytes.EqualFold, if the key s contains any non-ASCII UTF-8
|
||||||
|
// 2) equalFoldRight, if s contains special folding ASCII ('k', 'K', 's', 'S')
|
||||||
|
// 3) asciiEqualFold, no special, but includes non-letters (including _)
|
||||||
|
// 4) simpleLetterEqualFold, no specials, no non-letters.
|
||||||
|
//
|
||||||
|
// The letters S and K are special because they map to 3 runes, not just 2:
|
||||||
|
// * S maps to s and to U+017F 'ſ' Latin small letter long s
|
||||||
|
// * k maps to K and to U+212A 'K' Kelvin sign
|
||||||
|
// See https://play.golang.org/p/tTxjOc0OGo
|
||||||
|
//
|
||||||
|
// The returned function is specialized for matching against s and
|
||||||
|
// should only be given s. It's not curried for performance reasons.
|
||||||
|
func foldFunc(s []byte) func(s, t []byte) bool {
|
||||||
|
nonLetter := false
|
||||||
|
special := false // special letter
|
||||||
|
for _, b := range s {
|
||||||
|
if b >= utf8.RuneSelf {
|
||||||
|
return bytes.EqualFold
|
||||||
|
}
|
||||||
|
upper := b & caseMask
|
||||||
|
if upper < 'A' || upper > 'Z' {
|
||||||
|
nonLetter = true
|
||||||
|
} else if upper == 'K' || upper == 'S' {
|
||||||
|
// See above for why these letters are special.
|
||||||
|
special = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if special {
|
||||||
|
return equalFoldRight
|
||||||
|
}
|
||||||
|
if nonLetter {
|
||||||
|
return asciiEqualFold
|
||||||
|
}
|
||||||
|
return simpleLetterEqualFold
|
||||||
|
}
|
||||||
|
|
||||||
|
// equalFoldRight is a specialization of bytes.EqualFold when s is
|
||||||
|
// known to be all ASCII (including punctuation), but contains an 's',
|
||||||
|
// 'S', 'k', or 'K', requiring a Unicode fold on the bytes in t.
|
||||||
|
// See comments on foldFunc.
|
||||||
|
func equalFoldRight(s, t []byte) bool {
|
||||||
|
for _, sb := range s {
|
||||||
|
if len(t) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
tb := t[0]
|
||||||
|
if tb < utf8.RuneSelf {
|
||||||
|
if sb != tb {
|
||||||
|
sbUpper := sb & caseMask
|
||||||
|
if 'A' <= sbUpper && sbUpper <= 'Z' {
|
||||||
|
if sbUpper != tb&caseMask {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t = t[1:]
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// sb is ASCII and t is not. t must be either kelvin
|
||||||
|
// sign or long s; sb must be s, S, k, or K.
|
||||||
|
tr, size := utf8.DecodeRune(t)
|
||||||
|
switch sb {
|
||||||
|
case 's', 'S':
|
||||||
|
if tr != smallLongEss {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case 'k', 'K':
|
||||||
|
if tr != kelvin {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
t = t[size:]
|
||||||
|
|
||||||
|
}
|
||||||
|
if len(t) > 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// asciiEqualFold is a specialization of bytes.EqualFold for use when
|
||||||
|
// s is all ASCII (but may contain non-letters) and contains no
|
||||||
|
// special-folding letters.
|
||||||
|
// See comments on foldFunc.
|
||||||
|
func asciiEqualFold(s, t []byte) bool {
|
||||||
|
if len(s) != len(t) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, sb := range s {
|
||||||
|
tb := t[i]
|
||||||
|
if sb == tb {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ('a' <= sb && sb <= 'z') || ('A' <= sb && sb <= 'Z') {
|
||||||
|
if sb&caseMask != tb&caseMask {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// simpleLetterEqualFold is a specialization of bytes.EqualFold for
|
||||||
|
// use when s is all ASCII letters (no underscores, etc) and also
|
||||||
|
// doesn't contain 'k', 'K', 's', or 'S'.
|
||||||
|
// See comments on foldFunc.
|
||||||
|
func simpleLetterEqualFold(s, t []byte) bool {
|
||||||
|
if len(s) != len(t) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, b := range s {
|
||||||
|
if b&caseMask != t[i]&caseMask {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
|
@ -0,0 +1,141 @@
|
||||||
|
// Copyright 2010 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package json
|
||||||
|
|
||||||
|
import "bytes"
|
||||||
|
|
||||||
|
// Compact appends to dst the JSON-encoded src with
|
||||||
|
// insignificant space characters elided.
|
||||||
|
func Compact(dst *bytes.Buffer, src []byte) error {
|
||||||
|
return compact(dst, src, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func compact(dst *bytes.Buffer, src []byte, escape bool) error {
|
||||||
|
origLen := dst.Len()
|
||||||
|
var scan scanner
|
||||||
|
scan.reset()
|
||||||
|
start := 0
|
||||||
|
for i, c := range src {
|
||||||
|
if escape && (c == '<' || c == '>' || c == '&') {
|
||||||
|
if start < i {
|
||||||
|
dst.Write(src[start:i])
|
||||||
|
}
|
||||||
|
dst.WriteString(`\u00`)
|
||||||
|
dst.WriteByte(hex[c>>4])
|
||||||
|
dst.WriteByte(hex[c&0xF])
|
||||||
|
start = i + 1
|
||||||
|
}
|
||||||
|
// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
|
||||||
|
if c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
|
||||||
|
if start < i {
|
||||||
|
dst.Write(src[start:i])
|
||||||
|
}
|
||||||
|
dst.WriteString(`\u202`)
|
||||||
|
dst.WriteByte(hex[src[i+2]&0xF])
|
||||||
|
start = i + 3
|
||||||
|
}
|
||||||
|
v := scan.step(&scan, c)
|
||||||
|
if v >= scanSkipSpace {
|
||||||
|
if v == scanError {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if start < i {
|
||||||
|
dst.Write(src[start:i])
|
||||||
|
}
|
||||||
|
start = i + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if scan.eof() == scanError {
|
||||||
|
dst.Truncate(origLen)
|
||||||
|
return scan.err
|
||||||
|
}
|
||||||
|
if start < len(src) {
|
||||||
|
dst.Write(src[start:])
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newline(dst *bytes.Buffer, prefix, indent string, depth int) {
|
||||||
|
dst.WriteByte('\n')
|
||||||
|
dst.WriteString(prefix)
|
||||||
|
for i := 0; i < depth; i++ {
|
||||||
|
dst.WriteString(indent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Indent appends to dst an indented form of the JSON-encoded src.
|
||||||
|
// Each element in a JSON object or array begins on a new,
|
||||||
|
// indented line beginning with prefix followed by one or more
|
||||||
|
// copies of indent according to the indentation nesting.
|
||||||
|
// The data appended to dst does not begin with the prefix nor
|
||||||
|
// any indentation, to make it easier to embed inside other formatted JSON data.
|
||||||
|
// Although leading space characters (space, tab, carriage return, newline)
|
||||||
|
// at the beginning of src are dropped, trailing space characters
|
||||||
|
// at the end of src are preserved and copied to dst.
|
||||||
|
// For example, if src has no trailing spaces, neither will dst;
|
||||||
|
// if src ends in a trailing newline, so will dst.
|
||||||
|
func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
|
||||||
|
origLen := dst.Len()
|
||||||
|
var scan scanner
|
||||||
|
scan.reset()
|
||||||
|
needIndent := false
|
||||||
|
depth := 0
|
||||||
|
for _, c := range src {
|
||||||
|
scan.bytes++
|
||||||
|
v := scan.step(&scan, c)
|
||||||
|
if v == scanSkipSpace {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if v == scanError {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if needIndent && v != scanEndObject && v != scanEndArray {
|
||||||
|
needIndent = false
|
||||||
|
depth++
|
||||||
|
newline(dst, prefix, indent, depth)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit semantically uninteresting bytes
|
||||||
|
// (in particular, punctuation in strings) unmodified.
|
||||||
|
if v == scanContinue {
|
||||||
|
dst.WriteByte(c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add spacing around real punctuation.
|
||||||
|
switch c {
|
||||||
|
case '{', '[':
|
||||||
|
// delay indent so that empty object and array are formatted as {} and [].
|
||||||
|
needIndent = true
|
||||||
|
dst.WriteByte(c)
|
||||||
|
|
||||||
|
case ',':
|
||||||
|
dst.WriteByte(c)
|
||||||
|
newline(dst, prefix, indent, depth)
|
||||||
|
|
||||||
|
case ':':
|
||||||
|
dst.WriteByte(c)
|
||||||
|
dst.WriteByte(' ')
|
||||||
|
|
||||||
|
case '}', ']':
|
||||||
|
if needIndent {
|
||||||
|
// suppress indent in empty object/array
|
||||||
|
needIndent = false
|
||||||
|
} else {
|
||||||
|
depth--
|
||||||
|
newline(dst, prefix, indent, depth)
|
||||||
|
}
|
||||||
|
dst.WriteByte(c)
|
||||||
|
|
||||||
|
default:
|
||||||
|
dst.WriteByte(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if scan.eof() == scanError {
|
||||||
|
dst.Truncate(origLen)
|
||||||
|
return scan.err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,697 @@
|
||||||
|
// Copyright 2010 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package json
|
||||||
|
|
||||||
|
// JSON value parser state machine.
|
||||||
|
// Just about at the limit of what is reasonable to write by hand.
|
||||||
|
// Some parts are a bit tedious, but overall it nicely factors out the
|
||||||
|
// otherwise common code from the multiple scanning functions
|
||||||
|
// in this package (Compact, Indent, checkValid, nextValue, etc).
|
||||||
|
//
|
||||||
|
// This file starts with two simple examples using the scanner
|
||||||
|
// before diving into the scanner itself.
|
||||||
|
|
||||||
|
import "strconv"
|
||||||
|
|
||||||
|
// checkValid verifies that data is valid JSON-encoded data.
|
||||||
|
// scan is passed in for use by checkValid to avoid an allocation.
|
||||||
|
func checkValid(data []byte, scan *scanner) error {
|
||||||
|
scan.reset()
|
||||||
|
for _, c := range data {
|
||||||
|
scan.bytes++
|
||||||
|
if scan.step(scan, c) == scanError {
|
||||||
|
return scan.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if scan.eof() == scanError {
|
||||||
|
return scan.err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nextValue splits data after the next whole JSON value,
|
||||||
|
// returning that value and the bytes that follow it as separate slices.
|
||||||
|
// scan is passed in for use by nextValue to avoid an allocation.
|
||||||
|
func nextValue(data []byte, scan *scanner) (value, rest []byte, err error) {
|
||||||
|
scan.reset()
|
||||||
|
for i, c := range data {
|
||||||
|
v := scan.step(scan, c)
|
||||||
|
if v >= scanEndObject {
|
||||||
|
switch v {
|
||||||
|
// probe the scanner with a space to determine whether we will
|
||||||
|
// get scanEnd on the next character. Otherwise, if the next character
|
||||||
|
// is not a space, scanEndTop allocates a needless error.
|
||||||
|
case scanEndObject, scanEndArray, scanEndParams:
|
||||||
|
if scan.step(scan, ' ') == scanEnd {
|
||||||
|
return data[:i+1], data[i+1:], nil
|
||||||
|
}
|
||||||
|
case scanError:
|
||||||
|
return nil, nil, scan.err
|
||||||
|
case scanEnd:
|
||||||
|
return data[:i], data[i:], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if scan.eof() == scanError {
|
||||||
|
return nil, nil, scan.err
|
||||||
|
}
|
||||||
|
return data, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// A SyntaxError is a description of a JSON syntax error.
|
||||||
|
type SyntaxError struct {
|
||||||
|
msg string // description of error
|
||||||
|
Offset int64 // error occurred after reading Offset bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SyntaxError) Error() string { return e.msg }
|
||||||
|
|
||||||
|
// A scanner is a JSON scanning state machine.
|
||||||
|
// Callers call scan.reset() and then pass bytes in one at a time
|
||||||
|
// by calling scan.step(&scan, c) for each byte.
|
||||||
|
// The return value, referred to as an opcode, tells the
|
||||||
|
// caller about significant parsing events like beginning
|
||||||
|
// and ending literals, objects, and arrays, so that the
|
||||||
|
// caller can follow along if it wishes.
|
||||||
|
// The return value scanEnd indicates that a single top-level
|
||||||
|
// JSON value has been completed, *before* the byte that
|
||||||
|
// just got passed in. (The indication must be delayed in order
|
||||||
|
// to recognize the end of numbers: is 123 a whole value or
|
||||||
|
// the beginning of 12345e+6?).
|
||||||
|
type scanner struct {
|
||||||
|
// The step is a func to be called to execute the next transition.
|
||||||
|
// Also tried using an integer constant and a single func
|
||||||
|
// with a switch, but using the func directly was 10% faster
|
||||||
|
// on a 64-bit Mac Mini, and it's nicer to read.
|
||||||
|
step func(*scanner, byte) int
|
||||||
|
|
||||||
|
// Reached end of top-level value.
|
||||||
|
endTop bool
|
||||||
|
|
||||||
|
// Stack of what we're in the middle of - array values, object keys, object values.
|
||||||
|
parseState []int
|
||||||
|
|
||||||
|
// Error that happened, if any.
|
||||||
|
err error
|
||||||
|
|
||||||
|
// 1-byte redo (see undo method)
|
||||||
|
redo bool
|
||||||
|
redoCode int
|
||||||
|
redoState func(*scanner, byte) int
|
||||||
|
|
||||||
|
// total bytes consumed, updated by decoder.Decode
|
||||||
|
bytes int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// These values are returned by the state transition functions
|
||||||
|
// assigned to scanner.state and the method scanner.eof.
|
||||||
|
// They give details about the current state of the scan that
|
||||||
|
// callers might be interested to know about.
|
||||||
|
// It is okay to ignore the return value of any particular
|
||||||
|
// call to scanner.state: if one call returns scanError,
|
||||||
|
// every subsequent call will return scanError too.
|
||||||
|
const (
|
||||||
|
// Continue.
|
||||||
|
scanContinue = iota // uninteresting byte
|
||||||
|
scanBeginLiteral // end implied by next result != scanContinue
|
||||||
|
scanBeginObject // begin object
|
||||||
|
scanObjectKey // just finished object key (string)
|
||||||
|
scanObjectValue // just finished non-last object value
|
||||||
|
scanEndObject // end object (implies scanObjectValue if possible)
|
||||||
|
scanBeginArray // begin array
|
||||||
|
scanArrayValue // just finished array value
|
||||||
|
scanEndArray // end array (implies scanArrayValue if possible)
|
||||||
|
scanBeginName // begin function call
|
||||||
|
scanParam // begin function argument
|
||||||
|
scanEndParams // end function call
|
||||||
|
scanSkipSpace // space byte; can skip; known to be last "continue" result
|
||||||
|
|
||||||
|
// Stop.
|
||||||
|
scanEnd // top-level value ended *before* this byte; known to be first "stop" result
|
||||||
|
scanError // hit an error, scanner.err.
|
||||||
|
)
|
||||||
|
|
||||||
|
// These values are stored in the parseState stack.
|
||||||
|
// They give the current state of a composite value
|
||||||
|
// being scanned. If the parser is inside a nested value
|
||||||
|
// the parseState describes the nested state, outermost at entry 0.
|
||||||
|
const (
|
||||||
|
parseObjectKey = iota // parsing object key (before colon)
|
||||||
|
parseObjectValue // parsing object value (after colon)
|
||||||
|
parseArrayValue // parsing array value
|
||||||
|
parseName // parsing unquoted name
|
||||||
|
parseParam // parsing function argument value
|
||||||
|
)
|
||||||
|
|
||||||
|
// reset prepares the scanner for use.
|
||||||
|
// It must be called before calling s.step.
|
||||||
|
func (s *scanner) reset() {
|
||||||
|
s.step = stateBeginValue
|
||||||
|
s.parseState = s.parseState[0:0]
|
||||||
|
s.err = nil
|
||||||
|
s.redo = false
|
||||||
|
s.endTop = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// eof tells the scanner that the end of input has been reached.
|
||||||
|
// It returns a scan status just as s.step does.
|
||||||
|
func (s *scanner) eof() int {
|
||||||
|
if s.err != nil {
|
||||||
|
return scanError
|
||||||
|
}
|
||||||
|
if s.endTop {
|
||||||
|
return scanEnd
|
||||||
|
}
|
||||||
|
s.step(s, ' ')
|
||||||
|
if s.endTop {
|
||||||
|
return scanEnd
|
||||||
|
}
|
||||||
|
if s.err == nil {
|
||||||
|
s.err = &SyntaxError{"unexpected end of JSON input", s.bytes}
|
||||||
|
}
|
||||||
|
return scanError
|
||||||
|
}
|
||||||
|
|
||||||
|
// pushParseState pushes a new parse state p onto the parse stack.
|
||||||
|
func (s *scanner) pushParseState(p int) {
|
||||||
|
s.parseState = append(s.parseState, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// popParseState pops a parse state (already obtained) off the stack
|
||||||
|
// and updates s.step accordingly.
|
||||||
|
func (s *scanner) popParseState() {
|
||||||
|
n := len(s.parseState) - 1
|
||||||
|
s.parseState = s.parseState[0:n]
|
||||||
|
s.redo = false
|
||||||
|
if n == 0 {
|
||||||
|
s.step = stateEndTop
|
||||||
|
s.endTop = true
|
||||||
|
} else {
|
||||||
|
s.step = stateEndValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSpace(c byte) bool {
|
||||||
|
return c == ' ' || c == '\t' || c == '\r' || c == '\n'
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateBeginValueOrEmpty is the state after reading `[`.
|
||||||
|
func stateBeginValueOrEmpty(s *scanner, c byte) int {
|
||||||
|
if c <= ' ' && isSpace(c) {
|
||||||
|
return scanSkipSpace
|
||||||
|
}
|
||||||
|
if c == ']' {
|
||||||
|
return stateEndValue(s, c)
|
||||||
|
}
|
||||||
|
return stateBeginValue(s, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateBeginValue is the state at the beginning of the input.
|
||||||
|
func stateBeginValue(s *scanner, c byte) int {
|
||||||
|
if c <= ' ' && isSpace(c) {
|
||||||
|
return scanSkipSpace
|
||||||
|
}
|
||||||
|
switch c {
|
||||||
|
case '{':
|
||||||
|
s.step = stateBeginStringOrEmpty
|
||||||
|
s.pushParseState(parseObjectKey)
|
||||||
|
return scanBeginObject
|
||||||
|
case '[':
|
||||||
|
s.step = stateBeginValueOrEmpty
|
||||||
|
s.pushParseState(parseArrayValue)
|
||||||
|
return scanBeginArray
|
||||||
|
case '"':
|
||||||
|
s.step = stateInString
|
||||||
|
return scanBeginLiteral
|
||||||
|
case '-':
|
||||||
|
s.step = stateNeg
|
||||||
|
return scanBeginLiteral
|
||||||
|
case '0': // beginning of 0.123
|
||||||
|
s.step = state0
|
||||||
|
return scanBeginLiteral
|
||||||
|
case 'n':
|
||||||
|
s.step = stateNew0
|
||||||
|
return scanBeginName
|
||||||
|
}
|
||||||
|
if '1' <= c && c <= '9' { // beginning of 1234.5
|
||||||
|
s.step = state1
|
||||||
|
return scanBeginLiteral
|
||||||
|
}
|
||||||
|
if isName(c) {
|
||||||
|
s.step = stateName
|
||||||
|
return scanBeginName
|
||||||
|
}
|
||||||
|
return s.error(c, "looking for beginning of value")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isName(c byte) bool {
|
||||||
|
return c == '$' || c == '_' || 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9'
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateBeginStringOrEmpty is the state after reading `{`.
|
||||||
|
func stateBeginStringOrEmpty(s *scanner, c byte) int {
|
||||||
|
if c <= ' ' && isSpace(c) {
|
||||||
|
return scanSkipSpace
|
||||||
|
}
|
||||||
|
if c == '}' {
|
||||||
|
n := len(s.parseState)
|
||||||
|
s.parseState[n-1] = parseObjectValue
|
||||||
|
return stateEndValue(s, c)
|
||||||
|
}
|
||||||
|
return stateBeginString(s, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateBeginString is the state after reading `{"key": value,`.
|
||||||
|
func stateBeginString(s *scanner, c byte) int {
|
||||||
|
if c <= ' ' && isSpace(c) {
|
||||||
|
return scanSkipSpace
|
||||||
|
}
|
||||||
|
if c == '"' {
|
||||||
|
s.step = stateInString
|
||||||
|
return scanBeginLiteral
|
||||||
|
}
|
||||||
|
if isName(c) {
|
||||||
|
s.step = stateName
|
||||||
|
return scanBeginName
|
||||||
|
}
|
||||||
|
return s.error(c, "looking for beginning of object key string")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateEndValue is the state after completing a value,
|
||||||
|
// such as after reading `{}` or `true` or `["x"`.
|
||||||
|
func stateEndValue(s *scanner, c byte) int {
|
||||||
|
n := len(s.parseState)
|
||||||
|
if n == 0 {
|
||||||
|
// Completed top-level before the current byte.
|
||||||
|
s.step = stateEndTop
|
||||||
|
s.endTop = true
|
||||||
|
return stateEndTop(s, c)
|
||||||
|
}
|
||||||
|
if c <= ' ' && isSpace(c) {
|
||||||
|
s.step = stateEndValue
|
||||||
|
return scanSkipSpace
|
||||||
|
}
|
||||||
|
ps := s.parseState[n-1]
|
||||||
|
switch ps {
|
||||||
|
case parseObjectKey:
|
||||||
|
if c == ':' {
|
||||||
|
s.parseState[n-1] = parseObjectValue
|
||||||
|
s.step = stateBeginValue
|
||||||
|
return scanObjectKey
|
||||||
|
}
|
||||||
|
return s.error(c, "after object key")
|
||||||
|
case parseObjectValue:
|
||||||
|
if c == ',' {
|
||||||
|
s.parseState[n-1] = parseObjectKey
|
||||||
|
s.step = stateBeginStringOrEmpty
|
||||||
|
return scanObjectValue
|
||||||
|
}
|
||||||
|
if c == '}' {
|
||||||
|
s.popParseState()
|
||||||
|
return scanEndObject
|
||||||
|
}
|
||||||
|
return s.error(c, "after object key:value pair")
|
||||||
|
case parseArrayValue:
|
||||||
|
if c == ',' {
|
||||||
|
s.step = stateBeginValueOrEmpty
|
||||||
|
return scanArrayValue
|
||||||
|
}
|
||||||
|
if c == ']' {
|
||||||
|
s.popParseState()
|
||||||
|
return scanEndArray
|
||||||
|
}
|
||||||
|
return s.error(c, "after array element")
|
||||||
|
case parseParam:
|
||||||
|
if c == ',' {
|
||||||
|
s.step = stateBeginValue
|
||||||
|
return scanParam
|
||||||
|
}
|
||||||
|
if c == ')' {
|
||||||
|
s.popParseState()
|
||||||
|
return scanEndParams
|
||||||
|
}
|
||||||
|
return s.error(c, "after array element")
|
||||||
|
}
|
||||||
|
return s.error(c, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateEndTop is the state after finishing the top-level value,
|
||||||
|
// such as after reading `{}` or `[1,2,3]`.
|
||||||
|
// Only space characters should be seen now.
|
||||||
|
func stateEndTop(s *scanner, c byte) int {
|
||||||
|
if c != ' ' && c != '\t' && c != '\r' && c != '\n' {
|
||||||
|
// Complain about non-space byte on next call.
|
||||||
|
s.error(c, "after top-level value")
|
||||||
|
}
|
||||||
|
return scanEnd
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateInString is the state after reading `"`.
|
||||||
|
func stateInString(s *scanner, c byte) int {
|
||||||
|
if c == '"' {
|
||||||
|
s.step = stateEndValue
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
if c == '\\' {
|
||||||
|
s.step = stateInStringEsc
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
if c < 0x20 {
|
||||||
|
return s.error(c, "in string literal")
|
||||||
|
}
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateInStringEsc is the state after reading `"\` during a quoted string.
|
||||||
|
func stateInStringEsc(s *scanner, c byte) int {
|
||||||
|
switch c {
|
||||||
|
case 'b', 'f', 'n', 'r', 't', '\\', '/', '"':
|
||||||
|
s.step = stateInString
|
||||||
|
return scanContinue
|
||||||
|
case 'u':
|
||||||
|
s.step = stateInStringEscU
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return s.error(c, "in string escape code")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateInStringEscU is the state after reading `"\u` during a quoted string.
|
||||||
|
func stateInStringEscU(s *scanner, c byte) int {
|
||||||
|
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||||
|
s.step = stateInStringEscU1
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
// numbers
|
||||||
|
return s.error(c, "in \\u hexadecimal character escape")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateInStringEscU1 is the state after reading `"\u1` during a quoted string.
|
||||||
|
func stateInStringEscU1(s *scanner, c byte) int {
|
||||||
|
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||||
|
s.step = stateInStringEscU12
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
// numbers
|
||||||
|
return s.error(c, "in \\u hexadecimal character escape")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateInStringEscU12 is the state after reading `"\u12` during a quoted string.
|
||||||
|
func stateInStringEscU12(s *scanner, c byte) int {
|
||||||
|
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||||
|
s.step = stateInStringEscU123
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
// numbers
|
||||||
|
return s.error(c, "in \\u hexadecimal character escape")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateInStringEscU123 is the state after reading `"\u123` during a quoted string.
|
||||||
|
func stateInStringEscU123(s *scanner, c byte) int {
|
||||||
|
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
|
||||||
|
s.step = stateInString
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
// numbers
|
||||||
|
return s.error(c, "in \\u hexadecimal character escape")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateNeg is the state after reading `-` during a number.
|
||||||
|
func stateNeg(s *scanner, c byte) int {
|
||||||
|
if c == '0' {
|
||||||
|
s.step = state0
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
if '1' <= c && c <= '9' {
|
||||||
|
s.step = state1
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return s.error(c, "in numeric literal")
|
||||||
|
}
|
||||||
|
|
||||||
|
// state1 is the state after reading a non-zero integer during a number,
|
||||||
|
// such as after reading `1` or `100` but not `0`.
|
||||||
|
func state1(s *scanner, c byte) int {
|
||||||
|
if '0' <= c && c <= '9' {
|
||||||
|
s.step = state1
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return state0(s, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// state0 is the state after reading `0` during a number.
|
||||||
|
func state0(s *scanner, c byte) int {
|
||||||
|
if c == '.' {
|
||||||
|
s.step = stateDot
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
if c == 'e' || c == 'E' {
|
||||||
|
s.step = stateE
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return stateEndValue(s, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateDot is the state after reading the integer and decimal point in a number,
|
||||||
|
// such as after reading `1.`.
|
||||||
|
func stateDot(s *scanner, c byte) int {
|
||||||
|
if '0' <= c && c <= '9' {
|
||||||
|
s.step = stateDot0
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return s.error(c, "after decimal point in numeric literal")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateDot0 is the state after reading the integer, decimal point, and subsequent
|
||||||
|
// digits of a number, such as after reading `3.14`.
|
||||||
|
func stateDot0(s *scanner, c byte) int {
|
||||||
|
if '0' <= c && c <= '9' {
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
if c == 'e' || c == 'E' {
|
||||||
|
s.step = stateE
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return stateEndValue(s, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateE is the state after reading the mantissa and e in a number,
|
||||||
|
// such as after reading `314e` or `0.314e`.
|
||||||
|
func stateE(s *scanner, c byte) int {
|
||||||
|
if c == '+' || c == '-' {
|
||||||
|
s.step = stateESign
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return stateESign(s, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateESign is the state after reading the mantissa, e, and sign in a number,
|
||||||
|
// such as after reading `314e-` or `0.314e+`.
|
||||||
|
func stateESign(s *scanner, c byte) int {
|
||||||
|
if '0' <= c && c <= '9' {
|
||||||
|
s.step = stateE0
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return s.error(c, "in exponent of numeric literal")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateE0 is the state after reading the mantissa, e, optional sign,
|
||||||
|
// and at least one digit of the exponent in a number,
|
||||||
|
// such as after reading `314e-2` or `0.314e+1` or `3.14e0`.
|
||||||
|
func stateE0(s *scanner, c byte) int {
|
||||||
|
if '0' <= c && c <= '9' {
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return stateEndValue(s, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateNew0 is the state after reading `n`.
|
||||||
|
func stateNew0(s *scanner, c byte) int {
|
||||||
|
if c == 'e' {
|
||||||
|
s.step = stateNew1
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
s.step = stateName
|
||||||
|
return stateName(s, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateNew1 is the state after reading `ne`.
|
||||||
|
func stateNew1(s *scanner, c byte) int {
|
||||||
|
if c == 'w' {
|
||||||
|
s.step = stateNew2
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
s.step = stateName
|
||||||
|
return stateName(s, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateNew2 is the state after reading `new`.
|
||||||
|
func stateNew2(s *scanner, c byte) int {
|
||||||
|
s.step = stateName
|
||||||
|
if c == ' ' {
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return stateName(s, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateName is the state while reading an unquoted function name.
|
||||||
|
func stateName(s *scanner, c byte) int {
|
||||||
|
if isName(c) {
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
if c == '(' {
|
||||||
|
s.step = stateParamOrEmpty
|
||||||
|
s.pushParseState(parseParam)
|
||||||
|
return scanParam
|
||||||
|
}
|
||||||
|
return stateEndValue(s, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateParamOrEmpty is the state after reading `(`.
|
||||||
|
func stateParamOrEmpty(s *scanner, c byte) int {
|
||||||
|
if c <= ' ' && isSpace(c) {
|
||||||
|
return scanSkipSpace
|
||||||
|
}
|
||||||
|
if c == ')' {
|
||||||
|
return stateEndValue(s, c)
|
||||||
|
}
|
||||||
|
return stateBeginValue(s, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateT is the state after reading `t`.
|
||||||
|
func stateT(s *scanner, c byte) int {
|
||||||
|
if c == 'r' {
|
||||||
|
s.step = stateTr
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return s.error(c, "in literal true (expecting 'r')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateTr is the state after reading `tr`.
|
||||||
|
func stateTr(s *scanner, c byte) int {
|
||||||
|
if c == 'u' {
|
||||||
|
s.step = stateTru
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return s.error(c, "in literal true (expecting 'u')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateTru is the state after reading `tru`.
|
||||||
|
func stateTru(s *scanner, c byte) int {
|
||||||
|
if c == 'e' {
|
||||||
|
s.step = stateEndValue
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return s.error(c, "in literal true (expecting 'e')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateF is the state after reading `f`.
|
||||||
|
func stateF(s *scanner, c byte) int {
|
||||||
|
if c == 'a' {
|
||||||
|
s.step = stateFa
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return s.error(c, "in literal false (expecting 'a')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateFa is the state after reading `fa`.
|
||||||
|
func stateFa(s *scanner, c byte) int {
|
||||||
|
if c == 'l' {
|
||||||
|
s.step = stateFal
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return s.error(c, "in literal false (expecting 'l')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateFal is the state after reading `fal`.
|
||||||
|
func stateFal(s *scanner, c byte) int {
|
||||||
|
if c == 's' {
|
||||||
|
s.step = stateFals
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return s.error(c, "in literal false (expecting 's')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateFals is the state after reading `fals`.
|
||||||
|
func stateFals(s *scanner, c byte) int {
|
||||||
|
if c == 'e' {
|
||||||
|
s.step = stateEndValue
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return s.error(c, "in literal false (expecting 'e')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateN is the state after reading `n`.
|
||||||
|
func stateN(s *scanner, c byte) int {
|
||||||
|
if c == 'u' {
|
||||||
|
s.step = stateNu
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return s.error(c, "in literal null (expecting 'u')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateNu is the state after reading `nu`.
|
||||||
|
func stateNu(s *scanner, c byte) int {
|
||||||
|
if c == 'l' {
|
||||||
|
s.step = stateNul
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return s.error(c, "in literal null (expecting 'l')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateNul is the state after reading `nul`.
|
||||||
|
func stateNul(s *scanner, c byte) int {
|
||||||
|
if c == 'l' {
|
||||||
|
s.step = stateEndValue
|
||||||
|
return scanContinue
|
||||||
|
}
|
||||||
|
return s.error(c, "in literal null (expecting 'l')")
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateError is the state after reaching a syntax error,
|
||||||
|
// such as after reading `[1}` or `5.1.2`.
|
||||||
|
func stateError(s *scanner, c byte) int {
|
||||||
|
return scanError
|
||||||
|
}
|
||||||
|
|
||||||
|
// error records an error and switches to the error state.
|
||||||
|
func (s *scanner) error(c byte, context string) int {
|
||||||
|
s.step = stateError
|
||||||
|
s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes}
|
||||||
|
return scanError
|
||||||
|
}
|
||||||
|
|
||||||
|
// quoteChar formats c as a quoted character literal
|
||||||
|
func quoteChar(c byte) string {
|
||||||
|
// special cases - different from quoted strings
|
||||||
|
if c == '\'' {
|
||||||
|
return `'\''`
|
||||||
|
}
|
||||||
|
if c == '"' {
|
||||||
|
return `'"'`
|
||||||
|
}
|
||||||
|
|
||||||
|
// use quoted string with different quotation marks
|
||||||
|
s := strconv.Quote(string(c))
|
||||||
|
return "'" + s[1:len(s)-1] + "'"
|
||||||
|
}
|
||||||
|
|
||||||
|
// undo causes the scanner to return scanCode from the next state transition.
|
||||||
|
// This gives callers a simple 1-byte undo mechanism.
|
||||||
|
func (s *scanner) undo(scanCode int) {
|
||||||
|
if s.redo {
|
||||||
|
panic("json: invalid use of scanner")
|
||||||
|
}
|
||||||
|
s.redoCode = scanCode
|
||||||
|
s.redoState = s.step
|
||||||
|
s.step = stateRedo
|
||||||
|
s.redo = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateRedo helps implement the scanner's 1-byte undo.
|
||||||
|
func stateRedo(s *scanner, c byte) int {
|
||||||
|
s.redo = false
|
||||||
|
s.step = s.redoState
|
||||||
|
return s.redoCode
|
||||||
|
}
|
|
@ -0,0 +1,510 @@
|
||||||
|
// Copyright 2010 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package json
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Decoder reads and decodes JSON values from an input stream.
|
||||||
|
type Decoder struct {
|
||||||
|
r io.Reader
|
||||||
|
buf []byte
|
||||||
|
d decodeState
|
||||||
|
scanp int // start of unread data in buf
|
||||||
|
scan scanner
|
||||||
|
err error
|
||||||
|
|
||||||
|
tokenState int
|
||||||
|
tokenStack []int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDecoder returns a new decoder that reads from r.
|
||||||
|
//
|
||||||
|
// The decoder introduces its own buffering and may
|
||||||
|
// read data from r beyond the JSON values requested.
|
||||||
|
func NewDecoder(r io.Reader) *Decoder {
|
||||||
|
return &Decoder{r: r}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
|
||||||
|
// Number instead of as a float64.
|
||||||
|
func (dec *Decoder) UseNumber() { dec.d.useNumber = true }
|
||||||
|
|
||||||
|
// Decode reads the next JSON-encoded value from its
|
||||||
|
// input and stores it in the value pointed to by v.
|
||||||
|
//
|
||||||
|
// See the documentation for Unmarshal for details about
|
||||||
|
// the conversion of JSON into a Go value.
|
||||||
|
func (dec *Decoder) Decode(v interface{}) error {
|
||||||
|
if dec.err != nil {
|
||||||
|
return dec.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := dec.tokenPrepareForDecode(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !dec.tokenValueAllowed() {
|
||||||
|
return &SyntaxError{msg: "not at beginning of value"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read whole value into buffer.
|
||||||
|
n, err := dec.readValue()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dec.d.init(dec.buf[dec.scanp : dec.scanp+n])
|
||||||
|
dec.scanp += n
|
||||||
|
|
||||||
|
// Don't save err from unmarshal into dec.err:
|
||||||
|
// the connection is still usable since we read a complete JSON
|
||||||
|
// object from it before the error happened.
|
||||||
|
err = dec.d.unmarshal(v)
|
||||||
|
|
||||||
|
// fixup token streaming state
|
||||||
|
dec.tokenValueEnd()
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Buffered returns a reader of the data remaining in the Decoder's
|
||||||
|
// buffer. The reader is valid until the next call to Decode.
|
||||||
|
func (dec *Decoder) Buffered() io.Reader {
|
||||||
|
return bytes.NewReader(dec.buf[dec.scanp:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// readValue reads a JSON value into dec.buf.
|
||||||
|
// It returns the length of the encoding.
|
||||||
|
func (dec *Decoder) readValue() (int, error) {
|
||||||
|
dec.scan.reset()
|
||||||
|
|
||||||
|
scanp := dec.scanp
|
||||||
|
var err error
|
||||||
|
Input:
|
||||||
|
for {
|
||||||
|
// Look in the buffer for a new value.
|
||||||
|
for i, c := range dec.buf[scanp:] {
|
||||||
|
dec.scan.bytes++
|
||||||
|
v := dec.scan.step(&dec.scan, c)
|
||||||
|
if v == scanEnd {
|
||||||
|
scanp += i
|
||||||
|
break Input
|
||||||
|
}
|
||||||
|
// scanEnd is delayed one byte.
|
||||||
|
// We might block trying to get that byte from src,
|
||||||
|
// so instead invent a space byte.
|
||||||
|
if (v == scanEndObject || v == scanEndArray) && dec.scan.step(&dec.scan, ' ') == scanEnd {
|
||||||
|
scanp += i + 1
|
||||||
|
break Input
|
||||||
|
}
|
||||||
|
if v == scanError {
|
||||||
|
dec.err = dec.scan.err
|
||||||
|
return 0, dec.scan.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scanp = len(dec.buf)
|
||||||
|
|
||||||
|
// Did the last read have an error?
|
||||||
|
// Delayed until now to allow buffer scan.
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
if dec.scan.step(&dec.scan, ' ') == scanEnd {
|
||||||
|
break Input
|
||||||
|
}
|
||||||
|
if nonSpace(dec.buf) {
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dec.err = err
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n := scanp - dec.scanp
|
||||||
|
err = dec.refill()
|
||||||
|
scanp = dec.scanp + n
|
||||||
|
}
|
||||||
|
return scanp - dec.scanp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dec *Decoder) refill() error {
|
||||||
|
// Make room to read more into the buffer.
|
||||||
|
// First slide down data already consumed.
|
||||||
|
if dec.scanp > 0 {
|
||||||
|
n := copy(dec.buf, dec.buf[dec.scanp:])
|
||||||
|
dec.buf = dec.buf[:n]
|
||||||
|
dec.scanp = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grow buffer if not large enough.
|
||||||
|
const minRead = 512
|
||||||
|
if cap(dec.buf)-len(dec.buf) < minRead {
|
||||||
|
newBuf := make([]byte, len(dec.buf), 2*cap(dec.buf)+minRead)
|
||||||
|
copy(newBuf, dec.buf)
|
||||||
|
dec.buf = newBuf
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read. Delay error for next iteration (after scan).
|
||||||
|
n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)])
|
||||||
|
dec.buf = dec.buf[0 : len(dec.buf)+n]
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func nonSpace(b []byte) bool {
|
||||||
|
for _, c := range b {
|
||||||
|
if !isSpace(c) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// An Encoder writes JSON values to an output stream.
|
||||||
|
type Encoder struct {
|
||||||
|
w io.Writer
|
||||||
|
err error
|
||||||
|
escapeHTML bool
|
||||||
|
|
||||||
|
indentBuf *bytes.Buffer
|
||||||
|
indentPrefix string
|
||||||
|
indentValue string
|
||||||
|
|
||||||
|
ext Extension
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEncoder returns a new encoder that writes to w.
|
||||||
|
func NewEncoder(w io.Writer) *Encoder {
|
||||||
|
return &Encoder{w: w, escapeHTML: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode writes the JSON encoding of v to the stream,
|
||||||
|
// followed by a newline character.
|
||||||
|
//
|
||||||
|
// See the documentation for Marshal for details about the
|
||||||
|
// conversion of Go values to JSON.
|
||||||
|
func (enc *Encoder) Encode(v interface{}) error {
|
||||||
|
if enc.err != nil {
|
||||||
|
return enc.err
|
||||||
|
}
|
||||||
|
e := newEncodeState()
|
||||||
|
e.ext = enc.ext
|
||||||
|
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Terminate each value with a newline.
|
||||||
|
// This makes the output look a little nicer
|
||||||
|
// when debugging, and some kind of space
|
||||||
|
// is required if the encoded value was a number,
|
||||||
|
// so that the reader knows there aren't more
|
||||||
|
// digits coming.
|
||||||
|
e.WriteByte('\n')
|
||||||
|
|
||||||
|
b := e.Bytes()
|
||||||
|
if enc.indentBuf != nil {
|
||||||
|
enc.indentBuf.Reset()
|
||||||
|
err = Indent(enc.indentBuf, b, enc.indentPrefix, enc.indentValue)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b = enc.indentBuf.Bytes()
|
||||||
|
}
|
||||||
|
if _, err = enc.w.Write(b); err != nil {
|
||||||
|
enc.err = err
|
||||||
|
}
|
||||||
|
encodeStatePool.Put(e)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Indent sets the encoder to format each encoded value with Indent.
|
||||||
|
func (enc *Encoder) Indent(prefix, indent string) {
|
||||||
|
enc.indentBuf = new(bytes.Buffer)
|
||||||
|
enc.indentPrefix = prefix
|
||||||
|
enc.indentValue = indent
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableHTMLEscaping causes the encoder not to escape angle brackets
|
||||||
|
// ("<" and ">") or ampersands ("&") in JSON strings.
|
||||||
|
func (enc *Encoder) DisableHTMLEscaping() {
|
||||||
|
enc.escapeHTML = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// RawMessage is a raw encoded JSON value.
|
||||||
|
// It implements Marshaler and Unmarshaler and can
|
||||||
|
// be used to delay JSON decoding or precompute a JSON encoding.
|
||||||
|
type RawMessage []byte
|
||||||
|
|
||||||
|
// MarshalJSON returns *m as the JSON encoding of m.
|
||||||
|
func (m *RawMessage) MarshalJSON() ([]byte, error) {
|
||||||
|
return *m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON sets *m to a copy of data.
|
||||||
|
func (m *RawMessage) UnmarshalJSON(data []byte) error {
|
||||||
|
if m == nil {
|
||||||
|
return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
|
||||||
|
}
|
||||||
|
*m = append((*m)[0:0], data...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Marshaler = (*RawMessage)(nil)
|
||||||
|
var _ Unmarshaler = (*RawMessage)(nil)
|
||||||
|
|
||||||
|
// A Token holds a value of one of these types:
|
||||||
|
//
|
||||||
|
// Delim, for the four JSON delimiters [ ] { }
|
||||||
|
// bool, for JSON booleans
|
||||||
|
// float64, for JSON numbers
|
||||||
|
// Number, for JSON numbers
|
||||||
|
// string, for JSON string literals
|
||||||
|
// nil, for JSON null
|
||||||
|
//
|
||||||
|
type Token interface{}
|
||||||
|
|
||||||
|
const (
|
||||||
|
tokenTopValue = iota
|
||||||
|
tokenArrayStart
|
||||||
|
tokenArrayValue
|
||||||
|
tokenArrayComma
|
||||||
|
tokenObjectStart
|
||||||
|
tokenObjectKey
|
||||||
|
tokenObjectColon
|
||||||
|
tokenObjectValue
|
||||||
|
tokenObjectComma
|
||||||
|
)
|
||||||
|
|
||||||
|
// advance tokenstate from a separator state to a value state
|
||||||
|
func (dec *Decoder) tokenPrepareForDecode() error {
|
||||||
|
// Note: Not calling peek before switch, to avoid
|
||||||
|
// putting peek into the standard Decode path.
|
||||||
|
// peek is only called when using the Token API.
|
||||||
|
switch dec.tokenState {
|
||||||
|
case tokenArrayComma:
|
||||||
|
c, err := dec.peek()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if c != ',' {
|
||||||
|
return &SyntaxError{"expected comma after array element", 0}
|
||||||
|
}
|
||||||
|
dec.scanp++
|
||||||
|
dec.tokenState = tokenArrayValue
|
||||||
|
case tokenObjectColon:
|
||||||
|
c, err := dec.peek()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if c != ':' {
|
||||||
|
return &SyntaxError{"expected colon after object key", 0}
|
||||||
|
}
|
||||||
|
dec.scanp++
|
||||||
|
dec.tokenState = tokenObjectValue
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dec *Decoder) tokenValueAllowed() bool {
|
||||||
|
switch dec.tokenState {
|
||||||
|
case tokenTopValue, tokenArrayStart, tokenArrayValue, tokenObjectValue:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dec *Decoder) tokenValueEnd() {
|
||||||
|
switch dec.tokenState {
|
||||||
|
case tokenArrayStart, tokenArrayValue:
|
||||||
|
dec.tokenState = tokenArrayComma
|
||||||
|
case tokenObjectValue:
|
||||||
|
dec.tokenState = tokenObjectComma
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Delim is a JSON array or object delimiter, one of [ ] { or }.
|
||||||
|
type Delim rune
|
||||||
|
|
||||||
|
func (d Delim) String() string {
|
||||||
|
return string(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token returns the next JSON token in the input stream.
|
||||||
|
// At the end of the input stream, Token returns nil, io.EOF.
|
||||||
|
//
|
||||||
|
// Token guarantees that the delimiters [ ] { } it returns are
|
||||||
|
// properly nested and matched: if Token encounters an unexpected
|
||||||
|
// delimiter in the input, it will return an error.
|
||||||
|
//
|
||||||
|
// The input stream consists of basic JSON values—bool, string,
|
||||||
|
// number, and null—along with delimiters [ ] { } of type Delim
|
||||||
|
// to mark the start and end of arrays and objects.
|
||||||
|
// Commas and colons are elided.
|
||||||
|
func (dec *Decoder) Token() (Token, error) {
|
||||||
|
for {
|
||||||
|
c, err := dec.peek()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch c {
|
||||||
|
case '[':
|
||||||
|
if !dec.tokenValueAllowed() {
|
||||||
|
return dec.tokenError(c)
|
||||||
|
}
|
||||||
|
dec.scanp++
|
||||||
|
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
|
||||||
|
dec.tokenState = tokenArrayStart
|
||||||
|
return Delim('['), nil
|
||||||
|
|
||||||
|
case ']':
|
||||||
|
if dec.tokenState != tokenArrayStart && dec.tokenState != tokenArrayComma {
|
||||||
|
return dec.tokenError(c)
|
||||||
|
}
|
||||||
|
dec.scanp++
|
||||||
|
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
|
||||||
|
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
|
||||||
|
dec.tokenValueEnd()
|
||||||
|
return Delim(']'), nil
|
||||||
|
|
||||||
|
case '{':
|
||||||
|
if !dec.tokenValueAllowed() {
|
||||||
|
return dec.tokenError(c)
|
||||||
|
}
|
||||||
|
dec.scanp++
|
||||||
|
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
|
||||||
|
dec.tokenState = tokenObjectStart
|
||||||
|
return Delim('{'), nil
|
||||||
|
|
||||||
|
case '}':
|
||||||
|
if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma {
|
||||||
|
return dec.tokenError(c)
|
||||||
|
}
|
||||||
|
dec.scanp++
|
||||||
|
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
|
||||||
|
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
|
||||||
|
dec.tokenValueEnd()
|
||||||
|
return Delim('}'), nil
|
||||||
|
|
||||||
|
case ':':
|
||||||
|
if dec.tokenState != tokenObjectColon {
|
||||||
|
return dec.tokenError(c)
|
||||||
|
}
|
||||||
|
dec.scanp++
|
||||||
|
dec.tokenState = tokenObjectValue
|
||||||
|
continue
|
||||||
|
|
||||||
|
case ',':
|
||||||
|
if dec.tokenState == tokenArrayComma {
|
||||||
|
dec.scanp++
|
||||||
|
dec.tokenState = tokenArrayValue
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if dec.tokenState == tokenObjectComma {
|
||||||
|
dec.scanp++
|
||||||
|
dec.tokenState = tokenObjectKey
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return dec.tokenError(c)
|
||||||
|
|
||||||
|
case '"':
|
||||||
|
if dec.tokenState == tokenObjectStart || dec.tokenState == tokenObjectKey {
|
||||||
|
var x string
|
||||||
|
old := dec.tokenState
|
||||||
|
dec.tokenState = tokenTopValue
|
||||||
|
err := dec.Decode(&x)
|
||||||
|
dec.tokenState = old
|
||||||
|
if err != nil {
|
||||||
|
clearOffset(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dec.tokenState = tokenObjectColon
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
|
||||||
|
default:
|
||||||
|
if !dec.tokenValueAllowed() {
|
||||||
|
return dec.tokenError(c)
|
||||||
|
}
|
||||||
|
var x interface{}
|
||||||
|
if err := dec.Decode(&x); err != nil {
|
||||||
|
clearOffset(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func clearOffset(err error) {
|
||||||
|
if s, ok := err.(*SyntaxError); ok {
|
||||||
|
s.Offset = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dec *Decoder) tokenError(c byte) (Token, error) {
|
||||||
|
var context string
|
||||||
|
switch dec.tokenState {
|
||||||
|
case tokenTopValue:
|
||||||
|
context = " looking for beginning of value"
|
||||||
|
case tokenArrayStart, tokenArrayValue, tokenObjectValue:
|
||||||
|
context = " looking for beginning of value"
|
||||||
|
case tokenArrayComma:
|
||||||
|
context = " after array element"
|
||||||
|
case tokenObjectKey:
|
||||||
|
context = " looking for beginning of object key string"
|
||||||
|
case tokenObjectColon:
|
||||||
|
context = " after object key"
|
||||||
|
case tokenObjectComma:
|
||||||
|
context = " after object key:value pair"
|
||||||
|
}
|
||||||
|
return nil, &SyntaxError{"invalid character " + quoteChar(c) + " " + context, 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
// More reports whether there is another element in the
|
||||||
|
// current array or object being parsed.
|
||||||
|
func (dec *Decoder) More() bool {
|
||||||
|
c, err := dec.peek()
|
||||||
|
return err == nil && c != ']' && c != '}'
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dec *Decoder) peek() (byte, error) {
|
||||||
|
var err error
|
||||||
|
for {
|
||||||
|
for i := dec.scanp; i < len(dec.buf); i++ {
|
||||||
|
c := dec.buf[i]
|
||||||
|
if isSpace(c) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dec.scanp = i
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
// buffer has been scanned, now report any error
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
err = dec.refill()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
TODO
|
||||||
|
|
||||||
|
// EncodeToken writes the given JSON token to the stream.
|
||||||
|
// It returns an error if the delimiters [ ] { } are not properly used.
|
||||||
|
//
|
||||||
|
// EncodeToken does not call Flush, because usually it is part of
|
||||||
|
// a larger operation such as Encode, and those will call Flush when finished.
|
||||||
|
// Callers that create an Encoder and then invoke EncodeToken directly,
|
||||||
|
// without using Encode, need to call Flush when finished to ensure that
|
||||||
|
// the JSON is written to the underlying writer.
|
||||||
|
func (e *Encoder) EncodeToken(t Token) error {
|
||||||
|
...
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
|
@ -0,0 +1,44 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package json
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// tagOptions is the string following a comma in a struct field's "json"
|
||||||
|
// tag, or the empty string. It does not include the leading comma.
|
||||||
|
type tagOptions string
|
||||||
|
|
||||||
|
// parseTag splits a struct field's json tag into its name and
|
||||||
|
// comma-separated options.
|
||||||
|
func parseTag(tag string) (string, tagOptions) {
|
||||||
|
if idx := strings.Index(tag, ","); idx != -1 {
|
||||||
|
return tag[:idx], tagOptions(tag[idx+1:])
|
||||||
|
}
|
||||||
|
return tag, tagOptions("")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains reports whether a comma-separated list of options
|
||||||
|
// contains a particular substr flag. substr must be surrounded by a
|
||||||
|
// string boundary or commas.
|
||||||
|
func (o tagOptions) Contains(optionName string) bool {
|
||||||
|
if len(o) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s := string(o)
|
||||||
|
for s != "" {
|
||||||
|
var next string
|
||||||
|
i := strings.Index(s, ",")
|
||||||
|
if i >= 0 {
|
||||||
|
s, next = s[:i], s[i+1:]
|
||||||
|
}
|
||||||
|
if s == optionName {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
s = next
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
|
@ -0,0 +1,77 @@
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <sasl/sasl.h>
|
||||||
|
|
||||||
|
static int mgo_sasl_simple(void *context, int id, const char **result, unsigned int *len)
|
||||||
|
{
|
||||||
|
if (!result) {
|
||||||
|
return SASL_BADPARAM;
|
||||||
|
}
|
||||||
|
switch (id) {
|
||||||
|
case SASL_CB_USER:
|
||||||
|
*result = (char *)context;
|
||||||
|
break;
|
||||||
|
case SASL_CB_AUTHNAME:
|
||||||
|
*result = (char *)context;
|
||||||
|
break;
|
||||||
|
case SASL_CB_LANGUAGE:
|
||||||
|
*result = NULL;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return SASL_BADPARAM;
|
||||||
|
}
|
||||||
|
if (len) {
|
||||||
|
*len = *result ? strlen(*result) : 0;
|
||||||
|
}
|
||||||
|
return SASL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef int (*callback)(void);
|
||||||
|
|
||||||
|
static int mgo_sasl_secret(sasl_conn_t *conn, void *context, int id, sasl_secret_t **result)
|
||||||
|
{
|
||||||
|
if (!conn || !result || id != SASL_CB_PASS) {
|
||||||
|
return SASL_BADPARAM;
|
||||||
|
}
|
||||||
|
*result = (sasl_secret_t *)context;
|
||||||
|
return SASL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
sasl_callback_t *mgo_sasl_callbacks(const char *username, const char *password)
|
||||||
|
{
|
||||||
|
sasl_callback_t *cb = malloc(4 * sizeof(sasl_callback_t));
|
||||||
|
int n = 0;
|
||||||
|
|
||||||
|
size_t len = strlen(password);
|
||||||
|
sasl_secret_t *secret = (sasl_secret_t*)malloc(sizeof(sasl_secret_t) + len);
|
||||||
|
if (!secret) {
|
||||||
|
free(cb);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
strcpy((char *)secret->data, password);
|
||||||
|
secret->len = len;
|
||||||
|
|
||||||
|
cb[n].id = SASL_CB_PASS;
|
||||||
|
cb[n].proc = (callback)&mgo_sasl_secret;
|
||||||
|
cb[n].context = secret;
|
||||||
|
n++;
|
||||||
|
|
||||||
|
cb[n].id = SASL_CB_USER;
|
||||||
|
cb[n].proc = (callback)&mgo_sasl_simple;
|
||||||
|
cb[n].context = (char*)username;
|
||||||
|
n++;
|
||||||
|
|
||||||
|
cb[n].id = SASL_CB_AUTHNAME;
|
||||||
|
cb[n].proc = (callback)&mgo_sasl_simple;
|
||||||
|
cb[n].context = (char*)username;
|
||||||
|
n++;
|
||||||
|
|
||||||
|
cb[n].id = SASL_CB_LIST_END;
|
||||||
|
cb[n].proc = NULL;
|
||||||
|
cb[n].context = NULL;
|
||||||
|
|
||||||
|
return cb;
|
||||||
|
}
|
|
@ -0,0 +1,138 @@
|
||||||
|
// Package sasl is an implementation detail of the mgo package.
|
||||||
|
//
|
||||||
|
// This package is not meant to be used by itself.
|
||||||
|
//
|
||||||
|
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package sasl
|
||||||
|
|
||||||
|
// #cgo LDFLAGS: -lsasl2
|
||||||
|
//
|
||||||
|
// struct sasl_conn {};
|
||||||
|
//
|
||||||
|
// #include <stdlib.h>
|
||||||
|
// #include <sasl/sasl.h>
|
||||||
|
//
|
||||||
|
// sasl_callback_t *mgo_sasl_callbacks(const char *username, const char *password);
|
||||||
|
//
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
type saslStepper interface {
|
||||||
|
Step(serverData []byte) (clientData []byte, done bool, err error)
|
||||||
|
Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type saslSession struct {
|
||||||
|
conn *C.sasl_conn_t
|
||||||
|
step int
|
||||||
|
mech string
|
||||||
|
|
||||||
|
cstrings []*C.char
|
||||||
|
callbacks *C.sasl_callback_t
|
||||||
|
}
|
||||||
|
|
||||||
|
var initError error
|
||||||
|
var initOnce sync.Once
|
||||||
|
|
||||||
|
func initSASL() {
|
||||||
|
rc := C.sasl_client_init(nil)
|
||||||
|
if rc != C.SASL_OK {
|
||||||
|
initError = saslError(rc, nil, "cannot initialize SASL library")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(username, password, mechanism, service, host string) (saslStepper, error) {
|
||||||
|
initOnce.Do(initSASL)
|
||||||
|
if initError != nil {
|
||||||
|
return nil, initError
|
||||||
|
}
|
||||||
|
|
||||||
|
ss := &saslSession{mech: mechanism}
|
||||||
|
if service == "" {
|
||||||
|
service = "mongodb"
|
||||||
|
}
|
||||||
|
if i := strings.Index(host, ":"); i >= 0 {
|
||||||
|
host = host[:i]
|
||||||
|
}
|
||||||
|
ss.callbacks = C.mgo_sasl_callbacks(ss.cstr(username), ss.cstr(password))
|
||||||
|
rc := C.sasl_client_new(ss.cstr(service), ss.cstr(host), nil, nil, ss.callbacks, 0, &ss.conn)
|
||||||
|
if rc != C.SASL_OK {
|
||||||
|
ss.Close()
|
||||||
|
return nil, saslError(rc, nil, "cannot create new SASL client")
|
||||||
|
}
|
||||||
|
return ss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss *saslSession) cstr(s string) *C.char {
|
||||||
|
cstr := C.CString(s)
|
||||||
|
ss.cstrings = append(ss.cstrings, cstr)
|
||||||
|
return cstr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss *saslSession) Close() {
|
||||||
|
for _, cstr := range ss.cstrings {
|
||||||
|
C.free(unsafe.Pointer(cstr))
|
||||||
|
}
|
||||||
|
ss.cstrings = nil
|
||||||
|
|
||||||
|
if ss.callbacks != nil {
|
||||||
|
C.free(unsafe.Pointer(ss.callbacks))
|
||||||
|
}
|
||||||
|
|
||||||
|
// The documentation of SASL dispose makes it clear that this should only
|
||||||
|
// be done when the connection is done, not when the authentication phase
|
||||||
|
// is done, because an encryption layer may have been negotiated.
|
||||||
|
// Even then, we'll do this for now, because it's simpler and prevents
|
||||||
|
// keeping track of this state for every socket. If it breaks, we'll fix it.
|
||||||
|
C.sasl_dispose(&ss.conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) {
|
||||||
|
ss.step++
|
||||||
|
if ss.step > 10 {
|
||||||
|
return nil, false, fmt.Errorf("too many SASL steps without authentication")
|
||||||
|
}
|
||||||
|
var cclientData *C.char
|
||||||
|
var cclientDataLen C.uint
|
||||||
|
var rc C.int
|
||||||
|
if ss.step == 1 {
|
||||||
|
var mechanism *C.char // ignored - must match cred
|
||||||
|
rc = C.sasl_client_start(ss.conn, ss.cstr(ss.mech), nil, &cclientData, &cclientDataLen, &mechanism)
|
||||||
|
} else {
|
||||||
|
var cserverData *C.char
|
||||||
|
var cserverDataLen C.uint
|
||||||
|
if len(serverData) > 0 {
|
||||||
|
cserverData = (*C.char)(unsafe.Pointer(&serverData[0]))
|
||||||
|
cserverDataLen = C.uint(len(serverData))
|
||||||
|
}
|
||||||
|
rc = C.sasl_client_step(ss.conn, cserverData, cserverDataLen, nil, &cclientData, &cclientDataLen)
|
||||||
|
}
|
||||||
|
if cclientData != nil && cclientDataLen > 0 {
|
||||||
|
clientData = C.GoBytes(unsafe.Pointer(cclientData), C.int(cclientDataLen))
|
||||||
|
}
|
||||||
|
if rc == C.SASL_OK {
|
||||||
|
return clientData, true, nil
|
||||||
|
}
|
||||||
|
if rc == C.SASL_CONTINUE {
|
||||||
|
return clientData, false, nil
|
||||||
|
}
|
||||||
|
return nil, false, saslError(rc, ss.conn, "cannot establish SASL session")
|
||||||
|
}
|
||||||
|
|
||||||
|
func saslError(rc C.int, conn *C.sasl_conn_t, msg string) error {
|
||||||
|
var detail string
|
||||||
|
if conn == nil {
|
||||||
|
detail = C.GoString(C.sasl_errstring(rc, nil, nil))
|
||||||
|
} else {
|
||||||
|
detail = C.GoString(C.sasl_errdetail(conn))
|
||||||
|
}
|
||||||
|
return fmt.Errorf(msg + ": " + detail)
|
||||||
|
}
|
|
@ -0,0 +1,122 @@
|
||||||
|
#include "sasl_windows.h"
|
||||||
|
|
||||||
|
static const LPSTR SSPI_PACKAGE_NAME = "kerberos";
|
||||||
|
|
||||||
|
SECURITY_STATUS SEC_ENTRY sspi_acquire_credentials_handle(CredHandle *cred_handle, char *username, char *password, char *domain)
|
||||||
|
{
|
||||||
|
SEC_WINNT_AUTH_IDENTITY auth_identity;
|
||||||
|
SECURITY_INTEGER ignored;
|
||||||
|
|
||||||
|
auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_ANSI;
|
||||||
|
auth_identity.User = (LPSTR) username;
|
||||||
|
auth_identity.UserLength = strlen(username);
|
||||||
|
auth_identity.Password = NULL;
|
||||||
|
auth_identity.PasswordLength = 0;
|
||||||
|
if(password){
|
||||||
|
auth_identity.Password = (LPSTR) password;
|
||||||
|
auth_identity.PasswordLength = strlen(password);
|
||||||
|
}
|
||||||
|
auth_identity.Domain = (LPSTR) domain;
|
||||||
|
auth_identity.DomainLength = strlen(domain);
|
||||||
|
return call_sspi_acquire_credentials_handle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, &auth_identity, NULL, NULL, cred_handle, &ignored);
|
||||||
|
}
|
||||||
|
|
||||||
|
int sspi_step(CredHandle *cred_handle, int has_context, CtxtHandle *context, PVOID buffer, ULONG buffer_length, PVOID *out_buffer, ULONG *out_buffer_length, char *target)
|
||||||
|
{
|
||||||
|
SecBufferDesc inbuf;
|
||||||
|
SecBuffer in_bufs[1];
|
||||||
|
SecBufferDesc outbuf;
|
||||||
|
SecBuffer out_bufs[1];
|
||||||
|
|
||||||
|
if (has_context > 0) {
|
||||||
|
// If we already have a context, we now have data to send.
|
||||||
|
// Put this data in an inbuf.
|
||||||
|
inbuf.ulVersion = SECBUFFER_VERSION;
|
||||||
|
inbuf.cBuffers = 1;
|
||||||
|
inbuf.pBuffers = in_bufs;
|
||||||
|
in_bufs[0].pvBuffer = buffer;
|
||||||
|
in_bufs[0].cbBuffer = buffer_length;
|
||||||
|
in_bufs[0].BufferType = SECBUFFER_TOKEN;
|
||||||
|
}
|
||||||
|
|
||||||
|
outbuf.ulVersion = SECBUFFER_VERSION;
|
||||||
|
outbuf.cBuffers = 1;
|
||||||
|
outbuf.pBuffers = out_bufs;
|
||||||
|
out_bufs[0].pvBuffer = NULL;
|
||||||
|
out_bufs[0].cbBuffer = 0;
|
||||||
|
out_bufs[0].BufferType = SECBUFFER_TOKEN;
|
||||||
|
|
||||||
|
ULONG context_attr = 0;
|
||||||
|
|
||||||
|
int ret = call_sspi_initialize_security_context(cred_handle,
|
||||||
|
has_context > 0 ? context : NULL,
|
||||||
|
(LPSTR) target,
|
||||||
|
ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_MUTUAL_AUTH,
|
||||||
|
0,
|
||||||
|
SECURITY_NETWORK_DREP,
|
||||||
|
has_context > 0 ? &inbuf : NULL,
|
||||||
|
0,
|
||||||
|
context,
|
||||||
|
&outbuf,
|
||||||
|
&context_attr,
|
||||||
|
NULL);
|
||||||
|
|
||||||
|
*out_buffer = malloc(out_bufs[0].cbBuffer);
|
||||||
|
*out_buffer_length = out_bufs[0].cbBuffer;
|
||||||
|
memcpy(*out_buffer, out_bufs[0].pvBuffer, *out_buffer_length);
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
int sspi_send_client_authz_id(CtxtHandle *context, PVOID *buffer, ULONG *buffer_length, char *user_plus_realm)
|
||||||
|
{
|
||||||
|
SecPkgContext_Sizes sizes;
|
||||||
|
SECURITY_STATUS status = call_sspi_query_context_attributes(context, SECPKG_ATTR_SIZES, &sizes);
|
||||||
|
|
||||||
|
if (status != SEC_E_OK) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t user_plus_realm_length = strlen(user_plus_realm);
|
||||||
|
int msgSize = 4 + user_plus_realm_length;
|
||||||
|
char *msg = malloc((sizes.cbSecurityTrailer + msgSize + sizes.cbBlockSize) * sizeof(char));
|
||||||
|
msg[sizes.cbSecurityTrailer + 0] = 1;
|
||||||
|
msg[sizes.cbSecurityTrailer + 1] = 0;
|
||||||
|
msg[sizes.cbSecurityTrailer + 2] = 0;
|
||||||
|
msg[sizes.cbSecurityTrailer + 3] = 0;
|
||||||
|
memcpy(&msg[sizes.cbSecurityTrailer + 4], user_plus_realm, user_plus_realm_length);
|
||||||
|
|
||||||
|
SecBuffer wrapBufs[3];
|
||||||
|
SecBufferDesc wrapBufDesc;
|
||||||
|
wrapBufDesc.cBuffers = 3;
|
||||||
|
wrapBufDesc.pBuffers = wrapBufs;
|
||||||
|
wrapBufDesc.ulVersion = SECBUFFER_VERSION;
|
||||||
|
|
||||||
|
wrapBufs[0].cbBuffer = sizes.cbSecurityTrailer;
|
||||||
|
wrapBufs[0].BufferType = SECBUFFER_TOKEN;
|
||||||
|
wrapBufs[0].pvBuffer = msg;
|
||||||
|
|
||||||
|
wrapBufs[1].cbBuffer = msgSize;
|
||||||
|
wrapBufs[1].BufferType = SECBUFFER_DATA;
|
||||||
|
wrapBufs[1].pvBuffer = msg + sizes.cbSecurityTrailer;
|
||||||
|
|
||||||
|
wrapBufs[2].cbBuffer = sizes.cbBlockSize;
|
||||||
|
wrapBufs[2].BufferType = SECBUFFER_PADDING;
|
||||||
|
wrapBufs[2].pvBuffer = msg + sizes.cbSecurityTrailer + msgSize;
|
||||||
|
|
||||||
|
status = call_sspi_encrypt_message(context, SECQOP_WRAP_NO_ENCRYPT, &wrapBufDesc, 0);
|
||||||
|
if (status != SEC_E_OK) {
|
||||||
|
free(msg);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
*buffer_length = wrapBufs[0].cbBuffer + wrapBufs[1].cbBuffer + wrapBufs[2].cbBuffer;
|
||||||
|
*buffer = malloc(*buffer_length);
|
||||||
|
|
||||||
|
memcpy(*buffer, wrapBufs[0].pvBuffer, wrapBufs[0].cbBuffer);
|
||||||
|
memcpy(*buffer + wrapBufs[0].cbBuffer, wrapBufs[1].pvBuffer, wrapBufs[1].cbBuffer);
|
||||||
|
memcpy(*buffer + wrapBufs[0].cbBuffer + wrapBufs[1].cbBuffer, wrapBufs[2].pvBuffer, wrapBufs[2].cbBuffer);
|
||||||
|
|
||||||
|
free(msg);
|
||||||
|
return SEC_E_OK;
|
||||||
|
}
|
|
@ -0,0 +1,142 @@
|
||||||
|
package sasl
|
||||||
|
|
||||||
|
// #include "sasl_windows.h"
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
type saslStepper interface {
|
||||||
|
Step(serverData []byte) (clientData []byte, done bool, err error)
|
||||||
|
Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type saslSession struct {
|
||||||
|
// Credentials
|
||||||
|
mech string
|
||||||
|
service string
|
||||||
|
host string
|
||||||
|
userPlusRealm string
|
||||||
|
target string
|
||||||
|
domain string
|
||||||
|
|
||||||
|
// Internal state
|
||||||
|
authComplete bool
|
||||||
|
errored bool
|
||||||
|
step int
|
||||||
|
|
||||||
|
// C internal state
|
||||||
|
credHandle C.CredHandle
|
||||||
|
context C.CtxtHandle
|
||||||
|
hasContext C.int
|
||||||
|
|
||||||
|
// Keep track of pointers we need to explicitly free
|
||||||
|
stringsToFree []*C.char
|
||||||
|
}
|
||||||
|
|
||||||
|
var initError error
|
||||||
|
var initOnce sync.Once
|
||||||
|
|
||||||
|
func initSSPI() {
|
||||||
|
rc := C.load_secur32_dll()
|
||||||
|
if rc != 0 {
|
||||||
|
initError = fmt.Errorf("Error loading libraries: %v", rc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(username, password, mechanism, service, host string) (saslStepper, error) {
|
||||||
|
initOnce.Do(initSSPI)
|
||||||
|
ss := &saslSession{mech: mechanism, hasContext: 0, userPlusRealm: username}
|
||||||
|
if service == "" {
|
||||||
|
service = "mongodb"
|
||||||
|
}
|
||||||
|
if i := strings.Index(host, ":"); i >= 0 {
|
||||||
|
host = host[:i]
|
||||||
|
}
|
||||||
|
ss.service = service
|
||||||
|
ss.host = host
|
||||||
|
|
||||||
|
usernameComponents := strings.Split(username, "@")
|
||||||
|
if len(usernameComponents) < 2 {
|
||||||
|
return nil, fmt.Errorf("Username '%v' doesn't contain a realm!", username)
|
||||||
|
}
|
||||||
|
user := usernameComponents[0]
|
||||||
|
ss.domain = usernameComponents[1]
|
||||||
|
ss.target = fmt.Sprintf("%s/%s", ss.service, ss.host)
|
||||||
|
|
||||||
|
var status C.SECURITY_STATUS
|
||||||
|
// Step 0: call AcquireCredentialsHandle to get a nice SSPI CredHandle
|
||||||
|
if len(password) > 0 {
|
||||||
|
status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), ss.cstr(password), ss.cstr(ss.domain))
|
||||||
|
} else {
|
||||||
|
status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), nil, ss.cstr(ss.domain))
|
||||||
|
}
|
||||||
|
if status != C.SEC_E_OK {
|
||||||
|
ss.errored = true
|
||||||
|
return nil, fmt.Errorf("Couldn't create new SSPI client, error code %v", status)
|
||||||
|
}
|
||||||
|
return ss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss *saslSession) cstr(s string) *C.char {
|
||||||
|
cstr := C.CString(s)
|
||||||
|
ss.stringsToFree = append(ss.stringsToFree, cstr)
|
||||||
|
return cstr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss *saslSession) Close() {
|
||||||
|
for _, cstr := range ss.stringsToFree {
|
||||||
|
C.free(unsafe.Pointer(cstr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) {
|
||||||
|
ss.step++
|
||||||
|
if ss.step > 10 {
|
||||||
|
return nil, false, fmt.Errorf("too many SSPI steps without authentication")
|
||||||
|
}
|
||||||
|
var buffer C.PVOID
|
||||||
|
var bufferLength C.ULONG
|
||||||
|
var outBuffer C.PVOID
|
||||||
|
var outBufferLength C.ULONG
|
||||||
|
if len(serverData) > 0 {
|
||||||
|
buffer = (C.PVOID)(unsafe.Pointer(&serverData[0]))
|
||||||
|
bufferLength = C.ULONG(len(serverData))
|
||||||
|
}
|
||||||
|
var status C.int
|
||||||
|
if ss.authComplete {
|
||||||
|
// Step 3: last bit of magic to use the correct server credentials
|
||||||
|
status = C.sspi_send_client_authz_id(&ss.context, &outBuffer, &outBufferLength, ss.cstr(ss.userPlusRealm))
|
||||||
|
} else {
|
||||||
|
// Step 1 + Step 2: set up security context with the server and TGT
|
||||||
|
status = C.sspi_step(&ss.credHandle, ss.hasContext, &ss.context, buffer, bufferLength, &outBuffer, &outBufferLength, ss.cstr(ss.target))
|
||||||
|
}
|
||||||
|
if outBuffer != C.PVOID(nil) {
|
||||||
|
defer C.free(unsafe.Pointer(outBuffer))
|
||||||
|
}
|
||||||
|
if status != C.SEC_E_OK && status != C.SEC_I_CONTINUE_NEEDED {
|
||||||
|
ss.errored = true
|
||||||
|
return nil, false, ss.handleSSPIErrorCode(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientData = C.GoBytes(unsafe.Pointer(outBuffer), C.int(outBufferLength))
|
||||||
|
if status == C.SEC_E_OK {
|
||||||
|
ss.authComplete = true
|
||||||
|
return clientData, true, nil
|
||||||
|
} else {
|
||||||
|
ss.hasContext = 1
|
||||||
|
return clientData, false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss *saslSession) handleSSPIErrorCode(code C.int) error {
|
||||||
|
switch {
|
||||||
|
case code == C.SEC_E_TARGET_UNKNOWN:
|
||||||
|
return fmt.Errorf("Target %v@%v not found", ss.target, ss.domain)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("Unknown error doing step %v, error code %v", ss.step, code)
|
||||||
|
}
|
|
@ -0,0 +1,7 @@
|
||||||
|
#include <windows.h>
|
||||||
|
|
||||||
|
#include "sspi_windows.h"
|
||||||
|
|
||||||
|
SECURITY_STATUS SEC_ENTRY sspi_acquire_credentials_handle(CredHandle* cred_handle, char* username, char* password, char* domain);
|
||||||
|
int sspi_step(CredHandle* cred_handle, int has_context, CtxtHandle* context, PVOID buffer, ULONG buffer_length, PVOID* out_buffer, ULONG* out_buffer_length, char* target);
|
||||||
|
int sspi_send_client_authz_id(CtxtHandle* context, PVOID* buffer, ULONG* buffer_length, char* user_plus_realm);
|
|
@ -0,0 +1,96 @@
|
||||||
|
// Code adapted from the NodeJS kerberos library:
|
||||||
|
//
|
||||||
|
// https://github.com/christkv/kerberos/tree/master/lib/win32/kerberos_sspi.c
|
||||||
|
//
|
||||||
|
// Under the terms of the Apache License, Version 2.0:
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
#include "sspi_windows.h"
|
||||||
|
|
||||||
|
static HINSTANCE sspi_secur32_dll = NULL;
|
||||||
|
|
||||||
|
int load_secur32_dll()
|
||||||
|
{
|
||||||
|
sspi_secur32_dll = LoadLibrary("secur32.dll");
|
||||||
|
if (sspi_secur32_dll == NULL) {
|
||||||
|
return GetLastError();
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
SECURITY_STATUS SEC_ENTRY call_sspi_encrypt_message(PCtxtHandle phContext, unsigned long fQOP, PSecBufferDesc pMessage, unsigned long MessageSeqNo)
|
||||||
|
{
|
||||||
|
if (sspi_secur32_dll == NULL) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
encryptMessage_fn pfn_encryptMessage = (encryptMessage_fn) GetProcAddress(sspi_secur32_dll, "EncryptMessage");
|
||||||
|
if (!pfn_encryptMessage) {
|
||||||
|
return -2;
|
||||||
|
}
|
||||||
|
return (*pfn_encryptMessage)(phContext, fQOP, pMessage, MessageSeqNo);
|
||||||
|
}
|
||||||
|
|
||||||
|
SECURITY_STATUS SEC_ENTRY call_sspi_acquire_credentials_handle(
|
||||||
|
LPSTR pszPrincipal, LPSTR pszPackage, unsigned long fCredentialUse,
|
||||||
|
void *pvLogonId, void *pAuthData, SEC_GET_KEY_FN pGetKeyFn, void *pvGetKeyArgument,
|
||||||
|
PCredHandle phCredential, PTimeStamp ptsExpiry)
|
||||||
|
{
|
||||||
|
if (sspi_secur32_dll == NULL) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
acquireCredentialsHandle_fn pfn_acquireCredentialsHandle;
|
||||||
|
#ifdef _UNICODE
|
||||||
|
pfn_acquireCredentialsHandle = (acquireCredentialsHandle_fn) GetProcAddress(sspi_secur32_dll, "AcquireCredentialsHandleW");
|
||||||
|
#else
|
||||||
|
pfn_acquireCredentialsHandle = (acquireCredentialsHandle_fn) GetProcAddress(sspi_secur32_dll, "AcquireCredentialsHandleA");
|
||||||
|
#endif
|
||||||
|
if (!pfn_acquireCredentialsHandle) {
|
||||||
|
return -2;
|
||||||
|
}
|
||||||
|
return (*pfn_acquireCredentialsHandle)(
|
||||||
|
pszPrincipal, pszPackage, fCredentialUse, pvLogonId, pAuthData,
|
||||||
|
pGetKeyFn, pvGetKeyArgument, phCredential, ptsExpiry);
|
||||||
|
}
|
||||||
|
|
||||||
|
SECURITY_STATUS SEC_ENTRY call_sspi_initialize_security_context(
|
||||||
|
PCredHandle phCredential, PCtxtHandle phContext, LPSTR pszTargetName,
|
||||||
|
unsigned long fContextReq, unsigned long Reserved1, unsigned long TargetDataRep,
|
||||||
|
PSecBufferDesc pInput, unsigned long Reserved2, PCtxtHandle phNewContext,
|
||||||
|
PSecBufferDesc pOutput, unsigned long *pfContextAttr, PTimeStamp ptsExpiry)
|
||||||
|
{
|
||||||
|
if (sspi_secur32_dll == NULL) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
initializeSecurityContext_fn pfn_initializeSecurityContext;
|
||||||
|
#ifdef _UNICODE
|
||||||
|
pfn_initializeSecurityContext = (initializeSecurityContext_fn) GetProcAddress(sspi_secur32_dll, "InitializeSecurityContextW");
|
||||||
|
#else
|
||||||
|
pfn_initializeSecurityContext = (initializeSecurityContext_fn) GetProcAddress(sspi_secur32_dll, "InitializeSecurityContextA");
|
||||||
|
#endif
|
||||||
|
if (!pfn_initializeSecurityContext) {
|
||||||
|
return -2;
|
||||||
|
}
|
||||||
|
return (*pfn_initializeSecurityContext)(
|
||||||
|
phCredential, phContext, pszTargetName, fContextReq, Reserved1, TargetDataRep,
|
||||||
|
pInput, Reserved2, phNewContext, pOutput, pfContextAttr, ptsExpiry);
|
||||||
|
}
|
||||||
|
|
||||||
|
SECURITY_STATUS SEC_ENTRY call_sspi_query_context_attributes(PCtxtHandle phContext, unsigned long ulAttribute, void *pBuffer)
|
||||||
|
{
|
||||||
|
if (sspi_secur32_dll == NULL) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
queryContextAttributes_fn pfn_queryContextAttributes;
|
||||||
|
#ifdef _UNICODE
|
||||||
|
pfn_queryContextAttributes = (queryContextAttributes_fn) GetProcAddress(sspi_secur32_dll, "QueryContextAttributesW");
|
||||||
|
#else
|
||||||
|
pfn_queryContextAttributes = (queryContextAttributes_fn) GetProcAddress(sspi_secur32_dll, "QueryContextAttributesA");
|
||||||
|
#endif
|
||||||
|
if (!pfn_queryContextAttributes) {
|
||||||
|
return -2;
|
||||||
|
}
|
||||||
|
return (*pfn_queryContextAttributes)(phContext, ulAttribute, pBuffer);
|
||||||
|
}
|
|
@ -0,0 +1,70 @@
|
||||||
|
// Code adapted from the NodeJS kerberos library:
|
||||||
|
//
|
||||||
|
// https://github.com/christkv/kerberos/tree/master/lib/win32/kerberos_sspi.h
|
||||||
|
//
|
||||||
|
// Under the terms of the Apache License, Version 2.0:
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
#ifndef SSPI_WINDOWS_H
|
||||||
|
#define SSPI_WINDOWS_H
|
||||||
|
|
||||||
|
#define SECURITY_WIN32 1
|
||||||
|
|
||||||
|
#include <windows.h>
|
||||||
|
#include <sspi.h>
|
||||||
|
|
||||||
|
int load_secur32_dll();
|
||||||
|
|
||||||
|
SECURITY_STATUS SEC_ENTRY call_sspi_encrypt_message(PCtxtHandle phContext, unsigned long fQOP, PSecBufferDesc pMessage, unsigned long MessageSeqNo);
|
||||||
|
|
||||||
|
typedef DWORD (WINAPI *encryptMessage_fn)(PCtxtHandle phContext, ULONG fQOP, PSecBufferDesc pMessage, ULONG MessageSeqNo);
|
||||||
|
|
||||||
|
SECURITY_STATUS SEC_ENTRY call_sspi_acquire_credentials_handle(
|
||||||
|
LPSTR pszPrincipal, // Name of principal
|
||||||
|
LPSTR pszPackage, // Name of package
|
||||||
|
unsigned long fCredentialUse, // Flags indicating use
|
||||||
|
void *pvLogonId, // Pointer to logon ID
|
||||||
|
void *pAuthData, // Package specific data
|
||||||
|
SEC_GET_KEY_FN pGetKeyFn, // Pointer to GetKey() func
|
||||||
|
void *pvGetKeyArgument, // Value to pass to GetKey()
|
||||||
|
PCredHandle phCredential, // (out) Cred Handle
|
||||||
|
PTimeStamp ptsExpiry // (out) Lifetime (optional)
|
||||||
|
);
|
||||||
|
|
||||||
|
typedef DWORD (WINAPI *acquireCredentialsHandle_fn)(
|
||||||
|
LPSTR pszPrincipal, LPSTR pszPackage, unsigned long fCredentialUse,
|
||||||
|
void *pvLogonId, void *pAuthData, SEC_GET_KEY_FN pGetKeyFn, void *pvGetKeyArgument,
|
||||||
|
PCredHandle phCredential, PTimeStamp ptsExpiry
|
||||||
|
);
|
||||||
|
|
||||||
|
SECURITY_STATUS SEC_ENTRY call_sspi_initialize_security_context(
|
||||||
|
PCredHandle phCredential, // Cred to base context
|
||||||
|
PCtxtHandle phContext, // Existing context (OPT)
|
||||||
|
LPSTR pszTargetName, // Name of target
|
||||||
|
unsigned long fContextReq, // Context Requirements
|
||||||
|
unsigned long Reserved1, // Reserved, MBZ
|
||||||
|
unsigned long TargetDataRep, // Data rep of target
|
||||||
|
PSecBufferDesc pInput, // Input Buffers
|
||||||
|
unsigned long Reserved2, // Reserved, MBZ
|
||||||
|
PCtxtHandle phNewContext, // (out) New Context handle
|
||||||
|
PSecBufferDesc pOutput, // (inout) Output Buffers
|
||||||
|
unsigned long *pfContextAttr, // (out) Context attrs
|
||||||
|
PTimeStamp ptsExpiry // (out) Life span (OPT)
|
||||||
|
);
|
||||||
|
|
||||||
|
typedef DWORD (WINAPI *initializeSecurityContext_fn)(
|
||||||
|
PCredHandle phCredential, PCtxtHandle phContext, LPSTR pszTargetName, unsigned long fContextReq,
|
||||||
|
unsigned long Reserved1, unsigned long TargetDataRep, PSecBufferDesc pInput, unsigned long Reserved2,
|
||||||
|
PCtxtHandle phNewContext, PSecBufferDesc pOutput, unsigned long *pfContextAttr, PTimeStamp ptsExpiry);
|
||||||
|
|
||||||
|
SECURITY_STATUS SEC_ENTRY call_sspi_query_context_attributes(
|
||||||
|
PCtxtHandle phContext, // Context to query
|
||||||
|
unsigned long ulAttribute, // Attribute to query
|
||||||
|
void *pBuffer // Buffer for attributes
|
||||||
|
);
|
||||||
|
|
||||||
|
typedef DWORD (WINAPI *queryContextAttributes_fn)(
|
||||||
|
PCtxtHandle phContext, unsigned long ulAttribute, void *pBuffer);
|
||||||
|
|
||||||
|
#endif // SSPI_WINDOWS_H
|
|
@ -0,0 +1,266 @@
|
||||||
|
// mgo - MongoDB driver for Go
|
||||||
|
//
|
||||||
|
// Copyright (c) 2014 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
//
|
||||||
|
// All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
// list of conditions and the following disclaimer.
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
// Pacakage scram implements a SCRAM-{SHA-1,etc} client per RFC5802.
|
||||||
|
//
|
||||||
|
// http://tools.ietf.org/html/rfc5802
|
||||||
|
//
|
||||||
|
package scram
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"hash"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc).
|
||||||
|
//
|
||||||
|
// A Client may be used within a SASL conversation with logic resembling:
|
||||||
|
//
|
||||||
|
// var in []byte
|
||||||
|
// var client = scram.NewClient(sha1.New, user, pass)
|
||||||
|
// for client.Step(in) {
|
||||||
|
// out := client.Out()
|
||||||
|
// // send out to server
|
||||||
|
// in := serverOut
|
||||||
|
// }
|
||||||
|
// if client.Err() != nil {
|
||||||
|
// // auth failed
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
type Client struct {
|
||||||
|
newHash func() hash.Hash
|
||||||
|
|
||||||
|
user string
|
||||||
|
pass string
|
||||||
|
step int
|
||||||
|
out bytes.Buffer
|
||||||
|
err error
|
||||||
|
|
||||||
|
clientNonce []byte
|
||||||
|
serverNonce []byte
|
||||||
|
saltedPass []byte
|
||||||
|
authMsg bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient returns a new SCRAM-* client with the provided hash algorithm.
|
||||||
|
//
|
||||||
|
// For SCRAM-SHA-1, for example, use:
|
||||||
|
//
|
||||||
|
// client := scram.NewClient(sha1.New, user, pass)
|
||||||
|
//
|
||||||
|
func NewClient(newHash func() hash.Hash, user, pass string) *Client {
|
||||||
|
c := &Client{
|
||||||
|
newHash: newHash,
|
||||||
|
user: user,
|
||||||
|
pass: pass,
|
||||||
|
}
|
||||||
|
c.out.Grow(256)
|
||||||
|
c.authMsg.Grow(256)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Out returns the data to be sent to the server in the current step.
|
||||||
|
func (c *Client) Out() []byte {
|
||||||
|
if c.out.Len() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.out.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Err returns the error that ocurred, or nil if there were no errors.
|
||||||
|
func (c *Client) Err() error {
|
||||||
|
return c.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNonce sets the client nonce to the provided value.
|
||||||
|
// If not set, the nonce is generated automatically out of crypto/rand on the first step.
|
||||||
|
func (c *Client) SetNonce(nonce []byte) {
|
||||||
|
c.clientNonce = nonce
|
||||||
|
}
|
||||||
|
|
||||||
|
var escaper = strings.NewReplacer("=", "=3D", ",", "=2C")
|
||||||
|
|
||||||
|
// Step processes the incoming data from the server and makes the
|
||||||
|
// next round of data for the server available via Client.Out.
|
||||||
|
// Step returns false if there are no errors and more data is
|
||||||
|
// still expected.
|
||||||
|
func (c *Client) Step(in []byte) bool {
|
||||||
|
c.out.Reset()
|
||||||
|
if c.step > 2 || c.err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.step++
|
||||||
|
switch c.step {
|
||||||
|
case 1:
|
||||||
|
c.err = c.step1(in)
|
||||||
|
case 2:
|
||||||
|
c.err = c.step2(in)
|
||||||
|
case 3:
|
||||||
|
c.err = c.step3(in)
|
||||||
|
}
|
||||||
|
return c.step > 2 || c.err != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) step1(in []byte) error {
|
||||||
|
if len(c.clientNonce) == 0 {
|
||||||
|
const nonceLen = 6
|
||||||
|
buf := make([]byte, nonceLen + b64.EncodedLen(nonceLen))
|
||||||
|
if _, err := rand.Read(buf[:nonceLen]); err != nil {
|
||||||
|
return fmt.Errorf("cannot read random SCRAM-SHA-1 nonce from operating system: %v", err)
|
||||||
|
}
|
||||||
|
c.clientNonce = buf[nonceLen:]
|
||||||
|
b64.Encode(c.clientNonce, buf[:nonceLen])
|
||||||
|
}
|
||||||
|
c.authMsg.WriteString("n=")
|
||||||
|
escaper.WriteString(&c.authMsg, c.user)
|
||||||
|
c.authMsg.WriteString(",r=")
|
||||||
|
c.authMsg.Write(c.clientNonce)
|
||||||
|
|
||||||
|
c.out.WriteString("n,,")
|
||||||
|
c.out.Write(c.authMsg.Bytes())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var b64 = base64.StdEncoding
|
||||||
|
|
||||||
|
func (c *Client) step2(in []byte) error {
|
||||||
|
c.authMsg.WriteByte(',')
|
||||||
|
c.authMsg.Write(in)
|
||||||
|
|
||||||
|
fields := bytes.Split(in, []byte(","))
|
||||||
|
if len(fields) != 3 {
|
||||||
|
return fmt.Errorf("expected 3 fields in first SCRAM-SHA-1 server message, got %d: %q", len(fields), in)
|
||||||
|
}
|
||||||
|
if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 {
|
||||||
|
return fmt.Errorf("server sent an invalid SCRAM-SHA-1 nonce: %q", fields[0])
|
||||||
|
}
|
||||||
|
if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 {
|
||||||
|
return fmt.Errorf("server sent an invalid SCRAM-SHA-1 salt: %q", fields[1])
|
||||||
|
}
|
||||||
|
if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 {
|
||||||
|
return fmt.Errorf("server sent an invalid SCRAM-SHA-1 iteration count: %q", fields[2])
|
||||||
|
}
|
||||||
|
|
||||||
|
c.serverNonce = fields[0][2:]
|
||||||
|
if !bytes.HasPrefix(c.serverNonce, c.clientNonce) {
|
||||||
|
return fmt.Errorf("server SCRAM-SHA-1 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce)
|
||||||
|
}
|
||||||
|
|
||||||
|
salt := make([]byte, b64.DecodedLen(len(fields[1][2:])))
|
||||||
|
n, err := b64.Decode(salt, fields[1][2:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot decode SCRAM-SHA-1 salt sent by server: %q", fields[1])
|
||||||
|
}
|
||||||
|
salt = salt[:n]
|
||||||
|
iterCount, err := strconv.Atoi(string(fields[2][2:]))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("server sent an invalid SCRAM-SHA-1 iteration count: %q", fields[2])
|
||||||
|
}
|
||||||
|
c.saltPassword(salt, iterCount)
|
||||||
|
|
||||||
|
c.authMsg.WriteString(",c=biws,r=")
|
||||||
|
c.authMsg.Write(c.serverNonce)
|
||||||
|
|
||||||
|
c.out.WriteString("c=biws,r=")
|
||||||
|
c.out.Write(c.serverNonce)
|
||||||
|
c.out.WriteString(",p=")
|
||||||
|
c.out.Write(c.clientProof())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) step3(in []byte) error {
|
||||||
|
var isv, ise bool
|
||||||
|
var fields = bytes.Split(in, []byte(","))
|
||||||
|
if len(fields) == 1 {
|
||||||
|
isv = bytes.HasPrefix(fields[0], []byte("v="))
|
||||||
|
ise = bytes.HasPrefix(fields[0], []byte("e="))
|
||||||
|
}
|
||||||
|
if ise {
|
||||||
|
return fmt.Errorf("SCRAM-SHA-1 authentication error: %s", fields[0][2:])
|
||||||
|
} else if !isv {
|
||||||
|
return fmt.Errorf("unsupported SCRAM-SHA-1 final message from server: %q", in)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(c.serverSignature(), fields[0][2:]) {
|
||||||
|
return fmt.Errorf("cannot authenticate SCRAM-SHA-1 server signature: %q", fields[0][2:])
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) saltPassword(salt []byte, iterCount int) {
|
||||||
|
mac := hmac.New(c.newHash, []byte(c.pass))
|
||||||
|
mac.Write(salt)
|
||||||
|
mac.Write([]byte{0, 0, 0, 1})
|
||||||
|
ui := mac.Sum(nil)
|
||||||
|
hi := make([]byte, len(ui))
|
||||||
|
copy(hi, ui)
|
||||||
|
for i := 1; i < iterCount; i++ {
|
||||||
|
mac.Reset()
|
||||||
|
mac.Write(ui)
|
||||||
|
mac.Sum(ui[:0])
|
||||||
|
for j, b := range ui {
|
||||||
|
hi[j] ^= b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.saltedPass = hi
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) clientProof() []byte {
|
||||||
|
mac := hmac.New(c.newHash, c.saltedPass)
|
||||||
|
mac.Write([]byte("Client Key"))
|
||||||
|
clientKey := mac.Sum(nil)
|
||||||
|
hash := c.newHash()
|
||||||
|
hash.Write(clientKey)
|
||||||
|
storedKey := hash.Sum(nil)
|
||||||
|
mac = hmac.New(c.newHash, storedKey)
|
||||||
|
mac.Write(c.authMsg.Bytes())
|
||||||
|
clientProof := mac.Sum(nil)
|
||||||
|
for i, b := range clientKey {
|
||||||
|
clientProof[i] ^= b
|
||||||
|
}
|
||||||
|
clientProof64 := make([]byte, b64.EncodedLen(len(clientProof)))
|
||||||
|
b64.Encode(clientProof64, clientProof)
|
||||||
|
return clientProof64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) serverSignature() []byte {
|
||||||
|
mac := hmac.New(c.newHash, c.saltedPass)
|
||||||
|
mac.Write([]byte("Server Key"))
|
||||||
|
serverKey := mac.Sum(nil)
|
||||||
|
|
||||||
|
mac = hmac.New(c.newHash, serverKey)
|
||||||
|
mac.Write(c.authMsg.Bytes())
|
||||||
|
serverSignature := mac.Sum(nil)
|
||||||
|
|
||||||
|
encoded := make([]byte, b64.EncodedLen(len(serverSignature)))
|
||||||
|
b64.Encode(encoded, serverSignature)
|
||||||
|
return encoded
|
||||||
|
}
|
|
@ -0,0 +1,133 @@
|
||||||
|
// mgo - MongoDB driver for Go
|
||||||
|
//
|
||||||
|
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
//
|
||||||
|
// All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
// list of conditions and the following disclaimer.
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
package mgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Logging integration.
|
||||||
|
|
||||||
|
// Avoid importing the log type information unnecessarily. There's a small cost
|
||||||
|
// associated with using an interface rather than the type. Depending on how
|
||||||
|
// often the logger is plugged in, it would be worth using the type instead.
|
||||||
|
type log_Logger interface {
|
||||||
|
Output(calldepth int, s string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
globalLogger log_Logger
|
||||||
|
globalDebug bool
|
||||||
|
globalMutex sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// RACE WARNING: There are known data races when logging, which are manually
|
||||||
|
// silenced when the race detector is in use. These data races won't be
|
||||||
|
// observed in typical use, because logging is supposed to be set up once when
|
||||||
|
// the application starts. Having raceDetector as a constant, the compiler
|
||||||
|
// should elide the locks altogether in actual use.
|
||||||
|
|
||||||
|
// Specify the *log.Logger object where log messages should be sent to.
|
||||||
|
func SetLogger(logger log_Logger) {
|
||||||
|
if raceDetector {
|
||||||
|
globalMutex.Lock()
|
||||||
|
defer globalMutex.Unlock()
|
||||||
|
}
|
||||||
|
globalLogger = logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable the delivery of debug messages to the logger. Only meaningful
|
||||||
|
// if a logger is also set.
|
||||||
|
func SetDebug(debug bool) {
|
||||||
|
if raceDetector {
|
||||||
|
globalMutex.Lock()
|
||||||
|
defer globalMutex.Unlock()
|
||||||
|
}
|
||||||
|
globalDebug = debug
|
||||||
|
}
|
||||||
|
|
||||||
|
func log(v ...interface{}) {
|
||||||
|
if raceDetector {
|
||||||
|
globalMutex.Lock()
|
||||||
|
defer globalMutex.Unlock()
|
||||||
|
}
|
||||||
|
if globalLogger != nil {
|
||||||
|
globalLogger.Output(2, fmt.Sprint(v...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func logln(v ...interface{}) {
|
||||||
|
if raceDetector {
|
||||||
|
globalMutex.Lock()
|
||||||
|
defer globalMutex.Unlock()
|
||||||
|
}
|
||||||
|
if globalLogger != nil {
|
||||||
|
globalLogger.Output(2, fmt.Sprintln(v...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func logf(format string, v ...interface{}) {
|
||||||
|
if raceDetector {
|
||||||
|
globalMutex.Lock()
|
||||||
|
defer globalMutex.Unlock()
|
||||||
|
}
|
||||||
|
if globalLogger != nil {
|
||||||
|
globalLogger.Output(2, fmt.Sprintf(format, v...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func debug(v ...interface{}) {
|
||||||
|
if raceDetector {
|
||||||
|
globalMutex.Lock()
|
||||||
|
defer globalMutex.Unlock()
|
||||||
|
}
|
||||||
|
if globalDebug && globalLogger != nil {
|
||||||
|
globalLogger.Output(2, fmt.Sprint(v...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func debugln(v ...interface{}) {
|
||||||
|
if raceDetector {
|
||||||
|
globalMutex.Lock()
|
||||||
|
defer globalMutex.Unlock()
|
||||||
|
}
|
||||||
|
if globalDebug && globalLogger != nil {
|
||||||
|
globalLogger.Output(2, fmt.Sprintln(v...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func debugf(format string, v ...interface{}) {
|
||||||
|
if raceDetector {
|
||||||
|
globalMutex.Lock()
|
||||||
|
defer globalMutex.Unlock()
|
||||||
|
}
|
||||||
|
if globalDebug && globalLogger != nil {
|
||||||
|
globalLogger.Output(2, fmt.Sprintf(format, v...))
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
// mgo - MongoDB driver for Go
|
||||||
|
//
|
||||||
|
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
//
|
||||||
|
// All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
// list of conditions and the following disclaimer.
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
package mgo
|
||||||
|
|
||||||
|
type queue struct {
|
||||||
|
elems []interface{}
|
||||||
|
nelems, popi, pushi int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *queue) Len() int {
|
||||||
|
return q.nelems
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *queue) Push(elem interface{}) {
|
||||||
|
//debugf("Pushing(pushi=%d popi=%d cap=%d): %#v\n",
|
||||||
|
// q.pushi, q.popi, len(q.elems), elem)
|
||||||
|
if q.nelems == len(q.elems) {
|
||||||
|
q.expand()
|
||||||
|
}
|
||||||
|
q.elems[q.pushi] = elem
|
||||||
|
q.nelems++
|
||||||
|
q.pushi = (q.pushi + 1) % len(q.elems)
|
||||||
|
//debugf(" Pushed(pushi=%d popi=%d cap=%d): %#v\n",
|
||||||
|
// q.pushi, q.popi, len(q.elems), elem)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *queue) Pop() (elem interface{}) {
|
||||||
|
//debugf("Popping(pushi=%d popi=%d cap=%d)\n",
|
||||||
|
// q.pushi, q.popi, len(q.elems))
|
||||||
|
if q.nelems == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
elem = q.elems[q.popi]
|
||||||
|
q.elems[q.popi] = nil // Help GC.
|
||||||
|
q.nelems--
|
||||||
|
q.popi = (q.popi + 1) % len(q.elems)
|
||||||
|
//debugf(" Popped(pushi=%d popi=%d cap=%d): %#v\n",
|
||||||
|
// q.pushi, q.popi, len(q.elems), elem)
|
||||||
|
return elem
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *queue) expand() {
|
||||||
|
curcap := len(q.elems)
|
||||||
|
var newcap int
|
||||||
|
if curcap == 0 {
|
||||||
|
newcap = 8
|
||||||
|
} else if curcap < 1024 {
|
||||||
|
newcap = curcap * 2
|
||||||
|
} else {
|
||||||
|
newcap = curcap + (curcap / 4)
|
||||||
|
}
|
||||||
|
elems := make([]interface{}, newcap)
|
||||||
|
|
||||||
|
if q.popi == 0 {
|
||||||
|
copy(elems, q.elems)
|
||||||
|
q.pushi = curcap
|
||||||
|
} else {
|
||||||
|
newpopi := newcap - (curcap - q.popi)
|
||||||
|
copy(elems, q.elems[:q.popi])
|
||||||
|
copy(elems[newpopi:], q.elems[q.popi:])
|
||||||
|
q.popi = newpopi
|
||||||
|
}
|
||||||
|
for i := range q.elems {
|
||||||
|
q.elems[i] = nil // Help GC.
|
||||||
|
}
|
||||||
|
q.elems = elems
|
||||||
|
}
|
|
@ -0,0 +1,5 @@
|
||||||
|
// +build !race
|
||||||
|
|
||||||
|
package mgo
|
||||||
|
|
||||||
|
const raceDetector = false
|
|
@ -0,0 +1,5 @@
|
||||||
|
// +build race
|
||||||
|
|
||||||
|
package mgo
|
||||||
|
|
||||||
|
const raceDetector = true
|
|
@ -0,0 +1,11 @@
|
||||||
|
//+build sasl
|
||||||
|
|
||||||
|
package mgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"gopkg.in/mgo.v2/internal/sasl"
|
||||||
|
)
|
||||||
|
|
||||||
|
func saslNew(cred Credential, host string) (saslStepper, error) {
|
||||||
|
return sasl.New(cred.Username, cred.Password, cred.Mechanism, cred.Service, host)
|
||||||
|
}
|
|
@ -0,0 +1,11 @@
|
||||||
|
//+build !sasl
|
||||||
|
|
||||||
|
package mgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func saslNew(cred Credential, host string) (saslStepper, error) {
|
||||||
|
return nil, fmt.Errorf("SASL support not enabled during build (-tags sasl)")
|
||||||
|
}
|
|
@ -0,0 +1,463 @@
|
||||||
|
// mgo - MongoDB driver for Go
|
||||||
|
//
|
||||||
|
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
//
|
||||||
|
// All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
// list of conditions and the following disclaimer.
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
package mgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/mgo.v2/bson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Mongo server encapsulation.
|
||||||
|
|
||||||
|
type mongoServer struct {
|
||||||
|
sync.RWMutex
|
||||||
|
Addr string
|
||||||
|
ResolvedAddr string
|
||||||
|
tcpaddr *net.TCPAddr
|
||||||
|
unusedSockets []*mongoSocket
|
||||||
|
liveSockets []*mongoSocket
|
||||||
|
closed bool
|
||||||
|
abended bool
|
||||||
|
sync chan bool
|
||||||
|
dial dialer
|
||||||
|
pingValue time.Duration
|
||||||
|
pingIndex int
|
||||||
|
pingCount uint32
|
||||||
|
pingWindow [6]time.Duration
|
||||||
|
info *mongoServerInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
type dialer struct {
|
||||||
|
old func(addr net.Addr) (net.Conn, error)
|
||||||
|
new func(addr *ServerAddr) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dial dialer) isSet() bool {
|
||||||
|
return dial.old != nil || dial.new != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mongoServerInfo struct {
|
||||||
|
Master bool
|
||||||
|
Mongos bool
|
||||||
|
Tags bson.D
|
||||||
|
MaxWireVersion int
|
||||||
|
SetName string
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultServerInfo mongoServerInfo
|
||||||
|
|
||||||
|
func newServer(addr string, tcpaddr *net.TCPAddr, sync chan bool, dial dialer) *mongoServer {
|
||||||
|
server := &mongoServer{
|
||||||
|
Addr: addr,
|
||||||
|
ResolvedAddr: tcpaddr.String(),
|
||||||
|
tcpaddr: tcpaddr,
|
||||||
|
sync: sync,
|
||||||
|
dial: dial,
|
||||||
|
info: &defaultServerInfo,
|
||||||
|
pingValue: time.Hour, // Push it back before an actual ping.
|
||||||
|
}
|
||||||
|
go server.pinger(true)
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
|
||||||
|
var errPoolLimit = errors.New("per-server connection limit reached")
|
||||||
|
var errServerClosed = errors.New("server was closed")
|
||||||
|
|
||||||
|
// AcquireSocket returns a socket for communicating with the server.
|
||||||
|
// This will attempt to reuse an old connection, if one is available. Otherwise,
|
||||||
|
// it will establish a new one. The returned socket is owned by the call site,
|
||||||
|
// and will return to the cache when the socket has its Release method called
|
||||||
|
// the same number of times as AcquireSocket + Acquire were called for it.
|
||||||
|
// If the poolLimit argument is greater than zero and the number of sockets in
|
||||||
|
// use in this server is greater than the provided limit, errPoolLimit is
|
||||||
|
// returned.
|
||||||
|
func (server *mongoServer) AcquireSocket(poolLimit int, timeout time.Duration) (socket *mongoSocket, abended bool, err error) {
|
||||||
|
for {
|
||||||
|
server.Lock()
|
||||||
|
abended = server.abended
|
||||||
|
if server.closed {
|
||||||
|
server.Unlock()
|
||||||
|
return nil, abended, errServerClosed
|
||||||
|
}
|
||||||
|
n := len(server.unusedSockets)
|
||||||
|
if poolLimit > 0 && len(server.liveSockets)-n >= poolLimit {
|
||||||
|
server.Unlock()
|
||||||
|
return nil, false, errPoolLimit
|
||||||
|
}
|
||||||
|
if n > 0 {
|
||||||
|
socket = server.unusedSockets[n-1]
|
||||||
|
server.unusedSockets[n-1] = nil // Help GC.
|
||||||
|
server.unusedSockets = server.unusedSockets[:n-1]
|
||||||
|
info := server.info
|
||||||
|
server.Unlock()
|
||||||
|
err = socket.InitialAcquire(info, timeout)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
server.Unlock()
|
||||||
|
socket, err = server.Connect(timeout)
|
||||||
|
if err == nil {
|
||||||
|
server.Lock()
|
||||||
|
// We've waited for the Connect, see if we got
|
||||||
|
// closed in the meantime
|
||||||
|
if server.closed {
|
||||||
|
server.Unlock()
|
||||||
|
socket.Release()
|
||||||
|
socket.Close()
|
||||||
|
return nil, abended, errServerClosed
|
||||||
|
}
|
||||||
|
server.liveSockets = append(server.liveSockets, socket)
|
||||||
|
server.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes a new connection to the server. This should
|
||||||
|
// generally be done through server.AcquireSocket().
|
||||||
|
func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) {
|
||||||
|
server.RLock()
|
||||||
|
master := server.info.Master
|
||||||
|
dial := server.dial
|
||||||
|
server.RUnlock()
|
||||||
|
|
||||||
|
logf("Establishing new connection to %s (timeout=%s)...", server.Addr, timeout)
|
||||||
|
var conn net.Conn
|
||||||
|
var err error
|
||||||
|
switch {
|
||||||
|
case !dial.isSet():
|
||||||
|
// Cannot do this because it lacks timeout support. :-(
|
||||||
|
//conn, err = net.DialTCP("tcp", nil, server.tcpaddr)
|
||||||
|
conn, err = net.DialTimeout("tcp", server.ResolvedAddr, timeout)
|
||||||
|
if tcpconn, ok := conn.(*net.TCPConn); ok {
|
||||||
|
tcpconn.SetKeepAlive(true)
|
||||||
|
} else if err == nil {
|
||||||
|
panic("internal error: obtained TCP connection is not a *net.TCPConn!?")
|
||||||
|
}
|
||||||
|
case dial.old != nil:
|
||||||
|
conn, err = dial.old(server.tcpaddr)
|
||||||
|
case dial.new != nil:
|
||||||
|
conn, err = dial.new(&ServerAddr{server.Addr, server.tcpaddr})
|
||||||
|
default:
|
||||||
|
panic("dialer is set, but both dial.old and dial.new are nil")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logf("Connection to %s failed: %v", server.Addr, err.Error())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
logf("Connection to %s established.", server.Addr)
|
||||||
|
|
||||||
|
stats.conn(+1, master)
|
||||||
|
return newSocket(server, conn, timeout), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close forces closing all sockets that are alive, whether
|
||||||
|
// they're currently in use or not.
|
||||||
|
func (server *mongoServer) Close() {
|
||||||
|
server.Lock()
|
||||||
|
server.closed = true
|
||||||
|
liveSockets := server.liveSockets
|
||||||
|
unusedSockets := server.unusedSockets
|
||||||
|
server.liveSockets = nil
|
||||||
|
server.unusedSockets = nil
|
||||||
|
server.Unlock()
|
||||||
|
logf("Connections to %s closing (%d live sockets).", server.Addr, len(liveSockets))
|
||||||
|
for i, s := range liveSockets {
|
||||||
|
s.Close()
|
||||||
|
liveSockets[i] = nil
|
||||||
|
}
|
||||||
|
for i := range unusedSockets {
|
||||||
|
unusedSockets[i] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecycleSocket puts socket back into the unused cache.
|
||||||
|
func (server *mongoServer) RecycleSocket(socket *mongoSocket) {
|
||||||
|
server.Lock()
|
||||||
|
if !server.closed {
|
||||||
|
server.unusedSockets = append(server.unusedSockets, socket)
|
||||||
|
}
|
||||||
|
server.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeSocket(sockets []*mongoSocket, socket *mongoSocket) []*mongoSocket {
|
||||||
|
for i, s := range sockets {
|
||||||
|
if s == socket {
|
||||||
|
copy(sockets[i:], sockets[i+1:])
|
||||||
|
n := len(sockets) - 1
|
||||||
|
sockets[n] = nil
|
||||||
|
sockets = sockets[:n]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sockets
|
||||||
|
}
|
||||||
|
|
||||||
|
// AbendSocket notifies the server that the given socket has terminated
|
||||||
|
// abnormally, and thus should be discarded rather than cached.
|
||||||
|
func (server *mongoServer) AbendSocket(socket *mongoSocket) {
|
||||||
|
server.Lock()
|
||||||
|
server.abended = true
|
||||||
|
if server.closed {
|
||||||
|
server.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
server.liveSockets = removeSocket(server.liveSockets, socket)
|
||||||
|
server.unusedSockets = removeSocket(server.unusedSockets, socket)
|
||||||
|
server.Unlock()
|
||||||
|
// Maybe just a timeout, but suggest a cluster sync up just in case.
|
||||||
|
select {
|
||||||
|
case server.sync <- true:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *mongoServer) SetInfo(info *mongoServerInfo) {
|
||||||
|
server.Lock()
|
||||||
|
server.info = info
|
||||||
|
server.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *mongoServer) Info() *mongoServerInfo {
|
||||||
|
server.Lock()
|
||||||
|
info := server.info
|
||||||
|
server.Unlock()
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *mongoServer) hasTags(serverTags []bson.D) bool {
|
||||||
|
NextTagSet:
|
||||||
|
for _, tags := range serverTags {
|
||||||
|
NextReqTag:
|
||||||
|
for _, req := range tags {
|
||||||
|
for _, has := range server.info.Tags {
|
||||||
|
if req.Name == has.Name {
|
||||||
|
if req.Value == has.Value {
|
||||||
|
continue NextReqTag
|
||||||
|
}
|
||||||
|
continue NextTagSet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue NextTagSet
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var pingDelay = 15 * time.Second
|
||||||
|
|
||||||
|
func (server *mongoServer) pinger(loop bool) {
|
||||||
|
var delay time.Duration
|
||||||
|
if raceDetector {
|
||||||
|
// This variable is only ever touched by tests.
|
||||||
|
globalMutex.Lock()
|
||||||
|
delay = pingDelay
|
||||||
|
globalMutex.Unlock()
|
||||||
|
} else {
|
||||||
|
delay = pingDelay
|
||||||
|
}
|
||||||
|
op := queryOp{
|
||||||
|
collection: "admin.$cmd",
|
||||||
|
query: bson.D{{"ping", 1}},
|
||||||
|
flags: flagSlaveOk,
|
||||||
|
limit: -1,
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
if loop {
|
||||||
|
time.Sleep(delay)
|
||||||
|
}
|
||||||
|
op := op
|
||||||
|
socket, _, err := server.AcquireSocket(0, delay)
|
||||||
|
if err == nil {
|
||||||
|
start := time.Now()
|
||||||
|
_, _ = socket.SimpleQuery(&op)
|
||||||
|
delay := time.Now().Sub(start)
|
||||||
|
|
||||||
|
server.pingWindow[server.pingIndex] = delay
|
||||||
|
server.pingIndex = (server.pingIndex + 1) % len(server.pingWindow)
|
||||||
|
server.pingCount++
|
||||||
|
var max time.Duration
|
||||||
|
for i := 0; i < len(server.pingWindow) && uint32(i) < server.pingCount; i++ {
|
||||||
|
if server.pingWindow[i] > max {
|
||||||
|
max = server.pingWindow[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
socket.Release()
|
||||||
|
server.Lock()
|
||||||
|
if server.closed {
|
||||||
|
loop = false
|
||||||
|
}
|
||||||
|
server.pingValue = max
|
||||||
|
server.Unlock()
|
||||||
|
logf("Ping for %s is %d ms", server.Addr, max/time.Millisecond)
|
||||||
|
} else if err == errServerClosed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !loop {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mongoServerSlice []*mongoServer
|
||||||
|
|
||||||
|
func (s mongoServerSlice) Len() int {
|
||||||
|
return len(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mongoServerSlice) Less(i, j int) bool {
|
||||||
|
return s[i].ResolvedAddr < s[j].ResolvedAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mongoServerSlice) Swap(i, j int) {
|
||||||
|
s[i], s[j] = s[j], s[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mongoServerSlice) Sort() {
|
||||||
|
sort.Sort(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mongoServerSlice) Search(resolvedAddr string) (i int, ok bool) {
|
||||||
|
n := len(s)
|
||||||
|
i = sort.Search(n, func(i int) bool {
|
||||||
|
return s[i].ResolvedAddr >= resolvedAddr
|
||||||
|
})
|
||||||
|
return i, i != n && s[i].ResolvedAddr == resolvedAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
type mongoServers struct {
|
||||||
|
slice mongoServerSlice
|
||||||
|
}
|
||||||
|
|
||||||
|
func (servers *mongoServers) Search(resolvedAddr string) (server *mongoServer) {
|
||||||
|
if i, ok := servers.slice.Search(resolvedAddr); ok {
|
||||||
|
return servers.slice[i]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (servers *mongoServers) Add(server *mongoServer) {
|
||||||
|
servers.slice = append(servers.slice, server)
|
||||||
|
servers.slice.Sort()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (servers *mongoServers) Remove(other *mongoServer) (server *mongoServer) {
|
||||||
|
if i, found := servers.slice.Search(other.ResolvedAddr); found {
|
||||||
|
server = servers.slice[i]
|
||||||
|
copy(servers.slice[i:], servers.slice[i+1:])
|
||||||
|
n := len(servers.slice) - 1
|
||||||
|
servers.slice[n] = nil // Help GC.
|
||||||
|
servers.slice = servers.slice[:n]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (servers *mongoServers) Slice() []*mongoServer {
|
||||||
|
return ([]*mongoServer)(servers.slice)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (servers *mongoServers) Get(i int) *mongoServer {
|
||||||
|
return servers.slice[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (servers *mongoServers) Len() int {
|
||||||
|
return len(servers.slice)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (servers *mongoServers) Empty() bool {
|
||||||
|
return len(servers.slice) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (servers *mongoServers) HasMongos() bool {
|
||||||
|
for _, s := range servers.slice {
|
||||||
|
if s.Info().Mongos {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// BestFit returns the best guess of what would be the most interesting
|
||||||
|
// server to perform operations on at this point in time.
|
||||||
|
func (servers *mongoServers) BestFit(mode Mode, serverTags []bson.D) *mongoServer {
|
||||||
|
var best *mongoServer
|
||||||
|
for _, next := range servers.slice {
|
||||||
|
if best == nil {
|
||||||
|
best = next
|
||||||
|
best.RLock()
|
||||||
|
if serverTags != nil && !next.info.Mongos && !best.hasTags(serverTags) {
|
||||||
|
best.RUnlock()
|
||||||
|
best = nil
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
next.RLock()
|
||||||
|
swap := false
|
||||||
|
switch {
|
||||||
|
case serverTags != nil && !next.info.Mongos && !next.hasTags(serverTags):
|
||||||
|
// Must have requested tags.
|
||||||
|
case mode == Secondary && next.info.Master && !next.info.Mongos:
|
||||||
|
// Must be a secondary or mongos.
|
||||||
|
case next.info.Master != best.info.Master && mode != Nearest:
|
||||||
|
// Prefer slaves, unless the mode is PrimaryPreferred.
|
||||||
|
swap = (mode == PrimaryPreferred) != best.info.Master
|
||||||
|
case absDuration(next.pingValue-best.pingValue) > 15*time.Millisecond:
|
||||||
|
// Prefer nearest server.
|
||||||
|
swap = next.pingValue < best.pingValue
|
||||||
|
case len(next.liveSockets)-len(next.unusedSockets) < len(best.liveSockets)-len(best.unusedSockets):
|
||||||
|
// Prefer servers with less connections.
|
||||||
|
swap = true
|
||||||
|
}
|
||||||
|
if swap {
|
||||||
|
best.RUnlock()
|
||||||
|
best = next
|
||||||
|
} else {
|
||||||
|
next.RUnlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if best != nil {
|
||||||
|
best.RUnlock()
|
||||||
|
}
|
||||||
|
return best
|
||||||
|
}
|
||||||
|
|
||||||
|
func absDuration(d time.Duration) time.Duration {
|
||||||
|
if d < 0 {
|
||||||
|
return -d
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,707 @@
|
||||||
|
// mgo - MongoDB driver for Go
|
||||||
|
//
|
||||||
|
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
//
|
||||||
|
// All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
// list of conditions and the following disclaimer.
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
package mgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/mgo.v2/bson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type replyFunc func(err error, reply *replyOp, docNum int, docData []byte)
|
||||||
|
|
||||||
|
type mongoSocket struct {
|
||||||
|
sync.Mutex
|
||||||
|
server *mongoServer // nil when cached
|
||||||
|
conn net.Conn
|
||||||
|
timeout time.Duration
|
||||||
|
addr string // For debugging only.
|
||||||
|
nextRequestId uint32
|
||||||
|
replyFuncs map[uint32]replyFunc
|
||||||
|
references int
|
||||||
|
creds []Credential
|
||||||
|
logout []Credential
|
||||||
|
cachedNonce string
|
||||||
|
gotNonce sync.Cond
|
||||||
|
dead error
|
||||||
|
serverInfo *mongoServerInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
type queryOpFlags uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
_ queryOpFlags = 1 << iota
|
||||||
|
flagTailable
|
||||||
|
flagSlaveOk
|
||||||
|
flagLogReplay
|
||||||
|
flagNoCursorTimeout
|
||||||
|
flagAwaitData
|
||||||
|
)
|
||||||
|
|
||||||
|
type queryOp struct {
|
||||||
|
collection string
|
||||||
|
query interface{}
|
||||||
|
skip int32
|
||||||
|
limit int32
|
||||||
|
selector interface{}
|
||||||
|
flags queryOpFlags
|
||||||
|
replyFunc replyFunc
|
||||||
|
|
||||||
|
mode Mode
|
||||||
|
options queryWrapper
|
||||||
|
hasOptions bool
|
||||||
|
serverTags []bson.D
|
||||||
|
}
|
||||||
|
|
||||||
|
type queryWrapper struct {
|
||||||
|
Query interface{} "$query"
|
||||||
|
OrderBy interface{} "$orderby,omitempty"
|
||||||
|
Hint interface{} "$hint,omitempty"
|
||||||
|
Explain bool "$explain,omitempty"
|
||||||
|
Snapshot bool "$snapshot,omitempty"
|
||||||
|
ReadPreference bson.D "$readPreference,omitempty"
|
||||||
|
MaxScan int "$maxScan,omitempty"
|
||||||
|
MaxTimeMS int "$maxTimeMS,omitempty"
|
||||||
|
Comment string "$comment,omitempty"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (op *queryOp) finalQuery(socket *mongoSocket) interface{} {
|
||||||
|
if op.flags&flagSlaveOk != 0 && socket.ServerInfo().Mongos {
|
||||||
|
var modeName string
|
||||||
|
switch op.mode {
|
||||||
|
case Strong:
|
||||||
|
modeName = "primary"
|
||||||
|
case Monotonic, Eventual:
|
||||||
|
modeName = "secondaryPreferred"
|
||||||
|
case PrimaryPreferred:
|
||||||
|
modeName = "primaryPreferred"
|
||||||
|
case Secondary:
|
||||||
|
modeName = "secondary"
|
||||||
|
case SecondaryPreferred:
|
||||||
|
modeName = "secondaryPreferred"
|
||||||
|
case Nearest:
|
||||||
|
modeName = "nearest"
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unsupported read mode: %d", op.mode))
|
||||||
|
}
|
||||||
|
op.hasOptions = true
|
||||||
|
op.options.ReadPreference = make(bson.D, 0, 2)
|
||||||
|
op.options.ReadPreference = append(op.options.ReadPreference, bson.DocElem{"mode", modeName})
|
||||||
|
if len(op.serverTags) > 0 {
|
||||||
|
op.options.ReadPreference = append(op.options.ReadPreference, bson.DocElem{"tags", op.serverTags})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if op.hasOptions {
|
||||||
|
if op.query == nil {
|
||||||
|
var empty bson.D
|
||||||
|
op.options.Query = empty
|
||||||
|
} else {
|
||||||
|
op.options.Query = op.query
|
||||||
|
}
|
||||||
|
debugf("final query is %#v\n", &op.options)
|
||||||
|
return &op.options
|
||||||
|
}
|
||||||
|
return op.query
|
||||||
|
}
|
||||||
|
|
||||||
|
type getMoreOp struct {
|
||||||
|
collection string
|
||||||
|
limit int32
|
||||||
|
cursorId int64
|
||||||
|
replyFunc replyFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
type replyOp struct {
|
||||||
|
flags uint32
|
||||||
|
cursorId int64
|
||||||
|
firstDoc int32
|
||||||
|
replyDocs int32
|
||||||
|
}
|
||||||
|
|
||||||
|
type insertOp struct {
|
||||||
|
collection string // "database.collection"
|
||||||
|
documents []interface{} // One or more documents to insert
|
||||||
|
flags uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type updateOp struct {
|
||||||
|
Collection string `bson:"-"` // "database.collection"
|
||||||
|
Selector interface{} `bson:"q"`
|
||||||
|
Update interface{} `bson:"u"`
|
||||||
|
Flags uint32 `bson:"-"`
|
||||||
|
Multi bool `bson:"multi,omitempty"`
|
||||||
|
Upsert bool `bson:"upsert,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type deleteOp struct {
|
||||||
|
Collection string `bson:"-"` // "database.collection"
|
||||||
|
Selector interface{} `bson:"q"`
|
||||||
|
Flags uint32 `bson:"-"`
|
||||||
|
Limit int `bson:"limit"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type killCursorsOp struct {
|
||||||
|
cursorIds []int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestInfo struct {
|
||||||
|
bufferPos int
|
||||||
|
replyFunc replyFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket {
|
||||||
|
socket := &mongoSocket{
|
||||||
|
conn: conn,
|
||||||
|
addr: server.Addr,
|
||||||
|
server: server,
|
||||||
|
replyFuncs: make(map[uint32]replyFunc),
|
||||||
|
}
|
||||||
|
socket.gotNonce.L = &socket.Mutex
|
||||||
|
if err := socket.InitialAcquire(server.Info(), timeout); err != nil {
|
||||||
|
panic("newSocket: InitialAcquire returned error: " + err.Error())
|
||||||
|
}
|
||||||
|
stats.socketsAlive(+1)
|
||||||
|
debugf("Socket %p to %s: initialized", socket, socket.addr)
|
||||||
|
socket.resetNonce()
|
||||||
|
go socket.readLoop()
|
||||||
|
return socket
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server returns the server that the socket is associated with.
|
||||||
|
// It returns nil while the socket is cached in its respective server.
|
||||||
|
func (socket *mongoSocket) Server() *mongoServer {
|
||||||
|
socket.Lock()
|
||||||
|
server := socket.server
|
||||||
|
socket.Unlock()
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerInfo returns details for the server at the time the socket
|
||||||
|
// was initially acquired.
|
||||||
|
func (socket *mongoSocket) ServerInfo() *mongoServerInfo {
|
||||||
|
socket.Lock()
|
||||||
|
serverInfo := socket.serverInfo
|
||||||
|
socket.Unlock()
|
||||||
|
return serverInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitialAcquire obtains the first reference to the socket, either
|
||||||
|
// right after the connection is made or once a recycled socket is
|
||||||
|
// being put back in use.
|
||||||
|
func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error {
|
||||||
|
socket.Lock()
|
||||||
|
if socket.references > 0 {
|
||||||
|
panic("Socket acquired out of cache with references")
|
||||||
|
}
|
||||||
|
if socket.dead != nil {
|
||||||
|
dead := socket.dead
|
||||||
|
socket.Unlock()
|
||||||
|
return dead
|
||||||
|
}
|
||||||
|
socket.references++
|
||||||
|
socket.serverInfo = serverInfo
|
||||||
|
socket.timeout = timeout
|
||||||
|
stats.socketsInUse(+1)
|
||||||
|
stats.socketRefs(+1)
|
||||||
|
socket.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Acquire obtains an additional reference to the socket.
|
||||||
|
// The socket will only be recycled when it's released as many
|
||||||
|
// times as it's been acquired.
|
||||||
|
func (socket *mongoSocket) Acquire() (info *mongoServerInfo) {
|
||||||
|
socket.Lock()
|
||||||
|
if socket.references == 0 {
|
||||||
|
panic("Socket got non-initial acquire with references == 0")
|
||||||
|
}
|
||||||
|
// We'll track references to dead sockets as well.
|
||||||
|
// Caller is still supposed to release the socket.
|
||||||
|
socket.references++
|
||||||
|
stats.socketRefs(+1)
|
||||||
|
serverInfo := socket.serverInfo
|
||||||
|
socket.Unlock()
|
||||||
|
return serverInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release decrements a socket reference. The socket will be
|
||||||
|
// recycled once its released as many times as it's been acquired.
|
||||||
|
func (socket *mongoSocket) Release() {
|
||||||
|
socket.Lock()
|
||||||
|
if socket.references == 0 {
|
||||||
|
panic("socket.Release() with references == 0")
|
||||||
|
}
|
||||||
|
socket.references--
|
||||||
|
stats.socketRefs(-1)
|
||||||
|
if socket.references == 0 {
|
||||||
|
stats.socketsInUse(-1)
|
||||||
|
server := socket.server
|
||||||
|
socket.Unlock()
|
||||||
|
socket.LogoutAll()
|
||||||
|
// If the socket is dead server is nil.
|
||||||
|
if server != nil {
|
||||||
|
server.RecycleSocket(socket)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
socket.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTimeout changes the timeout used on socket operations.
|
||||||
|
func (socket *mongoSocket) SetTimeout(d time.Duration) {
|
||||||
|
socket.Lock()
|
||||||
|
socket.timeout = d
|
||||||
|
socket.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
type deadlineType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
readDeadline deadlineType = 1
|
||||||
|
writeDeadline deadlineType = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
func (socket *mongoSocket) updateDeadline(which deadlineType) {
|
||||||
|
var when time.Time
|
||||||
|
if socket.timeout > 0 {
|
||||||
|
when = time.Now().Add(socket.timeout)
|
||||||
|
}
|
||||||
|
whichstr := ""
|
||||||
|
switch which {
|
||||||
|
case readDeadline | writeDeadline:
|
||||||
|
whichstr = "read/write"
|
||||||
|
socket.conn.SetDeadline(when)
|
||||||
|
case readDeadline:
|
||||||
|
whichstr = "read"
|
||||||
|
socket.conn.SetReadDeadline(when)
|
||||||
|
case writeDeadline:
|
||||||
|
whichstr = "write"
|
||||||
|
socket.conn.SetWriteDeadline(when)
|
||||||
|
default:
|
||||||
|
panic("invalid parameter to updateDeadline")
|
||||||
|
}
|
||||||
|
debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close terminates the socket use.
|
||||||
|
func (socket *mongoSocket) Close() {
|
||||||
|
socket.kill(errors.New("Closed explicitly"), false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) kill(err error, abend bool) {
|
||||||
|
socket.Lock()
|
||||||
|
if socket.dead != nil {
|
||||||
|
debugf("Socket %p to %s: killed again: %s (previously: %s)", socket, socket.addr, err.Error(), socket.dead.Error())
|
||||||
|
socket.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logf("Socket %p to %s: closing: %s (abend=%v)", socket, socket.addr, err.Error(), abend)
|
||||||
|
socket.dead = err
|
||||||
|
socket.conn.Close()
|
||||||
|
stats.socketsAlive(-1)
|
||||||
|
replyFuncs := socket.replyFuncs
|
||||||
|
socket.replyFuncs = make(map[uint32]replyFunc)
|
||||||
|
server := socket.server
|
||||||
|
socket.server = nil
|
||||||
|
socket.gotNonce.Broadcast()
|
||||||
|
socket.Unlock()
|
||||||
|
for _, replyFunc := range replyFuncs {
|
||||||
|
logf("Socket %p to %s: notifying replyFunc of closed socket: %s", socket, socket.addr, err.Error())
|
||||||
|
replyFunc(err, nil, -1, nil)
|
||||||
|
}
|
||||||
|
if abend {
|
||||||
|
server.AbendSocket(socket)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) SimpleQuery(op *queryOp) (data []byte, err error) {
|
||||||
|
var wait, change sync.Mutex
|
||||||
|
var replyDone bool
|
||||||
|
var replyData []byte
|
||||||
|
var replyErr error
|
||||||
|
wait.Lock()
|
||||||
|
op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) {
|
||||||
|
change.Lock()
|
||||||
|
if !replyDone {
|
||||||
|
replyDone = true
|
||||||
|
replyErr = err
|
||||||
|
if err == nil {
|
||||||
|
replyData = docData
|
||||||
|
}
|
||||||
|
}
|
||||||
|
change.Unlock()
|
||||||
|
wait.Unlock()
|
||||||
|
}
|
||||||
|
err = socket.Query(op)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
wait.Lock()
|
||||||
|
change.Lock()
|
||||||
|
data = replyData
|
||||||
|
err = replyErr
|
||||||
|
change.Unlock()
|
||||||
|
return data, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (socket *mongoSocket) Query(ops ...interface{}) (err error) {
|
||||||
|
|
||||||
|
if lops := socket.flushLogout(); len(lops) > 0 {
|
||||||
|
ops = append(lops, ops...)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 0, 256)
|
||||||
|
|
||||||
|
// Serialize operations synchronously to avoid interrupting
|
||||||
|
// other goroutines while we can't really be sending data.
|
||||||
|
// Also, record id positions so that we can compute request
|
||||||
|
// ids at once later with the lock already held.
|
||||||
|
requests := make([]requestInfo, len(ops))
|
||||||
|
requestCount := 0
|
||||||
|
|
||||||
|
for _, op := range ops {
|
||||||
|
debugf("Socket %p to %s: serializing op: %#v", socket, socket.addr, op)
|
||||||
|
if qop, ok := op.(*queryOp); ok {
|
||||||
|
if cmd, ok := qop.query.(*findCmd); ok {
|
||||||
|
debugf("Socket %p to %s: find command: %#v", socket, socket.addr, cmd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
start := len(buf)
|
||||||
|
var replyFunc replyFunc
|
||||||
|
switch op := op.(type) {
|
||||||
|
|
||||||
|
case *updateOp:
|
||||||
|
buf = addHeader(buf, 2001)
|
||||||
|
buf = addInt32(buf, 0) // Reserved
|
||||||
|
buf = addCString(buf, op.Collection)
|
||||||
|
buf = addInt32(buf, int32(op.Flags))
|
||||||
|
debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.Selector)
|
||||||
|
buf, err = addBSON(buf, op.Selector)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
debugf("Socket %p to %s: serializing update document: %#v", socket, socket.addr, op.Update)
|
||||||
|
buf, err = addBSON(buf, op.Update)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
case *insertOp:
|
||||||
|
buf = addHeader(buf, 2002)
|
||||||
|
buf = addInt32(buf, int32(op.flags))
|
||||||
|
buf = addCString(buf, op.collection)
|
||||||
|
for _, doc := range op.documents {
|
||||||
|
debugf("Socket %p to %s: serializing document for insertion: %#v", socket, socket.addr, doc)
|
||||||
|
buf, err = addBSON(buf, doc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case *queryOp:
|
||||||
|
buf = addHeader(buf, 2004)
|
||||||
|
buf = addInt32(buf, int32(op.flags))
|
||||||
|
buf = addCString(buf, op.collection)
|
||||||
|
buf = addInt32(buf, op.skip)
|
||||||
|
buf = addInt32(buf, op.limit)
|
||||||
|
buf, err = addBSON(buf, op.finalQuery(socket))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if op.selector != nil {
|
||||||
|
buf, err = addBSON(buf, op.selector)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
replyFunc = op.replyFunc
|
||||||
|
|
||||||
|
case *getMoreOp:
|
||||||
|
buf = addHeader(buf, 2005)
|
||||||
|
buf = addInt32(buf, 0) // Reserved
|
||||||
|
buf = addCString(buf, op.collection)
|
||||||
|
buf = addInt32(buf, op.limit)
|
||||||
|
buf = addInt64(buf, op.cursorId)
|
||||||
|
replyFunc = op.replyFunc
|
||||||
|
|
||||||
|
case *deleteOp:
|
||||||
|
buf = addHeader(buf, 2006)
|
||||||
|
buf = addInt32(buf, 0) // Reserved
|
||||||
|
buf = addCString(buf, op.Collection)
|
||||||
|
buf = addInt32(buf, int32(op.Flags))
|
||||||
|
debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.Selector)
|
||||||
|
buf, err = addBSON(buf, op.Selector)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
case *killCursorsOp:
|
||||||
|
buf = addHeader(buf, 2007)
|
||||||
|
buf = addInt32(buf, 0) // Reserved
|
||||||
|
buf = addInt32(buf, int32(len(op.cursorIds)))
|
||||||
|
for _, cursorId := range op.cursorIds {
|
||||||
|
buf = addInt64(buf, cursorId)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
panic("internal error: unknown operation type")
|
||||||
|
}
|
||||||
|
|
||||||
|
setInt32(buf, start, int32(len(buf)-start))
|
||||||
|
|
||||||
|
if replyFunc != nil {
|
||||||
|
request := &requests[requestCount]
|
||||||
|
request.replyFunc = replyFunc
|
||||||
|
request.bufferPos = start
|
||||||
|
requestCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Buffer is ready for the pipe. Lock, allocate ids, and enqueue.
|
||||||
|
|
||||||
|
socket.Lock()
|
||||||
|
if socket.dead != nil {
|
||||||
|
dead := socket.dead
|
||||||
|
socket.Unlock()
|
||||||
|
debugf("Socket %p to %s: failing query, already closed: %s", socket, socket.addr, socket.dead.Error())
|
||||||
|
// XXX This seems necessary in case the session is closed concurrently
|
||||||
|
// with a query being performed, but it's not yet tested:
|
||||||
|
for i := 0; i != requestCount; i++ {
|
||||||
|
request := &requests[i]
|
||||||
|
if request.replyFunc != nil {
|
||||||
|
request.replyFunc(dead, nil, -1, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dead
|
||||||
|
}
|
||||||
|
|
||||||
|
wasWaiting := len(socket.replyFuncs) > 0
|
||||||
|
|
||||||
|
// Reserve id 0 for requests which should have no responses.
|
||||||
|
requestId := socket.nextRequestId + 1
|
||||||
|
if requestId == 0 {
|
||||||
|
requestId++
|
||||||
|
}
|
||||||
|
socket.nextRequestId = requestId + uint32(requestCount)
|
||||||
|
for i := 0; i != requestCount; i++ {
|
||||||
|
request := &requests[i]
|
||||||
|
setInt32(buf, request.bufferPos+4, int32(requestId))
|
||||||
|
socket.replyFuncs[requestId] = request.replyFunc
|
||||||
|
requestId++
|
||||||
|
}
|
||||||
|
|
||||||
|
debugf("Socket %p to %s: sending %d op(s) (%d bytes)", socket, socket.addr, len(ops), len(buf))
|
||||||
|
stats.sentOps(len(ops))
|
||||||
|
|
||||||
|
socket.updateDeadline(writeDeadline)
|
||||||
|
_, err = socket.conn.Write(buf)
|
||||||
|
if !wasWaiting && requestCount > 0 {
|
||||||
|
socket.updateDeadline(readDeadline)
|
||||||
|
}
|
||||||
|
socket.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func fill(r net.Conn, b []byte) error {
|
||||||
|
l := len(b)
|
||||||
|
n, err := r.Read(b)
|
||||||
|
for n != l && err == nil {
|
||||||
|
var ni int
|
||||||
|
ni, err = r.Read(b[n:])
|
||||||
|
n += ni
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Estimated minimum cost per socket: 1 goroutine + memory for the largest
|
||||||
|
// document ever seen.
|
||||||
|
func (socket *mongoSocket) readLoop() {
|
||||||
|
p := make([]byte, 36) // 16 from header + 20 from OP_REPLY fixed fields
|
||||||
|
s := make([]byte, 4)
|
||||||
|
conn := socket.conn // No locking, conn never changes.
|
||||||
|
for {
|
||||||
|
err := fill(conn, p)
|
||||||
|
if err != nil {
|
||||||
|
socket.kill(err, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
totalLen := getInt32(p, 0)
|
||||||
|
responseTo := getInt32(p, 8)
|
||||||
|
opCode := getInt32(p, 12)
|
||||||
|
|
||||||
|
// Don't use socket.server.Addr here. socket is not
|
||||||
|
// locked and socket.server may go away.
|
||||||
|
debugf("Socket %p to %s: got reply (%d bytes)", socket, socket.addr, totalLen)
|
||||||
|
|
||||||
|
_ = totalLen
|
||||||
|
|
||||||
|
if opCode != 1 {
|
||||||
|
socket.kill(errors.New("opcode != 1, corrupted data?"), true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reply := replyOp{
|
||||||
|
flags: uint32(getInt32(p, 16)),
|
||||||
|
cursorId: getInt64(p, 20),
|
||||||
|
firstDoc: getInt32(p, 28),
|
||||||
|
replyDocs: getInt32(p, 32),
|
||||||
|
}
|
||||||
|
|
||||||
|
stats.receivedOps(+1)
|
||||||
|
stats.receivedDocs(int(reply.replyDocs))
|
||||||
|
|
||||||
|
socket.Lock()
|
||||||
|
replyFunc, ok := socket.replyFuncs[uint32(responseTo)]
|
||||||
|
if ok {
|
||||||
|
delete(socket.replyFuncs, uint32(responseTo))
|
||||||
|
}
|
||||||
|
socket.Unlock()
|
||||||
|
|
||||||
|
if replyFunc != nil && reply.replyDocs == 0 {
|
||||||
|
replyFunc(nil, &reply, -1, nil)
|
||||||
|
} else {
|
||||||
|
for i := 0; i != int(reply.replyDocs); i++ {
|
||||||
|
err := fill(conn, s)
|
||||||
|
if err != nil {
|
||||||
|
if replyFunc != nil {
|
||||||
|
replyFunc(err, nil, -1, nil)
|
||||||
|
}
|
||||||
|
socket.kill(err, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b := make([]byte, int(getInt32(s, 0)))
|
||||||
|
|
||||||
|
// copy(b, s) in an efficient way.
|
||||||
|
b[0] = s[0]
|
||||||
|
b[1] = s[1]
|
||||||
|
b[2] = s[2]
|
||||||
|
b[3] = s[3]
|
||||||
|
|
||||||
|
err = fill(conn, b[4:])
|
||||||
|
if err != nil {
|
||||||
|
if replyFunc != nil {
|
||||||
|
replyFunc(err, nil, -1, nil)
|
||||||
|
}
|
||||||
|
socket.kill(err, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if globalDebug && globalLogger != nil {
|
||||||
|
m := bson.M{}
|
||||||
|
if err := bson.Unmarshal(b, m); err == nil {
|
||||||
|
debugf("Socket %p to %s: received document: %#v", socket, socket.addr, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if replyFunc != nil {
|
||||||
|
replyFunc(nil, &reply, i, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// XXX Do bound checking against totalLen.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
socket.Lock()
|
||||||
|
if len(socket.replyFuncs) == 0 {
|
||||||
|
// Nothing else to read for now. Disable deadline.
|
||||||
|
socket.conn.SetReadDeadline(time.Time{})
|
||||||
|
} else {
|
||||||
|
socket.updateDeadline(readDeadline)
|
||||||
|
}
|
||||||
|
socket.Unlock()
|
||||||
|
|
||||||
|
// XXX Do bound checking against totalLen.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var emptyHeader = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||||
|
|
||||||
|
func addHeader(b []byte, opcode int) []byte {
|
||||||
|
i := len(b)
|
||||||
|
b = append(b, emptyHeader...)
|
||||||
|
// Enough for current opcodes.
|
||||||
|
b[i+12] = byte(opcode)
|
||||||
|
b[i+13] = byte(opcode >> 8)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func addInt32(b []byte, i int32) []byte {
|
||||||
|
return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24))
|
||||||
|
}
|
||||||
|
|
||||||
|
func addInt64(b []byte, i int64) []byte {
|
||||||
|
return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24),
|
||||||
|
byte(i>>32), byte(i>>40), byte(i>>48), byte(i>>56))
|
||||||
|
}
|
||||||
|
|
||||||
|
func addCString(b []byte, s string) []byte {
|
||||||
|
b = append(b, []byte(s)...)
|
||||||
|
b = append(b, 0)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func addBSON(b []byte, doc interface{}) ([]byte, error) {
|
||||||
|
if doc == nil {
|
||||||
|
return append(b, 5, 0, 0, 0, 0), nil
|
||||||
|
}
|
||||||
|
data, err := bson.Marshal(doc)
|
||||||
|
if err != nil {
|
||||||
|
return b, err
|
||||||
|
}
|
||||||
|
return append(b, data...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setInt32(b []byte, pos int, i int32) {
|
||||||
|
b[pos] = byte(i)
|
||||||
|
b[pos+1] = byte(i >> 8)
|
||||||
|
b[pos+2] = byte(i >> 16)
|
||||||
|
b[pos+3] = byte(i >> 24)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getInt32(b []byte, pos int) int32 {
|
||||||
|
return (int32(b[pos+0])) |
|
||||||
|
(int32(b[pos+1]) << 8) |
|
||||||
|
(int32(b[pos+2]) << 16) |
|
||||||
|
(int32(b[pos+3]) << 24)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getInt64(b []byte, pos int) int64 {
|
||||||
|
return (int64(b[pos+0])) |
|
||||||
|
(int64(b[pos+1]) << 8) |
|
||||||
|
(int64(b[pos+2]) << 16) |
|
||||||
|
(int64(b[pos+3]) << 24) |
|
||||||
|
(int64(b[pos+4]) << 32) |
|
||||||
|
(int64(b[pos+5]) << 40) |
|
||||||
|
(int64(b[pos+6]) << 48) |
|
||||||
|
(int64(b[pos+7]) << 56)
|
||||||
|
}
|
|
@ -0,0 +1,147 @@
|
||||||
|
// mgo - MongoDB driver for Go
|
||||||
|
//
|
||||||
|
// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
|
||||||
|
//
|
||||||
|
// All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
// list of conditions and the following disclaimer.
|
||||||
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||||
|
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
package mgo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var stats *Stats
|
||||||
|
var statsMutex sync.Mutex
|
||||||
|
|
||||||
|
func SetStats(enabled bool) {
|
||||||
|
statsMutex.Lock()
|
||||||
|
if enabled {
|
||||||
|
if stats == nil {
|
||||||
|
stats = &Stats{}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
stats = nil
|
||||||
|
}
|
||||||
|
statsMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetStats() (snapshot Stats) {
|
||||||
|
statsMutex.Lock()
|
||||||
|
snapshot = *stats
|
||||||
|
statsMutex.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResetStats() {
|
||||||
|
statsMutex.Lock()
|
||||||
|
debug("Resetting stats")
|
||||||
|
old := stats
|
||||||
|
stats = &Stats{}
|
||||||
|
// These are absolute values:
|
||||||
|
stats.Clusters = old.Clusters
|
||||||
|
stats.SocketsInUse = old.SocketsInUse
|
||||||
|
stats.SocketsAlive = old.SocketsAlive
|
||||||
|
stats.SocketRefs = old.SocketRefs
|
||||||
|
statsMutex.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type Stats struct {
|
||||||
|
Clusters int
|
||||||
|
MasterConns int
|
||||||
|
SlaveConns int
|
||||||
|
SentOps int
|
||||||
|
ReceivedOps int
|
||||||
|
ReceivedDocs int
|
||||||
|
SocketsAlive int
|
||||||
|
SocketsInUse int
|
||||||
|
SocketRefs int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stats *Stats) cluster(delta int) {
|
||||||
|
if stats != nil {
|
||||||
|
statsMutex.Lock()
|
||||||
|
stats.Clusters += delta
|
||||||
|
statsMutex.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stats *Stats) conn(delta int, master bool) {
|
||||||
|
if stats != nil {
|
||||||
|
statsMutex.Lock()
|
||||||
|
if master {
|
||||||
|
stats.MasterConns += delta
|
||||||
|
} else {
|
||||||
|
stats.SlaveConns += delta
|
||||||
|
}
|
||||||
|
statsMutex.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stats *Stats) sentOps(delta int) {
|
||||||
|
if stats != nil {
|
||||||
|
statsMutex.Lock()
|
||||||
|
stats.SentOps += delta
|
||||||
|
statsMutex.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stats *Stats) receivedOps(delta int) {
|
||||||
|
if stats != nil {
|
||||||
|
statsMutex.Lock()
|
||||||
|
stats.ReceivedOps += delta
|
||||||
|
statsMutex.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stats *Stats) receivedDocs(delta int) {
|
||||||
|
if stats != nil {
|
||||||
|
statsMutex.Lock()
|
||||||
|
stats.ReceivedDocs += delta
|
||||||
|
statsMutex.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stats *Stats) socketsInUse(delta int) {
|
||||||
|
if stats != nil {
|
||||||
|
statsMutex.Lock()
|
||||||
|
stats.SocketsInUse += delta
|
||||||
|
statsMutex.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stats *Stats) socketsAlive(delta int) {
|
||||||
|
if stats != nil {
|
||||||
|
statsMutex.Lock()
|
||||||
|
stats.SocketsAlive += delta
|
||||||
|
statsMutex.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stats *Stats) socketRefs(delta int) {
|
||||||
|
if stats != nil {
|
||||||
|
statsMutex.Lock()
|
||||||
|
stats.SocketRefs += delta
|
||||||
|
statsMutex.Unlock()
|
||||||
|
}
|
||||||
|
}
|
|
@ -54,7 +54,7 @@ github.com/joho/godotenv/autoload
|
||||||
# github.com/json-iterator/go v1.1.10
|
# github.com/json-iterator/go v1.1.10
|
||||||
## explicit
|
## explicit
|
||||||
github.com/json-iterator/go
|
github.com/json-iterator/go
|
||||||
# github.com/koding/cache v0.0.0-20140912085602-487fc0ca06f9
|
# github.com/koding/cache v0.0.0-20161222233015-e8a81b0b3f20
|
||||||
## explicit
|
## explicit
|
||||||
github.com/koding/cache
|
github.com/koding/cache
|
||||||
# github.com/leodido/go-urn v1.2.0
|
# github.com/leodido/go-urn v1.2.0
|
||||||
|
@ -156,6 +156,13 @@ google.golang.org/protobuf/runtime/protoimpl
|
||||||
# gopkg.in/gorp.v1 v1.7.1
|
# gopkg.in/gorp.v1 v1.7.1
|
||||||
## explicit
|
## explicit
|
||||||
gopkg.in/gorp.v1
|
gopkg.in/gorp.v1
|
||||||
|
# gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22
|
||||||
|
## explicit
|
||||||
|
gopkg.in/mgo.v2
|
||||||
|
gopkg.in/mgo.v2/bson
|
||||||
|
gopkg.in/mgo.v2/internal/json
|
||||||
|
gopkg.in/mgo.v2/internal/sasl
|
||||||
|
gopkg.in/mgo.v2/internal/scram
|
||||||
# gopkg.in/yaml.v2 v2.4.0
|
# gopkg.in/yaml.v2 v2.4.0
|
||||||
## explicit
|
## explicit
|
||||||
gopkg.in/yaml.v2
|
gopkg.in/yaml.v2
|
||||||
|
|
Загрузка…
Ссылка в новой задаче