From cc9f29b71c8d78180aa5815c9fa19bd262421092 Mon Sep 17 00:00:00 2001 From: kelvin Date: Sun, 23 Jul 2023 17:18:40 +0800 Subject: [PATCH] update --- main.go | 85 ++++++++++++++++++++++++++++-------------------------- service.go | 72 +++++++++++++++++++-------------------------- 2 files changed, 74 insertions(+), 83 deletions(-) diff --git a/main.go b/main.go index 997a2eb..6e32c89 100644 --- a/main.go +++ b/main.go @@ -134,6 +134,7 @@ type ModelInfo struct { ModelPrice float64 `json:"model_price"` ModelPrice2 float64 `json:"model_price2"` ModelPrepayment int `json:"model_prepayment"` + ServerId int `json:"server_id"` } var ( @@ -705,39 +706,8 @@ func Cors() gin.HandlerFunc { } } -func main() { - - // 禁用控制台颜色,将日志写入文件时不需要控制台颜色。 - gin.DisableConsoleColor() - - // 记录到文件。 - filename := time.Now().Format("20060102150405") + ".log" - f, _ := os.Create("logs/gin" + filename) - //gin.DefaultWriter = io.MultiWriter(f) - // 如果需要同时将日志写入文件和控制台,请使用以下代码。 - gin.DefaultWriter = io.MultiWriter(f, os.Stdout) - - log.SetOutput(gin.DefaultWriter) - - r := gin.Default() - - r.Use(Cors()) - - r.GET("/dashboard/billing/credit_grants", checkKeyMid(), balance) - - r.GET("/v1/models", handleGetModels) - r.OPTIONS("/v1/*path", handleOptions) - r.POST("/v1/chat/completions", checkKeyMid(), completions) - r.POST("/v1/completions", checkKeyMid(), completions) - r.POST("/v1/embeddings", checkKeyMid(), embeddings) - r.POST("/v1/edits", checkKeyMid(), edit) - r.POST("/v1/images/generations", checkKeyMid(), images) - - r.POST("/mock1", mockBalanceInquiry) - r.POST("/mock2", mockBalanceConsumption) - +func test_redis() { Redis = InitRedis() - //添加reids测试数据 //var serverInfo ServerInfo = ServerInfo{ // ServerAddress: "https://gptp.any-door.cn", @@ -797,24 +767,54 @@ func main() { // modelInfoStr2, _ := json.Marshal(&modelInfo2) // Redis.Set(context.Background(), "model:images-generations", modelInfoStr2, 0) - // var userInfo UserInfo = UserInfo{ - // UID: "1", - // SID: "1", - // } + var userInfo UserInfo = UserInfo{ + UID: "0", + SID: "1", + } // var userInfo2 UserInfo = UserInfo{ // UID: "2", // SID: "2", // } - // userInfoStr, _ := json.Marshal(&userInfo) + userInfoStr, _ := json.Marshal(&userInfo) // userInfoStr2, _ := json.Marshal(&userInfo2) - // Redis.Set(context.Background(), "user:8aeb3747-715c-48e8-8b80-aec815949f22", userInfoStr, 0) + Redis.Set(context.Background(), "user:key0", userInfoStr, 0) // Redis.Set(context.Background(), "user:AK-7d8ab782-a152-4cc1-9972-568713465c96", userInfoStr2, 0) - // Redis.IncrByFloat(context.Background(), "user:1:balance", 1000).Result() + Redis.IncrByFloat(context.Background(), "user:0:balance", 1000).Result() // Redis.IncrByFloat(context.Background(), "user:2:balance", 1000).Result() +} + +func main() { + + // 禁用控制台颜色,将日志写入文件时不需要控制台颜色。 + gin.DisableConsoleColor() + + // 记录到文件。 + filename := time.Now().Format("20060102150405") + ".log" + f, _ := os.Create("logs/gin" + filename) + //gin.DefaultWriter = io.MultiWriter(f) + // 如果需要同时将日志写入文件和控制台,请使用以下代码。 + gin.DefaultWriter = io.MultiWriter(f, os.Stdout) - //r.Run("127.0.0.1:8080") - //docker下使用 + log.SetOutput(gin.DefaultWriter) + + r := gin.Default() + + //添加跨域支持 + r.Use(Cors()) + + r.GET("/dashboard/billing/credit_grants", checkKeyMid(), balance) + + r.GET("/v1/models", handleGetModels) + r.OPTIONS("/v1/*path", handleOptions) + r.POST("/v1/chat/completions", checkKeyMid(), completions) + r.POST("/v1/completions", checkKeyMid(), completions) + r.POST("/v1/embeddings", checkKeyMid(), embeddings) + r.POST("/v1/edits", checkKeyMid(), edit) + r.POST("/v1/images/generations", checkKeyMid(), images) + + //r.POST("/mock1", mockBalanceInquiry) + //r.POST("/mock2", mockBalanceConsumption) // 定义一个GET请求测试接口 r.GET("/ping", func(c *gin.Context) { @@ -823,5 +823,8 @@ func main() { }) }) + //添加测试数据 + //test_redis() + r.Run("0.0.0.0:8080") } diff --git a/service.go b/service.go index 2683d0e..8b82a09 100644 --- a/service.go +++ b/service.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "log" + "strconv" "time" ) @@ -59,12 +60,24 @@ func queryBlance(key string) (float64, error) { // 余额查询 func checkBlance(key string, model string) (ServerInfo, error) { var serverInfo ServerInfo + + //获取模型价格 + modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result() + if err != nil { + return serverInfo, errors.New("模型信息不存在") + } + var modelInfo ModelInfo + err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) + if err != nil { + return serverInfo, errors.New("模型信息解析失败") + } + //获取用户信息 userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result() var userInfo UserInfo err = json.Unmarshal([]byte(userInfoStr), &userInfo) //获取服务器信息 - serverInfoStr, err := Redis.Get(context.Background(), "server:"+userInfo.SID).Result() + serverInfoStr, err := Redis.Get(context.Background(), "server:"+strconv.Itoa(modelInfo.ServerId)).Result() if err != nil { return serverInfo, errors.New("服务器信息不存在") } @@ -72,16 +85,7 @@ func checkBlance(key string, model string) (ServerInfo, error) { if err != nil { return serverInfo, errors.New("服务器信息解析失败") } - //获取模型价格 - modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result() - if err != nil { - return serverInfo, errors.New("模型信息不存在") - } - var modelInfo ModelInfo - err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) - if err != nil { - return serverInfo, errors.New("模型信息解析失败") - } + //计算余额-先扣除指定金额 balance, err := Redis.IncrByFloat(context.Background(), "user:"+userInfo.UID+":balance", -(float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice)).Result() if err != nil { @@ -99,12 +103,24 @@ func checkBlance(key string, model string) (ServerInfo, error) { // 余额查询 for images func checkBlanceForImages(key string, model string, n int) (ServerInfo, error) { var serverInfo ServerInfo + + //获取模型价格 + modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result() + if err != nil { + return serverInfo, errors.New("模型信息不存在") + } + var modelInfo ModelInfo + err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) + if err != nil { + return serverInfo, errors.New("模型信息解析失败") + } + //获取用户信息 userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result() var userInfo UserInfo err = json.Unmarshal([]byte(userInfoStr), &userInfo) //获取服务器信息 - serverInfoStr, err := Redis.Get(context.Background(), "server:"+userInfo.SID).Result() + serverInfoStr, err := Redis.Get(context.Background(), "server:"+strconv.Itoa(modelInfo.ServerId)).Result() if err != nil { return serverInfo, errors.New("服务器信息不存在") } @@ -112,16 +128,7 @@ func checkBlanceForImages(key string, model string, n int) (ServerInfo, error) { if err != nil { return serverInfo, errors.New("服务器信息解析失败") } - //获取模型价格 - modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result() - if err != nil { - return serverInfo, errors.New("模型信息不存在") - } - var modelInfo ModelInfo - err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) - if err != nil { - return serverInfo, errors.New("模型信息解析失败") - } + //计算余额-先扣除指定金额 balance, err := Redis.IncrByFloat(context.Background(), "user:"+userInfo.UID+":balance", -(float64(modelInfo.ModelPrepayment*n) * modelInfo.ModelPrice)).Result() if err != nil { @@ -138,20 +145,11 @@ func checkBlanceForImages(key string, model string, n int) (ServerInfo, error) { // 预扣返还 func checkBlanceReturn(key string, model string) error { - var serverInfo ServerInfo //获取用户信息 userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result() var userInfo UserInfo err = json.Unmarshal([]byte(userInfoStr), &userInfo) - //获取服务器信息 - serverInfoStr, err := Redis.Get(context.Background(), "server:"+userInfo.SID).Result() - if err != nil { - return errors.New("服务器信息不存在") - } - err = json.Unmarshal([]byte(serverInfoStr), &serverInfo) - if err != nil { - return errors.New("服务器信息解析失败") - } + //获取模型价格 modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result() if err != nil { @@ -169,20 +167,10 @@ func checkBlanceReturn(key string, model string) error { // 预扣返还 for images func checkBlanceReturnForImages(key string, model string, n int) error { - var serverInfo ServerInfo //获取用户信息 userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result() var userInfo UserInfo err = json.Unmarshal([]byte(userInfoStr), &userInfo) - //获取服务器信息 - serverInfoStr, err := Redis.Get(context.Background(), "server:"+userInfo.SID).Result() - if err != nil { - return errors.New("服务器信息不存在") - } - err = json.Unmarshal([]byte(serverInfoStr), &serverInfo) - if err != nil { - return errors.New("服务器信息解析失败") - } //获取模型价格 modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result() if err != nil {