luwak/handler/handler.go
sjcsjc123 b6ace815d2 v1
2023-07-21 11:33:36 +08:00

354 lines
8.7 KiB
Go

package handler
import (
"database/sql"
"encoding/json"
"fmt"
_ "github.com/go-sql-driver/mysql"
"luwak/model"
"net/http"
)
// Package-level registries, filled in by RegisterAll.
var (
	// handlers maps a route such as "/query/user" to the HTTP handler
	// serving it.
	handlers map[string]func(http.ResponseWriter, *http.Request)
	// connection maps a schema name to its *sql.DB handle (RegisterAll uses
	// each table's name as the schema name in the DSN).
	connection map[string]*sql.DB
)
// GetTableAndColumns reads the column layout of every schema in the
// package-level connection map from information_schema and returns one
// model.Table per (schema, table) pair.
//
// The keys of connection are used as the schema names to look up. Columns of
// each table are listed in their definition order. It panics if a query,
// scan, or row iteration fails.
func GetTableAndColumns() []model.Table {
	tables := make(map[string]map[string][]string, len(connection))
	for dbName, db := range connection {
		tables[dbName] = loadSchemaColumns(db, dbName)
	}
	return convertMapToTable(tables)
}

// loadSchemaColumns returns the column names of every table in schema
// dbName, keyed by table name. It panics on any database error.
func loadSchemaColumns(db *sql.DB, dbName string) map[string][]string {
	const query = "SELECT table_name, column_name FROM information_schema.columns " +
		"WHERE table_schema = ? ORDER BY table_name, ordinal_position"
	rows, err := db.Query(query, dbName)
	if err != nil {
		panic(fmt.Errorf("listing columns of schema %q: %w", dbName, err))
	}
	// Closed per schema, so no result set outlives its own iteration.
	defer rows.Close()

	columns := make(map[string][]string)
	for rows.Next() {
		var tableName, columnName string
		if err := rows.Scan(&tableName, &columnName); err != nil {
			panic(fmt.Errorf("scanning columns of schema %q: %w", dbName, err))
		}
		columns[tableName] = append(columns[tableName], columnName)
	}
	// rows.Next returns false on failure as well as at the end of the data;
	// only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		panic(fmt.Errorf("reading columns of schema %q: %w", dbName, err))
	}
	return columns
}
// convertMapToTable flattens a schema -> table -> columns map into a slice
// holding one model.Table per (schema, table) pair. The result order follows
// Go map iteration and is therefore unspecified; a nil or empty input gives a
// nil slice.
func convertMapToTable(table map[string]map[string][]string) []model.Table {
	var flat []model.Table
	for schema, byName := range table {
		for name, cols := range byName {
			entry := model.Table{DbName: schema, TableName: name, Columns: cols}
			flat = append(flat, entry)
		}
	}
	return flat
}
// RegisterAll (re)creates the handler and connection registries, registers
// the insert, query, update and delete handlers for every known table, and
// prints the column layout returned by GetTableAndColumns.
//
// The table list is a hard-coded mock standing in for table structures that
// are meant to be loaded from Redis. Each table gets a MySQL handle opened
// with the DSN "root:root@tcp(127.0.0.1:3306)/<table name>", i.e. the table
// name is also used as the database name and as the connection map key.
// It panics if sql.Open returns an error.
func RegisterAll() {
	handlers = map[string]func(http.ResponseWriter, *http.Request){}
	connection = map[string]*sql.DB{}

	// Mock of the table structures that should come from Redis.
	mocked := []model.Table{{TableName: "user", Columns: []string{"id", "name", "age"}}}

	registrars := []func(model.Table){
		registerInsertHandler,
		registerQueryHandler,
		registerUpdateHandler,
		registerDeleteHandler,
	}
	for i := range mocked {
		t := mocked[i]
		db, err := sql.Open("mysql", "root:root@tcp(127.0.0.1:3306)/"+t.TableName)
		if err != nil {
			panic(err)
		}
		connection[t.TableName] = db
		for _, register := range registrars {
			register(t)
		}
	}

	fmt.Println(GetTableAndColumns())
}
// registerInsertHandler registers the POST /insert/<table> handler.
//
// The request body is a JSON array of objects, each mapping column names to
// the values of one new row, e.g. [{"name": "alice", "age": 30}]. Every key
// must be listed in table.Columns, and every value must be a JSON scalar
// (string, number, boolean or null; null is stored as SQL NULL). Values are
// passed as bound statement arguments, never spliced into the SQL text. All
// rows are written in one transaction, so a failure part-way through leaves
// the table unchanged.
//
// Status codes:
//   - 405 for non-POST requests.
//   - 400 for an undecodable body, an empty object, an unknown column, or a
//     non-scalar value.
//   - 500 when the database rejects the batch.
//   - 200 with an empty body on success (including an empty array).
func registerInsertHandler(table model.Table) {
	// Identifiers cannot be bound as arguments, so only names from this
	// whitelist ever reach the SQL text (quoted with MySQL backticks).
	known := make(map[string]bool, len(table.Columns))
	for _, col := range table.Columns {
		known[col] = true
	}
	// scalar reports whether v can be bound as a single statement argument;
	// numbers arrive as json.Number because of UseNumber below.
	scalar := func(v interface{}) bool {
		switch v.(type) {
		case nil, string, bool, json.Number:
			return true
		}
		return false
	}

	handlers["/insert/"+table.TableName] = func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		defer r.Body.Close()

		var batch []map[string]interface{}
		dec := json.NewDecoder(r.Body)
		// Keep numbers as json.Number: going through float64 would turn e.g.
		// 1000000 into "1e+06" and lose precision above 2^53.
		dec.UseNumber()
		if err := dec.Decode(&batch); err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		// Validate the whole batch before touching the database.
		for _, row := range batch {
			if len(row) == 0 {
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			for col, val := range row {
				if !known[col] || !scalar(val) {
					w.WriteHeader(http.StatusBadRequest)
					return
				}
			}
		}
		if len(batch) == 0 {
			return
		}

		tx, err := connection[table.TableName].Begin()
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		// Undo everything unless Commit succeeds. Rollback after a successful
		// Commit only reports that the transaction is done, so its error is
		// deliberately ignored.
		defer func() { _ = tx.Rollback() }()

		for _, row := range batch {
			cols, marks, sep := "", "", ""
			args := make([]interface{}, 0, len(row))
			// Columns, placeholders and args are built in the same pass, so
			// they line up even though map iteration order is random.
			for col, val := range row {
				cols += sep + "`" + col + "`"
				marks += sep + "?"
				args = append(args, val)
				sep = ", "
			}
			stmt := "insert into `" + table.TableName + "` (" + cols + ") values (" + marks + ")"
			if _, err := tx.Exec(stmt, args...); err != nil {
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
		}
		if err := tx.Commit(); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
	}
}
// registerQueryHandler registers the POST /query/<table> handler.
//
// The request body is a JSON object of equality filters, e.g.
// {"name": "alice", "age": 30}; an empty object selects every row. Filters
// are ANDed together, and a null value matches SQL NULL. Every key must be
// listed in table.Columns, and every value must be a JSON scalar. Values are
// bound as statement arguments, never spliced into the SQL text.
//
// On success it responds 200 with a JSON array of objects mapping column
// name to the column's text form. SQL NULL is reported as "", and an empty
// result is encoded as null.
//
// Status codes:
//   - 405 for non-POST requests.
//   - 400 for an undecodable body, an unknown column, or a non-scalar value.
//   - 500 for database or encoding failures.
func registerQueryHandler(table model.Table) {
	// Identifiers cannot be bound as arguments, so only names from this
	// whitelist ever reach the SQL text (quoted with MySQL backticks).
	known := make(map[string]bool, len(table.Columns))
	for _, col := range table.Columns {
		known[col] = true
	}

	handlers["/query/"+table.TableName] = func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		defer r.Body.Close()

		filter := make(map[string]interface{})
		dec := json.NewDecoder(r.Body)
		// Keep numbers as json.Number so they are not reformatted via float64.
		dec.UseNumber()
		if err := dec.Decode(&filter); err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		where, sep := "", " where "
		args := make([]interface{}, 0, len(filter))
		for col, val := range filter {
			if !known[col] {
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			switch val.(type) {
			case nil:
				// "col = NULL" never matches; NULL needs IS NULL.
				where += sep + "`" + col + "` is null"
			case string, bool, json.Number:
				where += sep + "`" + col + "` = ?"
				args = append(args, val)
			default:
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			sep = " and "
		}

		rows, err := connection[table.TableName].Query("select * from `"+table.TableName+"`"+where, args...)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		// Without this the underlying connection is never returned to the pool.
		defer rows.Close()

		columns, err := rows.Columns()
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		// Scan buffers are reused across rows; each value is copied out with
		// string() before the next Scan overwrites it.
		values := make([]sql.RawBytes, len(columns))
		scanArgs := make([]interface{}, len(values))
		for i := range values {
			scanArgs[i] = &values[i]
		}

		var result []map[string]string
		for rows.Next() {
			if err := rows.Scan(scanArgs...); err != nil {
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			row := make(map[string]string, len(columns))
			for i, raw := range values {
				// A NULL column scans as nil RawBytes, and string(nil) is "".
				row[columns[i]] = string(raw)
			}
			result = append(result, row)
		}
		// rows.Next also returns false on failure; don't send a truncated result.
		if err := rows.Err(); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		body, err := json.Marshal(result)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json; charset=utf-8")
		// Once Write has started the status line is already sent, so a write
		// error cannot be turned into an error response.
		_, _ = w.Write(body)
	}
}
// registerUpdateHandler registers the POST /update/<table> handler.
//
// The request body is a JSON object whose top-level keys are the columns to
// set, plus a mandatory non-empty "where" object of equality conditions, e.g.
// {"age": 31, "where": {"name": "alice"}}. Conditions are ANDed, and a null
// condition value matches SQL NULL. A null set value stores NULL. Every key
// (top-level and inside "where") must be listed in table.Columns, and every
// value must be a JSON scalar. Values are bound as statement arguments,
// never spliced into the SQL text.
//
// Status codes:
//   - 405 for non-POST requests.
//   - 400 for an undecodable body, a missing or empty "where", no columns to
//     set, an unknown column, or a non-scalar value.
//   - 500 when the database rejects the statement.
//   - 200 on success.
func registerUpdateHandler(table model.Table) {
	// Identifiers cannot be bound as arguments, so only names from this
	// whitelist ever reach the SQL text (quoted with MySQL backticks).
	known := make(map[string]bool, len(table.Columns))
	for _, col := range table.Columns {
		known[col] = true
	}
	// scalar reports whether v can be bound as a single statement argument;
	// numbers arrive as json.Number because of UseNumber below.
	scalar := func(v interface{}) bool {
		switch v.(type) {
		case nil, string, bool, json.Number:
			return true
		}
		return false
	}

	handlers["/update/"+table.TableName] = func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		defer r.Body.Close()

		param := make(map[string]interface{})
		dec := json.NewDecoder(r.Body)
		dec.UseNumber()
		if err := dec.Decode(&param); err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		// A missing or empty condition would update every row; refuse it.
		cond, ok := param["where"].(map[string]interface{})
		if !ok || len(cond) == 0 {
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		// SET arguments come first, then WHERE arguments, matching the order
		// of placeholders in the final statement.
		args := make([]interface{}, 0, len(param)+len(cond))
		set, sep := "", ""
		for col, val := range param {
			if col == "where" {
				continue
			}
			if !known[col] || !scalar(val) {
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			set += sep + "`" + col + "` = ?"
			args = append(args, val)
			sep = ", "
		}
		if set == "" {
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		where := ""
		sep = " where "
		for col, val := range cond {
			if !known[col] || !scalar(val) {
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			if val == nil {
				// "col = NULL" never matches; NULL needs IS NULL.
				where += sep + "`" + col + "` is null"
			} else {
				where += sep + "`" + col + "` = ?"
				args = append(args, val)
			}
			sep = " and "
		}

		stmt := "update `" + table.TableName + "` set " + set + where
		if _, err := connection[table.TableName].Exec(stmt, args...); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.WriteHeader(http.StatusOK)
	}
}
// registerDeleteHandler registers the POST /delete/<table> handler.
//
// The request body is a non-empty JSON object of equality conditions, e.g.
// {"id": 7}. Conditions are ANDed, and a null value matches SQL NULL. Every
// key must be listed in table.Columns, and every value must be a JSON scalar.
// Values are bound as statement arguments, never spliced into the SQL text.
//
// Status codes:
//   - 405 for non-POST requests.
//   - 400 for an undecodable or empty body, an unknown column, or a
//     non-scalar value.
//   - 500 when the database rejects the statement.
//   - 200 on success.
func registerDeleteHandler(table model.Table) {
	// Identifiers cannot be bound as arguments, so only names from this
	// whitelist ever reach the SQL text (quoted with MySQL backticks).
	known := make(map[string]bool, len(table.Columns))
	for _, col := range table.Columns {
		known[col] = true
	}

	handlers["/delete/"+table.TableName] = func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		defer r.Body.Close()

		filter := make(map[string]interface{})
		dec := json.NewDecoder(r.Body)
		dec.UseNumber()
		if err := dec.Decode(&filter); err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		// An empty filter would delete every row; refuse it.
		if len(filter) == 0 {
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		where, sep := "", " where "
		args := make([]interface{}, 0, len(filter))
		for col, val := range filter {
			if !known[col] {
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			switch val.(type) {
			case nil:
				// "col = NULL" never matches; NULL needs IS NULL.
				where += sep + "`" + col + "` is null"
			case string, bool, json.Number:
				where += sep + "`" + col + "` = ?"
				args = append(args, val)
			default:
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			sep = " and "
		}

		stmt := "delete from `" + table.TableName + "`" + where
		if _, err := connection[table.TableName].Exec(stmt, args...); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.WriteHeader(http.StatusOK)
	}
}
// GetHandlers returns the route -> handler map built by RegisterAll, with
// keys of the form "/insert/<table>", "/query/<table>", "/update/<table>"
// and "/delete/<table>".
//
// The package's own map is returned, not a copy, so changes the caller makes
// to it are visible to this package. Within this file the map is only
// created by RegisterAll, so it is nil if called before that.
func GetHandlers() map[string]func(http.ResponseWriter, *http.Request) {
	return handlers
}
// checkParam reports whether every key of param is acceptable for a table
// with the given columns. A key is acceptable if it is listed in cols or is
// one of the reserved request keywords "where", "limit", "offset", "order",
// "having" and "group", which are skipped. A nil or empty param is
// acceptable.
func checkParam(param map[string]interface{}, cols []string) bool {
	for key := range param {
		switch key {
		case "where", "limit", "offset", "order", "having", "group":
			continue // reserved keywords are not column names
		}
		found := false
		for _, col := range cols {
			if col == key {
				found = true
				break // stop at the first match
			}
		}
		if !found {
			return false
		}
	}
	return true
}