Skip to content

Commit cc39db7

Browse files
authored
Merge pull request #1382 from mattn/codex/sqlite3-bind-fastpath
[codex] optimize sqlite bind fast path
2 parents edadafa + 9a908a9 commit cc39db7

1 file changed

Lines changed: 96 additions & 39 deletions

File tree

sqlite3.go

Lines changed: 96 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -866,26 +866,12 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.Named
866866
}
867867
var res driver.Result
868868
if s.(*SQLiteStmt).s != nil {
869-
stmtArgs := make([]driver.NamedValue, 0, len(args))
870869
na := s.NumInput()
871870
if len(args)-start < na {
872871
s.Close()
873872
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
874873
}
875-
// consume the number of arguments used in the current
876-
// statement and append all named arguments not
877-
// contained therein
878-
if na > 0 {
879-
stmtArgs = append(stmtArgs, args[start:start+na]...)
880-
for i := range args {
881-
if (i < start || i >= na) && args[i].Name != "" {
882-
stmtArgs = append(stmtArgs, args[i])
883-
}
884-
}
885-
for i := range stmtArgs {
886-
stmtArgs[i].Ordinal = i + 1
887-
}
888-
}
874+
stmtArgs := stmtArgs(args, start, na)
889875
res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs)
890876
if err != nil && err != driver.ErrSkip {
891877
s.Close()
@@ -921,7 +907,6 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
921907
func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
922908
start := 0
923909
for {
924-
stmtArgs := make([]driver.NamedValue, 0, len(args))
925910
s, err := c.prepare(ctx, query)
926911
if err != nil {
927912
return nil, err
@@ -932,18 +917,7 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
932917
s.Close()
933918
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start)
934919
}
935-
// consume the number of arguments used in the current
936-
// statement and append all named arguments not contained
937-
// therein
938-
stmtArgs = append(stmtArgs, args[start:start+na]...)
939-
for i := range args {
940-
if (i < start || i >= na) && args[i].Name != "" {
941-
stmtArgs = append(stmtArgs, args[i])
942-
}
943-
}
944-
for i := range stmtArgs {
945-
stmtArgs[i].Ordinal = i + 1
946-
}
920+
stmtArgs := stmtArgs(args, start, na)
947921
rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs)
948922
if err != nil && err != driver.ErrSkip {
949923
s.Close()
@@ -1957,6 +1931,36 @@ func (s *SQLiteStmt) NumInput() int {
19571931

19581932
var placeHolder = []byte{0}
19591933

1934+
func stmtArgs(args []driver.NamedValue, start, na int) []driver.NamedValue {
1935+
if na == 0 {
1936+
return nil
1937+
}
1938+
1939+
end := start + na
1940+
hasNamedOutside := false
1941+
for i := range args {
1942+
if args[i].Name != "" && (i < start || i >= end) {
1943+
hasNamedOutside = true
1944+
break
1945+
}
1946+
}
1947+
if start == 0 && !hasNamedOutside {
1948+
return args[start:end]
1949+
}
1950+
1951+
stmtArgs := make([]driver.NamedValue, 0, len(args))
1952+
stmtArgs = append(stmtArgs, args[start:end]...)
1953+
for i := range args {
1954+
if args[i].Name != "" && (i < start || i >= end) {
1955+
stmtArgs = append(stmtArgs, args[i])
1956+
}
1957+
}
1958+
for i := range stmtArgs {
1959+
stmtArgs[i].Ordinal = i + 1
1960+
}
1961+
return stmtArgs
1962+
}
1963+
19601964
func (s *SQLiteStmt) bind(args []driver.NamedValue) error {
19611965
rv := C.sqlite3_reset(s.s)
19621966
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
@@ -1965,26 +1969,79 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error {
19651969

19661970
C.sqlite3_clear_bindings(s.s)
19671971

1972+
hasNamed := false
1973+
for i := range args {
1974+
if args[i].Name != "" {
1975+
hasNamed = true
1976+
break
1977+
}
1978+
}
1979+
1980+
if !hasNamed {
1981+
for _, arg := range args {
1982+
n := C.int(arg.Ordinal)
1983+
switch v := arg.Value.(type) {
1984+
case nil:
1985+
rv = C.sqlite3_bind_null(s.s, n)
1986+
case string:
1987+
if len(v) == 0 {
1988+
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0))
1989+
} else {
1990+
b := []byte(v)
1991+
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
1992+
}
1993+
case int64:
1994+
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
1995+
case bool:
1996+
if v {
1997+
rv = C.sqlite3_bind_int(s.s, n, 1)
1998+
} else {
1999+
rv = C.sqlite3_bind_int(s.s, n, 0)
2000+
}
2001+
case float64:
2002+
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
2003+
case []byte:
2004+
if v == nil {
2005+
rv = C.sqlite3_bind_null(s.s, n)
2006+
} else {
2007+
ln := len(v)
2008+
if ln == 0 {
2009+
v = placeHolder
2010+
}
2011+
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
2012+
}
2013+
case time.Time:
2014+
b := []byte(v.Format(SQLiteTimestampFormats[0]))
2015+
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
2016+
}
2017+
if rv != C.SQLITE_OK {
2018+
return s.c.lastError()
2019+
}
2020+
}
2021+
return nil
2022+
}
2023+
19682024
bindIndices := make([][3]int, len(args))
1969-
prefixes := []string{":", "@", "$"}
2025+
prefixes := [3]string{":", "@", "$"}
19702026
for i, v := range args {
19712027
bindIndices[i][0] = v.Ordinal
1972-
if v.Name != "" {
1973-
for j := range prefixes {
1974-
cname := C.CString(prefixes[j] + v.Name)
1975-
bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname))
1976-
C.free(unsafe.Pointer(cname))
1977-
}
1978-
args[i].Ordinal = bindIndices[i][0]
2028+
if v.Name == "" {
2029+
continue
2030+
}
2031+
for j := range prefixes {
2032+
cname := C.CString(prefixes[j] + v.Name)
2033+
bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname))
2034+
C.free(unsafe.Pointer(cname))
19792035
}
2036+
args[i].Ordinal = bindIndices[i][0]
19802037
}
19812038

19822039
for i, arg := range args {
1983-
for j := range bindIndices[i] {
1984-
if bindIndices[i][j] == 0 {
2040+
for _, idx := range bindIndices[i] {
2041+
if idx == 0 {
19852042
continue
19862043
}
1987-
n := C.int(bindIndices[i][j])
2044+
n := C.int(idx)
19882045
switch v := arg.Value.(type) {
19892046
case nil:
19902047
rv = C.sqlite3_bind_null(s.s, n)

0 commit comments

Comments
 (0)