-
Notifications
You must be signed in to change notification settings - Fork 2
/
DB.go
541 lines (474 loc) · 14.9 KB
/
DB.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
package drugdose
import (
"context"
"errors"
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/hasura/go-graphql-client"
"database/sql"
// MySQL driver needed for sql module
_ "github.com/go-sql-driver/mysql"
// SQLite driver needed for sql module
_ "modernc.org/sqlite"
)
// Names of the database/sql drivers the package distinguishes between.
// They are compared against Config.DBDriver to select driver specific
// behaviour.
const (
	SqliteDriver string = "sqlite"
	MysqlDriver  string = "mysql"
)

// Table names used when building SQL statements.
const (
	// loggingTableName is the table updated by ChangeUserLog.
	loggingTableName string = "userLogs"
	// NOTE(review): presumably the table holding per-user settings - confirm
	// at the use sites.
	userSetTableName string = "userSettings"
)

// When this number is set as the reference ID for remembering
// a particular input, it means that it's now "forgotten"
// and there should be no attempts to "remember" any inputs.
// This is related to the RememberConfig() and ForgetConfig() functions.
const ForgetInputConfigMagicNumber string = "0"

// Descriptions of completed actions. FetchFromSource and ChangeUserLog
// store theirs in the Action field of the ErrorInfo they report.
const (
	ActionFetchFromSource         string = "fetching from source completed"
	ActionChangeUserLog           string = "changing user log completed"
	ActionAddToInfoTable          string = "adding to info table completed"
	ActionFetchFromPsychonautWiki string = "fetching from psychonautwiki completed"
	ActionAddToDoseTable          string = "adding to dose table completed"
	ActionRemoveLogs              string = "removing logs from dose table completed"
	ActionRemoveSingleDrugInfo    string = "removing single drug info completed"
	ActionSetUserSettings         string = "user settings change completed"
	ActionRememberDosing          string = "dosing remember completed"
	ActionForgetDosing            string = "dosing forgetting completed"
)

