From 23158a718529c604ee41399c3010b6e6f50b1f1e Mon Sep 17 00:00:00 2001 From: Kelvin Date: Sun, 3 Sep 2023 15:50:48 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E5=90=8E=E7=9A=84=E7=89=88?= =?UTF-8?q?=E6=9C=AC=EF=BC=8C=E7=94=A8=E6=88=B7=E6=B7=BB=E5=9B=9ETPM?= =?UTF-8?q?=E4=B8=8ERPM=EF=BC=8C=E6=8E=A5=E5=8F=A3=E8=B0=83=E7=94=A8?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0MAX=5FTOKENS=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .drone.yml | 4 +- api/server.go | 19 +++-- common/redis.go | 6 +- common/utils.go | 4 +- controller/billing.go | 2 +- controller/relay.go | 170 +++++++++++++++++++++++--------------- deploy/docker-compose.yml | 7 +- main.go | 39 ++++++--- model/openai.go | 1 + model/user.go | 2 + service/business.go | 35 +++++--- 11 files changed, 187 insertions(+), 102 deletions(-) diff --git a/.drone.yml b/.drone.yml index 125912e..c4badb5 100644 --- a/.drone.yml +++ b/.drone.yml @@ -25,7 +25,7 @@ steps: # 来源目录 source: deploy/ # 目标服务器目录 - target: /data/wwwroot/api2gpt/mid + target: /data/wwwroot/api2gpt/mid2 script: - - cd /data/wwwroot/api2gpt/mid + - cd /data/wwwroot/api2gpt/mid2 - docker-compose up -d --build \ No newline at end of file diff --git a/api/server.go b/api/server.go index b7e92b1..abbaaaa 100644 --- a/api/server.go +++ b/api/server.go @@ -4,11 +4,12 @@ import ( "api2gpt-mid/model" "bytes" "encoding/json" - "io/ioutil" + "io" "net/http" + "os" ) -// 余额消费 +// BalanceConsumption 余额消费 func BalanceConsumption(key string, modelStr string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) { var data = model.Consumption{ SecretKey: key, @@ -23,7 +24,10 @@ func BalanceConsumption(key string, modelStr string, prompt_tokens int, completi // 构造post请求的body reqBody := bytes.NewBuffer(jsonData) - url := "http://172.17.0.1:8080/other/usageRecord" + url := os.Getenv("SERVER_API_USAGE_RECORD_STRING") + if url == "" { + url = "http://172.17.0.1:8080/other/usageRecord" + } req2, err := http.NewRequest("POST", url, reqBody) // 设置http请求的header @@ -34,9 +38,14 @@ func BalanceConsumption(key string, modelStr string, prompt_tokens int, completi if err != nil { return "", err } - defer resp.Body.Close() + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + + } + }(resp.Body) - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { return "", err } diff --git a/common/redis.go b/common/redis.go index eedbfd9..04028a4 100644 --- a/common/redis.go +++ b/common/redis.go @@ -12,7 +12,11 @@ var RDB *redis.Client // InitRedisClient This function is called after init() func InitRedisClient() (err error) { SysLog("Redis start connection") - opt, err := redis.ParseURL("redis://@localhost:6379/0?dial_timeout=5s") + redisConnStr := os.Getenv("REDIS_CONN_STRING") + if redisConnStr == "" { + redisConnStr = "redis://@localhost:6379/0?dial_timeout=5s" + } + opt, err := redis.ParseURL(redisConnStr) //opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) if err != nil { FatalLog("failed to parse Redis connection string: " + err.Error()) diff --git a/common/utils.go b/common/utils.go index 31e4514..b717efd 100644 --- a/common/utils.go +++ b/common/utils.go @@ -7,7 +7,7 @@ import ( "strings" ) -// 计算Messages中的token数量 +// NumTokensFromMessages 计算Messages中的token数量 func NumTokensFromMessages(messages []model.Message, model string) int { if strings.Contains(model, "gpt-3.5") { model = "gpt-3.5-turbo" @@ -29,7 +29,7 @@ func NumTokensFromMessages(messages []model.Message, model string) int { return numTokens } -// 计算String中的token数量 +// NumTokensFromString 计算String中的token数量 func NumTokensFromString(msg string, model string) int { if strings.Contains(model, "gpt-3.5") { model = "gpt-3.5-turbo" diff --git a/controller/billing.go b/controller/billing.go index c80ef8a..21b2a6b 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -13,7 +13,7 @@ type CreditSummary struct { TotalRemaining float64 `json:"total_remaining"` } -// 余额查询 +// Balance 余额查询 func Balance(c *gin.Context) { auth := c.Request.Header.Get("Authorization") key := strings.Trim(auth, "Bearer ") diff --git a/controller/relay.go b/controller/relay.go index 0d01928..083db7b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -10,7 +10,6 @@ import ( "fmt" "github.com/gin-gonic/gin" "io" - "io/ioutil" "log" "math/rand" "net/http" @@ -66,12 +65,13 @@ func Images(c *gin.Context) { keyList := strings.Split(serverInfo.AvailableKey, ",") if len(keyList) > 1 { // 随机数种子 - rand.Seed(time.Now().UnixNano()) + source := rand.NewSource(time.Now().UnixNano()) + random := rand.New(source) // 从数组中随机选择一个元素 - serverKey = keyList[rand.Intn(len(keyList))] + serverKey = keyList[random.Intn(len(keyList))] } req.Header.Set("Authorization", "Bearer "+serverKey) - req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) + req.Body = io.NopCloser(bytes.NewReader(newReqBody)) } sss, err := json.Marshal(imagesRequest) log.Printf("开始处理返回逻辑 %d", string(sss)) @@ -79,34 +79,42 @@ func Images(c *gin.Context) { proxy.ModifyResponse = func(resp *http.Response) error { if resp.StatusCode != http.StatusOK { //退回预扣除的余额 - service.CheckBlanceReturnForImages(key, modelStr, imagesRequest.N) + err := service.CheckBlanceReturnForImages(key, modelStr, imagesRequest.N) + if err != nil { + return err + } return nil } resp.Header.Set("Openai-Organization", "api2gpt") var imagesResponse model.ImagesResponse - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { log.Printf("读取返回数据出错: %v", err) return err } - json.Unmarshal(body, &imagesResponse) - resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + err = json.Unmarshal(body, &imagesResponse) + if err != nil { + log.Printf("json解析数据出错: %v", err) + return err + } + resp.Body = io.NopCloser(bytes.NewReader(body)) log.Printf("image size: %v", len(imagesResponse.Data)) timestamp := time.Now().Unix() timestampID := "img-" + fmt.Sprintf("%d", timestamp) //消费余额 - service.ConsumptionForImages(key, modelStr, imagesRequest.N, len(imagesResponse.Data), timestampID) + _, err = service.ConsumptionForImages(key, modelStr, imagesRequest.N, len(imagesResponse.Data), timestampID) + if err != nil { + return err + } return nil } - proxy.ServeHTTP(c.Writer, c.Request) - } func Edit(c *gin.Context) { - var prompt_tokens int - var complate_tokens int - var total_tokens int + var promptTokens int + var complateTokens int + var totalTokens int var chatRequest model.ChatRequest if err := c.ShouldBindJSON(&chatRequest); err != nil { c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) @@ -116,7 +124,7 @@ func Edit(c *gin.Context) { auth := c.Request.Header.Get("Authorization") key := auth[7:] //根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整 - serverInfo, err := service.CheckBlance(key, chatRequest.Model) + serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens) if err != nil { c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err) @@ -152,12 +160,13 @@ func Edit(c *gin.Context) { keyList := strings.Split(serverInfo.AvailableKey, ",") if len(keyList) > 1 { // 随机数种子 - rand.Seed(time.Now().UnixNano()) + source := rand.NewSource(time.Now().UnixNano()) + random := rand.New(source) // 从数组中随机选择一个元素 - serverKey = keyList[rand.Intn(len(keyList))] + serverKey = keyList[random.Intn(len(keyList))] } req.Header.Set("Authorization", "Bearer "+serverKey) - req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) + req.Body = io.NopCloser(bytes.NewReader(newReqBody)) } sss, err := json.Marshal(chatRequest) log.Printf("开始处理返回逻辑 %d", string(sss)) @@ -165,21 +174,28 @@ func Edit(c *gin.Context) { proxy.ModifyResponse = func(resp *http.Response) error { resp.Header.Set("Openai-Organization", "api2gpt") var chatResponse model.ChatResponse - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { log.Printf("读取返回数据出错: %v", err) return err } - json.Unmarshal(body, &chatResponse) - prompt_tokens = chatResponse.Usage.PromptTokens - complate_tokens = chatResponse.Usage.CompletionTokens - total_tokens = chatResponse.Usage.TotalTokens - resp.Body = ioutil.NopCloser(bytes.NewReader(body)) - log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", prompt_tokens, complate_tokens, total_tokens) + err = json.Unmarshal(body, &chatResponse) + if err != nil { + log.Printf("json解析数据出错: %v", err) + return err + } + promptTokens = chatResponse.Usage.PromptTokens + complateTokens = chatResponse.Usage.CompletionTokens + totalTokens = chatResponse.Usage.TotalTokens + resp.Body = io.NopCloser(bytes.NewReader(body)) + log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", promptTokens, complateTokens, totalTokens) timestamp := time.Now().Unix() timestampID := "edit-" + fmt.Sprintf("%d", timestamp) //消费余额 - service.Consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID) + _, err = service.Consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID) + if err != nil { + return err + } return nil } @@ -188,19 +204,16 @@ func Edit(c *gin.Context) { } func Embeddings(c *gin.Context) { - var prompt_tokens int - var total_tokens int + var promptTokens int + var totalTokens int var chatRequest model.ChatRequest if err := c.ShouldBindJSON(&chatRequest); err != nil { c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) return } - auth := c.Request.Header.Get("Authorization") - //key := strings.Trim(auth, "Bearer ") key := auth[7:] - //根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整 - serverInfo, err := service.CheckBlance(key, chatRequest.Model) + serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens) if err != nil { c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err) @@ -236,12 +249,13 @@ func Embeddings(c *gin.Context) { keyList := strings.Split(serverInfo.AvailableKey, ",") if len(keyList) > 1 { // 随机数种子 - rand.Seed(time.Now().UnixNano()) + source := rand.NewSource(time.Now().UnixNano()) + random := rand.New(source) // 从数组中随机选择一个元素 - serverKey = keyList[rand.Intn(len(keyList))] + serverKey = keyList[random.Intn(len(keyList))] } req.Header.Set("Authorization", "Bearer "+serverKey) - req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) + req.Body = io.NopCloser(bytes.NewReader(newReqBody)) } sss, err := json.Marshal(chatRequest) log.Printf("开始处理返回逻辑 %d", string(sss)) @@ -249,20 +263,27 @@ func Embeddings(c *gin.Context) { proxy.ModifyResponse = func(resp *http.Response) error { resp.Header.Set("Openai-Organization", "api2gpt") var chatResponse model.ChatResponse - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { log.Printf("读取返回数据出错: %v", err) return err } - json.Unmarshal(body, &chatResponse) - prompt_tokens = chatResponse.Usage.PromptTokens - total_tokens = chatResponse.Usage.TotalTokens - resp.Body = ioutil.NopCloser(bytes.NewReader(body)) - log.Printf("prompt_tokens: %v total_tokens: %v", prompt_tokens, total_tokens) + err = json.Unmarshal(body, &chatResponse) + if err != nil { + log.Printf("json解析数据出错: %v", err) + return err + } + promptTokens = chatResponse.Usage.PromptTokens + totalTokens = chatResponse.Usage.TotalTokens + resp.Body = io.NopCloser(bytes.NewReader(body)) + log.Printf("prompt_tokens: %v total_tokens: %v", promptTokens, totalTokens) timestamp := time.Now().Unix() timestampID := "emb-" + fmt.Sprintf("%d", timestamp) //消费余额 - service.Consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID) + _, err = service.Consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID) + if err != nil { + return err + } return nil } @@ -271,9 +292,9 @@ func Embeddings(c *gin.Context) { } func Completions(c *gin.Context) { - var prompt_tokens int - var completion_tokens int - var total_tokens int + var promptTokens int + var completionTokens int + var totalTokens int var chatRequest model.ChatRequest if err := c.ShouldBindJSON(&chatRequest); err != nil { c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) @@ -282,8 +303,7 @@ func Completions(c *gin.Context) { auth := c.Request.Header.Get("Authorization") key := auth[7:] - //根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整 - serverInfo, err := service.CheckBlance(key, chatRequest.Model) + serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens) if err != nil { c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err) @@ -318,12 +338,13 @@ func Completions(c *gin.Context) { keyList := strings.Split(serverInfo.AvailableKey, ",") if len(keyList) > 1 { // 随机数种子 - rand.Seed(time.Now().UnixNano()) + source := rand.NewSource(time.Now().UnixNano()) + random := rand.New(source) // 从数组中随机选择一个元素 - serverKey = keyList[rand.Intn(len(keyList))] + serverKey = keyList[random.Intn(len(keyList))] } req.Header.Set("Authorization", "Bearer "+serverKey) - req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) + req.Body = io.NopCloser(bytes.NewReader(newReqBody)) } sss, err := json.Marshal(chatRequest) if err != nil { @@ -336,7 +357,10 @@ func Completions(c *gin.Context) { log.Printf("流式回应 http status code: %v", resp.StatusCode) if resp.StatusCode != http.StatusOK { //退回预扣除的余额 - service.CheckBlanceReturn(key, chatRequest.Model) + err = service.CheckBlanceReturn(key, chatRequest.Model, chatRequest.MaxTokens) + if err != nil { + return err + } return nil } chatRequestId := "" @@ -361,7 +385,10 @@ func Completions(c *gin.Context) { //去除回应中的data:前缀 var trimStr = strings.Trim(string(chunk), "data: ") if trimStr != "\n" { - json.Unmarshal([]byte(trimStr), &chatResponse) + err := json.Unmarshal([]byte(trimStr), &chatResponse) + if err != nil { + return err + } if chatResponse.Choices != nil { if chatResponse.Choices[0].Text != "" { reqContent += chatResponse.Choices[0].Text @@ -382,16 +409,19 @@ func Completions(c *gin.Context) { } } if chatRequest.Model == "text-davinci-003" { - prompt_tokens = common.NumTokensFromString(chatRequest.Prompt, chatRequest.Model) + promptTokens = common.NumTokensFromString(chatRequest.Prompt, chatRequest.Model) } else { - prompt_tokens = common.NumTokensFromMessages(chatRequest.Messages, chatRequest.Model) + promptTokens = common.NumTokensFromMessages(chatRequest.Messages, chatRequest.Model) } - completion_tokens = common.NumTokensFromString(reqContent, chatRequest.Model) + completionTokens = common.NumTokensFromString(reqContent, chatRequest.Model) log.Printf("返回内容:%v", reqContent) - total_tokens = prompt_tokens + completion_tokens - log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens) + totalTokens = promptTokens + completionTokens + log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens) //消费余额 - service.Consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatRequestId) + _, err := service.Consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatRequestId) + if err != nil { + return err + } return nil } } else { @@ -399,19 +429,25 @@ func Completions(c *gin.Context) { proxy.ModifyResponse = func(resp *http.Response) error { resp.Header.Set("Openai-Organization", "api2gpt") var chatResponse model.ChatResponse - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { log.Printf("非流式回应,处理 err: %v", err) return err } - json.Unmarshal(body, &chatResponse) - prompt_tokens = chatResponse.Usage.PromptTokens - completion_tokens = chatResponse.Usage.CompletionTokens - total_tokens = chatResponse.Usage.TotalTokens - resp.Body = ioutil.NopCloser(bytes.NewReader(body)) - log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens) + err = json.Unmarshal(body, &chatResponse) + if err != nil { + return err + } + promptTokens = chatResponse.Usage.PromptTokens + completionTokens = chatResponse.Usage.CompletionTokens + totalTokens = chatResponse.Usage.TotalTokens + resp.Body = io.NopCloser(bytes.NewReader(body)) + log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens) //消费余额 - service.Consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatResponse.Id) + _, err = service.Consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatResponse.Id) + if err != nil { + return err + } return nil } } @@ -419,7 +455,7 @@ func Completions(c *gin.Context) { proxy.ServeHTTP(c.Writer, c.Request) } -// 针对接口预检OPTIONS的处理 +// HandleOptions 针对接口预检OPTIONS的处理 func HandleOptions(c *gin.Context) { // BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1 c.Header("Access-Control-Allow-Origin", "*") diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index fa41857..fa8c569 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -1,6 +1,6 @@ version: "3" services: - server: + api2gpt-mid: build: context: . dockerfile: Dockerfile @@ -9,9 +9,12 @@ services: # 时区上海 TZ: Asia/Shanghai REDIS_ADDRESS: 172.17.0.1:6379 + REDIS_CONN_STRING: redis://@172.17.0.1:6379/0?dial_timeout=5s + SERVER_API_USAGE_RECORD_STRING: http://172.17.0.1:8080/other/usageRecord privileged: true restart: always + command: --port 8080 --log-dir /app/logs ports: - - 8082:8080 + - 8083:8080 volumes: - ./logs:/app/logs \ No newline at end of file diff --git a/main.go b/main.go index 07690dc..76ab1f8 100644 --- a/main.go +++ b/main.go @@ -23,17 +23,26 @@ func test_redis() { // } serverInfoStr, _ := json.Marshal(&serverInfo) // serverInfoStr2, _ := json.Marshal(&serverInfo2) - common.RedisSet("server:0", string(serverInfoStr), 0) - common.RedisSet("server:1", string(serverInfoStr), 0) + err := common.RedisSet("server:0", string(serverInfoStr), 0) + if err != nil { + return + } + err = common.RedisSet("server:1", string(serverInfoStr), 0) + if err != nil { + return + } // Redis.Set(context.Background(), "server:2", serverInfoStr2, 0) - // var modelInfo ModelInfo = ModelInfo{ - // ModelName: "gpt-3.5-turbo-0613", - // ModelPrice: 0.0001, - // ModelPrepayment: 4000, - // } - // modelInfoStr, _ := json.Marshal(&modelInfo) - // Redis.Set(context.Background(), "model:gpt-3.5-turbo-0613", modelInfoStr, 0) + var modelInfo model.ModelInfo = model.ModelInfo{ + ModelName: "images-generations", + ModelPrice: 0.0001, + ModelPrepayment: 4000, + } + modelInfoStr, _ := json.Marshal(&modelInfo) + err = common.RedisSet("model:images-generations", string(modelInfoStr), 0) + if err != nil { + return + } // var modelInfo1 ModelInfo = ModelInfo{ // ModelName: "gpt-3.5-turbo-16k", // ModelPrice: 0.0001, @@ -82,10 +91,16 @@ func test_redis() { // } userInfoStr, _ := json.Marshal(&userInfo) // userInfoStr2, _ := json.Marshal(&userInfo2) - common.RedisSet("user:key0", string(userInfoStr), 0) + err = common.RedisSet("user:key0", string(userInfoStr), 0) + if err != nil { + return + } //Redis.Set(context.Background(), "user:key0", userInfoStr, 0) // Redis.Set(context.Background(), "user:AK-7d8ab782-a152-4cc1-9972-568713465c96", userInfoStr2, 0) - common.RedisIncrByFloat("user:0:balance", 1000) + _, err = common.RedisIncrByFloat("user:0:balance", 1000) + if err != nil { + return + } //Redis.IncrByFloat(context.Background(), "user:0:balance", 1000).Result() // Redis.IncrByFloat(context.Background(), "user:2:balance", 1000).Result() } @@ -113,7 +128,7 @@ func main() { router.SetRouter(server) //添加测试数据 - test_redis() + //test_redis() var port = os.Getenv("PORT") if port == "" { diff --git a/model/openai.go b/model/openai.go index c831c0d..664240b 100644 --- a/model/openai.go +++ b/model/openai.go @@ -17,6 +17,7 @@ type ChatRequest struct { Instruction string `json:"instruction,omitempty"` N int `json:"n,omitempty"` Functions interface{} `json:"functions,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` } type ImagesRequest struct { diff --git a/model/user.go b/model/user.go index 8904af7..076bc15 100644 --- a/model/user.go +++ b/model/user.go @@ -3,4 +3,6 @@ package model type UserInfo struct { UID string `json:"uid"` SID string `json:"sid"` + RPM int `json:"rpm"` + TPM int `json:"tpm"` } diff --git a/service/business.go b/service/business.go index 03defdb..4e95565 100644 --- a/service/business.go +++ b/service/business.go @@ -11,10 +11,10 @@ import ( "time" ) -// 检测key是否存在,是否超出每分钟请求次数 +// CheckKeyAndTimeCount 检测key是否存在,是否超出每分钟请求次数 func CheckKeyAndTimeCount(key string) (int, error) { var timeOut = 60 * time.Second - var timeCount = 30 + var timeCount = 60 userInfoStr, err := common.RedisGet("user:" + key) log.Printf("用户信息 %s", userInfoStr) if err != nil { @@ -27,6 +27,9 @@ func CheckKeyAndTimeCount(key string) (int, error) { //用户状态异常 return 401, errors.New("40004") } + if userInfo.RPM > 0 { + timeCount = userInfo.RPM + } count, err := common.RedisIncr("user:count:" + key) log.Printf("用户请求次数 %d", count) if err != nil { @@ -60,7 +63,7 @@ func QueryBlance(key string) (float64, error) { } // 余额查询 -func CheckBlance(key string, modelStr string) (model.ServerInfo, error) { +func CheckBlance(key string, modelStr string, maxTokens int) (model.ServerInfo, error) { var serverInfo model.ServerInfo //获取模型价格 @@ -89,13 +92,19 @@ func CheckBlance(key string, modelStr string) (model.ServerInfo, error) { } //计算余额-先扣除指定金额 - balance, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", -(float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice)) + if maxTokens == 0 { + maxTokens = modelInfo.ModelPrepayment + } + balance, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", -(float64(maxTokens) * modelInfo.ModelPrice)) if err != nil { return serverInfo, errors.New("余额计算失败") } - log.Printf("用户余额 %f key: %v 预扣了:%f", balance, key, (float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice)) + log.Printf("用户余额 %f key: %v 预扣了:%f", balance, key, (float64(maxTokens) * modelInfo.ModelPrice)) if balance < 0 { - common.RedisIncrByFloat("user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment)*modelInfo.ModelPrice) + _, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", float64(maxTokens)*modelInfo.ModelPrice) + if err != nil { + return serverInfo, errors.New("用户缓存出错") + } return serverInfo, errors.New("用户余额不足") } @@ -138,7 +147,10 @@ func CheckBlanceForImages(key string, modelStr string, n int) (model.ServerInfo, } log.Printf("用户余额 %f key: %v 预扣了:%f", balance, key, (float64(modelInfo.ModelPrepayment*n) * modelInfo.ModelPrice)) if balance < 0 { - common.RedisIncrByFloat("user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment*n)*modelInfo.ModelPrice) + _, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment*n)*modelInfo.ModelPrice) + if err != nil { + return serverInfo, errors.New("用户缓存出错") + } return serverInfo, errors.New("用户余额不足") } @@ -146,7 +158,7 @@ func CheckBlanceForImages(key string, modelStr string, n int) (model.ServerInfo, } // 预扣返还 -func CheckBlanceReturn(key string, modelStr string) error { +func CheckBlanceReturn(key string, modelStr string, maxTokens int) error { //获取用户信息 userInfoStr, err := common.RedisGet("user:" + key) var userInfo model.UserInfo @@ -162,8 +174,11 @@ func CheckBlanceReturn(key string, modelStr string) error { if err != nil { return errors.New("模型信息解析失败") } - balance, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", (float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice)) - log.Printf("用户余额 %f key: %v 返还预扣:%f", balance, key, (float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice)) + if maxTokens == 0 { + maxTokens = modelInfo.ModelPrepayment + } + balance, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", (float64(maxTokens) * modelInfo.ModelPrice)) + log.Printf("用户余额 %f key: %v 返还预扣:%f", balance, key, (float64(maxTokens) * modelInfo.ModelPrice)) return nil }