// Package handler builds generic JSON-over-HTTP CRUD endpoints
// (/insert/<table>, /query/<table>, /update/<table>, /delete/<table>) on top
// of MySQL connections.
package handler

import (
	"database/sql"
	"encoding/json"
	"fmt"
	"net/http"
	"sort"
	"strings"

	_ "github.com/go-sql-driver/mysql"

	"luwak/model"
)

// handlers maps a request path such as "/insert/user" to its handler function.
var handlers map[string]func(http.ResponseWriter, *http.Request)

// connection maps a table name (set by RegisterAll) to the database handle its
// handlers use. GetTableAndColumns also treats each key as a schema name.
var connection map[string]*sql.DB

// GetTableAndColumns lists, via information_schema, the columns of every table
// in the schema named by each key of connection. Columns keep their ordinal
// (definition) order and the result is sorted by schema, then table name.
// It panics on any database error.
func GetTableAndColumns() []model.Table {
	tables := make(map[string]map[string][]string, len(connection))
	for dbName, db := range connection {
		tables[dbName] = loadSchemaColumns(db, dbName)
	}
	return convertMapToTable(tables)
}

// loadSchemaColumns returns table name -> column names for one schema and
// panics on error. It is a separate function so that the deferred rows.Close
// runs after each schema instead of once at the very end.
func loadSchemaColumns(db *sql.DB, schema string) map[string][]string {
	const query = "SELECT table_name, column_name FROM information_schema.columns " +
		"WHERE table_schema = ? ORDER BY table_name, ordinal_position"
	rows, err := db.Query(query, schema)
	if err != nil {
		panic(fmt.Errorf("querying columns of schema %q: %w", schema, err))
	}
	defer rows.Close()

	columns := make(map[string][]string)
	for rows.Next() {
		var tableName, columnName string
		if err := rows.Scan(&tableName, &columnName); err != nil {
			panic(fmt.Errorf("scanning columns of schema %q: %w", schema, err))
		}
		columns[tableName] = append(columns[tableName], columnName)
	}
	if err := rows.Err(); err != nil {
		panic(fmt.Errorf("reading columns of schema %q: %w", schema, err))
	}
	return columns
}

// convertMapToTable flattens schema -> table -> columns into a slice of
// model.Table, ordered by schema name and then table name so the result is
// deterministic (Go map iteration order is random).
func convertMapToTable(table map[string]map[string][]string) []model.Table {
	dbNames := make([]string, 0, len(table))
	for dbName := range table {
		dbNames = append(dbNames, dbName)
	}
	sort.Strings(dbNames)

	var tables []model.Table
	for _, dbName := range dbNames {
		byTable := table[dbName]
		tableNames := make([]string, 0, len(byTable))
		for tableName := range byTable {
			tableNames = append(tableNames, tableName)
		}
		sort.Strings(tableNames)

		for _, tableName := range tableNames {
			tables = append(tables, model.Table{
				DbName:    dbName,
				TableName: tableName,
				Columns:   byTable[tableName],
			})
		}
	}
	return tables
}

// RegisterAll opens a database connection and registers the insert, query,
// update and delete handlers for every table definition, then prints the
// tables and columns found in information_schema. It panics if sql.Open fails.
func RegisterAll() {
	handlers = make(map[string]func(http.ResponseWriter, *http.Request))
	connection = make(map[string]*sql.DB)

	// Simulates loading the table definitions from Redis.
	tables := []model.Table{
		{
			TableName: "user",
			Columns:   []string{"id", "name", "age"},
		},
	}
	for _, table := range tables {
		// NOTE(review): the DSN uses the table name as the database name and
		// hard-coded credentials; confirm this is intended outside local dev.
		db, err := sql.Open("mysql", "root:root@tcp(127.0.0.1:3306)/"+table.TableName)
		if err != nil {
			panic(err)
		}
		connection[table.TableName] = db
		registerInsertHandler(table)
		registerQueryHandler(table)
		registerUpdateHandler(table)
		registerDeleteHandler(table)
	}
	tableAndColumns := GetTableAndColumns()
	fmt.Println(tableAndColumns)
}

// quoteIdent wraps a MySQL identifier in backticks (doubling embedded
// backticks) so column or table names that are reserved words still work.
func quoteIdent(name string) string {
	return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}