// Sentinel errors, meant to be matched with errors.Is().
var (
	// NoValidSourceSel is wrapped by FetchFromSource when Config.UseSource
	// doesn't name a supported source.
	NoValidSourceSel error = errors.New("no valid source selected")
	// TimeoutValueEmptyError is wrapped by UseConfigTimeout when
	// Config.Timeout is empty or "none".
	TimeoutValueEmptyError error = errors.New("timeout value is empty")
)
// exitProgram prints an exit notice under the name printN and then
// terminates the process with exit status 1. It never returns.
func exitProgram(printN string) {
	const exitStatus = 1

	printName(printN, "exitProgram(): Exiting")
	os.Exit(exitStatus)
}
// errorCantOpenDB reports that the database at path couldn't be opened,
// printing err under the name printN, and then stops the program through
// exitProgram().
func errorCantOpenDB(path string, err error, printN string) {
	const notice string = "errorCantOpenDB(): Error opening DB:"

	printName(printN, notice, path, ":", err)
	exitProgram(printN)
}
// If err is not nil, starts a transaction rollback and returns the error
// through errChannel.
//
// This function is meant to be used within concurrently ran functions.
//
// The original error is always kept in the chain of errInfo.Err, so that
// errors.Is() and errors.As() keep working on it. If the rollback itself
// fails, its error is added to the message. A rollback that returns
// sql.ErrTxDone isn't treated as a failure: it only means the transaction
// was already finished (for example by a failed Commit or a cancelled
// context), so there was nothing left to roll back.
//
// err - the error to check
//
// tx - the transaction to roll back when err is not nil
//
// errChannel - channel receiving a copy of errInfo when err is not nil
// (set to nil if nothing should be sent)
//
// errInfo - its Err field is overwritten when err is not nil
//
// printN, xtraMsg - passed to sprintName() to build the message prefix
//
// Returns true if there's an error, false otherwise.
func handleErrRollback(err error, tx *sql.Tx, errChannel chan<- ErrorInfo,
	errInfo *ErrorInfo, printN string, xtraMsg string) bool {
	if err == nil {
		return false
	}

	rollbackErr := tx.Rollback()
	prefix := sprintName(printN, xtraMsg)
	if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
		// Only one %w is used, because multiple %w verbs need Go 1.20+.
		errInfo.Err = fmt.Errorf("error when attempting to roll back: %s%v; original error: %w",
			prefix, rollbackErr, err)
	} else {
		errInfo.Err = fmt.Errorf("rolling back: %s%w", prefix, err)
	}

	if errChannel != nil {
		errChannel <- *errInfo
	}
	return true
}
// If err is not nil, starts a transaction rollback and returns the error.
//
// This function is meant to be used within sequentially ran functions.
//
// The original error is always kept in the chain of the returned error, so
// that errors.Is() and errors.As() keep working on it. If the rollback
// itself fails, its error is added to the message. A rollback that returns
// sql.ErrTxDone isn't treated as a failure: it only means the transaction
// was already finished (for example by a failed Commit or a cancelled
// context), so there was nothing left to roll back.
//
// err - the error to check
//
// tx - the transaction to roll back when err is not nil
//
// printN, xtraMsg - passed to sprintName() to build the message prefix
//
// Returns the error if there's one, nil otherwise.
func handleErrRollbackSeq(err error, tx *sql.Tx, printN string, xtraMsg string) error {
	if err == nil {
		return nil
	}

	rollbackErr := tx.Rollback()
	prefix := sprintName(printN, xtraMsg)
	if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
		// Only one %w is used, because multiple %w verbs need Go 1.20+.
		return fmt.Errorf("error when attempting to roll back: %s%v; original error: %w",
			prefix, rollbackErr, err)
	}
	return fmt.Errorf("rolling back: %s%w", prefix, err)
}
// Make sure the input column name matches exactly with the proper names.
func checkColIsInvalid(validCols []string, gotCol string, printN string) error {
validCol := false
if gotCol != "" && gotCol != "none" && len(validCols) != 0 {
validCols := validLogCols()
for i := 0; i < len(validCols); i++ {
if gotCol == validCols[i] {
validCol = true
break
}
}
if validCol == false {
return fmt.Errorf("%s%w", sprintName(printN), InvalidColInput)
}
} else if gotCol == "" || gotCol == "none" {
return errors.New(sprintName(printN, "Empty column given."))
} else if len(validCols) == 0 {
return errors.New(sprintName(printN, "Invalid parameters when checking if column is invalid."))
}
return nil
}
// UserLog holds the values of a single log entry of a user.
type UserLog struct {
// StartTime is the value ChangeUserLog matches against the
// timeOfDoseStart column to select the log to update.
// NOTE(review): presumably a Unix timestamp in seconds, since
// ChangeUserLog replaces "now" with time.Now().Unix() for the
// start/end time columns - confirm.
StartTime int64
Username string
// NOTE(review): presumably a Unix timestamp like StartTime - confirm.
EndTime int64
DrugName string
Dose float32
DoseUnits string
DrugRoute string
Cost float32
CostCurrency string
}
// UserLogsError bundles a slice of logs with a username and an error value.
// It's the type returned by GetLogs(): ChangeUserLog reads its Err and
// UserLogs fields.
type UserLogsError struct {
UserLogs []UserLog
Username string
Err error
}
// DrugNamesError bundles a slice of drug names with a username and an
// error value.
// NOTE(review): the code producing it isn't visible in this part of the
// file; presumably Err is nil on success, like for UserLogsError.
type DrugNamesError struct {
DrugNames []string
Username string
Err error
}
// DrugInfoError bundles a slice of DrugInfo with a username and an
// error value.
// NOTE(review): the code producing it isn't visible in this part of the file.
type DrugInfoError struct {
DrugI []DrugInfo
Username string
Err error
}
// UserSettingError bundles a single setting value with a username and an
// error value.
// NOTE(review): the code producing it isn't visible in this part of the file.
type UserSettingError struct {
UserSetting string
Username string
Err error
}
// LogCountError bundles a count of logs with a username and an error value.
// NOTE(review): the code producing it isn't visible in this part of the file.
type LogCountError struct {
LogCount uint32
Username string
Err error
}
// AllUsersError bundles a slice of usernames with a username and an
// error value.
// NOTE(review): the code producing it isn't visible in this part of the file.
type AllUsersError struct {
AllUsers []string
Username string
Err error
}
// ErrorInfo is the result reported by functions like FetchFromSource and
// ChangeUserLog, both as return value and, when a channel is given,
// through their errChannel.
type ErrorInfo struct {
// Err is nil when the action finished without errors.
Err error
// Action is set to one of the Action* constants, describing what
// was attempted.
Action string
// Username is the user the action was done for.
Username string
}
// DrugInfo holds, judging by the field names, dose ranges and the
// durations of the different effect phases for one drug taken through
// one route.
// NOTE(review): the code which fills and reads this struct isn't visible
// in this part of the file; each *Min/*Max pair presumably uses the
// *Units field of its group (DoseUnits for the dose ranges) - confirm.
type DrugInfo struct {
DrugName string
DrugRoute string
Threshold float32
LowDoseMin float32
LowDoseMax float32
MediumDoseMin float32
MediumDoseMax float32
HighDoseMin float32
HighDoseMax float32
DoseUnits string
OnsetMin float32
OnsetMax float32
OnsetUnits string
ComeUpMin float32
ComeUpMax float32
ComeUpUnits string
PeakMin float32
PeakMax float32
PeakUnits string
OffsetMin float32
OffsetMax float32
OffsetUnits string
TotalDurMin float32
TotalDurMax float32
TotalDurUnits string
// NOTE(review): presumably the Unix timestamp of when the info was
// fetched from the source - confirm.
TimeOfFetch int64
}
// SyncTimestamps keeps a timestamp and a username together with a mutex.
// NOTE(review): presumably Lock guards LastTimestamp and LastUser while
// logs are added concurrently - confirm at the use sites, which aren't
// visible in this part of the file.
//
// Since it contains a sync.Mutex by value, it must not be copied after
// first use; pass it around as a pointer.
type SyncTimestamps struct {
LastTimestamp int64
LastUser string
Lock sync.Mutex
}
// xtrastmt builds an extra condition for a where clause, in the form:
// "<logical> <col> = ?"
//
// col - the column to compare against a placeholder
//
// logical - the logical operator joining the condition to the previous
// ones, for example "and"
func xtrastmt(col string, logical string) string {
	condition := col + " = ?"
	return logical + " " + condition
}
// checkIfExistsDB checks if at least one row exists in table, for which
// col equals the first of values and all extra conditions hold.
//
// db - open database connection
//
// ctx - context to be passed to sql queries
//
// col - the column to select and compare
//
// table - the table to look in
//
// driver, path - not used by this function
//
// xtrastmt - extra conditions appended to the where clause, each separated
// by a space, best created using xtrastmt(); can be nil
//
// values - the values for all placeholders, in order
//
// Returns true if a row is found. Returns false if no row is found or if
// an error happens, in which case the error is printed.
func checkIfExistsDB(db *sql.DB, ctx context.Context,
	col string, table string, driver string,
	path string, xtrastmt []string, values ...interface{}) bool {
	const printN string = "checkIfExistsDB()"

	// NOTE: this doesn't cause an SQL injection, because we're not taking
	// 'col' and 'table' from an user input.
	query := "select " + col + " from " + table + " where " + col + " = ?"
	for _, clause := range xtrastmt {
		query += " " + clause
	}

	prepared, prepErr := db.PrepareContext(ctx, query)
	if prepErr != nil {
		printName(printN, "SQL error in prepare for check if exists:", prepErr)
		return false
	}
	defer prepared.Close()

	var scanned string
	scanErr := prepared.QueryRowContext(ctx, values...).Scan(&scanned)
	switch {
	case scanErr == nil:
		return true
	case errors.Is(scanErr, sql.ErrNoRows):
		// Not finding anything is an expected outcome, so nothing is printed.
		return false
	default:
		printName(printN, "received weird error:", scanErr)
		return false
	}
}
// UseConfigTimeout parses the Timeout value of the Config struct as a
// duration and creates a WithTimeout context from it, based on
// context.Background().
//
// Returns the context and its cancel function, which the caller should
// call when done with the context. If Timeout is empty or "none", an error
// wrapping TimeoutValueEmptyError is returned, if Timeout can't be parsed
// by time.ParseDuration(), the parsing error is returned wrapped. In both
// cases the context and cancel function are nil.
func (cfg *Config) UseConfigTimeout() (context.Context, context.CancelFunc, error) {
	const printN string = "UseConfigTimeout()"

	switch cfg.Timeout {
	case "", "none":
		return nil, nil, fmt.Errorf("%s%w", sprintName(printN), TimeoutValueEmptyError)
	}

	timeout, parseErr := time.ParseDuration(cfg.Timeout)
	if parseErr != nil {
		return nil, nil, fmt.Errorf("%s%w", sprintName(printN, "time.ParseDuration(): "), parseErr)
	}

	timeoutCtx, cancelFn := context.WithTimeout(context.Background(), timeout)
	return timeoutCtx, cancelFn, nil
}
// OpenDBConnection opens a database connection using the driver and path
// set in the Config struct, then pings the database using PingDB().
// For the sqlite driver, the configured parameters are appended to the path.
//
// If opening or pinging fails, the error is printed and the program exits.
//
// After calling this function, don't forget to run: defer db.Close()
//
// db being the name of the returned *sql.DB variable
//
// ctx - context to be passed to PingDB(), first passing through WithTimeout()
func (cfg *Config) OpenDBConnection(ctx context.Context) *sql.DB {
	const printN string = "OpenDBConnection()"

	dataSource := cfg.DBSettings[cfg.DBDriver].Path
	if cfg.DBDriver == SqliteDriver {
		dataSource += cfg.DBSettings[cfg.DBDriver].Parameters
	}

	conn, openErr := sql.Open(cfg.DBDriver, dataSource)
	if openErr != nil {
		errorCantOpenDB(dataSource, openErr, printN)
	}

	cfg.PingDB(conn, ctx)

	return conn
}
// PingDB verifies a connection to the database is still alive,
// establishing a connection if necessary. If the ping fails, the error
// is printed together with the configured database path and the
// program exits.
//
// db - open database connection
//
// ctx - context to be passed to PingContext()
func (cfg *Config) PingDB(db *sql.DB, ctx context.Context) {
	const printN string = "PingDB()"

	if pingErr := db.PingContext(ctx); pingErr != nil {
		errorCantOpenDB(cfg.DBSettings[cfg.DBDriver].Path, pingErr, printN)
	}
}
// getTableNamesQuery generates and returns a query for looking up table
// names in the database.
// If tableName is empty, query returns all tables in the database.
// If tableName is not empty, query returns a specific table if it exists.
// If the configured driver is neither sqlite nor mysql, an empty string
// is returned.
//
// NOTE: tableName is concatenated into the query as is, so it must never
// come from untrusted input.
func (cfg *Config) getTableNamesQuery(tableName string) string {
	switch cfg.DBDriver {
	case SqliteDriver:
		query := "SELECT name FROM sqlite_schema WHERE type='table'"
		if tableName != "" {
			query += " AND name = '" + tableName + "'"
		}
		return query
	case MysqlDriver:
		query := "SELECT table_name FROM information_schema.tables WHERE table_schema = '" +
			mysqlDBNameFromDSN(cfg.DBSettings[cfg.DBDriver].Path) + "'"
		if tableName != "" {
			query += " AND table_name = '" + tableName + "'"
		}
		return query
	default:
		return ""
	}
}

