重构后的版本:为用户添加 TPM 与 RPM 限制,接口调用添加 MAX_TOKENS 设置

dev
Kelvin 3 years ago
parent 2725f4ab00
commit 23158a7185

@ -25,7 +25,7 @@ steps:
# 来源目录
source: deploy/
# 目标服务器目录
target: /data/wwwroot/api2gpt/mid
target: /data/wwwroot/api2gpt/mid2
script:
- cd /data/wwwroot/api2gpt/mid
- cd /data/wwwroot/api2gpt/mid2
- docker-compose up -d --build

@ -4,11 +4,12 @@ import (
"api2gpt-mid/model"
"bytes"
"encoding/json"
"io/ioutil"
"io"
"net/http"
"os"
)
// 余额消费
// BalanceConsumption 余额消费
func BalanceConsumption(key string, modelStr string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) {
var data = model.Consumption{
SecretKey: key,
@ -23,7 +24,10 @@ func BalanceConsumption(key string, modelStr string, prompt_tokens int, completi
// 构造post请求的body
reqBody := bytes.NewBuffer(jsonData)
url := "http://172.17.0.1:8080/other/usageRecord"
url := os.Getenv("SERVER_API_USAGE_RECORD_STRING")
if url == "" {
url = "http://172.17.0.1:8080/other/usageRecord"
}
req2, err := http.NewRequest("POST", url, reqBody)
// 设置http请求的header
@ -34,9 +38,14 @@ func BalanceConsumption(key string, modelStr string, prompt_tokens int, completi
if err != nil {
return "", err
}
defer resp.Body.Close()
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
}
}(resp.Body)
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}

@ -12,7 +12,11 @@ var RDB *redis.Client
// InitRedisClient This function is called after init()
func InitRedisClient() (err error) {
SysLog("Redis start connection")
opt, err := redis.ParseURL("redis://@localhost:6379/0?dial_timeout=5s")
redisConnStr := os.Getenv("REDIS_CONN_STRING")
if redisConnStr == "" {
redisConnStr = "redis://@localhost:6379/0?dial_timeout=5s"
}
opt, err := redis.ParseURL(redisConnStr)
//opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error())

@ -7,7 +7,7 @@ import (
"strings"
)
// 计算Messages中的token数量
// NumTokensFromMessages 计算Messages中的token数量
func NumTokensFromMessages(messages []model.Message, model string) int {
if strings.Contains(model, "gpt-3.5") {
model = "gpt-3.5-turbo"
@ -29,7 +29,7 @@ func NumTokensFromMessages(messages []model.Message, model string) int {
return numTokens
}
// 计算String中的token数量
// NumTokensFromString 计算String中的token数量
func NumTokensFromString(msg string, model string) int {
if strings.Contains(model, "gpt-3.5") {
model = "gpt-3.5-turbo"

@ -13,7 +13,7 @@ type CreditSummary struct {
TotalRemaining float64 `json:"total_remaining"`
}
// 余额查询
// Balance 余额查询
func Balance(c *gin.Context) {
auth := c.Request.Header.Get("Authorization")
key := strings.Trim(auth, "Bearer ")

@ -10,7 +10,6 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"io"
"io/ioutil"
"log"
"math/rand"
"net/http"
@ -66,12 +65,13 @@ func Images(c *gin.Context) {
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// 随机数种子
rand.Seed(time.Now().UnixNano())
source := rand.NewSource(time.Now().UnixNano())
random := rand.New(source)
// 从数组中随机选择一个元素
serverKey = keyList[rand.Intn(len(keyList))]
serverKey = keyList[random.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
req.Body = io.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(imagesRequest)
log.Printf("开始处理返回逻辑 %d", string(sss))
@ -79,34 +79,42 @@ func Images(c *gin.Context) {
proxy.ModifyResponse = func(resp *http.Response) error {
if resp.StatusCode != http.StatusOK {
//退回预扣除的余额
service.CheckBlanceReturnForImages(key, modelStr, imagesRequest.N)
err := service.CheckBlanceReturnForImages(key, modelStr, imagesRequest.N)
if err != nil {
return err
}
return nil
}
resp.Header.Set("Openai-Organization", "api2gpt")
var imagesResponse model.ImagesResponse
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("读取返回数据出错: %v", err)
return err
}
json.Unmarshal(body, &imagesResponse)
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
err = json.Unmarshal(body, &imagesResponse)
if err != nil {
log.Printf("json解析数据出错: %v", err)
return err
}
resp.Body = io.NopCloser(bytes.NewReader(body))
log.Printf("image size: %v", len(imagesResponse.Data))
timestamp := time.Now().Unix()
timestampID := "img-" + fmt.Sprintf("%d", timestamp)
//消费余额
service.ConsumptionForImages(key, modelStr, imagesRequest.N, len(imagesResponse.Data), timestampID)
_, err = service.ConsumptionForImages(key, modelStr, imagesRequest.N, len(imagesResponse.Data), timestampID)
if err != nil {
return err
}
return nil
}
proxy.ServeHTTP(c.Writer, c.Request)
}
func Edit(c *gin.Context) {
var prompt_tokens int
var complate_tokens int
var total_tokens int
var promptTokens int
var complateTokens int
var totalTokens int
var chatRequest model.ChatRequest
if err := c.ShouldBindJSON(&chatRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
@ -116,7 +124,7 @@ func Edit(c *gin.Context) {
auth := c.Request.Header.Get("Authorization")
key := auth[7:]
//根据KEY调用用户余额接口判断是否有足够的余额 后期可考虑判断max_tokens参数来调整
serverInfo, err := service.CheckBlance(key, chatRequest.Model)
serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
@ -152,12 +160,13 @@ func Edit(c *gin.Context) {
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// 随机数种子
rand.Seed(time.Now().UnixNano())
source := rand.NewSource(time.Now().UnixNano())
random := rand.New(source)
// 从数组中随机选择一个元素
serverKey = keyList[rand.Intn(len(keyList))]
serverKey = keyList[random.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
req.Body = io.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(chatRequest)
log.Printf("开始处理返回逻辑 %d", string(sss))
@ -165,21 +174,28 @@ func Edit(c *gin.Context) {
proxy.ModifyResponse = func(resp *http.Response) error {
resp.Header.Set("Openai-Organization", "api2gpt")
var chatResponse model.ChatResponse
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("读取返回数据出错: %v", err)
return err
}
json.Unmarshal(body, &chatResponse)
prompt_tokens = chatResponse.Usage.PromptTokens
complate_tokens = chatResponse.Usage.CompletionTokens
total_tokens = chatResponse.Usage.TotalTokens
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", prompt_tokens, complate_tokens, total_tokens)
err = json.Unmarshal(body, &chatResponse)
if err != nil {
log.Printf("json解析数据出错: %v", err)
return err
}
promptTokens = chatResponse.Usage.PromptTokens
complateTokens = chatResponse.Usage.CompletionTokens
totalTokens = chatResponse.Usage.TotalTokens
resp.Body = io.NopCloser(bytes.NewReader(body))
log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", promptTokens, complateTokens, totalTokens)
timestamp := time.Now().Unix()
timestampID := "edit-" + fmt.Sprintf("%d", timestamp)
//消费余额
service.Consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID)
_, err = service.Consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID)
if err != nil {
return err
}
return nil
}
@ -188,19 +204,16 @@ func Edit(c *gin.Context) {
}
func Embeddings(c *gin.Context) {
var prompt_tokens int
var total_tokens int
var promptTokens int
var totalTokens int
var chatRequest model.ChatRequest
if err := c.ShouldBindJSON(&chatRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
auth := c.Request.Header.Get("Authorization")
//key := strings.Trim(auth, "Bearer ")
key := auth[7:]
//根据KEY调用用户余额接口判断是否有足够的余额 后期可考虑判断max_tokens参数来调整
serverInfo, err := service.CheckBlance(key, chatRequest.Model)
serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
@ -236,12 +249,13 @@ func Embeddings(c *gin.Context) {
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// 随机数种子
rand.Seed(time.Now().UnixNano())
source := rand.NewSource(time.Now().UnixNano())
random := rand.New(source)
// 从数组中随机选择一个元素
serverKey = keyList[rand.Intn(len(keyList))]
serverKey = keyList[random.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
req.Body = io.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(chatRequest)
log.Printf("开始处理返回逻辑 %d", string(sss))
@ -249,20 +263,27 @@ func Embeddings(c *gin.Context) {
proxy.ModifyResponse = func(resp *http.Response) error {
resp.Header.Set("Openai-Organization", "api2gpt")
var chatResponse model.ChatResponse
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("读取返回数据出错: %v", err)
return err
}
json.Unmarshal(body, &chatResponse)
prompt_tokens = chatResponse.Usage.PromptTokens
total_tokens = chatResponse.Usage.TotalTokens
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
log.Printf("prompt_tokens: %v total_tokens: %v", prompt_tokens, total_tokens)
err = json.Unmarshal(body, &chatResponse)
if err != nil {
log.Printf("json解析数据出错: %v", err)
return err
}
promptTokens = chatResponse.Usage.PromptTokens
totalTokens = chatResponse.Usage.TotalTokens
resp.Body = io.NopCloser(bytes.NewReader(body))
log.Printf("prompt_tokens: %v total_tokens: %v", promptTokens, totalTokens)
timestamp := time.Now().Unix()
timestampID := "emb-" + fmt.Sprintf("%d", timestamp)
//消费余额
service.Consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID)
_, err = service.Consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID)
if err != nil {
return err
}
return nil
}
@ -271,9 +292,9 @@ func Embeddings(c *gin.Context) {
}
func Completions(c *gin.Context) {
var prompt_tokens int
var completion_tokens int
var total_tokens int
var promptTokens int
var completionTokens int
var totalTokens int
var chatRequest model.ChatRequest
if err := c.ShouldBindJSON(&chatRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
@ -282,8 +303,7 @@ func Completions(c *gin.Context) {
auth := c.Request.Header.Get("Authorization")
key := auth[7:]
//根据KEY调用用户余额接口判断是否有足够的余额 后期可考虑判断max_tokens参数来调整
serverInfo, err := service.CheckBlance(key, chatRequest.Model)
serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
@ -318,12 +338,13 @@ func Completions(c *gin.Context) {
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// 随机数种子
rand.Seed(time.Now().UnixNano())
source := rand.NewSource(time.Now().UnixNano())
random := rand.New(source)
// 从数组中随机选择一个元素
serverKey = keyList[rand.Intn(len(keyList))]
serverKey = keyList[random.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
req.Body = io.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(chatRequest)
if err != nil {
@ -336,7 +357,10 @@ func Completions(c *gin.Context) {
log.Printf("流式回应 http status code: %v", resp.StatusCode)
if resp.StatusCode != http.StatusOK {
//退回预扣除的余额
service.CheckBlanceReturn(key, chatRequest.Model)
err = service.CheckBlanceReturn(key, chatRequest.Model, chatRequest.MaxTokens)
if err != nil {
return err
}
return nil
}
chatRequestId := ""
@ -361,7 +385,10 @@ func Completions(c *gin.Context) {
//去除回应中的data:前缀
var trimStr = strings.Trim(string(chunk), "data: ")
if trimStr != "\n" {
json.Unmarshal([]byte(trimStr), &chatResponse)
err := json.Unmarshal([]byte(trimStr), &chatResponse)
if err != nil {
return err
}
if chatResponse.Choices != nil {
if chatResponse.Choices[0].Text != "" {
reqContent += chatResponse.Choices[0].Text
@ -382,16 +409,19 @@ func Completions(c *gin.Context) {
}
}
if chatRequest.Model == "text-davinci-003" {
prompt_tokens = common.NumTokensFromString(chatRequest.Prompt, chatRequest.Model)
promptTokens = common.NumTokensFromString(chatRequest.Prompt, chatRequest.Model)
} else {
prompt_tokens = common.NumTokensFromMessages(chatRequest.Messages, chatRequest.Model)
promptTokens = common.NumTokensFromMessages(chatRequest.Messages, chatRequest.Model)
}
completion_tokens = common.NumTokensFromString(reqContent, chatRequest.Model)
completionTokens = common.NumTokensFromString(reqContent, chatRequest.Model)
log.Printf("返回内容:%v", reqContent)
total_tokens = prompt_tokens + completion_tokens
log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens)
totalTokens = promptTokens + completionTokens
log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
//消费余额
service.Consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatRequestId)
_, err := service.Consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatRequestId)
if err != nil {
return err
}
return nil
}
} else {
@ -399,19 +429,25 @@ func Completions(c *gin.Context) {
proxy.ModifyResponse = func(resp *http.Response) error {
resp.Header.Set("Openai-Organization", "api2gpt")
var chatResponse model.ChatResponse
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("非流式回应,处理 err: %v", err)
return err
}
json.Unmarshal(body, &chatResponse)
prompt_tokens = chatResponse.Usage.PromptTokens
completion_tokens = chatResponse.Usage.CompletionTokens
total_tokens = chatResponse.Usage.TotalTokens
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens)
err = json.Unmarshal(body, &chatResponse)
if err != nil {
return err
}
promptTokens = chatResponse.Usage.PromptTokens
completionTokens = chatResponse.Usage.CompletionTokens
totalTokens = chatResponse.Usage.TotalTokens
resp.Body = io.NopCloser(bytes.NewReader(body))
log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
//消费余额
service.Consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatResponse.Id)
_, err = service.Consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatResponse.Id)
if err != nil {
return err
}
return nil
}
}
@ -419,7 +455,7 @@ func Completions(c *gin.Context) {
proxy.ServeHTTP(c.Writer, c.Request)
}
// 针对接口预检OPTIONS的处理
// HandleOptions 针对接口预检OPTIONS的处理
func HandleOptions(c *gin.Context) {
// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1
c.Header("Access-Control-Allow-Origin", "*")

@ -1,6 +1,6 @@
version: "3"
services:
server:
api2gpt-mid:
build:
context: .
dockerfile: Dockerfile
@ -9,9 +9,12 @@ services:
# 时区上海
TZ: Asia/Shanghai
REDIS_ADDRESS: 172.17.0.1:6379
REDIS_CONN_STRING: redis://@172.17.0.1:6379/0?dial_timeout=5s
SERVER_API_USAGE_RECORD_STRING: http://172.17.0.1:8080/other/usageRecord
privileged: true
restart: always
command: --port 8080 --log-dir /app/logs
ports:
- 8082:8080
- 8083:8080
volumes:
- ./logs:/app/logs

@ -23,17 +23,26 @@ func test_redis() {
// }
serverInfoStr, _ := json.Marshal(&serverInfo)
// serverInfoStr2, _ := json.Marshal(&serverInfo2)
common.RedisSet("server:0", string(serverInfoStr), 0)
common.RedisSet("server:1", string(serverInfoStr), 0)
err := common.RedisSet("server:0", string(serverInfoStr), 0)
if err != nil {
return
}
err = common.RedisSet("server:1", string(serverInfoStr), 0)
if err != nil {
return
}
// Redis.Set(context.Background(), "server:2", serverInfoStr2, 0)
// var modelInfo ModelInfo = ModelInfo{
// ModelName: "gpt-3.5-turbo-0613",
// ModelPrice: 0.0001,
// ModelPrepayment: 4000,
// }
// modelInfoStr, _ := json.Marshal(&modelInfo)
// Redis.Set(context.Background(), "model:gpt-3.5-turbo-0613", modelInfoStr, 0)
var modelInfo model.ModelInfo = model.ModelInfo{
ModelName: "images-generations",
ModelPrice: 0.0001,
ModelPrepayment: 4000,
}
modelInfoStr, _ := json.Marshal(&modelInfo)
err = common.RedisSet("model:images-generations", string(modelInfoStr), 0)
if err != nil {
return
}
// var modelInfo1 ModelInfo = ModelInfo{
// ModelName: "gpt-3.5-turbo-16k",
// ModelPrice: 0.0001,
@ -82,10 +91,16 @@ func test_redis() {
// }
userInfoStr, _ := json.Marshal(&userInfo)
// userInfoStr2, _ := json.Marshal(&userInfo2)
common.RedisSet("user:key0", string(userInfoStr), 0)
err = common.RedisSet("user:key0", string(userInfoStr), 0)
if err != nil {
return
}
//Redis.Set(context.Background(), "user:key0", userInfoStr, 0)
// Redis.Set(context.Background(), "user:AK-7d8ab782-a152-4cc1-9972-568713465c96", userInfoStr2, 0)
common.RedisIncrByFloat("user:0:balance", 1000)
_, err = common.RedisIncrByFloat("user:0:balance", 1000)
if err != nil {
return
}
//Redis.IncrByFloat(context.Background(), "user:0:balance", 1000).Result()
// Redis.IncrByFloat(context.Background(), "user:2:balance", 1000).Result()
}
@ -113,7 +128,7 @@ func main() {
router.SetRouter(server)
//添加测试数据
test_redis()
//test_redis()
var port = os.Getenv("PORT")
if port == "" {

@ -17,6 +17,7 @@ type ChatRequest struct {
Instruction string `json:"instruction,omitempty"`
N int `json:"n,omitempty"`
Functions interface{} `json:"functions,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
}
type ImagesRequest struct {

@ -3,4 +3,6 @@ package model
type UserInfo struct {
UID string `json:"uid"`
SID string `json:"sid"`
RPM int `json:"rpm"`
TPM int `json:"tpm"`
}

@ -11,10 +11,10 @@ import (
"time"
)
// 检测key是否存在是否超出每分钟请求次数
// CheckKeyAndTimeCount 检测key是否存在是否超出每分钟请求次数
func CheckKeyAndTimeCount(key string) (int, error) {
var timeOut = 60 * time.Second
var timeCount = 30
var timeCount = 60
userInfoStr, err := common.RedisGet("user:" + key)
log.Printf("用户信息 %s", userInfoStr)
if err != nil {
@ -27,6 +27,9 @@ func CheckKeyAndTimeCount(key string) (int, error) {
//用户状态异常
return 401, errors.New("40004")
}
if userInfo.RPM > 0 {
timeCount = userInfo.RPM
}
count, err := common.RedisIncr("user:count:" + key)
log.Printf("用户请求次数 %d", count)
if err != nil {
@ -60,7 +63,7 @@ func QueryBlance(key string) (float64, error) {
}
// 余额查询
func CheckBlance(key string, modelStr string) (model.ServerInfo, error) {
func CheckBlance(key string, modelStr string, maxTokens int) (model.ServerInfo, error) {
var serverInfo model.ServerInfo
//获取模型价格
@ -89,13 +92,19 @@ func CheckBlance(key string, modelStr string) (model.ServerInfo, error) {
}
//计算余额-先扣除指定金额
balance, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", -(float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice))
if maxTokens == 0 {
maxTokens = modelInfo.ModelPrepayment
}
balance, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", -(float64(maxTokens) * modelInfo.ModelPrice))
if err != nil {
return serverInfo, errors.New("余额计算失败")
}
log.Printf("用户余额 %f key: %v 预扣了:%f", balance, key, (float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice))
log.Printf("用户余额 %f key: %v 预扣了:%f", balance, key, (float64(maxTokens) * modelInfo.ModelPrice))
if balance < 0 {
common.RedisIncrByFloat("user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment)*modelInfo.ModelPrice)
_, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", float64(maxTokens)*modelInfo.ModelPrice)
if err != nil {
return serverInfo, errors.New("用户缓存出错")
}
return serverInfo, errors.New("用户余额不足")
}
@ -138,7 +147,10 @@ func CheckBlanceForImages(key string, modelStr string, n int) (model.ServerInfo,
}
log.Printf("用户余额 %f key: %v 预扣了:%f", balance, key, (float64(modelInfo.ModelPrepayment*n) * modelInfo.ModelPrice))
if balance < 0 {
common.RedisIncrByFloat("user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment*n)*modelInfo.ModelPrice)
_, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment*n)*modelInfo.ModelPrice)
if err != nil {
return serverInfo, errors.New("用户缓存出错")
}
return serverInfo, errors.New("用户余额不足")
}
@ -146,7 +158,7 @@ func CheckBlanceForImages(key string, modelStr string, n int) (model.ServerInfo,
}
// 预扣返还
func CheckBlanceReturn(key string, modelStr string) error {
func CheckBlanceReturn(key string, modelStr string, maxTokens int) error {
//获取用户信息
userInfoStr, err := common.RedisGet("user:" + key)
var userInfo model.UserInfo
@ -162,8 +174,11 @@ func CheckBlanceReturn(key string, modelStr string) error {
if err != nil {
return errors.New("模型信息解析失败")
}
balance, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", (float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice))
log.Printf("用户余额 %f key: %v 返还预扣:%f", balance, key, (float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice))
if maxTokens == 0 {
maxTokens = modelInfo.ModelPrepayment
}
balance, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", (float64(maxTokens) * modelInfo.ModelPrice))
log.Printf("用户余额 %f key: %v 返还预扣:%f", balance, key, (float64(maxTokens) * modelInfo.ModelPrice))
return nil
}

Loading…
Cancel
Save