diff --git a/api.go b/api.go deleted file mode 100644 index 75fa846..0000000 --- a/api.go +++ /dev/null @@ -1,102 +0,0 @@ -package main - -import ( - "bytes" - "encoding/json" - "io/ioutil" - "net/http" - - "github.com/gin-gonic/gin" -) - -type BalanceInfo struct { - ServerAddress string `json:"server_address"` - AvailableKey string `json:"available_key"` - UserBalance float64 `json:"user_balance"` - TokenRatio float64 `json:"token_ratio"` -} - -type Consumption struct { - SecretKey string `json:"secretKey"` - Model string `json:"model"` - MsgId string `json:"msgId"` - PromptTokens int `json:"promptTokens"` - CompletionTokens int `json:"completionTokens"` - TotalTokens int `json:"totalTokens"` -} - -func mockBalanceInquiry(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "server_address": "https://gptp.any-door.cn", - "available_key": "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", - "user_balance": 10000, - "token_ratio": 1000, - }) -} - -func mockBalanceConsumption(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "success": "true", - }) -} - -// 余额查询api调用 -func balanceInquiry(key string, model string) (*BalanceInfo, error) { - url := "http://localhost:8080/mock1?key=" + key + "&model=" + model - req, err := http.NewRequest("POST", url, nil) - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - panic(err) - } - defer resp.Body.Close() - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - var balanceInfo BalanceInfo - if err := json.Unmarshal(body, &balanceInfo); err != nil { - return nil, err - } - - return &balanceInfo, nil -} - -// 余额消费 -func balanceConsumption(key string, model string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) { - var data = Consumption{ - SecretKey: key, - Model: model, - MsgId: msg_id, - PromptTokens: prompt_tokens, - CompletionTokens: completion_tokens, - TotalTokens: total_tokens, - } - - jsonData, err := json.Marshal(data) - // 构造post请求的body - reqBody := bytes.NewBuffer(jsonData) - - url := "http://172.17.0.1:8080/other/usageRecord" - req2, err := http.NewRequest("POST", url, reqBody) - - // 设置http请求的header - req2.Header.Set("Content-Type", "application/json") - - client := &http.Client{} - resp, err := client.Do(req2) - if err != nil { - return "", err - } - defer resp.Body.Close() - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return "", err - } - return string(body), nil -} diff --git a/api/server.go b/api/server.go new file mode 100644 index 0000000..b7e92b1 --- /dev/null +++ b/api/server.go @@ -0,0 +1,44 @@ +package api + +import ( + "api2gpt-mid/model" + "bytes" + "encoding/json" + "io/ioutil" + "net/http" +) + +// 余额消费 +func BalanceConsumption(key string, modelStr string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) { + var data = model.Consumption{ + SecretKey: key, + Model: modelStr, + MsgId: msg_id, + PromptTokens: prompt_tokens, + CompletionTokens: completion_tokens, + TotalTokens: total_tokens, + } + + jsonData, err := json.Marshal(data) + // 构造post请求的body + reqBody := bytes.NewBuffer(jsonData) + + url := "http://172.17.0.1:8080/other/usageRecord" + req2, err := http.NewRequest("POST", url, reqBody) + + // 设置http请求的header + req2.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req2) + if err != nil { + return "", err + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", err + } + return string(body), nil +} diff --git a/common/utils.go b/common/utils.go new file mode 100644 index 0000000..31e4514 --- /dev/null +++ b/common/utils.go @@ -0,0 +1,50 @@ +package common + +import ( + "api2gpt-mid/model" + "fmt" + "github.com/pkoukk/tiktoken-go" + "strings" +) + +// 计算Messages中的token数量 +func NumTokensFromMessages(messages []model.Message, model string) int { + if strings.Contains(model, "gpt-3.5") { + model = "gpt-3.5-turbo" + } + if strings.Contains(model, "gpt-4") { + model = "gpt-4" + } + tkm, err := tiktoken.EncodingForModel(model) + if err != nil { + err = fmt.Errorf("getEncoding: %v", err) + return 0 + } + numTokens := 0 + for _, message := range messages { + numTokens += len(tkm.Encode(message.Content, nil, nil)) + numTokens += 6 + } + numTokens += 3 + return numTokens +} + +// 计算String中的token数量 +func NumTokensFromString(msg string, model string) int { + if strings.Contains(model, "gpt-3.5") { + model = "gpt-3.5-turbo" + } + if strings.Contains(model, "gpt-4") { + model = "gpt-4" + } + tkm, err := tiktoken.EncodingForModel(model) + if err != nil { + err = fmt.Errorf("getEncoding: %v", err) + return 0 + } + if model == "text-davinci-003" { + return len(tkm.Encode(msg, nil, nil)) + 1 + } else { + return len(tkm.Encode(msg, nil, nil)) + 9 + } +} diff --git a/controller/billing.go b/controller/billing.go new file mode 100644 index 0000000..c80ef8a --- /dev/null +++ b/controller/billing.go @@ -0,0 +1,30 @@ +package controller + +import ( + "api2gpt-mid/service" + "github.com/gin-gonic/gin" + "strings" +) + +type CreditSummary struct { + Object string `json:"object"` + TotalGranted float64 `json:"total_granted"` + TotalUsed float64 `json:"total_used"` + TotalRemaining float64 `json:"total_remaining"` +} + +// 余额查询 +func Balance(c *gin.Context) { + auth := c.Request.Header.Get("Authorization") + key := strings.Trim(auth, "Bearer ") + balance, err := service.QueryBlance(key) + if err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + } + var creditSummary CreditSummary + creditSummary.Object = "credit_grant" + creditSummary.TotalGranted = 999999 + creditSummary.TotalUsed = 999999 - balance + creditSummary.TotalRemaining = balance + c.JSON(200, creditSummary) +} diff --git a/controller/model.go b/controller/model.go new file mode 100644 index 0000000..8f2fd04 --- /dev/null +++ b/controller/model.go @@ -0,0 +1,458 @@ +package controller + +import ( + "fmt" + + "github.com/gin-gonic/gin" +) + +// https://platform.openai.com/docs/api-reference/models/list + +type OpenAIModelPermission struct { + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + AllowCreateEngine bool `json:"allow_create_engine"` + AllowSampling bool `json:"allow_sampling"` + AllowLogprobs bool `json:"allow_logprobs"` + AllowSearchIndices bool `json:"allow_search_indices"` + AllowView bool `json:"allow_view"` + AllowFineTuning bool `json:"allow_fine_tuning"` + Organization string `json:"organization"` + Group *string `json:"group"` + IsBlocking bool `json:"is_blocking"` +} + +type OpenAIModels struct { + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + OwnedBy string `json:"owned_by"` + Permission []OpenAIModelPermission `json:"permission"` + Root string `json:"root"` + Parent *string `json:"parent"` +} + +type OpenAIError struct { + Message string `json:"message"` + Type string `json:"type"` + Param string `json:"param"` + Code any `json:"code"` +} + +type OpenAIErrorWithStatusCode struct { + OpenAIError + StatusCode int `json:"status_code"` +} + +var openAIModels []OpenAIModels +var openAIModelsMap map[string]OpenAIModels + +func init() { + var permission []OpenAIModelPermission + permission = append(permission, OpenAIModelPermission{ + Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ", + Object: "model_permission", + Created: 1626777600, + AllowCreateEngine: true, + AllowSampling: true, + AllowLogprobs: true, + AllowSearchIndices: false, + AllowView: true, + AllowFineTuning: false, + Organization: "*", + Group: nil, + IsBlocking: false, + }) + // https://platform.openai.com/docs/models/model-endpoint-compatibility + openAIModels = []OpenAIModels{ + { + Id: "dall-e", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "dall-e", + Parent: nil, + }, + { + Id: "whisper-1", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "whisper-1", + Parent: nil, + }, + { + Id: "gpt-3.5-turbo", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-3.5-turbo", + Parent: nil, + }, + { + Id: "gpt-3.5-turbo-0301", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-3.5-turbo-0301", + Parent: nil, + }, + { + Id: "gpt-3.5-turbo-0613", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-3.5-turbo-0613", + Parent: nil, + }, + { + Id: "gpt-3.5-turbo-16k", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-3.5-turbo-16k", + Parent: nil, + }, + { + Id: "gpt-3.5-turbo-16k-0613", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-3.5-turbo-16k-0613", + Parent: nil, + }, + { + Id: "gpt-4", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4", + Parent: nil, + }, + { + Id: "gpt-4-0314", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-0314", + Parent: nil, + }, + { + Id: "gpt-4-0613", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-0613", + Parent: nil, + }, + { + Id: "gpt-4-32k", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-32k", + Parent: nil, + }, + { + Id: "gpt-4-32k-0314", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-32k-0314", + Parent: nil, + }, + { + Id: "gpt-4-32k-0613", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-32k-0613", + Parent: nil, + }, + { + Id: "text-embedding-ada-002", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-embedding-ada-002", + Parent: nil, + }, + { + Id: "text-davinci-003", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-davinci-003", + Parent: nil, + }, + { + Id: "text-davinci-002", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-davinci-002", + Parent: nil, + }, + { + Id: "text-curie-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-curie-001", + Parent: nil, + }, + { + Id: "text-babbage-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-babbage-001", + Parent: nil, + }, + { + Id: "text-ada-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-ada-001", + Parent: nil, + }, + { + Id: "text-moderation-latest", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-moderation-latest", + Parent: nil, + }, + { + Id: "text-moderation-stable", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-moderation-stable", + Parent: nil, + }, + { + Id: "text-davinci-edit-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-davinci-edit-001", + Parent: nil, + }, + { + Id: "code-davinci-edit-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "code-davinci-edit-001", + Parent: nil, + }, + { + Id: "claude-instant-1", + Object: "model", + Created: 1677649963, + OwnedBy: "anturopic", + Permission: permission, + Root: "claude-instant-1", + Parent: nil, + }, + { + Id: "claude-2", + Object: "model", + Created: 1677649963, + OwnedBy: "anturopic", + Permission: permission, + Root: "claude-2", + Parent: nil, + }, + { + Id: "ERNIE-Bot", + Object: "model", + Created: 1677649963, + OwnedBy: "baidu", + Permission: permission, + Root: "ERNIE-Bot", + Parent: nil, + }, + { + Id: "ERNIE-Bot-turbo", + Object: "model", + Created: 1677649963, + OwnedBy: "baidu", + Permission: permission, + Root: "ERNIE-Bot-turbo", + Parent: nil, + }, + { + Id: "Embedding-V1", + Object: "model", + Created: 1677649963, + OwnedBy: "baidu", + Permission: permission, + Root: "Embedding-V1", + Parent: nil, + }, + { + Id: "PaLM-2", + Object: "model", + Created: 1677649963, + OwnedBy: "google", + Permission: permission, + Root: "PaLM-2", + Parent: nil, + }, + { + Id: "chatglm_pro", + Object: "model", + Created: 1677649963, + OwnedBy: "zhipu", + Permission: permission, + Root: "chatglm_pro", + Parent: nil, + }, + { + Id: "chatglm_std", + Object: "model", + Created: 1677649963, + OwnedBy: "zhipu", + Permission: permission, + Root: "chatglm_std", + Parent: nil, + }, + { + Id: "chatglm_lite", + Object: "model", + Created: 1677649963, + OwnedBy: "zhipu", + Permission: permission, + Root: "chatglm_lite", + Parent: nil, + }, + { + Id: "qwen-v1", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-v1", + Parent: nil, + }, + { + Id: "qwen-plus-v1", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-plus-v1", + Parent: nil, + }, + { + Id: "SparkDesk", + Object: "model", + Created: 1677649963, + OwnedBy: "xunfei", + Permission: permission, + Root: "SparkDesk", + Parent: nil, + }, + { + Id: "360GPT_S2_V9", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "360GPT_S2_V9", + Parent: nil, + }, + { + Id: "embedding-bert-512-v1", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "embedding-bert-512-v1", + Parent: nil, + }, + { + Id: "embedding_s1_v1", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "embedding_s1_v1", + Parent: nil, + }, + { + Id: "semantic_similarity_s1_v1", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "semantic_similarity_s1_v1", + Parent: nil, + }, + { + Id: "360GPT_S2_V9.4", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "360GPT_S2_V9.4", + Parent: nil, + }, + } + openAIModelsMap = make(map[string]OpenAIModels) + for _, model := range openAIModels { + openAIModelsMap[model.Id] = model + } +} + +func ListModels(c *gin.Context) { + c.JSON(200, gin.H{ + "object": "list", + "data": openAIModels, + }) +} + +func RetrieveModel(c *gin.Context) { + modelId := c.Param("model") + if model, ok := openAIModelsMap[modelId]; ok { + c.JSON(200, model) + } else { + openAIError := OpenAIError{ + Message: fmt.Sprintf("The model '%s' does not exist", modelId), + Type: "invalid_request_error", + Param: "model", + Code: "model_not_found", + } + c.JSON(200, gin.H{ + "error": openAIError, + }) + } +} diff --git a/controller/relay.go b/controller/relay.go new file mode 100644 index 0000000..0d01928 --- /dev/null +++ b/controller/relay.go @@ -0,0 +1,430 @@ +package controller + +import ( + "api2gpt-mid/common" + "api2gpt-mid/model" + "api2gpt-mid/service" + "bufio" + "bytes" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io" + "io/ioutil" + "log" + "math/rand" + "net/http" + "net/http/httputil" + "net/url" + "strings" + "time" +) + +func Images(c *gin.Context) { + var imagesRequest model.ImagesRequest + var modelStr = "images-generations" + if err := c.ShouldBindJSON(&imagesRequest); err != nil { + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return + } + auth := c.Request.Header.Get("Authorization") + key := auth[7:] + + serverInfo, err := service.CheckBlanceForImages(key, modelStr, imagesRequest.N) + if err != nil { + c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) + log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, modelStr, err) + return + } + + log.Printf("请求的KEY: %v Model: %v", key, modelStr) + + remote, err := url.Parse(serverInfo.ServerAddress) + if err != nil { + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return + } + + proxy := httputil.NewSingleHostReverseProxy(remote) + newReqBody, err := json.Marshal(imagesRequest) + if err != nil { + log.Printf("http request err: %v", err) + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + proxy.Director = func(req *http.Request) { + req.Header = c.Request.Header + req.Host = remote.Host + req.URL.Scheme = remote.Scheme + req.URL.Host = remote.Host + req.URL.Path = c.Request.URL.Path + req.ContentLength = int64(len(newReqBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept-Encoding", "") + serverKey := serverInfo.AvailableKey + keyList := strings.Split(serverInfo.AvailableKey, ",") + if len(keyList) > 1 { + // 随机数种子 + rand.Seed(time.Now().UnixNano()) + // 从数组中随机选择一个元素 + serverKey = keyList[rand.Intn(len(keyList))] + } + req.Header.Set("Authorization", "Bearer "+serverKey) + req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) + } + sss, err := json.Marshal(imagesRequest) + log.Printf("开始处理返回逻辑 %d", string(sss)) + + proxy.ModifyResponse = func(resp *http.Response) error { + if resp.StatusCode != http.StatusOK { + //退回预扣除的余额 + service.CheckBlanceReturnForImages(key, modelStr, imagesRequest.N) + return nil + } + resp.Header.Set("Openai-Organization", "api2gpt") + var imagesResponse model.ImagesResponse + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Printf("读取返回数据出错: %v", err) + return err + } + json.Unmarshal(body, &imagesResponse) + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + log.Printf("image size: %v", len(imagesResponse.Data)) + timestamp := time.Now().Unix() + timestampID := "img-" + fmt.Sprintf("%d", timestamp) + //消费余额 + service.ConsumptionForImages(key, modelStr, imagesRequest.N, len(imagesResponse.Data), timestampID) + return nil + } + + proxy.ServeHTTP(c.Writer, c.Request) + +} + +func Edit(c *gin.Context) { + var prompt_tokens int + var complate_tokens int + var total_tokens int + var chatRequest model.ChatRequest + if err := c.ShouldBindJSON(&chatRequest); err != nil { + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return + } + + auth := c.Request.Header.Get("Authorization") + key := auth[7:] + //根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整 + serverInfo, err := service.CheckBlance(key, chatRequest.Model) + if err != nil { + c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) + log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err) + return + } + + log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model) + + remote, err := url.Parse(serverInfo.ServerAddress) + if err != nil { + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return + } + + proxy := httputil.NewSingleHostReverseProxy(remote) + newReqBody, err := json.Marshal(chatRequest) + if err != nil { + log.Printf("http request err: %v", err) + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + proxy.Director = func(req *http.Request) { + req.Header = c.Request.Header + req.Host = remote.Host + req.URL.Scheme = remote.Scheme + req.URL.Host = remote.Host + req.URL.Path = c.Request.URL.Path + req.ContentLength = int64(len(newReqBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept-Encoding", "") + serverKey := serverInfo.AvailableKey + keyList := strings.Split(serverInfo.AvailableKey, ",") + if len(keyList) > 1 { + // 随机数种子 + rand.Seed(time.Now().UnixNano()) + // 从数组中随机选择一个元素 + serverKey = keyList[rand.Intn(len(keyList))] + } + req.Header.Set("Authorization", "Bearer "+serverKey) + req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) + } + sss, err := json.Marshal(chatRequest) + log.Printf("开始处理返回逻辑 %d", string(sss)) + + proxy.ModifyResponse = func(resp *http.Response) error { + resp.Header.Set("Openai-Organization", "api2gpt") + var chatResponse model.ChatResponse + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Printf("读取返回数据出错: %v", err) + return err + } + json.Unmarshal(body, &chatResponse) + prompt_tokens = chatResponse.Usage.PromptTokens + complate_tokens = chatResponse.Usage.CompletionTokens + total_tokens = chatResponse.Usage.TotalTokens + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", prompt_tokens, complate_tokens, total_tokens) + timestamp := time.Now().Unix() + timestampID := "edit-" + fmt.Sprintf("%d", timestamp) + //消费余额 + service.Consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID) + return nil + } + + proxy.ServeHTTP(c.Writer, c.Request) + +} + +func Embeddings(c *gin.Context) { + var prompt_tokens int + var total_tokens int + var chatRequest model.ChatRequest + if err := c.ShouldBindJSON(&chatRequest); err != nil { + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return + } + + auth := c.Request.Header.Get("Authorization") + //key := strings.Trim(auth, "Bearer ") + key := auth[7:] + //根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整 + serverInfo, err := service.CheckBlance(key, chatRequest.Model) + if err != nil { + c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) + log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err) + return + } + + log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model) + + remote, err := url.Parse(serverInfo.ServerAddress) + if err != nil { + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return + } + + proxy := httputil.NewSingleHostReverseProxy(remote) + newReqBody, err := json.Marshal(chatRequest) + if err != nil { + log.Printf("http request err: %v", err) + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + proxy.Director = func(req *http.Request) { + req.Header = c.Request.Header + req.Host = remote.Host + req.URL.Scheme = remote.Scheme + req.URL.Host = remote.Host + req.URL.Path = c.Request.URL.Path + req.ContentLength = int64(len(newReqBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept-Encoding", "") + serverKey := serverInfo.AvailableKey + keyList := strings.Split(serverInfo.AvailableKey, ",") + if len(keyList) > 1 { + // 随机数种子 + rand.Seed(time.Now().UnixNano()) + // 从数组中随机选择一个元素 + serverKey = keyList[rand.Intn(len(keyList))] + } + req.Header.Set("Authorization", "Bearer "+serverKey) + req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) + } + sss, err := json.Marshal(chatRequest) + log.Printf("开始处理返回逻辑 %d", string(sss)) + + proxy.ModifyResponse = func(resp *http.Response) error { + resp.Header.Set("Openai-Organization", "api2gpt") + var chatResponse model.ChatResponse + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Printf("读取返回数据出错: %v", err) + return err + } + json.Unmarshal(body, &chatResponse) + prompt_tokens = chatResponse.Usage.PromptTokens + total_tokens = chatResponse.Usage.TotalTokens + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + log.Printf("prompt_tokens: %v total_tokens: %v", prompt_tokens, total_tokens) + timestamp := time.Now().Unix() + timestampID := "emb-" + fmt.Sprintf("%d", timestamp) + //消费余额 + service.Consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID) + return nil + } + + proxy.ServeHTTP(c.Writer, c.Request) + +} + +func Completions(c *gin.Context) { + var prompt_tokens int + var completion_tokens int + var total_tokens int + var chatRequest model.ChatRequest + if err := c.ShouldBindJSON(&chatRequest); err != nil { + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return + } + + auth := c.Request.Header.Get("Authorization") + key := auth[7:] + //根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整 + serverInfo, err := service.CheckBlance(key, chatRequest.Model) + if err != nil { + c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) + log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err) + return + } + + log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model) + + remote, err := url.Parse(serverInfo.ServerAddress) + if err != nil { + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return + } + + proxy := httputil.NewSingleHostReverseProxy(remote) + newReqBody, err := json.Marshal(chatRequest) + if err != nil { + log.Printf("http request err: %v", err) + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + proxy.Director = func(req *http.Request) { + req.Header = c.Request.Header + req.Host = remote.Host + req.URL.Scheme = remote.Scheme + req.URL.Host = remote.Host + req.URL.Path = c.Request.URL.Path + req.ContentLength = int64(len(newReqBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept-Encoding", "") + serverKey := serverInfo.AvailableKey + keyList := strings.Split(serverInfo.AvailableKey, ",") + if len(keyList) > 1 { + // 随机数种子 + rand.Seed(time.Now().UnixNano()) + // 从数组中随机选择一个元素 + serverKey = keyList[rand.Intn(len(keyList))] + } + req.Header.Set("Authorization", "Bearer "+serverKey) + req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) + } + sss, err := json.Marshal(chatRequest) + if err != nil { + log.Printf("chatRequest 转化出错 %v", err) + } + log.Printf("开始处理返回逻辑: %v", string(sss)) + if chatRequest.Stream { + // 流式回应,处理 + proxy.ModifyResponse = func(resp *http.Response) error { + log.Printf("流式回应 http status code: %v", resp.StatusCode) + if resp.StatusCode != http.StatusOK { + //退回预扣除的余额 + service.CheckBlanceReturn(key, chatRequest.Model) + return nil + } + chatRequestId := "" + reqContent := "" + reader := bufio.NewReader(resp.Body) + headers := resp.Header + for k, v := range headers { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.Header().Set("Openai-Organization", "api2gpt") + for { + chunk, err := reader.ReadBytes('\n') + if err != nil { + if err == io.EOF { + break + } + log.Printf("流式回应,处理 err %v:", err.Error()) + break + //return err + } + var chatResponse model.ChatResponse + //去除回应中的data:前缀 + var trimStr = strings.Trim(string(chunk), "data: ") + if trimStr != "\n" { + json.Unmarshal([]byte(trimStr), &chatResponse) + if chatResponse.Choices != nil { + if chatResponse.Choices[0].Text != "" { + reqContent += chatResponse.Choices[0].Text + } else { + reqContent += chatResponse.Choices[0].Delta.Content + } + + chatRequestId = chatResponse.Id + } + + // 写回数据 + _, err = c.Writer.Write([]byte(string(chunk) + "\n")) + if err != nil { + log.Printf("写回数据 err: %v", err.Error()) + return err + } + c.Writer.(http.Flusher).Flush() + } + } + if chatRequest.Model == "text-davinci-003" { + prompt_tokens = common.NumTokensFromString(chatRequest.Prompt, chatRequest.Model) + } else { + prompt_tokens = common.NumTokensFromMessages(chatRequest.Messages, chatRequest.Model) + } + completion_tokens = common.NumTokensFromString(reqContent, chatRequest.Model) + log.Printf("返回内容:%v", reqContent) + total_tokens = prompt_tokens + completion_tokens + log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens) + //消费余额 + service.Consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatRequestId) + return nil + } + } else { + // 非流式回应,处理 + proxy.ModifyResponse = func(resp *http.Response) error { + resp.Header.Set("Openai-Organization", "api2gpt") + var chatResponse model.ChatResponse + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Printf("非流式回应,处理 err: %v", err) + return err + } + json.Unmarshal(body, &chatResponse) + prompt_tokens = chatResponse.Usage.PromptTokens + completion_tokens = chatResponse.Usage.CompletionTokens + total_tokens = chatResponse.Usage.TotalTokens + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens) + //消费余额 + service.Consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatResponse.Id) + return nil + } + } + + proxy.ServeHTTP(c.Writer, c.Request) +} + +// 针对接口预检OPTIONS的处理 +func HandleOptions(c *gin.Context) { + // BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1 + c.Header("Access-Control-Allow-Origin", "*") + c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") + c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization") + c.Status(200) + return +} diff --git a/main.go b/main.go index d0b600e..07690dc 100644 --- a/main.go +++ b/main.go @@ -3,681 +3,28 @@ package main import ( "api2gpt-mid/common" "api2gpt-mid/middleware" - "bufio" - "bytes" + "api2gpt-mid/model" + "api2gpt-mid/router" "encoding/json" - "fmt" - "io" - "io/ioutil" - "log" - "math/rand" - "net/http" - "net/http/httputil" - "net/url" + "github.com/gin-gonic/gin" "os" "strconv" - "strings" - "time" - - "github.com/gin-gonic/gin" - "github.com/pkoukk/tiktoken-go" ) -type Message struct { - Role string `json:"role,omitempty"` - Name string `json:"name,omitempty"` - Content string `json:"content,omitempty"` -} - -type ChatRequest struct { - Stream bool `json:"stream,omitempty"` - Model string `json:"model,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - Messages []Message `json:"messages,omitempty"` - Prompt string `json:"prompt,omitempty"` - Input interface{} `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - N int `json:"n,omitempty"` - Functions interface{} `json:"functions,omitempty"` -} - -type ImagesRequest struct { - Prompt string `json:"prompt,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` -} - -type ChatResponse struct { - Id string `json:"id,omitempty"` - Object string `json:"object,omitempty"` - Created int64 `json:"created,omitempty"` - Model string `json:"model,omitempty"` - Usage Usage `json:"usage,omitempty"` - Choices []Choice `json:"choices,omitempty"` -} -type ImagesResponse struct { - Created int64 `json:"created,omitempty"` - Data []ImagesData `json:"data,omitempty"` -} - -type ImagesData struct { - Url string `json:"url,omitempty"` -} - -type Usage struct { - PromptTokens int `json:"prompt_tokens,omitempty"` - CompletionTokens int `json:"completion_tokens,omitempty"` - TotalTokens int `json:"total_tokens,omitempty"` -} - -type Choice struct { - Message Message `json:"message,omitempty"` - Delta Delta `json:"delta,omitempty"` - FinishReason string `json:"finish_reason,omitempty"` - Index int `json:"index,omitempty"` - Text string `json:"text,omitempty"` -} - -type Delta struct { - Content string `json:"content,omitempty"` -} - -type CreditSummary struct { - Object string `json:"object"` - TotalGranted float64 `json:"total_granted"` - TotalUsed float64 `json:"total_used"` - TotalRemaining float64 `json:"total_remaining"` -} - -type ListModelResponse struct { - Object string `json:"object"` - Data []Model `json:"data"` -} - -type Model struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - OwnedBy string `json:"owned_by"` - Permission []ModelPermission `json:"permission"` - Root string `json:"root"` - Parent any `json:"parent"` -} - -type ModelPermission struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - AllowCreateEngine bool `json:"allow_create_engine"` - AllowSampling bool `json:"allow_sampling"` - AllowLogprobs bool `json:"allow_logprobs"` - AllowSearchIndices bool `json:"allow_search_indices"` - AllowView bool `json:"allow_view"` - AllowFineTuning bool `json:"allow_fine_tuning"` - Organization string `json:"organization"` - Group any `json:"group"` - IsBlocking bool `json:"is_blocking"` -} - -type UserInfo struct { - UID string `json:"uid"` - SID string `json:"sid"` -} - -type ServerInfo struct { - ServerAddress string `json:"server_address"` - AvailableKey string `json:"available_key"` -} - -type ModelInfo struct { - ModelName string `json:"model_name"` - ModelPrice float64 `json:"model_price"` - ModelPrice2 float64 `json:"model_price2"` - ModelPrepayment int `json:"model_prepayment"` - ServerId int `json:"server_id"` -} - -// 计算Messages中的token数量 -func numTokensFromMessages(messages []Message, model string) int { - if strings.Contains(model, "gpt-3.5") { - model = "gpt-3.5-turbo" - } - if strings.Contains(model, "gpt-4") { - model = "gpt-4" - } - tkm, err := tiktoken.EncodingForModel(model) - if err != nil { - err = fmt.Errorf("getEncoding: %v", err) - return 0 - } - numTokens := 0 - for _, message := range messages { - numTokens += len(tkm.Encode(message.Content, nil, nil)) - numTokens += 6 - } - numTokens += 3 - return numTokens -} - -// 计算String中的token数量 -func numTokensFromString(msg string, model string) int { - if strings.Contains(model, "gpt-3.5") { - model = "gpt-3.5-turbo" - } - if strings.Contains(model, "gpt-4") { - model = "gpt-4" - } - tkm, err := tiktoken.EncodingForModel(model) - if err != nil { - err = fmt.Errorf("getEncoding: %v", err) - return 0 - } - if model == "text-davinci-003" { - return len(tkm.Encode(msg, nil, nil)) + 1 - } else { - return len(tkm.Encode(msg, nil, nil)) + 9 - } -} - -func images(c *gin.Context) { - var imagesRequest ImagesRequest - var model = "images-generations" - if err := c.ShouldBindJSON(&imagesRequest); err != nil { - c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) - return - } - auth := c.Request.Header.Get("Authorization") - key := auth[7:] - - serverInfo, err := checkBlanceForImages(key, model, imagesRequest.N) - if err != nil { - c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) - log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, model, err) - return - } - - log.Printf("请求的KEY: %v Model: %v", key, model) - - remote, err := url.Parse(serverInfo.ServerAddress) - if err != nil { - c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) - return - } - - proxy := httputil.NewSingleHostReverseProxy(remote) - newReqBody, err := json.Marshal(imagesRequest) - if err != nil { - log.Printf("http request err: %v", err) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - proxy.Director = func(req *http.Request) { - req.Header = c.Request.Header - req.Host = remote.Host - req.URL.Scheme = remote.Scheme - req.URL.Host = remote.Host - req.URL.Path = c.Request.URL.Path - req.ContentLength = int64(len(newReqBody)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept-Encoding", "") - serverKey := serverInfo.AvailableKey - keyList := strings.Split(serverInfo.AvailableKey, ",") - if len(keyList) > 1 { - // 随机数种子 - rand.Seed(time.Now().UnixNano()) - // 从数组中随机选择一个元素 - serverKey = keyList[rand.Intn(len(keyList))] - } - req.Header.Set("Authorization", "Bearer "+serverKey) - req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) - } - sss, err := json.Marshal(imagesRequest) - log.Printf("开始处理返回逻辑 %d", string(sss)) - - proxy.ModifyResponse = func(resp *http.Response) error { - if resp.StatusCode != http.StatusOK { - //退回预扣除的余额 - checkBlanceReturnForImages(key, model, imagesRequest.N) - return nil - } - resp.Header.Set("Openai-Organization", "api2gpt") - var imagesResponse ImagesResponse - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - log.Printf("读取返回数据出错: %v", err) - return err - } - json.Unmarshal(body, &imagesResponse) - resp.Body = ioutil.NopCloser(bytes.NewReader(body)) - log.Printf("image size: %v", len(imagesResponse.Data)) - timestamp := time.Now().Unix() - timestampID := "img-" + fmt.Sprintf("%d", timestamp) - //消费余额 - consumptionForImages(key, model, imagesRequest.N, len(imagesResponse.Data), timestampID) - return nil - } - - proxy.ServeHTTP(c.Writer, c.Request) - -} - -func edit(c *gin.Context) { - var prompt_tokens int - var complate_tokens int - var total_tokens int - var chatRequest ChatRequest - if err := c.ShouldBindJSON(&chatRequest); err != nil { - c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) - return - } - - auth := c.Request.Header.Get("Authorization") - key := auth[7:] - //根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整 - serverInfo, err := checkBlance(key, chatRequest.Model) - if err != nil { - c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) - log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err) - return - } - - log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model) - - remote, err := url.Parse(serverInfo.ServerAddress) - if err != nil { - c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) - return - } - - proxy := httputil.NewSingleHostReverseProxy(remote) - newReqBody, err := json.Marshal(chatRequest) - if err != nil { - log.Printf("http request err: %v", err) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - proxy.Director = func(req *http.Request) { - req.Header = c.Request.Header - req.Host = remote.Host - req.URL.Scheme = remote.Scheme - req.URL.Host = remote.Host - req.URL.Path = c.Request.URL.Path - req.ContentLength = int64(len(newReqBody)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept-Encoding", "") - serverKey := serverInfo.AvailableKey - keyList := strings.Split(serverInfo.AvailableKey, ",") - if len(keyList) > 1 { - // 随机数种子 - rand.Seed(time.Now().UnixNano()) - // 从数组中随机选择一个元素 - serverKey = keyList[rand.Intn(len(keyList))] - } - req.Header.Set("Authorization", "Bearer "+serverKey) - req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) - } - sss, err := json.Marshal(chatRequest) - log.Printf("开始处理返回逻辑 %d", string(sss)) - - proxy.ModifyResponse = func(resp *http.Response) error { - resp.Header.Set("Openai-Organization", "api2gpt") - var chatResponse ChatResponse - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - log.Printf("读取返回数据出错: %v", err) - return err - } - json.Unmarshal(body, &chatResponse) - prompt_tokens = chatResponse.Usage.PromptTokens - complate_tokens = chatResponse.Usage.CompletionTokens - total_tokens = chatResponse.Usage.TotalTokens - resp.Body = ioutil.NopCloser(bytes.NewReader(body)) - log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", prompt_tokens, complate_tokens, total_tokens) - timestamp := time.Now().Unix() - timestampID := "edit-" + fmt.Sprintf("%d", timestamp) - //消费余额 - consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID) - return nil - } - - proxy.ServeHTTP(c.Writer, c.Request) - -} - -func embeddings(c *gin.Context) { - var prompt_tokens int - var total_tokens int - var chatRequest ChatRequest - if err := c.ShouldBindJSON(&chatRequest); err != nil { - c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) - return - } - - auth := c.Request.Header.Get("Authorization") - //key := strings.Trim(auth, "Bearer ") - key := auth[7:] - //根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整 - serverInfo, err := checkBlance(key, chatRequest.Model) - if err != nil { - c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) - log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err) - return - } - - log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model) - - remote, err := url.Parse(serverInfo.ServerAddress) - if err != nil { - c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) - return - } - - proxy := httputil.NewSingleHostReverseProxy(remote) - newReqBody, err := json.Marshal(chatRequest) - if err != nil { - log.Printf("http request err: %v", err) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - proxy.Director = func(req *http.Request) { - req.Header = c.Request.Header - req.Host = remote.Host - req.URL.Scheme = remote.Scheme - req.URL.Host = remote.Host - req.URL.Path = c.Request.URL.Path - req.ContentLength = int64(len(newReqBody)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept-Encoding", "") - serverKey := serverInfo.AvailableKey - keyList := strings.Split(serverInfo.AvailableKey, ",") - if len(keyList) > 1 { - // 随机数种子 - rand.Seed(time.Now().UnixNano()) - // 从数组中随机选择一个元素 - serverKey = keyList[rand.Intn(len(keyList))] - } - req.Header.Set("Authorization", "Bearer "+serverKey) - req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) - } - sss, err := json.Marshal(chatRequest) - log.Printf("开始处理返回逻辑 %d", string(sss)) - - proxy.ModifyResponse = func(resp *http.Response) error { - resp.Header.Set("Openai-Organization", "api2gpt") - var chatResponse ChatResponse - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - log.Printf("读取返回数据出错: %v", err) - return err - } - json.Unmarshal(body, &chatResponse) - prompt_tokens = chatResponse.Usage.PromptTokens - total_tokens = chatResponse.Usage.TotalTokens - resp.Body = ioutil.NopCloser(bytes.NewReader(body)) - log.Printf("prompt_tokens: %v total_tokens: %v", prompt_tokens, total_tokens) - timestamp := time.Now().Unix() - timestampID := "emb-" + fmt.Sprintf("%d", timestamp) - //消费余额 - consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID) - return nil - } - - proxy.ServeHTTP(c.Writer, c.Request) - -} - -func completions(c *gin.Context) { - var prompt_tokens int - var completion_tokens int - var total_tokens int - var chatRequest ChatRequest - if err := c.ShouldBindJSON(&chatRequest); err != nil { - c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) - return - } - - auth := c.Request.Header.Get("Authorization") - key := auth[7:] - //根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整 - serverInfo, err := checkBlance(key, chatRequest.Model) - if err != nil { - c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) - log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err) - return - } - - log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model) - - remote, err := url.Parse(serverInfo.ServerAddress) - if err != nil { - c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) - return - } - - proxy := httputil.NewSingleHostReverseProxy(remote) - newReqBody, err := json.Marshal(chatRequest) - if err != nil { - log.Printf("http request err: %v", err) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - proxy.Director = func(req *http.Request) { - req.Header = c.Request.Header - req.Host = remote.Host - req.URL.Scheme = remote.Scheme - req.URL.Host = remote.Host - req.URL.Path = c.Request.URL.Path - req.ContentLength = int64(len(newReqBody)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept-Encoding", "") - serverKey := serverInfo.AvailableKey - keyList := strings.Split(serverInfo.AvailableKey, ",") - if len(keyList) > 1 { - // 随机数种子 - rand.Seed(time.Now().UnixNano()) - // 从数组中随机选择一个元素 - serverKey = keyList[rand.Intn(len(keyList))] - } - req.Header.Set("Authorization", "Bearer "+serverKey) - req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) - } - sss, err := json.Marshal(chatRequest) - if err != nil { - log.Printf("chatRequest 转化出错 %v", err) - } - log.Printf("开始处理返回逻辑: %v", string(sss)) - if chatRequest.Stream { - // 流式回应,处理 - proxy.ModifyResponse = func(resp *http.Response) error { - log.Printf("流式回应 http status code: %v", resp.StatusCode) - if resp.StatusCode != http.StatusOK { - //退回预扣除的余额 - checkBlanceReturn(key, chatRequest.Model) - return nil - } - chatRequestId := "" - reqContent := "" - reader := bufio.NewReader(resp.Body) - headers := resp.Header - for k, v := range headers { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.Header().Set("Openai-Organization", "api2gpt") - for { - chunk, err := reader.ReadBytes('\n') - if err != nil { - if err == io.EOF { - break - } - log.Printf("流式回应,处理 err %v:", err.Error()) - break - //return err - } - var chatResponse ChatResponse - //去除回应中的data:前缀 - var trimStr = strings.Trim(string(chunk), "data: ") - if trimStr != "\n" { - json.Unmarshal([]byte(trimStr), &chatResponse) - if chatResponse.Choices != nil { - if chatResponse.Choices[0].Text != "" { - reqContent += chatResponse.Choices[0].Text - } else { - reqContent += chatResponse.Choices[0].Delta.Content - } - - chatRequestId = chatResponse.Id - } - - // 写回数据 - _, err = c.Writer.Write([]byte(string(chunk) + "\n")) - if err != nil { - log.Printf("写回数据 err: %v", err.Error()) - return err - } - c.Writer.(http.Flusher).Flush() - } - } - if chatRequest.Model == "text-davinci-003" { - prompt_tokens = numTokensFromString(chatRequest.Prompt, chatRequest.Model) - } else { - prompt_tokens = numTokensFromMessages(chatRequest.Messages, chatRequest.Model) - } - completion_tokens = numTokensFromString(reqContent, chatRequest.Model) - log.Printf("返回内容:%v", reqContent) - total_tokens = prompt_tokens + completion_tokens - log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens) - //消费余额 - consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatRequestId) - return nil - } - } else { - // 非流式回应,处理 - proxy.ModifyResponse = func(resp *http.Response) error { - resp.Header.Set("Openai-Organization", "api2gpt") - var chatResponse ChatResponse - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - log.Printf("非流式回应,处理 err: %v", err) - return err - } - json.Unmarshal(body, &chatResponse) - prompt_tokens = chatResponse.Usage.PromptTokens - completion_tokens = chatResponse.Usage.CompletionTokens - total_tokens = chatResponse.Usage.TotalTokens - resp.Body = ioutil.NopCloser(bytes.NewReader(body)) - log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens) - //消费余额 - consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatResponse.Id) - return nil - } - } - - proxy.ServeHTTP(c.Writer, c.Request) -} - -// model查询列表 -func handleGetModels(c *gin.Context) { - // BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/3 - //models := []string{"gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002"} - models := []string{"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "text-davinci-003", "text-embedding-ada-002", "text-davinci-edit-001", "code-davinci-edit-001", "images-generations"} - result := ListModelResponse{ - Object: "list", - } - for _, model := range models { - result.Data = append(result.Data, Model{ - ID: model, - Object: "model", - Created: 1677649963, - OwnedBy: "openai", - Permission: []ModelPermission{ - { - ID: "", - Object: "model", - Created: 1679602087, - AllowCreateEngine: true, - AllowSampling: true, - AllowLogprobs: true, - AllowSearchIndices: true, - AllowView: true, - AllowFineTuning: true, - Organization: "*", - Group: nil, - IsBlocking: false, - }, - }, - Root: model, - Parent: nil, - }) - } - c.JSON(200, result) -} - -// 针对接口预检OPTIONS的处理 -func handleOptions(c *gin.Context) { - // BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1 - c.Header("Access-Control-Allow-Origin", "*") - c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") - c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization") - c.Status(200) - return -} - -// 余额查询 -func balance(c *gin.Context) { - auth := c.Request.Header.Get("Authorization") - key := strings.Trim(auth, "Bearer ") - balance, err := queryBlance(key) - if err != nil { - c.JSON(400, gin.H{"error": err.Error()}) - } - var creditSummary CreditSummary - creditSummary.Object = "credit_grant" - creditSummary.TotalGranted = 999999 - creditSummary.TotalUsed = 999999 - balance - creditSummary.TotalRemaining = balance - c.JSON(200, creditSummary) -} - -// key 检测中间件 -func checkKeyMid() gin.HandlerFunc { - return func(c *gin.Context) { - auth := c.Request.Header.Get("Authorization") - log.Printf("auth: %v", auth) - if auth == "" { - c.AbortWithStatusJSON(401, gin.H{"code": 40001}) - } else { - //key := strings.Trim(auth, "Bearer ") - key := auth[7:] - log.Printf("key: %v", key) - msg, err := checkKeyAndTimeCount(key) - if err != nil { - log.Printf("checkKeyMid err: %v", err) - c.AbortWithStatusJSON(msg, gin.H{"code": err.Error()}) - } - } - log.Printf("auth check end") - // 执行函数 - c.Next() - } -} - func test_redis() { //添加reids测试数据 - //var serverInfo ServerInfo = ServerInfo{ - // ServerAddress: "https://gptp.any-door.cn", - // AvailableKey: "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", - //} + var serverInfo model.ServerInfo = model.ServerInfo{ + ServerAddress: "https://gptp.any-door.cn", + AvailableKey: "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", + } // var serverInfo2 ServerInfo = ServerInfo{ // ServerAddress: "https://azure.any-door.cn", // AvailableKey: "6c4d2c65970b40e482e7cd27adb0d119", // } - //serverInfoStr, _ := json.Marshal(&serverInfo) + serverInfoStr, _ := json.Marshal(&serverInfo) // serverInfoStr2, _ := json.Marshal(&serverInfo2) - //Redis.Set(context.Background(), "server:1", serverInfoStr, 0) + common.RedisSet("server:0", string(serverInfoStr), 0) + common.RedisSet("server:1", string(serverInfoStr), 0) // Redis.Set(context.Background(), "server:2", serverInfoStr2, 0) // var modelInfo ModelInfo = ModelInfo{ @@ -725,7 +72,7 @@ func test_redis() { // modelInfoStr2, _ := json.Marshal(&modelInfo2) // Redis.Set(context.Background(), "model:images-generations", modelInfoStr2, 0) - var userInfo UserInfo = UserInfo{ + var userInfo model.UserInfo = model.UserInfo{ UID: "0", SID: "1", } @@ -763,25 +110,10 @@ func main() { server := gin.Default() server.Use(middleware.CORS()) - server.GET("/dashboard/billing/credit_grants", checkKeyMid(), balance) - - server.GET("/v1/models", handleGetModels) - server.OPTIONS("/v1/*path", handleOptions) - server.POST("/v1/chat/completions", checkKeyMid(), completions) - server.POST("/v1/completions", checkKeyMid(), completions) - server.POST("/v1/embeddings", checkKeyMid(), embeddings) - server.POST("/v1/edits", checkKeyMid(), edit) - server.POST("/v1/images/generations", checkKeyMid(), images) - - // 定义一个GET请求测试接口 - server.GET("/ping", func(c *gin.Context) { - c.JSON(200, gin.H{ - "message": "pong from api2gpt", - }) - }) + router.SetRouter(server) //添加测试数据 - //test_redis() + test_redis() var port = os.Getenv("PORT") if port == "" { diff --git a/middleware/auth.go b/middleware/auth.go new file mode 100644 index 0000000..3aa2c92 --- /dev/null +++ b/middleware/auth.go @@ -0,0 +1,29 @@ +package middleware + +import ( + "api2gpt-mid/service" + "github.com/gin-gonic/gin" + "log" +) + +func TokenAuth() func(c *gin.Context) { + return func(c *gin.Context) { + auth := c.Request.Header.Get("Authorization") + log.Printf("auth: %v", auth) + if auth == "" { + c.AbortWithStatusJSON(401, gin.H{"code": 40001}) + } else { + //key := strings.Trim(auth, "Bearer ") + key := auth[7:] + log.Printf("key: %v", key) + msg, err := service.CheckKeyAndTimeCount(key) + if err != nil { + log.Printf("checkKeyMid err: %v", err) + c.AbortWithStatusJSON(msg, gin.H{"code": err.Error()}) + } + } + log.Printf("auth check end") + // 执行函数 + c.Next() + } +} diff --git a/model/consumption.go b/model/consumption.go new file mode 100644 index 0000000..6eda1d9 --- /dev/null +++ b/model/consumption.go @@ -0,0 +1,10 @@ +package model + +type Consumption struct { + SecretKey string `json:"secretKey"` + Model string `json:"model"` + MsgId string `json:"msgId"` + PromptTokens int `json:"promptTokens"` + CompletionTokens int `json:"completionTokens"` + TotalTokens int `json:"totalTokens"` +} diff --git a/model/model.go b/model/model.go new file mode 100644 index 0000000..f342e35 --- /dev/null +++ b/model/model.go @@ -0,0 +1,9 @@ +package model + +type ModelInfo struct { + ModelName string `json:"model_name"` + ModelPrice float64 `json:"model_price"` + ModelPrice2 float64 `json:"model_price2"` + ModelPrepayment int `json:"model_prepayment"` + ServerId int `json:"server_id"` +} diff --git a/model/openai.go b/model/openai.go new file mode 100644 index 0000000..c831c0d --- /dev/null +++ b/model/openai.go @@ -0,0 +1,98 @@ +package model + +type Message struct { + Role string `json:"role,omitempty"` + Name string `json:"name,omitempty"` + Content string `json:"content,omitempty"` +} + +type ChatRequest struct { + Stream bool `json:"stream,omitempty"` + Model string `json:"model,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Messages []Message `json:"messages,omitempty"` + Prompt string `json:"prompt,omitempty"` + Input interface{} `json:"input,omitempty"` + Instruction string `json:"instruction,omitempty"` + N int `json:"n,omitempty"` + Functions interface{} `json:"functions,omitempty"` +} + +type ImagesRequest struct { + Prompt string `json:"prompt,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` +} + +type ChatResponse struct { + Id string `json:"id,omitempty"` + Object string `json:"object,omitempty"` + Created int64 `json:"created,omitempty"` + Model string `json:"model,omitempty"` + Usage Usage `json:"usage,omitempty"` + Choices []Choice `json:"choices,omitempty"` +} +type ImagesResponse struct { + Created int64 `json:"created,omitempty"` + Data []ImagesData `json:"data,omitempty"` +} + +type ImagesData struct { + Url string `json:"url,omitempty"` +} + +type Usage struct { + PromptTokens int `json:"prompt_tokens,omitempty"` + CompletionTokens int `json:"completion_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` +} + +type Choice struct { + Message Message `json:"message,omitempty"` + Delta Delta `json:"delta,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Index int `json:"index,omitempty"` + Text string `json:"text,omitempty"` +} + +type Delta struct { + Content string `json:"content,omitempty"` +} + +type CreditSummary struct { + Object string `json:"object"` + TotalGranted float64 `json:"total_granted"` + TotalUsed float64 `json:"total_used"` + TotalRemaining float64 `json:"total_remaining"` +} + +type ListModelResponse struct { + Object string `json:"object"` + Data []Model `json:"data"` +} + +type Model struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + OwnedBy string `json:"owned_by"` + Permission []ModelPermission `json:"permission"` + Root string `json:"root"` + Parent any `json:"parent"` +} + +type ModelPermission struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + AllowCreateEngine bool `json:"allow_create_engine"` + AllowSampling bool `json:"allow_sampling"` + AllowLogprobs bool `json:"allow_logprobs"` + AllowSearchIndices bool `json:"allow_search_indices"` + AllowView bool `json:"allow_view"` + AllowFineTuning bool `json:"allow_fine_tuning"` + Organization string `json:"organization"` + Group any `json:"group"` + IsBlocking bool `json:"is_blocking"` +} diff --git a/model/server.go b/model/server.go new file mode 100644 index 0000000..efc6cb0 --- /dev/null +++ b/model/server.go @@ -0,0 +1,6 @@ +package model + +type ServerInfo struct { + ServerAddress string `json:"server_address"` + AvailableKey string `json:"available_key"` +} diff --git a/model/user.go b/model/user.go new file mode 100644 index 0000000..8904af7 --- /dev/null +++ b/model/user.go @@ -0,0 +1,6 @@ +package model + +type UserInfo struct { + UID string `json:"uid"` + SID string `json:"sid"` +} diff --git a/router/api-router.go b/router/api-router.go index 21d76a8..0fc5375 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -1,7 +1,36 @@ package router -import "github.com/gin-gonic/gin" +import ( + "api2gpt-mid/controller" + "api2gpt-mid/middleware" + "github.com/gin-gonic/gin" +) func SetApiRouter(router *gin.Engine) { - + modelsRouter := router.Group("/v1/models") + modelsRouter.Use(middleware.TokenAuth()) + { + modelsRouter.GET("", controller.ListModels) + modelsRouter.GET("/:model", controller.RetrieveModel) + } + relayV1Router := router.Group("/v1") + relayV1Router.Use(middleware.TokenAuth()) + { + relayV1Router.POST("/completions", controller.Completions) + relayV1Router.POST("/chat/completions", controller.Completions) + relayV1Router.POST("/embeddings", controller.Embeddings) + relayV1Router.POST("/edits", controller.Edit) + relayV1Router.POST("/images/generations", controller.Images) + } + dashboardRouter := router.Group("/dashboard") + dashboardRouter.Use(middleware.TokenAuth()) + { + dashboardRouter.GET("/dashboard/billing/credit_grants", controller.Balance) + } + router.OPTIONS("/v1/*path", controller.HandleOptions) + router.GET("/ping", func(c *gin.Context) { + c.JSON(200, gin.H{ + "message": "pong from api2gpt", + }) + }) } diff --git a/service.go b/service/business.go similarity index 81% rename from service.go rename to service/business.go index dae2595..03defdb 100644 --- a/service.go +++ b/service/business.go @@ -1,7 +1,9 @@ -package main +package service import ( + "api2gpt-mid/api" "api2gpt-mid/common" + "api2gpt-mid/model" "encoding/json" "errors" "log" @@ -10,7 +12,7 @@ import ( ) // 检测key是否存在,是否超出每分钟请求次数 -func checkKeyAndTimeCount(key string) (int, error) { +func CheckKeyAndTimeCount(key string) (int, error) { var timeOut = 60 * time.Second var timeCount = 30 userInfoStr, err := common.RedisGet("user:" + key) @@ -19,7 +21,7 @@ func checkKeyAndTimeCount(key string) (int, error) { //用户不存在 return 401, errors.New("40003") } - var userInfo UserInfo + var userInfo model.UserInfo err = json.Unmarshal([]byte(userInfoStr), &userInfo) if err != nil { //用户状态异常 @@ -46,9 +48,9 @@ func checkKeyAndTimeCount(key string) (int, error) { return 200, nil } -func queryBlance(key string) (float64, error) { +func QueryBlance(key string) (float64, error) { userInfoStr, err := common.RedisGet("user:" + key) - var userInfo UserInfo + var userInfo model.UserInfo err = json.Unmarshal([]byte(userInfoStr), &userInfo) balance, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", 0) if err != nil { @@ -58,15 +60,15 @@ func queryBlance(key string) (float64, error) { } // 余额查询 -func checkBlance(key string, model string) (ServerInfo, error) { - var serverInfo ServerInfo +func CheckBlance(key string, modelStr string) (model.ServerInfo, error) { + var serverInfo model.ServerInfo //获取模型价格 - modelPriceStr, err := common.RedisGet("model:" + model) + modelPriceStr, err := common.RedisGet("model:" + modelStr) if err != nil { return serverInfo, errors.New("模型信息不存在") } - var modelInfo ModelInfo + var modelInfo model.ModelInfo err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) if err != nil { return serverInfo, errors.New("模型信息解析失败") @@ -74,7 +76,7 @@ func checkBlance(key string, model string) (ServerInfo, error) { //获取用户信息 userInfoStr, err := common.RedisGet("user:" + key) - var userInfo UserInfo + var userInfo model.UserInfo err = json.Unmarshal([]byte(userInfoStr), &userInfo) //获取服务器信息 serverInfoStr, err := common.RedisGet("server:" + strconv.Itoa(modelInfo.ServerId)) @@ -101,15 +103,15 @@ func checkBlance(key string, model string) (ServerInfo, error) { } // 余额查询 for images -func checkBlanceForImages(key string, model string, n int) (ServerInfo, error) { - var serverInfo ServerInfo +func CheckBlanceForImages(key string, modelStr string, n int) (model.ServerInfo, error) { + var serverInfo model.ServerInfo //获取模型价格 - modelPriceStr, err := common.RedisGet("model:" + model) + modelPriceStr, err := common.RedisGet("model:" + modelStr) if err != nil { return serverInfo, errors.New("模型信息不存在") } - var modelInfo ModelInfo + var modelInfo model.ModelInfo err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) if err != nil { return serverInfo, errors.New("模型信息解析失败") @@ -117,7 +119,7 @@ func checkBlanceForImages(key string, model string, n int) (ServerInfo, error) { //获取用户信息 userInfoStr, err := common.RedisGet("user:" + key) - var userInfo UserInfo + var userInfo model.UserInfo err = json.Unmarshal([]byte(userInfoStr), &userInfo) //获取服务器信息 serverInfoStr, err := common.RedisGet("server:" + strconv.Itoa(modelInfo.ServerId)) @@ -144,18 +146,18 @@ func checkBlanceForImages(key string, model string, n int) (ServerInfo, error) { } // 预扣返还 -func checkBlanceReturn(key string, model string) error { +func CheckBlanceReturn(key string, modelStr string) error { //获取用户信息 userInfoStr, err := common.RedisGet("user:" + key) - var userInfo UserInfo + var userInfo model.UserInfo err = json.Unmarshal([]byte(userInfoStr), &userInfo) //获取模型价格 - modelPriceStr, err := common.RedisGet("model:" + model) + modelPriceStr, err := common.RedisGet("model:" + modelStr) if err != nil { return errors.New("模型信息不存在") } - var modelInfo ModelInfo + var modelInfo model.ModelInfo err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) if err != nil { return errors.New("模型信息解析失败") @@ -166,17 +168,17 @@ func checkBlanceReturn(key string, model string) error { } // 预扣返还 for images -func checkBlanceReturnForImages(key string, model string, n int) error { +func CheckBlanceReturnForImages(key string, modelStr string, n int) error { //获取用户信息 userInfoStr, err := common.RedisGet("user:" + key) - var userInfo UserInfo + var userInfo model.UserInfo err = json.Unmarshal([]byte(userInfoStr), &userInfo) //获取模型价格 - modelPriceStr, err := common.RedisGet("model:" + model) + modelPriceStr, err := common.RedisGet("model:" + modelStr) if err != nil { return errors.New("模型信息不存在") } - var modelInfo ModelInfo + var modelInfo model.ModelInfo err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) if err != nil { return errors.New("模型信息解析失败") @@ -187,23 +189,23 @@ func checkBlanceReturnForImages(key string, model string, n int) error { } // 余额消费 -func consumption(key string, model string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) { +func Consumption(key string, modelStr string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) { //获取用户信息 userInfoStr, err := common.RedisGet("user:" + key) if err != nil { return "", errors.New("用户信息不存在") } - var userInfo UserInfo + var userInfo model.UserInfo err = json.Unmarshal([]byte(userInfoStr), &userInfo) if err != nil { return "", errors.New("用户信息解析失败") } //获取模型价格 - modelPriceStr, err := common.RedisGet("model:" + model) + modelPriceStr, err := common.RedisGet("model:" + modelStr) if err != nil { return "", errors.New("模型信息不存在") } - var modelInfo ModelInfo + var modelInfo model.ModelInfo err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) if err != nil { return "", errors.New("模型信息解析失败") @@ -211,7 +213,7 @@ func consumption(key string, model string, prompt_tokens int, completion_tokens balance, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment)*modelInfo.ModelPrice-(float64(prompt_tokens)*modelInfo.ModelPrice+float64(completion_tokens)*modelInfo.ModelPrice2)) // 余额消费日志请求 - result, err := balanceConsumption(key, model, prompt_tokens, completion_tokens, total_tokens, msg_id) + result, err := api.BalanceConsumption(key, modelStr, prompt_tokens, completion_tokens, total_tokens, msg_id) log.Printf("用户余额:%f 扣费KEY: %s 扣费token数: %d 扣费:%f 扣费日志发送结果 %s", balance, key, total_tokens, float64(modelInfo.ModelPrepayment)*modelInfo.ModelPrice-(float64(prompt_tokens)*modelInfo.ModelPrice+float64(completion_tokens)*modelInfo.ModelPrice2), result) if err != nil { log.Printf("%s 余额消费日志请求失败 %v", key, err) @@ -221,23 +223,23 @@ func consumption(key string, model string, prompt_tokens int, completion_tokens } // 余额消费 for images -func consumptionForImages(key string, model string, n int, dataNum int, msg_id string) (string, error) { +func ConsumptionForImages(key string, modelStr string, n int, dataNum int, msg_id string) (string, error) { //获取用户信息 userInfoStr, err := common.RedisGet("user:" + key) if err != nil { return "", errors.New("用户信息不存在") } - var userInfo UserInfo + var userInfo model.UserInfo err = json.Unmarshal([]byte(userInfoStr), &userInfo) if err != nil { return "", errors.New("用户信息解析失败") } //获取模型价格 - modelPriceStr, err := common.RedisGet("model:" + model) + modelPriceStr, err := common.RedisGet("model:" + modelStr) if err != nil { return "", errors.New("模型信息不存在") } - var modelInfo ModelInfo + var modelInfo model.ModelInfo err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) if err != nil { return "", errors.New("模型信息解析失败") @@ -245,7 +247,7 @@ func consumptionForImages(key string, model string, n int, dataNum int, msg_id s balance, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment*n)*modelInfo.ModelPrice-(float64(1000*dataNum)*modelInfo.ModelPrice)) // 余额消费日志请求 - result, err := balanceConsumption(key, model, 0, 1000*dataNum, 1000*dataNum, msg_id) + result, err := api.BalanceConsumption(key, modelStr, 0, 1000*dataNum, 1000*dataNum, msg_id) log.Printf("用户余额:%f 扣费KEY: %s 扣费token数: %d 扣费:%f 扣费日志发送结果 %s", balance, key, 1000*dataNum, float64(1000*dataNum)*modelInfo.ModelPrice, result) if err != nil { log.Printf("%s 余额消费日志请求失败 %v", key, err)