From 2301e09df4db963cafa45f5cd68fc5d0a422afcf Mon Sep 17 00:00:00 2001 From: Michael Chinn Date: Tue, 29 Aug 2023 20:10:39 -0400 Subject: [PATCH] workaround github.com/go-sql-driver/mysql/pull/1424 not released --- dump.go | 19 ++++++++++++++++++- mysqldump_test.go | 20 ++++++++++---------- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/dump.go b/dump.go index 5109f50..c0e1122 100644 --- a/dump.go +++ b/dump.go @@ -383,11 +383,28 @@ func (table *table) Init() error { table.values = make([]interface{}, len(tt)) for i, tp := range tt { - table.values[i] = reflect.New(tp.ScanType()).Interface() + table.values[i] = reflect.New(reflectColumnType(tp)).Interface() } return nil } +func reflectColumnType(tp *sql.ColumnType) reflect.Type { + // workaround https://github.com/go-sql-driver/mysql/pull/1424 till it's released + nullable, _ := tp.Nullable() + switch tp.DatabaseTypeName() { + case "TINYBLOB", "MEDIUMBLOB", "LONGBLOB", "BLOB", + "VARBINARY", "BINARY", "BIT", "GEOMETRY": + return reflect.TypeOf([]byte{}) + case "TINYTEXT", "MEDIUMTEXT", "LONGTEXT", "TEXT", + "VARCHAR", "CHAR", "DECIMAL", "ENUM", "SET", "JSON", "TIME": + if nullable { + return reflect.TypeOf(sql.NullString{}) + } + return reflect.TypeOf("") + } + return tp.ScanType() +} + func (table *table) Next() bool { if table.rows == nil { if err := table.Init(); err != nil { diff --git a/mysqldump_test.go b/mysqldump_test.go index ad62f59..4d6fce7 100644 --- a/mysqldump_test.go +++ b/mysqldump_test.go @@ -65,8 +65,8 @@ CREATE TABLE 'Test_Table' ( ~NullFloat64~ DOUBLE, ~bool~ TINYINT(1) NOT NULL, ~NullBool~ TINYINT(1), - ~time~ TIME NOT NULL, - ~NullTime~ TIME, + ~time~ DATETIME NOT NULL, + ~NullTime~ DATETIME, ~varbinary~ VARBINARY, ~rawbytes~ BLOB, PRIMARY KEY (~id~) @@ -128,8 +128,8 @@ func mockColumnRows() *sqlmock.Rows { AddRow("NullFloat64", "DOUBLE", "YES", "", nil, ""). AddRow("bool", "BOOL", "NO", "", nil, ""). AddRow("NullBool", "BOOL", "YES", "", nil, ""). - AddRow("time", "TIME", "NO", "", nil, ""). - AddRow("NullTime", "TIME", "YES", "", nil, ""). + AddRow("time", "DATETIME", "NO", "", nil, ""). + AddRow("NullTime", "DATETIME", "YES", "", nil, ""). AddRow("varbinary", "VARBINARY", "YES", "", nil, ""). AddRow("rawbytes", "BLOB", "YES", "", nil, "") } @@ -185,10 +185,10 @@ func c(name string, v interface{}) *sqlmock.Column { nullable = true t = "BOOL" case time.Time: - t = "TIME" + t = "DATETIME" case sql.NullTime: nullable = true - t = "TIME" + t = "DATETIME" case []byte: nullable = true t = "VARBINARY" @@ -243,8 +243,8 @@ func RunDump(t testing.TB, data *mysqldump.Data) { ~NullFloat64~ DOUBLE, ~bool~ TINYINT(1) NOT NULL, ~NullBool~ TINYINT(1), - ~time~ TIME NOT NULL, - ~NullTime~ TIME, + ~time~ DATETIME NOT NULL, + ~NullTime~ DATETIME, ~varbinary~ VARBINARY, ~rawbytes~ BLOB, PRIMARY KEY (~id~) @@ -421,8 +421,8 @@ func TestNoLockOk(t *testing.T) { ~NullFloat64~ DOUBLE, ~bool~ TINYINT(1) NOT NULL, ~NullBool~ TINYINT(1), - ~time~ TIME NOT NULL, - ~NullTime~ TIME, + ~time~ DATETIME NOT NULL, + ~NullTime~ DATETIME, ~varbinary~ VARBINARY, ~rawbytes~ BLOB, PRIMARY KEY (~id~)