diff --git a/deploy/Dockerfile b/deploy/Dockerfile index 7a3fd2d..6f5e7c4 100644 --- a/deploy/Dockerfile +++ b/deploy/Dockerfile @@ -1,6 +1,10 @@ FROM alpine:latest WORKDIR /app +RUN apk add --no-cache tzdata +ENV TZ Asia/Shanghai +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone + RUN adduser -S -D -H -h /app -s /sbin/nologin -u 1000 app ADD api2gpt-mid /app/ diff --git a/main.go b/main.go index c3ad534..8ddee31 100644 --- a/main.go +++ b/main.go @@ -321,6 +321,12 @@ func completions(c *gin.Context) { if chatRequest.Stream { // 流式回应,处理 proxy.ModifyResponse = func(resp *http.Response) error { + log.Printf("流式回应 http status code: %v", resp.StatusCode) + if resp.StatusCode != http.StatusOK { + //退回预扣除的余额 + checkBlanceReturn(key, chatRequest.Model) + return nil + } chatRequestId := "" reqContent := "" reader := bufio.NewReader(resp.Body) @@ -510,17 +516,17 @@ func main() { Redis = InitRedis() //添加reids测试数据 - // var serverInfo ServerInfo = ServerInfo{ - // ServerAddress: "https://gptp.any-door.cn", - // AvailableKey: "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,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,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", - // } + //var serverInfo ServerInfo = ServerInfo{ + // ServerAddress: "https://gptp.any-door.cn", + // AvailableKey: "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", + //} // var serverInfo2 ServerInfo = ServerInfo{ // ServerAddress: "https://azure.any-door.cn", // AvailableKey: "6c4d2c65970b40e482e7cd27adb0d119", // } - // serverInfoStr, _ := json.Marshal(&serverInfo) + //serverInfoStr, _ := json.Marshal(&serverInfo) // serverInfoStr2, _ := json.Marshal(&serverInfo2) - // Redis.Set(context.Background(), "server:1", serverInfoStr, 0) + //Redis.Set(context.Background(), "server:1", serverInfoStr, 0) // Redis.Set(context.Background(), "server:2", serverInfoStr2, 0) // var modelInfo ModelInfo = ModelInfo{ diff --git a/service.go b/service.go index e61f211..50432c6 100644 --- a/service.go +++ b/service.go @@ -87,7 +87,7 @@ func checkBlance(key string, model string) (ServerInfo, error) { if err != nil { return serverInfo, errors.New("余额计算失败") } - log.Printf("用户余额 %f key: %v", balance, key) + log.Printf("用户余额 %f key: %v 预扣了:%f", balance, key, (float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice)) if balance < 0 { Redis.IncrByFloat(context.Background(), "user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment)*modelInfo.ModelPrice).Result() return serverInfo, errors.New("用户余额不足") @@ -96,6 +96,36 @@ func checkBlance(key string, model string) (ServerInfo, error) { return serverInfo, nil } +func checkBlanceReturn(key string, model string) error { + var serverInfo ServerInfo + //获取用户信息 + userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result() + var userInfo UserInfo + err = json.Unmarshal([]byte(userInfoStr), &userInfo) + //获取服务器信息 + serverInfoStr, err := Redis.Get(context.Background(), "server:"+userInfo.SID).Result() + if err != nil { + return errors.New("服务器信息不存在") + } + err = json.Unmarshal([]byte(serverInfoStr), &serverInfo) + if err != nil { + return errors.New("服务器信息解析失败") + } + //获取模型价格 + modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result() + if err != nil { + return errors.New("模型信息不存在") + } + var modelInfo ModelInfo + err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) + if err != nil { + return errors.New("模型信息解析失败") + } + balance, err := Redis.IncrByFloat(context.Background(), "user:"+userInfo.UID+":balance", (float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice)).Result() + log.Printf("用户余额 %f key: %v 返还预扣:%f", balance, key, (float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice)) + return nil +} + // 余额消费 func consumption(key string, model string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) { //获取用户信息