// hasColumn reports whether name is one of cols.
func hasColumn(cols []string, name string) bool {
	for _, col := range cols {
		if col == name {
			return true
		}
	}
	return false
}

// isSQLScalar reports whether a decoded JSON value can be bound as a single SQL
// parameter; JSON objects and arrays cannot.
func isSQLScalar(v interface{}) bool {
	switch v.(type) {
	case nil, string, bool, float64, json.Number:
		return true
	}
	return false
}

// columnArgs validates fields against cols and returns the quoted column names
// and their values, in ascending key order so the generated SQL is stable.
// ok is false if any key is not a column (this also rejects the reserved keys
// that checkParam lets through) or any value is a JSON object/array.
func columnArgs(fields map[string]interface{}, cols []string) (names []string, args []interface{}, ok bool) {
	keys := make([]string, 0, len(fields))
	for k := range fields {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	names = make([]string, 0, len(keys))
	args = make([]interface{}, 0, len(keys))
	for _, k := range keys {
		v := fields[k]
		if !hasColumn(cols, k) || !isSQLScalar(v) {
			return nil, nil, false
		}
		names = append(names, quoteIdent(k))
		args = append(args, v)
	}
	return names, args, true
}

// joinComparisons renders each name as "name <op> ?" and joins them with sep,
// e.g. ("`a`", "`b`"), "=", ", " -> "`a` = ?, `b` = ?".
func joinComparisons(names []string, op, sep string) string {
	parts := make([]string, len(names))
	for i, name := range names {
		parts[i] = name + " " + op + " ?"
	}
	return strings.Join(parts, sep)
}

// decodeJSONBody decodes the request body into dst. Numbers are kept as
// json.Number (a string) rather than float64, so large integer IDs keep their
// precision and are never rendered in exponent form.
func decodeJSONBody(r *http.Request, dst interface{}) error {
	dec := json.NewDecoder(r.Body)
	dec.UseNumber()
	return dec.Decode(dst)
}

// scanStringRows reads every row into a map from column name to the column's
// text value; SQL NULL becomes "". It returns an empty, non-nil slice when
// there are no rows, so the JSON response is [] rather than null.
func scanStringRows(rows *sql.Rows) ([]map[string]string, error) {
	columns, err := rows.Columns()
	if err != nil {
		return nil, err
	}
	// The RawBytes buffers are reused across rows; each value is copied into a
	// string before the next Scan overwrites it.
	values := make([]sql.RawBytes, len(columns))
	scanArgs := make([]interface{}, len(values))
	for i := range values {
		scanArgs[i] = &values[i]
	}

	result := make([]map[string]string, 0)
	for rows.Next() {
		if err := rows.Scan(scanArgs...); err != nil {
			return nil, err
		}
		row := make(map[string]string, len(columns))
		for i, col := range values {
			row[columns[i]] = string(col) // string(nil) == "" for NULL
		}
		result = append(result, row)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}

// registerInsertHandler registers POST /insert/<table>. The body is a JSON array
// of objects, and each object becomes one INSERT of its fields. All rows are
// inserted in one transaction, so either every row is stored or none is.
// Responds 405 for non-POST and 400 for malformed JSON, an empty object,
// unknown columns or object/array values. Database errors return 500.
func registerInsertHandler(table model.Table) {
	handlers["/insert/"+table.TableName] = func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		var param []map[string]interface{}
		if err := decodeJSONBody(r, &param); err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		// Validate and build every statement before touching the database.
		type insertStmt struct {
			query string
			args  []interface{}
		}
		stmts := make([]insertStmt, 0, len(param))
		for _, p := range param {
			if !checkParam(p, table.Columns) {
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			names, args, ok := columnArgs(p, table.Columns)
			if !ok || len(names) == 0 {
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			placeholders := strings.TrimSuffix(strings.Repeat("?,", len(names)), ",")
			query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
				quoteIdent(table.TableName), strings.Join(names, ","), placeholders)
			stmts = append(stmts, insertStmt{query: query, args: args})
		}
		if len(stmts) == 0 {
			return
		}

		ctx := r.Context()
		tx, err := connection[table.TableName].BeginTx(ctx, nil)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		for _, s := range stmts {
			if _, err := tx.ExecContext(ctx, s.query, s.args...); err != nil {
				// The Exec error is what matters; a failed rollback adds nothing
				// the client can act on.
				_ = tx.Rollback()
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
		}
		if err := tx.Commit(); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
	}
}

// registerQueryHandler registers POST /query/<table>. The body is a JSON object
// whose fields are ANDed equality filters; an empty object selects every row.
// Comparisons use MySQL's NULL-safe <=> so a JSON null matches NULL columns.
// The response is a JSON array of rows, each a column -> text value map.
// Responds 405 for non-POST, 400 for invalid input, 500 on database errors.
func registerQueryHandler(table model.Table) {
	handlers["/query/"+table.TableName] = func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		param := make(map[string]interface{})
		if err := decodeJSONBody(r, &param); err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		if !checkParam(param, table.Columns) {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		names, args, ok := columnArgs(param, table.Columns)
		if !ok {
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		query := "SELECT * FROM " + quoteIdent(table.TableName)
		if len(names) > 0 {
			query += " WHERE " + joinComparisons(names, "<=>", " AND ")
		}
		rows, err := connection[table.TableName].QueryContext(r.Context(), query, args...)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		defer rows.Close()

		result, err := scanStringRows(rows)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		body, err := json.Marshal(result)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json; charset=utf-8")
		// Write sends the 200 status itself; if it fails the client is gone and
		// a second WriteHeader would only be ignored by net/http.
		_, _ = w.Write(body)
	}
}

// registerUpdateHandler registers POST /update/<table>. The body is a JSON
// object: every field except "where" is a column to set, and "where" is a
// non-empty object of ANDed NULL-safe equality filters. Both parts are
// validated against the table's columns. Responds 405 for non-POST, 400 for
// invalid input (including a missing or empty "where" or nothing to set),
// 500 on database errors, 200 on success.
func registerUpdateHandler(table model.Table) {
	handlers["/update/"+table.TableName] = func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		param := make(map[string]interface{})
		if err := decodeJSONBody(r, &param); err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		if !checkParam(param, table.Columns) {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		// An empty filter would update every row, so it is rejected.
		cond, ok := param["where"].(map[string]interface{})
		if !ok || len(cond) == 0 {
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		fields := make(map[string]interface{}, len(param))
		for k, v := range param {
			if k != "where" {
				fields[k] = v
			}
		}
		setNames, setArgs, ok := columnArgs(fields, table.Columns)
		if !ok || len(setNames) == 0 {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		whereNames, whereArgs, ok := columnArgs(cond, table.Columns)
		if !ok {
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		query := fmt.Sprintf("UPDATE %s SET %s WHERE %s",
			quoteIdent(table.TableName),
			joinComparisons(setNames, "=", ", "),
			joinComparisons(whereNames, "<=>", " AND "))
		args := append(setArgs, whereArgs...)
		if _, err := connection[table.TableName].ExecContext(r.Context(), query, args...); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.WriteHeader(http.StatusOK)
	}
}

// registerDeleteHandler registers POST /delete/<table>. The body is a non-empty
// JSON object of ANDed NULL-safe equality filters; an empty object is rejected
// so a request can never delete the whole table. Responds 405 for non-POST,
// 400 for invalid input, 500 on database errors, 200 on success.
func registerDeleteHandler(table model.Table) {
	handlers["/delete/"+table.TableName] = func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		param := make(map[string]interface{})
		if err := decodeJSONBody(r, &param); err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		if !checkParam(param, table.Columns) {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		names, args, ok := columnArgs(param, table.Columns)
		if !ok || len(names) == 0 {
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		query := "DELETE FROM " + quoteIdent(table.TableName) +
			" WHERE " + joinComparisons(names, "<=>", " AND ")
		if _, err := connection[table.TableName].ExecContext(r.Context(), query, args...); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.WriteHeader(http.StatusOK)
	}
}

// GetHandlers returns the path -> handler map filled by RegisterAll.
func GetHandlers() map[string]func(http.ResponseWriter, *http.Request) {
	return handlers
}

// checkParam reports whether every key of param is either a column in cols or
// one of the reserved request keys (where, limit, offset, order, having,
// group). It does not inspect values or the contents of "where".
func checkParam(param map[string]interface{}, cols []string) bool {
	for k := range param {
		switch k {
		case "where", "limit", "offset", "order", "having", "group":
			// Reserved request keys are not column names.
			continue
		}
		if !hasColumn(cols, k) {
			return false
		}
	}
	return true
}