database.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. // Package database 数据库连接
  2. package database
  3. import (
  4. "database/sql"
  5. "fmt"
  6. "gorm.io/gorm"
  7. gormlgger "gorm.io/gorm/logger"
  8. "github.com/runningwater/gohub/pkg/config"
  9. )
  10. // DB 数据库连接实例
  11. var DB *gorm.DB
  12. var SQLDB *sql.DB
  13. func Connect(dbConfig gorm.Dialector, _logger gormlgger.Interface) {
  14. // 使用 gorm.Open 连接数据库
  15. var err error
  16. DB, err = gorm.Open(dbConfig, &gorm.Config{
  17. Logger: _logger,
  18. })
  19. // 处理错误
  20. if err != nil {
  21. fmt.Println("数据库连接失败", err.Error())
  22. }
  23. // 获取底层的 sql.DB 实例
  24. SQLDB, err = DB.DB()
  25. if err != nil {
  26. fmt.Println("数据库连接失败", err.Error())
  27. }
  28. }
  29. func CurrentDatabase() string {
  30. return DB.Migrator().CurrentDatabase()
  31. }
  32. func DeleteAllTables() error {
  33. var err error
  34. switch config.Get("database.connection") {
  35. case "mysql":
  36. err = deleteMySQLTables()
  37. case "sqlite":
  38. err = deleteAllSQLiteTables()
  39. default:
  40. err = fmt.Errorf("不支持的数据库类型: %s", config.Get("database.connection"))
  41. }
  42. return err
  43. }
  44. func deleteMySQLTables() error {
  45. dbname := CurrentDatabase()
  46. var tables []string
  47. err := DB.Table("information_schema.tables").
  48. Where("table_schema = ?", dbname).
  49. Pluck("table_name", &tables).Error
  50. if err != nil {
  51. return err
  52. }
  53. // 暂时关闭外键检测
  54. DB.Exec("SET FOREIGN_KEY_CHECKS = 0;")
  55. // 删除所有表
  56. for _, table := range tables {
  57. if err := DB.Migrator().DropTable(table); err != nil {
  58. return err
  59. }
  60. }
  61. // 重新启用外键检测
  62. DB.Exec("SET FOREIGN_KEY_CHECKS = 1;")
  63. return nil
  64. }
  65. func deleteAllSQLiteTables() error {
  66. var tables []string
  67. err := DB.Raw("SELECT name FROM sqlite_master WHERE type='table'").
  68. Pluck("name", &tables).Error
  69. if err != nil {
  70. return err
  71. }
  72. // 删除所有表
  73. for _, table := range tables {
  74. if err := DB.Migrator().DropTable(table); err != nil {
  75. return err
  76. }
  77. }
  78. return nil
  79. }
  80. // TableName 获取表名
  81. func TableName(obj any) string {
  82. stmt := &gorm.Statement{DB: DB}
  83. err := stmt.Parse(obj)
  84. if err != nil {
  85. return ""
  86. }
  87. return stmt.Schema.Table
  88. }