// mysqlDBNameFromDSN extracts the database name from a MySQL data source
// name of the form [user[:password]@][protocol[(address)]]/dbname[?params].
//
// The name starts after the last "/", because the password or address may
// contain slashes themselves, and ends before the first "?" following it.
// Returns an empty string if dsn contains no "/", instead of panicking.
func mysqlDBNameFromDSN(dsn string) string {
	lastSlash := strings.LastIndex(dsn, "/")
	if lastSlash < 0 {
		return ""
	}

	dbName := dsn[lastSlash+1:]
	if paramsStart := strings.Index(dbName, "?"); paramsStart >= 0 {
		dbName = dbName[:paramsStart]
	}
	return dbName
}
// CheckTables returns true if a table exists in the database with the name
// tableName. It returns false in case of error or if the table isn't found.
// If tableName is empty it will search for all tables in the database,
// returning true if there's at least one.
//
// Errors are printed, including the ones which stop the iteration over the
// results before any row is read.
//
// db - open database connection
//
// ctx - context to be passed to sql queries
//
// tableName - name of table to check if it exists
func (cfg *Config) CheckTables(db *sql.DB, ctx context.Context, tableName string) bool {
	const printN string = "CheckTables()"

	rows, err := db.QueryContext(ctx, cfg.getTableNamesQuery(tableName))
	if err != nil {
		printName(printN, err)
		return false
	}
	defer rows.Close()

	// A single row is enough to know the answer, so there's no need to
	// collect every table name.
	if !rows.Next() {
		// Next() also returns false on failure, so tell apart
		// "no tables" from an error.
		if err := rows.Err(); err != nil {
			printName(printN, err)
		}
		return false
	}

	var name string
	if err := rows.Scan(&name); err != nil {
		printName(printN, err)
		return false
	}

	return true
}
// FetchFromSource picks the proper function for fetching drug information,
// depending on the source configured in the Config struct. The fetching
// itself and storing of the information is done by the picked function.
//
// The returned ErrorInfo has Err set to nil on success. If the configured
// source isn't supported, Err wraps NoValidSourceSel. If the values in
// xtraNeeded are missing or of the wrong type, an error is returned
// instead of panicking.
//
// db - open database connection
//
// ctx - context to be passed to sql queries
//
// errChannel - the goroutine channel which returns the errors
// (set to nil if function doesn't need to be concurrent)
//
// drugname - the name of the substance to fetch information for
//
// username - the user requesting the fetch
//
// xtraNeeded - these are values dependent on the configured source, the order
// in which they're given also matters, the order in which every source is
// described, is the one in which the values should be given
//
// for psychonautwiki: the initialised structure for the graphql client,
// best done using InitGraphqlClient(), but can be done manually if needed
func (cfg *Config) FetchFromSource(db *sql.DB, ctx context.Context,
	errChannel chan<- ErrorInfo, drugname string, username string,
	xtraNeeded ...any) ErrorInfo {
	const printN string = "FetchFromSource()"

	tempErrInfo := ErrorInfo{
		Err:      nil,
		Action:   ActionFetchFromSource,
		Username: username,
	}

	// report delivers the current result through errChannel, if there's
	// one, and hands it back for returning.
	report := func() ErrorInfo {
		if errChannel != nil {
			errChannel <- tempErrInfo
		}
		return tempErrInfo
	}

	gotsrcData := GetSourceData()

	printNameVerbose(cfg.VerbosePrinting,
		printN, "Using API from settings.toml:", cfg.UseSource)
	printNameVerbose(cfg.VerbosePrinting,
		printN, "Got API URL from sources.toml:", gotsrcData[cfg.UseSource].API_ADDRESS)

	switch cfg.UseSource {
	case "psychonautwiki":
		if len(xtraNeeded) == 0 {
			tempErrInfo.Err = fmt.Errorf("%s%s", sprintName(printN),
				"missing graphql client in xtraNeeded")
			return report()
		}

		// Checked assertion, so a wrong value results in an error
		// for the caller, not a panic.
		client, ok := xtraNeeded[0].(graphql.Client)
		if !ok {
			tempErrInfo.Err = fmt.Errorf("%sfirst value in xtraNeeded has type %T, expected graphql.Client",
				sprintName(printN), xtraNeeded[0])
			return report()
		}

		// The channel is nil here, because the result is reported once
		// by this function.
		gotErrInfo := cfg.FetchPsyWiki(db, ctx, nil, drugname, client, username)
		if gotErrInfo.Err != nil {
			tempErrInfo.Err = fmt.Errorf("%s%w",
				sprintName(printN, "While fetching from: ", cfg.UseSource, " ; error: "),
				gotErrInfo.Err)
		}
	default:
		tempErrInfo.Err = fmt.Errorf("%s%w: %s", sprintName(printN), NoValidSourceSel, cfg.UseSource)
	}

	return report()
}
// ChangeUserLog can be used to modify log data of a single log.
//
// The log to change is first looked up using GetLogs(), then the given
// column is updated inside of a transaction. The returned ErrorInfo has
// Err set to nil on success.
//
// db - open database connection
//
// ctx - context to be passed to sql queries
//
// errChannel - the goroutine channel which returns the errors
// (set to nil if function doesn't need to be concurrent)
//
// set - what log data to change, if name is invalid, InvalidColInput
// error will be sent through errChannel and returned
//
// id - if 0 will change the newest log, else it will change the log with
// the given id
//
// username - the user whose log we're changing
//
// setValue - the new value to set, for the start and end time columns
// it must be an integer or "now", which is replaced with the current
// Unix time, for the dose column it must be a number
func (cfg *Config) ChangeUserLog(db *sql.DB, ctx context.Context, errChannel chan<- ErrorInfo,
	set string, id int64, username string, setValue string) ErrorInfo {
	const printN string = "ChangeUserLog()"

	tempErrInfo := ErrorInfo{
		Err:      nil,
		Action:   ActionChangeUserLog,
		Username: username,
	}

	// fail stores err in the result, delivers the result through
	// errChannel, if there's one, and hands it back for returning.
	fail := func(err error) ErrorInfo {
		tempErrInfo.Err = err
		if errChannel != nil {
			errChannel <- tempErrInfo
		}
		return tempErrInfo
	}

	if err := checkColIsInvalid(validLogCols(), set, printN); err != nil {
		return fail(err)
	}

	if set == LogStartTimeCol || set == LogEndTimeCol {
		if setValue == "now" {
			setValue = strconv.FormatInt(time.Now().Unix(), 10)
		}

		if _, err := strconv.ParseInt(setValue, 10, 64); err != nil {
			return fail(fmt.Errorf("%s%w", sprintName(printN, "strconv.ParseInt(): "), err))
		}
	}

	if set == "dose" {
		if _, err := strconv.ParseFloat(setValue, 64); err != nil {
			return fail(fmt.Errorf("%s%w", sprintName(printN, "strconv.ParseFloat(): "), err))
		}
	}

	gotUserLogsErr := cfg.GetLogs(db, ctx, nil, 1, id, username, true, "", "")
	if gotUserLogsErr.Err != nil {
		return fail(fmt.Errorf("%s%w", sprintName(printN), gotUserLogsErr.Err))
	}
	// Guard against an empty result, so indexing below can't panic.
	if len(gotUserLogsErr.UserLogs) == 0 {
		return fail(fmt.Errorf("%s", sprintName(printN, "no log found to change")))
	}

	// Logs are identified in the table by their start time.
	id = gotUserLogsErr.UserLogs[0].StartTime

	// The column name can't be a placeholder, it's safe to insert it
	// because it was validated by checkColIsInvalid() above.
	stmtStr := fmt.Sprintf("update "+loggingTableName+" set %s = ? where timeOfDoseStart = ?", set)

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fail(fmt.Errorf("%s%w", sprintName(printN, "db.Begin(): "), err))
	}

	// The context variants are used so the configured timeout is honoured
	// for every step of the transaction.
	stmt, err := tx.PrepareContext(ctx, stmtStr)
	if handleErrRollback(err, tx, errChannel, &tempErrInfo, printN, "tx.Prepare(): ") {
		return tempErrInfo
	}
	defer stmt.Close()

	_, err = stmt.ExecContext(ctx, setValue, id)
	if handleErrRollback(err, tx, errChannel, &tempErrInfo, printN, "stmt.Exec(): ") {
		return tempErrInfo
	}

	if err = tx.Commit(); err != nil {
		// Best effort only: after a failed Commit the transaction is
		// normally finished already and Rollback returns sql.ErrTxDone,
		// which must not hide the commit error, so its result is ignored.
		_ = tx.Rollback()
		return fail(fmt.Errorf("%s%w", sprintName(printN, "tx.Commit(): "), err))
	}

	printName(printN, "entry:", id, "; changed:", set, "; to value:", setValue, "; for user:", username)

	if errChannel != nil {
		errChannel <- tempErrInfo
	}

	return tempErrInfo
}