mirror of
https://codeberg.org/forgejo/forgejo.git
synced 2025-01-12 15:49:28 -05:00
792b4dba2c
* update github.com/blevesearch/bleve v2.0.2 -> v2.0.3 * github.com/denisenkom/go-mssqldb v0.9.0 -> v0.10.0 * github.com/editorconfig/editorconfig-core-go v2.4.1 -> v2.4.2 * github.com/go-chi/cors v1.1.1 -> v1.2.0 * github.com/go-git/go-billy v5.0.0 -> v5.1.0 * github.com/go-git/go-git v5.2.0 -> v5.3.0 * github.com/go-ldap/ldap v3.2.4 -> v3.3.0 * github.com/go-redis/redis v8.6.0 -> v8.8.2 * github.com/go-sql-driver/mysql v1.5.0 -> v1.6.0 * github.com/go-swagger/go-swagger v0.26.1 -> v0.27.0 * github.com/lib/pq v1.9.0 -> v1.10.1 * github.com/mattn/go-sqlite3 v1.14.6 -> v1.14.7 * github.com/go-testfixtures/testfixtures v3.5.0 -> v3.6.0 * github.com/issue9/identicon v1.0.1 -> v1.2.0 * github.com/klauspost/compress v1.11.8 -> v1.12.1 * github.com/mgechev/revive v1.0.3 -> v1.0.6 * github.com/microcosm-cc/bluemonday v1.0.7 -> v1.0.8 * github.com/niklasfasching/go-org v1.4.0 -> v1.5.0 * github.com/olivere/elastic v7.0.22 -> v7.0.24 * github.com/pelletier/go-toml v1.8.1 -> v1.9.0 * github.com/prometheus/client_golang v1.9.0 -> v1.10.0 * github.com/xanzy/go-gitlab v0.44.0 -> v0.48.0 * github.com/yuin/goldmark v1.3.3 -> v1.3.5 * github.com/6543/go-version v1.2.4 -> v1.3.1 * do github.com/lib/pq v1.10.0 -> v1.10.1 again ...
500 lines
12 KiB
Go
Vendored
500 lines
12 KiB
Go
Vendored
package mssql
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"unicode"
|
|
)
|
|
|
|
const defaultServerPort = 1433
|
|
|
|
type connectParams struct {
|
|
logFlags uint64
|
|
port uint64
|
|
host string
|
|
instance string
|
|
database string
|
|
user string
|
|
password string
|
|
dial_timeout time.Duration
|
|
conn_timeout time.Duration
|
|
keepAlive time.Duration
|
|
encrypt bool
|
|
disableEncryption bool
|
|
trustServerCertificate bool
|
|
certificate string
|
|
hostInCertificate string
|
|
hostInCertificateProvided bool
|
|
serverSPN string
|
|
workstation string
|
|
appname string
|
|
typeFlags uint8
|
|
failOverPartner string
|
|
failOverPort uint64
|
|
packetSize uint16
|
|
fedAuthLibrary int
|
|
fedAuthADALWorkflow byte
|
|
}
|
|
|
|
// default packet size for TDS buffer
|
|
const defaultPacketSize = 4096
|
|
|
|
func parseConnectParams(dsn string) (connectParams, error) {
|
|
p := connectParams{
|
|
fedAuthLibrary: fedAuthLibraryReserved,
|
|
}
|
|
|
|
var params map[string]string
|
|
if strings.HasPrefix(dsn, "odbc:") {
|
|
parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
|
|
if err != nil {
|
|
return p, err
|
|
}
|
|
params = parameters
|
|
} else if strings.HasPrefix(dsn, "sqlserver://") {
|
|
parameters, err := splitConnectionStringURL(dsn)
|
|
if err != nil {
|
|
return p, err
|
|
}
|
|
params = parameters
|
|
} else {
|
|
params = splitConnectionString(dsn)
|
|
}
|
|
|
|
strlog, ok := params["log"]
|
|
if ok {
|
|
var err error
|
|
p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
|
|
if err != nil {
|
|
return p, fmt.Errorf("invalid log parameter '%s': %s", strlog, err.Error())
|
|
}
|
|
}
|
|
server := params["server"]
|
|
parts := strings.SplitN(server, `\`, 2)
|
|
p.host = parts[0]
|
|
if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
|
|
p.host = "localhost"
|
|
}
|
|
if len(parts) > 1 {
|
|
p.instance = parts[1]
|
|
}
|
|
p.database = params["database"]
|
|
p.user = params["user id"]
|
|
p.password = params["password"]
|
|
|
|
p.port = 0
|
|
strport, ok := params["port"]
|
|
if ok {
|
|
var err error
|
|
p.port, err = strconv.ParseUint(strport, 10, 16)
|
|
if err != nil {
|
|
f := "invalid tcp port '%v': %v"
|
|
return p, fmt.Errorf(f, strport, err.Error())
|
|
}
|
|
}
|
|
|
|
// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
|
|
p.packetSize = defaultPacketSize
|
|
strpsize, ok := params["packet size"]
|
|
if ok {
|
|
var err error
|
|
psize, err := strconv.ParseUint(strpsize, 0, 16)
|
|
if err != nil {
|
|
f := "invalid packet size '%v': %v"
|
|
return p, fmt.Errorf(f, strpsize, err.Error())
|
|
}
|
|
|
|
// Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
|
|
// NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
|
|
// a higher packet size, the server will respond with an ENVCHANGE request to
|
|
// alter the packet size to 16383 bytes.
|
|
p.packetSize = uint16(psize)
|
|
if p.packetSize < 512 {
|
|
p.packetSize = 512
|
|
} else if p.packetSize > 32767 {
|
|
p.packetSize = 32767
|
|
}
|
|
}
|
|
|
|
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
|
|
//
|
|
// Do not set a connection timeout. Use Context to manage such things.
|
|
// Default to zero, but still allow it to be set.
|
|
if strconntimeout, ok := params["connection timeout"]; ok {
|
|
timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
|
|
if err != nil {
|
|
f := "invalid connection timeout '%v': %v"
|
|
return p, fmt.Errorf(f, strconntimeout, err.Error())
|
|
}
|
|
p.conn_timeout = time.Duration(timeout) * time.Second
|
|
}
|
|
p.dial_timeout = 15 * time.Second
|
|
if strdialtimeout, ok := params["dial timeout"]; ok {
|
|
timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
|
|
if err != nil {
|
|
f := "invalid dial timeout '%v': %v"
|
|
return p, fmt.Errorf(f, strdialtimeout, err.Error())
|
|
}
|
|
p.dial_timeout = time.Duration(timeout) * time.Second
|
|
}
|
|
|
|
// default keep alive should be 30 seconds according to spec:
|
|
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
|
|
p.keepAlive = 30 * time.Second
|
|
if keepAlive, ok := params["keepalive"]; ok {
|
|
timeout, err := strconv.ParseUint(keepAlive, 10, 64)
|
|
if err != nil {
|
|
f := "invalid keepAlive value '%s': %s"
|
|
return p, fmt.Errorf(f, keepAlive, err.Error())
|
|
}
|
|
p.keepAlive = time.Duration(timeout) * time.Second
|
|
}
|
|
encrypt, ok := params["encrypt"]
|
|
if ok {
|
|
if strings.EqualFold(encrypt, "DISABLE") {
|
|
p.disableEncryption = true
|
|
} else {
|
|
var err error
|
|
p.encrypt, err = strconv.ParseBool(encrypt)
|
|
if err != nil {
|
|
f := "invalid encrypt '%s': %s"
|
|
return p, fmt.Errorf(f, encrypt, err.Error())
|
|
}
|
|
}
|
|
} else {
|
|
p.trustServerCertificate = true
|
|
}
|
|
trust, ok := params["trustservercertificate"]
|
|
if ok {
|
|
var err error
|
|
p.trustServerCertificate, err = strconv.ParseBool(trust)
|
|
if err != nil {
|
|
f := "invalid trust server certificate '%s': %s"
|
|
return p, fmt.Errorf(f, trust, err.Error())
|
|
}
|
|
}
|
|
p.certificate = params["certificate"]
|
|
p.hostInCertificate, ok = params["hostnameincertificate"]
|
|
if ok {
|
|
p.hostInCertificateProvided = true
|
|
} else {
|
|
p.hostInCertificate = p.host
|
|
p.hostInCertificateProvided = false
|
|
}
|
|
|
|
serverSPN, ok := params["serverspn"]
|
|
if ok {
|
|
p.serverSPN = serverSPN
|
|
} else {
|
|
p.serverSPN = generateSpn(p.host, resolveServerPort(p.port))
|
|
}
|
|
|
|
workstation, ok := params["workstation id"]
|
|
if ok {
|
|
p.workstation = workstation
|
|
} else {
|
|
workstation, err := os.Hostname()
|
|
if err == nil {
|
|
p.workstation = workstation
|
|
}
|
|
}
|
|
|
|
appname, ok := params["app name"]
|
|
if !ok {
|
|
appname = "go-mssqldb"
|
|
}
|
|
p.appname = appname
|
|
|
|
appintent, ok := params["applicationintent"]
|
|
if ok {
|
|
if appintent == "ReadOnly" {
|
|
if p.database == "" {
|
|
return p, fmt.Errorf("database must be specified when ApplicationIntent is ReadOnly")
|
|
}
|
|
p.typeFlags |= fReadOnlyIntent
|
|
}
|
|
}
|
|
|
|
failOverPartner, ok := params["failoverpartner"]
|
|
if ok {
|
|
p.failOverPartner = failOverPartner
|
|
}
|
|
|
|
failOverPort, ok := params["failoverport"]
|
|
if ok {
|
|
var err error
|
|
p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
|
|
if err != nil {
|
|
f := "invalid tcp port '%v': %v"
|
|
return p, fmt.Errorf(f, failOverPort, err.Error())
|
|
}
|
|
}
|
|
|
|
return p, nil
|
|
}
|
|
|
|
// convert connectionParams to url style connection string
|
|
// used mostly for testing
|
|
func (p connectParams) toUrl() *url.URL {
|
|
q := url.Values{}
|
|
if p.database != "" {
|
|
q.Add("database", p.database)
|
|
}
|
|
if p.logFlags != 0 {
|
|
q.Add("log", strconv.FormatUint(p.logFlags, 10))
|
|
}
|
|
res := url.URL{
|
|
Scheme: "sqlserver",
|
|
Host: p.host,
|
|
User: url.UserPassword(p.user, p.password),
|
|
}
|
|
if p.instance != "" {
|
|
res.Path = p.instance
|
|
}
|
|
if len(q) > 0 {
|
|
res.RawQuery = q.Encode()
|
|
}
|
|
return &res
|
|
}
|
|
|
|
func splitConnectionString(dsn string) (res map[string]string) {
|
|
res = map[string]string{}
|
|
parts := strings.Split(dsn, ";")
|
|
for _, part := range parts {
|
|
if len(part) == 0 {
|
|
continue
|
|
}
|
|
lst := strings.SplitN(part, "=", 2)
|
|
name := strings.TrimSpace(strings.ToLower(lst[0]))
|
|
if len(name) == 0 {
|
|
continue
|
|
}
|
|
var value string = ""
|
|
if len(lst) > 1 {
|
|
value = strings.TrimSpace(lst[1])
|
|
}
|
|
res[name] = value
|
|
}
|
|
return res
|
|
}
|
|
|
|
// Splits a URL of the form sqlserver://username:password@host/instance?param1=value¶m2=value
|
|
func splitConnectionStringURL(dsn string) (map[string]string, error) {
|
|
res := map[string]string{}
|
|
|
|
u, err := url.Parse(dsn)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
if u.Scheme != "sqlserver" {
|
|
return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
|
|
}
|
|
|
|
if u.User != nil {
|
|
res["user id"] = u.User.Username()
|
|
p, exists := u.User.Password()
|
|
if exists {
|
|
res["password"] = p
|
|
}
|
|
}
|
|
|
|
host, port, err := net.SplitHostPort(u.Host)
|
|
if err != nil {
|
|
host = u.Host
|
|
}
|
|
|
|
if len(u.Path) > 0 {
|
|
res["server"] = host + "\\" + u.Path[1:]
|
|
} else {
|
|
res["server"] = host
|
|
}
|
|
|
|
if len(port) > 0 {
|
|
res["port"] = port
|
|
}
|
|
|
|
query := u.Query()
|
|
for k, v := range query {
|
|
if len(v) > 1 {
|
|
return res, fmt.Errorf("key %s provided more than once", k)
|
|
}
|
|
res[strings.ToLower(k)] = v[0]
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// Splits a URL in the ODBC format
|
|
func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
|
|
res := map[string]string{}
|
|
|
|
type parserState int
|
|
const (
|
|
// Before the start of a key
|
|
parserStateBeforeKey parserState = iota
|
|
|
|
// Inside a key
|
|
parserStateKey
|
|
|
|
// Beginning of a value. May be bare or braced
|
|
parserStateBeginValue
|
|
|
|
// Inside a bare value
|
|
parserStateBareValue
|
|
|
|
// Inside a braced value
|
|
parserStateBracedValue
|
|
|
|
// A closing brace inside a braced value.
|
|
// May be the end of the value or an escaped closing brace, depending on the next character
|
|
parserStateBracedValueClosingBrace
|
|
|
|
// After a value. Next character should be a semicolon or whitespace.
|
|
parserStateEndValue
|
|
)
|
|
|
|
var state = parserStateBeforeKey
|
|
|
|
var key string
|
|
var value string
|
|
|
|
for i, c := range dsn {
|
|
switch state {
|
|
case parserStateBeforeKey:
|
|
switch {
|
|
case c == '=':
|
|
return res, fmt.Errorf("unexpected character = at index %d. Expected start of key or semi-colon or whitespace", i)
|
|
case !unicode.IsSpace(c) && c != ';':
|
|
state = parserStateKey
|
|
key += string(c)
|
|
}
|
|
|
|
case parserStateKey:
|
|
switch c {
|
|
case '=':
|
|
key = normalizeOdbcKey(key)
|
|
state = parserStateBeginValue
|
|
|
|
case ';':
|
|
// Key without value
|
|
key = normalizeOdbcKey(key)
|
|
res[key] = value
|
|
key = ""
|
|
value = ""
|
|
state = parserStateBeforeKey
|
|
|
|
default:
|
|
key += string(c)
|
|
}
|
|
|
|
case parserStateBeginValue:
|
|
switch {
|
|
case c == '{':
|
|
state = parserStateBracedValue
|
|
case c == ';':
|
|
// Empty value
|
|
res[key] = value
|
|
key = ""
|
|
state = parserStateBeforeKey
|
|
case unicode.IsSpace(c):
|
|
// Ignore whitespace
|
|
default:
|
|
state = parserStateBareValue
|
|
value += string(c)
|
|
}
|
|
|
|
case parserStateBareValue:
|
|
if c == ';' {
|
|
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
|
|
key = ""
|
|
value = ""
|
|
state = parserStateBeforeKey
|
|
} else {
|
|
value += string(c)
|
|
}
|
|
|
|
case parserStateBracedValue:
|
|
if c == '}' {
|
|
state = parserStateBracedValueClosingBrace
|
|
} else {
|
|
value += string(c)
|
|
}
|
|
|
|
case parserStateBracedValueClosingBrace:
|
|
if c == '}' {
|
|
// Escaped closing brace
|
|
value += string(c)
|
|
state = parserStateBracedValue
|
|
continue
|
|
}
|
|
|
|
// End of braced value
|
|
res[key] = value
|
|
key = ""
|
|
value = ""
|
|
|
|
// This character is the first character past the end,
|
|
// so it needs to be parsed like the parserStateEndValue state.
|
|
state = parserStateEndValue
|
|
switch {
|
|
case c == ';':
|
|
state = parserStateBeforeKey
|
|
case unicode.IsSpace(c):
|
|
// Ignore whitespace
|
|
default:
|
|
return res, fmt.Errorf("unexpected character %c at index %d. Expected semi-colon or whitespace", c, i)
|
|
}
|
|
|
|
case parserStateEndValue:
|
|
switch {
|
|
case c == ';':
|
|
state = parserStateBeforeKey
|
|
case unicode.IsSpace(c):
|
|
// Ignore whitespace
|
|
default:
|
|
return res, fmt.Errorf("unexpected character %c at index %d. Expected semi-colon or whitespace", c, i)
|
|
}
|
|
}
|
|
}
|
|
|
|
switch state {
|
|
case parserStateBeforeKey: // Okay
|
|
case parserStateKey: // Unfinished key. Treat as key without value.
|
|
key = normalizeOdbcKey(key)
|
|
res[key] = value
|
|
case parserStateBeginValue: // Empty value
|
|
res[key] = value
|
|
case parserStateBareValue:
|
|
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
|
|
case parserStateBracedValue:
|
|
return res, fmt.Errorf("unexpected end of braced value at index %d", len(dsn))
|
|
case parserStateBracedValueClosingBrace: // End of braced value
|
|
res[key] = value
|
|
case parserStateEndValue: // Okay
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// Normalizes the given string as an ODBC-format key
|
|
func normalizeOdbcKey(s string) string {
|
|
return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
|
|
}
|
|
|
|
func resolveServerPort(port uint64) uint64 {
|
|
if port == 0 {
|
|
return defaultServerPort
|
|
}
|
|
|
|
return port
|
|
}
|
|
|
|
func generateSpn(host string, port uint64) string {
|
|
return fmt.Sprintf("MSSQLSvc/%s:%d", host, port)
|
|
}
|