Skip to content

Commit

Permalink
fix: use native ScanType from driver and enhance RowBuffer to underst…
Browse files Browse the repository at this point in the history
…and more types
  • Loading branch information
mechinn committed Aug 22, 2023
1 parent 3254d43 commit 31781fc
Show file tree
Hide file tree
Showing 4 changed files with 488 additions and 76 deletions.
129 changes: 89 additions & 40 deletions dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ import (
/*
Data struct to configure dump behavior
Out: Stream to wite to
Connection: Database connection to dump
IgnoreTables: Mark sensitive tables to ignore
MaxAllowedPacket: Sets the largest packet size to use in backups
LockTables: Lock all tables for the duration of the dump
Out: Stream to wite to
Connection: Database connection to dump
IgnoreTables: Mark sensitive tables to ignore
MaxAllowedPacket: Sets the largest packet size to use in backups
LockTables: Lock all tables for the duration of the dump
*/
type Data struct {
Out io.Writer
Expand Down Expand Up @@ -68,7 +68,7 @@ const headerTmpl = `-- Go SQL Dump {{ .DumpVersion }}
/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;
/*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */;
/*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */;
SET NAMES utf8mb4 ;
/*!50503 SET NAMES UTF8 */;
/*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */;
/*!40103 SET TIME_ZONE='+00:00' */;
/*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */;
Expand Down Expand Up @@ -99,7 +99,7 @@ const tableTmpl = `
DROP TABLE IF EXISTS {{ .NameEsc }};
/*!40101 SET @saved_cs_client = @@character_set_client */;
SET character_set_client = utf8mb4 ;
/*!50503 SET character_set_client = utf8mb4 */;
{{ .CreateSQL }};
/*!40101 SET character_set_client = @saved_cs_client */;
Expand Down Expand Up @@ -296,7 +296,7 @@ func (table *table) CreateSQL() (string, error) {
}

if tableReturn.String != table.Name {
return "", errors.New("Returned table is not the same as requested table")
return "", errors.New("returned table is not the same as requested table")
}

return tableSQL.String, nil
Expand Down Expand Up @@ -383,38 +383,11 @@ func (table *table) Init() error {

table.values = make([]interface{}, len(tt))
for i, tp := range tt {
table.values[i] = reflect.New(reflectColumnType(tp)).Interface()
table.values[i] = reflect.New(tp.ScanType()).Interface()
}
return nil
}

func reflectColumnType(tp *sql.ColumnType) reflect.Type {
// reflect for scanable
switch tp.ScanType().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return reflect.TypeOf(sql.NullInt64{})
case reflect.Float32, reflect.Float64:
return reflect.TypeOf(sql.NullFloat64{})
case reflect.String:
return reflect.TypeOf(sql.NullString{})
}

// determine by name
switch tp.DatabaseTypeName() {
case "BLOB", "BINARY":
return reflect.TypeOf(sql.RawBytes{})
case "VARCHAR", "TEXT", "DECIMAL", "JSON":
return reflect.TypeOf(sql.NullString{})
case "BIGINT", "TINYINT", "INT":
return reflect.TypeOf(sql.NullInt64{})
case "DOUBLE":
return reflect.TypeOf(sql.NullFloat64{})
}

// unknown datatype
return tp.ScanType()
}

func (table *table) Next() bool {
if table.rows == nil {
if err := table.Init(); err != nil {
Expand Down Expand Up @@ -443,6 +416,30 @@ func (table *table) RowValues() string {
return table.RowBuffer().String()
}

func writeString(b *bytes.Buffer, s string) {
fmt.Fprintf(b, "'%s'", sanitize(s))
}

func writeBool(b *bytes.Buffer, s bool) {
if s {
fmt.Fprintf(b, "1")
} else {
fmt.Fprintf(b, "0")
}
}

func writeBinary(b *bytes.Buffer, s []byte) {
if len(s) == 0 {
b.WriteString(nullType)
} else {
fmt.Fprintf(b, "_binary '%s'", sanitize(string(s)))
}
}

func writeTime(b *bytes.Buffer, s time.Time) {
fmt.Fprintf(b, "'%s'", sanitize(s.UTC().Format(time.DateTime)))
}

func (table *table) RowBuffer() *bytes.Buffer {
var b bytes.Buffer
b.WriteString("(")
Expand All @@ -454,9 +451,51 @@ func (table *table) RowBuffer() *bytes.Buffer {
switch s := value.(type) {
case nil:
b.WriteString(nullType)
case *string:
writeString(&b, *s)
case *sql.NullString:
if s.Valid {
fmt.Fprintf(&b, "'%s'", sanitize(s.String))
writeString(&b, s.String)
} else {
b.WriteString(nullType)
}
case *bool:
writeBool(&b, *s)
case *sql.NullBool:
if s.Valid {
writeBool(&b, s.Bool)
} else {
b.WriteString(nullType)
}
case *uint:
fmt.Fprintf(&b, "%d", *s)
case *uint8:
fmt.Fprintf(&b, "%d", *s)
case *uint16:
fmt.Fprintf(&b, "%d", *s)
case *uint32:
fmt.Fprintf(&b, "%d", *s)
case *uint64:
fmt.Fprintf(&b, "%d", *s)
case *int:
fmt.Fprintf(&b, "%d", *s)
case *int8:
fmt.Fprintf(&b, "%d", *s)
case *int16:
fmt.Fprintf(&b, "%d", *s)
case *int32:
fmt.Fprintf(&b, "%d", *s)
case *int64:
fmt.Fprintf(&b, "%d", *s)
case *sql.NullInt16:
if s.Valid {
fmt.Fprintf(&b, "%d", s.Int16)
} else {
b.WriteString(nullType)
}
case *sql.NullInt32:
if s.Valid {
fmt.Fprintf(&b, "%d", s.Int32)
} else {
b.WriteString(nullType)
}
Expand All @@ -466,17 +505,27 @@ func (table *table) RowBuffer() *bytes.Buffer {
} else {
b.WriteString(nullType)
}
case *float32:
fmt.Fprintf(&b, "%f", *s)
case *float64:
fmt.Fprintf(&b, "%f", *s)
case *sql.NullFloat64:
if s.Valid {
fmt.Fprintf(&b, "%f", s.Float64)
} else {
b.WriteString(nullType)
}
case *[]byte:
writeBinary(&b, *s)
case *sql.RawBytes:
if len(*s) == 0 {
b.WriteString(nullType)
writeBinary(&b, *s)
case *time.Time:
writeTime(&b, *s)
case *sql.NullTime:
if s.Valid {
writeTime(&b, s.Time)
} else {
fmt.Fprintf(&b, "_binary '%s'", sanitize(string(*s)))
b.WriteString(nullType)
}
default:
fmt.Fprintf(&b, "'%s'", value)
Expand Down
10 changes: 5 additions & 5 deletions dump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ func TestCreateTableAllValuesWithNil(t *testing.T) {
AddRow("email", "").
AddRow("name", "")

rows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", ""), c("name", "")).
rows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", sql.NullString{}), c("name", "")).
AddRow(1, nil, "Test Name 1").
AddRow(2, "[email protected]", "Test Name 2").
AddRow(3, "", "Test Name 3")
Expand Down Expand Up @@ -266,7 +266,7 @@ func TestCreateTableOk(t *testing.T) {
AddRow("email", "").
AddRow("name", "")

createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", ""), c("name", "")).
createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", sql.NullString{}), c("name", "")).
AddRow(1, nil, "Test Name 1").
AddRow(2, "[email protected]", "Test Name 2")

Expand Down Expand Up @@ -294,7 +294,7 @@ func TestCreateTableOk(t *testing.T) {
DROP TABLE IF EXISTS ~Test_Table~;
/*!40101 SET @saved_cs_client = @@character_set_client */;
SET character_set_client = utf8mb4 ;
/*!50503 SET character_set_client = utf8mb4 */;
CREATE TABLE 'Test_Table' (~id~ int(11) NOT NULL AUTO_INCREMENT,~s~ char(60) DEFAULT NULL, PRIMARY KEY (~id~))ENGINE=InnoDB DEFAULT CHARSET=latin1;
/*!40101 SET character_set_client = @saved_cs_client */;
Expand Down Expand Up @@ -325,7 +325,7 @@ func TestCreateTableOkSmallPackets(t *testing.T) {
AddRow("email", "").
AddRow("name", "")

createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", ""), c("name", "")).
createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", sql.NullString{}), c("name", "")).
AddRow(1, nil, "Test Name 1").
AddRow(2, "[email protected]", "Test Name 2")

Expand Down Expand Up @@ -353,7 +353,7 @@ func TestCreateTableOkSmallPackets(t *testing.T) {
DROP TABLE IF EXISTS ~Test_Table~;
/*!40101 SET @saved_cs_client = @@character_set_client */;
SET character_set_client = utf8mb4 ;
/*!50503 SET character_set_client = utf8mb4 */;
CREATE TABLE 'Test_Table' (~id~ int(11) NOT NULL AUTO_INCREMENT,~s~ char(60) DEFAULT NULL, PRIMARY KEY (~id~))ENGINE=InnoDB DEFAULT CHARSET=latin1;
/*!40101 SET character_set_client = @saved_cs_client */;
Expand Down
2 changes: 1 addition & 1 deletion mysqldump.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Register a new dumper.
*/
func Register(db *sql.DB, dir, format string) (*Data, error) {
if !isDir(dir) {
return nil, errors.New("Invalid directory")
return nil, errors.New("invalid directory")
}

name := time.Now().Format(format)
Expand Down
Loading

0 comments on commit 31781fc

Please sign in to comment.