354 lines
8.7 KiB
Go
354 lines
8.7 KiB
Go
|
package handler
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
_ "github.com/go-sql-driver/mysql"
|
||
|
"luwak/model"
|
||
|
"net/http"
|
||
|
)
|
||
|
|
||
|
var (
	// handlers maps a URL path such as "/insert/user" to the function that
	// serves it. It is (re)created by RegisterAll and nil before that.
	handlers map[string]func(http.ResponseWriter, *http.Request)

	// connection maps a table name to the *sql.DB opened for it by
	// RegisterAll; it is nil before RegisterAll runs.
	connection map[string]*sql.DB
)
|
||
|
|
||
|
func GetTableAndColumns() []model.Table {
|
||
|
var rows *sql.Rows
|
||
|
tables := make(map[string]map[string][]string, 0)
|
||
|
defer func() {
|
||
|
if rows == nil {
|
||
|
return
|
||
|
}
|
||
|
if err := rows.Close(); err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
}()
|
||
|
for dbName, db := range connection {
|
||
|
tables[dbName] = make(map[string][]string, 0)
|
||
|
var err error
|
||
|
var query string
|
||
|
// 查询数据库下的所有表名和字段名
|
||
|
query = "SELECT table_name, column_name FROM information_schema.columns WHERE table_schema = ?"
|
||
|
rows, err = db.Query(query, dbName)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
for rows.Next() {
|
||
|
var tableName, fieldName string
|
||
|
err := rows.Scan(&tableName, &fieldName)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
tables[dbName][tableName] = append(tables[dbName][tableName], fieldName)
|
||
|
}
|
||
|
|
||
|
rows.Close()
|
||
|
}
|
||
|
return convertMapToTable(tables)
|
||
|
}
|
||
|
|
||
|
func convertMapToTable(table map[string]map[string][]string) []model.Table {
|
||
|
var tables []model.Table
|
||
|
for dbName, v := range table {
|
||
|
for tableName, columns := range v {
|
||
|
tables = append(tables, model.Table{
|
||
|
DbName: dbName,
|
||
|
TableName: tableName,
|
||
|
Columns: columns,
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
return tables
|
||
|
}
|
||
|
|
||
|
// RegisterAll 注册所有的handler函数和数据库连接
|
||
|
func RegisterAll() {
|
||
|
handlers = make(map[string]func(http.ResponseWriter, *http.Request))
|
||
|
connection = make(map[string]*sql.DB)
|
||
|
//模拟从redis获取表结构
|
||
|
tables := []model.Table{
|
||
|
{
|
||
|
TableName: "user",
|
||
|
Columns: []string{"id", "name", "age"},
|
||
|
},
|
||
|
}
|
||
|
for _, table := range tables {
|
||
|
db, err := sql.Open("mysql", "root:root@tcp(127.0.0.1:3306)/"+table.TableName)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
connection[table.TableName] = db
|
||
|
registerInsertHandler(table)
|
||
|
registerQueryHandler(table)
|
||
|
registerUpdateHandler(table)
|
||
|
registerDeleteHandler(table)
|
||
|
}
|
||
|
tableAndColumns := GetTableAndColumns()
|
||
|
fmt.Println(tableAndColumns)
|
||
|
}
|
||
|
|
||
|
// registerInsertHandler 注册一系列插入数据的处理函数
|
||
|
func registerInsertHandler(table model.Table) {
|
||
|
handlers["/insert/"+table.TableName] = func(w http.ResponseWriter, r *http.Request) {
|
||
|
if r.Method != http.MethodPost {
|
||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||
|
return
|
||
|
}
|
||
|
//获取请求体
|
||
|
body := r.Body
|
||
|
defer body.Close()
|
||
|
param := make([]map[string]interface{}, 0)
|
||
|
//解析请求体
|
||
|
err := json.NewDecoder(body).Decode(¶m)
|
||
|
if err != nil {
|
||
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
//校验请求体与表结构是否匹配
|
||
|
for _, p := range param {
|
||
|
if !checkParam(p, table.Columns) {
|
||
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
for _, p := range param {
|
||
|
//按照顺序构造column和value
|
||
|
columns := ""
|
||
|
values := ""
|
||
|
for field, value := range p {
|
||
|
columns += field + ","
|
||
|
values += fmt.Sprintf("'%v',", value)
|
||
|
}
|
||
|
columns = columns[:len(columns)-1]
|
||
|
values = values[:len(values)-1]
|
||
|
//构造sql语句
|
||
|
execSql := fmt.Sprintf("insert into %s (%s) values (%s)", table.TableName, columns, values)
|
||
|
//执行sql语句
|
||
|
_, err := connection[table.TableName].Exec(execSql)
|
||
|
if err != nil {
|
||
|
w.WriteHeader(http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// registerQueryHandler 注册一系列查询数据的处理函数
|
||
|
func registerQueryHandler(table model.Table) {
|
||
|
handlers["/query/"+table.TableName] = func(w http.ResponseWriter, r *http.Request) {
|
||
|
if r.Method != http.MethodPost {
|
||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||
|
return
|
||
|
}
|
||
|
//获取请求体
|
||
|
body := r.Body
|
||
|
defer body.Close()
|
||
|
param := make(map[string]interface{})
|
||
|
//解析请求体
|
||
|
err := json.NewDecoder(body).Decode(¶m)
|
||
|
if err != nil {
|
||
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
//校验请求体与表结构是否匹配
|
||
|
if !checkParam(param, table.Columns) {
|
||
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
//按照顺序构造column和value
|
||
|
where := "where"
|
||
|
if len(param) == 0 {
|
||
|
where = ""
|
||
|
} else {
|
||
|
for field, value := range param {
|
||
|
where += fmt.Sprintf("%s = '%v' and", field, value)
|
||
|
}
|
||
|
//去掉最后一个and
|
||
|
where = where[:len(where)-3]
|
||
|
}
|
||
|
//查询数据
|
||
|
db := connection[table.TableName]
|
||
|
//拼接sql语句
|
||
|
querySql := fmt.Sprintf("select * from %s %s", table.TableName, where)
|
||
|
rows, err := db.Query(querySql)
|
||
|
if err != nil {
|
||
|
w.WriteHeader(http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
//返回数据
|
||
|
var result []map[string]string
|
||
|
for rows.Next() {
|
||
|
columns, err := rows.Columns()
|
||
|
if err != nil {
|
||
|
w.WriteHeader(http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
values := make([]sql.RawBytes, len(columns))
|
||
|
scanArgs := make([]interface{}, len(values))
|
||
|
for i := range values {
|
||
|
scanArgs[i] = &values[i]
|
||
|
}
|
||
|
err = rows.Scan(scanArgs...)
|
||
|
if err != nil {
|
||
|
w.WriteHeader(http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
row := make(map[string]string)
|
||
|
for i, col := range values {
|
||
|
if col == nil {
|
||
|
row[columns[i]] = ""
|
||
|
} else {
|
||
|
row[columns[i]] = string(col)
|
||
|
}
|
||
|
}
|
||
|
result = append(result, row)
|
||
|
}
|
||
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||
|
bytes, err := json.Marshal(result)
|
||
|
if err != nil {
|
||
|
w.WriteHeader(http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
_, err = w.Write(bytes)
|
||
|
if err != nil {
|
||
|
w.WriteHeader(http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// registerUpdateHandler 注册一系列更新数据的处理函数
|
||
|
func registerUpdateHandler(table model.Table) {
|
||
|
handlers["/update/"+table.TableName] = func(w http.ResponseWriter, r *http.Request) {
|
||
|
if r.Method != http.MethodPost {
|
||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||
|
return
|
||
|
}
|
||
|
//获取请求体
|
||
|
body := r.Body
|
||
|
defer body.Close()
|
||
|
param := make(map[string]interface{})
|
||
|
//解析请求体
|
||
|
err := json.NewDecoder(body).Decode(¶m)
|
||
|
if err != nil {
|
||
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
//校验请求体与表结构是否匹配
|
||
|
if !checkParam(param, table.Columns) {
|
||
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
if param["where"] == nil {
|
||
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
//按照顺序构造column和value
|
||
|
set := ""
|
||
|
for field, value := range param {
|
||
|
if field == "where" {
|
||
|
continue
|
||
|
}
|
||
|
set += fmt.Sprintf("%s = '%v',", field, value)
|
||
|
}
|
||
|
//去掉最后一个逗号
|
||
|
set = set[:len(set)-1]
|
||
|
//构造where
|
||
|
m, ok := param["where"].(map[string]interface{})
|
||
|
if !ok {
|
||
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
where := "where "
|
||
|
for field, value := range m {
|
||
|
where += fmt.Sprintf("%s = '%v' and", field, value)
|
||
|
}
|
||
|
//去掉最后一个and
|
||
|
where = where[:len(where)-3]
|
||
|
//更新数据
|
||
|
db := connection[table.TableName]
|
||
|
//拼接sql语句
|
||
|
updateSql := fmt.Sprintf("update %s set %s %s", table.TableName, set, where)
|
||
|
_, err = db.Exec(updateSql)
|
||
|
if err != nil {
|
||
|
w.WriteHeader(http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
w.WriteHeader(http.StatusOK)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// registerDeleteHandler 注册一系列删除数据的处理函数
|
||
|
func registerDeleteHandler(table model.Table) {
|
||
|
handlers["/delete/"+table.TableName] = func(w http.ResponseWriter, r *http.Request) {
|
||
|
if r.Method != http.MethodPost {
|
||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||
|
return
|
||
|
}
|
||
|
//获取请求体
|
||
|
body := r.Body
|
||
|
defer body.Close()
|
||
|
param := make(map[string]interface{})
|
||
|
//解析请求体
|
||
|
err := json.NewDecoder(body).Decode(¶m)
|
||
|
if err != nil {
|
||
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
//校验请求体与表结构是否匹配
|
||
|
if !checkParam(param, table.Columns) {
|
||
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
//构造where
|
||
|
where := "where "
|
||
|
if len(param) == 0 {
|
||
|
w.WriteHeader(http.StatusBadRequest)
|
||
|
return
|
||
|
} else {
|
||
|
for field, value := range param {
|
||
|
where += fmt.Sprintf("%s = '%v' and", field, value)
|
||
|
}
|
||
|
//去掉最后一个and
|
||
|
where = where[:len(where)-3]
|
||
|
}
|
||
|
//删除数据
|
||
|
db := connection[table.TableName]
|
||
|
//拼接sql语句
|
||
|
deleteSql := fmt.Sprintf("delete from %s %s", table.TableName, where)
|
||
|
_, err = db.Exec(deleteSql)
|
||
|
if err != nil {
|
||
|
w.WriteHeader(http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
w.WriteHeader(http.StatusOK)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func GetHandlers() map[string]func(http.ResponseWriter, *http.Request) {
|
||
|
return handlers
|
||
|
}
|
||
|
|
||
|
// checkParam reports whether every key of param names one of cols.
//
// The reserved keys "where", "limit", "offset", "order", "having" and "group"
// are accepted without being checked against cols. An empty or nil param is
// always valid.
func checkParam(param map[string]interface{}, cols []string) bool {
nextKey:
	for key := range param {
		switch key {
		case "where", "limit", "offset", "order", "having", "group":
			continue
		}
		for _, col := range cols {
			if col == key {
				continue nextKey
			}
		}
		return false
	}
	return true
}
|