|
|
|
|
@ -10,7 +10,6 @@ import (
|
|
|
|
|
"fmt"
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
"io"
|
|
|
|
|
"io/ioutil"
|
|
|
|
|
"log"
|
|
|
|
|
"math/rand"
|
|
|
|
|
"net/http"
|
|
|
|
|
@ -66,12 +65,13 @@ func Images(c *gin.Context) {
|
|
|
|
|
keyList := strings.Split(serverInfo.AvailableKey, ",")
|
|
|
|
|
if len(keyList) > 1 {
|
|
|
|
|
// 随机数种子
|
|
|
|
|
rand.Seed(time.Now().UnixNano())
|
|
|
|
|
source := rand.NewSource(time.Now().UnixNano())
|
|
|
|
|
random := rand.New(source)
|
|
|
|
|
// 从数组中随机选择一个元素
|
|
|
|
|
serverKey = keyList[rand.Intn(len(keyList))]
|
|
|
|
|
serverKey = keyList[random.Intn(len(keyList))]
|
|
|
|
|
}
|
|
|
|
|
req.Header.Set("Authorization", "Bearer "+serverKey)
|
|
|
|
|
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
|
|
|
|
|
req.Body = io.NopCloser(bytes.NewReader(newReqBody))
|
|
|
|
|
}
|
|
|
|
|
sss, err := json.Marshal(imagesRequest)
|
|
|
|
|
log.Printf("开始处理返回逻辑 %d", string(sss))
|
|
|
|
|
@ -79,34 +79,42 @@ func Images(c *gin.Context) {
|
|
|
|
|
proxy.ModifyResponse = func(resp *http.Response) error {
|
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
|
|
|
//退回预扣除的余额
|
|
|
|
|
service.CheckBlanceReturnForImages(key, modelStr, imagesRequest.N)
|
|
|
|
|
err := service.CheckBlanceReturnForImages(key, modelStr, imagesRequest.N)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
resp.Header.Set("Openai-Organization", "api2gpt")
|
|
|
|
|
var imagesResponse model.ImagesResponse
|
|
|
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("读取返回数据出错: %v", err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
json.Unmarshal(body, &imagesResponse)
|
|
|
|
|
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
|
|
|
|
|
err = json.Unmarshal(body, &imagesResponse)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("json解析数据出错: %v", err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
resp.Body = io.NopCloser(bytes.NewReader(body))
|
|
|
|
|
log.Printf("image size: %v", len(imagesResponse.Data))
|
|
|
|
|
timestamp := time.Now().Unix()
|
|
|
|
|
timestampID := "img-" + fmt.Sprintf("%d", timestamp)
|
|
|
|
|
//消费余额
|
|
|
|
|
service.ConsumptionForImages(key, modelStr, imagesRequest.N, len(imagesResponse.Data), timestampID)
|
|
|
|
|
_, err = service.ConsumptionForImages(key, modelStr, imagesRequest.N, len(imagesResponse.Data), timestampID)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
proxy.ServeHTTP(c.Writer, c.Request)
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Edit(c *gin.Context) {
|
|
|
|
|
var prompt_tokens int
|
|
|
|
|
var complate_tokens int
|
|
|
|
|
var total_tokens int
|
|
|
|
|
var promptTokens int
|
|
|
|
|
var complateTokens int
|
|
|
|
|
var totalTokens int
|
|
|
|
|
var chatRequest model.ChatRequest
|
|
|
|
|
if err := c.ShouldBindJSON(&chatRequest); err != nil {
|
|
|
|
|
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
|
|
|
|
|
@ -116,7 +124,7 @@ func Edit(c *gin.Context) {
|
|
|
|
|
auth := c.Request.Header.Get("Authorization")
|
|
|
|
|
key := auth[7:]
|
|
|
|
|
//根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整
|
|
|
|
|
serverInfo, err := service.CheckBlance(key, chatRequest.Model)
|
|
|
|
|
serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
|
|
|
|
|
if err != nil {
|
|
|
|
|
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
|
|
|
|
|
log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
|
|
|
|
|
@ -152,12 +160,13 @@ func Edit(c *gin.Context) {
|
|
|
|
|
keyList := strings.Split(serverInfo.AvailableKey, ",")
|
|
|
|
|
if len(keyList) > 1 {
|
|
|
|
|
// 随机数种子
|
|
|
|
|
rand.Seed(time.Now().UnixNano())
|
|
|
|
|
source := rand.NewSource(time.Now().UnixNano())
|
|
|
|
|
random := rand.New(source)
|
|
|
|
|
// 从数组中随机选择一个元素
|
|
|
|
|
serverKey = keyList[rand.Intn(len(keyList))]
|
|
|
|
|
serverKey = keyList[random.Intn(len(keyList))]
|
|
|
|
|
}
|
|
|
|
|
req.Header.Set("Authorization", "Bearer "+serverKey)
|
|
|
|
|
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
|
|
|
|
|
req.Body = io.NopCloser(bytes.NewReader(newReqBody))
|
|
|
|
|
}
|
|
|
|
|
sss, err := json.Marshal(chatRequest)
|
|
|
|
|
log.Printf("开始处理返回逻辑 %d", string(sss))
|
|
|
|
|
@ -165,21 +174,28 @@ func Edit(c *gin.Context) {
|
|
|
|
|
proxy.ModifyResponse = func(resp *http.Response) error {
|
|
|
|
|
resp.Header.Set("Openai-Organization", "api2gpt")
|
|
|
|
|
var chatResponse model.ChatResponse
|
|
|
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("读取返回数据出错: %v", err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
json.Unmarshal(body, &chatResponse)
|
|
|
|
|
prompt_tokens = chatResponse.Usage.PromptTokens
|
|
|
|
|
complate_tokens = chatResponse.Usage.CompletionTokens
|
|
|
|
|
total_tokens = chatResponse.Usage.TotalTokens
|
|
|
|
|
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
|
|
|
|
|
log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", prompt_tokens, complate_tokens, total_tokens)
|
|
|
|
|
err = json.Unmarshal(body, &chatResponse)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("json解析数据出错: %v", err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
promptTokens = chatResponse.Usage.PromptTokens
|
|
|
|
|
complateTokens = chatResponse.Usage.CompletionTokens
|
|
|
|
|
totalTokens = chatResponse.Usage.TotalTokens
|
|
|
|
|
resp.Body = io.NopCloser(bytes.NewReader(body))
|
|
|
|
|
log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", promptTokens, complateTokens, totalTokens)
|
|
|
|
|
timestamp := time.Now().Unix()
|
|
|
|
|
timestampID := "edit-" + fmt.Sprintf("%d", timestamp)
|
|
|
|
|
//消费余额
|
|
|
|
|
service.Consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID)
|
|
|
|
|
_, err = service.Consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -188,19 +204,16 @@ func Edit(c *gin.Context) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Embeddings(c *gin.Context) {
|
|
|
|
|
var prompt_tokens int
|
|
|
|
|
var total_tokens int
|
|
|
|
|
var promptTokens int
|
|
|
|
|
var totalTokens int
|
|
|
|
|
var chatRequest model.ChatRequest
|
|
|
|
|
if err := c.ShouldBindJSON(&chatRequest); err != nil {
|
|
|
|
|
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auth := c.Request.Header.Get("Authorization")
|
|
|
|
|
//key := strings.Trim(auth, "Bearer ")
|
|
|
|
|
key := auth[7:]
|
|
|
|
|
//根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整
|
|
|
|
|
serverInfo, err := service.CheckBlance(key, chatRequest.Model)
|
|
|
|
|
serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
|
|
|
|
|
if err != nil {
|
|
|
|
|
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
|
|
|
|
|
log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
|
|
|
|
|
@ -236,12 +249,13 @@ func Embeddings(c *gin.Context) {
|
|
|
|
|
keyList := strings.Split(serverInfo.AvailableKey, ",")
|
|
|
|
|
if len(keyList) > 1 {
|
|
|
|
|
// 随机数种子
|
|
|
|
|
rand.Seed(time.Now().UnixNano())
|
|
|
|
|
source := rand.NewSource(time.Now().UnixNano())
|
|
|
|
|
random := rand.New(source)
|
|
|
|
|
// 从数组中随机选择一个元素
|
|
|
|
|
serverKey = keyList[rand.Intn(len(keyList))]
|
|
|
|
|
serverKey = keyList[random.Intn(len(keyList))]
|
|
|
|
|
}
|
|
|
|
|
req.Header.Set("Authorization", "Bearer "+serverKey)
|
|
|
|
|
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
|
|
|
|
|
req.Body = io.NopCloser(bytes.NewReader(newReqBody))
|
|
|
|
|
}
|
|
|
|
|
sss, err := json.Marshal(chatRequest)
|
|
|
|
|
log.Printf("开始处理返回逻辑 %d", string(sss))
|
|
|
|
|
@ -249,20 +263,27 @@ func Embeddings(c *gin.Context) {
|
|
|
|
|
proxy.ModifyResponse = func(resp *http.Response) error {
|
|
|
|
|
resp.Header.Set("Openai-Organization", "api2gpt")
|
|
|
|
|
var chatResponse model.ChatResponse
|
|
|
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("读取返回数据出错: %v", err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
json.Unmarshal(body, &chatResponse)
|
|
|
|
|
prompt_tokens = chatResponse.Usage.PromptTokens
|
|
|
|
|
total_tokens = chatResponse.Usage.TotalTokens
|
|
|
|
|
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
|
|
|
|
|
log.Printf("prompt_tokens: %v total_tokens: %v", prompt_tokens, total_tokens)
|
|
|
|
|
err = json.Unmarshal(body, &chatResponse)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("json解析数据出错: %v", err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
promptTokens = chatResponse.Usage.PromptTokens
|
|
|
|
|
totalTokens = chatResponse.Usage.TotalTokens
|
|
|
|
|
resp.Body = io.NopCloser(bytes.NewReader(body))
|
|
|
|
|
log.Printf("prompt_tokens: %v total_tokens: %v", promptTokens, totalTokens)
|
|
|
|
|
timestamp := time.Now().Unix()
|
|
|
|
|
timestampID := "emb-" + fmt.Sprintf("%d", timestamp)
|
|
|
|
|
//消费余额
|
|
|
|
|
service.Consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID)
|
|
|
|
|
_, err = service.Consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -271,9 +292,9 @@ func Embeddings(c *gin.Context) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Completions(c *gin.Context) {
|
|
|
|
|
var prompt_tokens int
|
|
|
|
|
var completion_tokens int
|
|
|
|
|
var total_tokens int
|
|
|
|
|
var promptTokens int
|
|
|
|
|
var completionTokens int
|
|
|
|
|
var totalTokens int
|
|
|
|
|
var chatRequest model.ChatRequest
|
|
|
|
|
if err := c.ShouldBindJSON(&chatRequest); err != nil {
|
|
|
|
|
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
|
|
|
|
|
@ -282,8 +303,7 @@ func Completions(c *gin.Context) {
|
|
|
|
|
|
|
|
|
|
auth := c.Request.Header.Get("Authorization")
|
|
|
|
|
key := auth[7:]
|
|
|
|
|
//根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整
|
|
|
|
|
serverInfo, err := service.CheckBlance(key, chatRequest.Model)
|
|
|
|
|
serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
|
|
|
|
|
if err != nil {
|
|
|
|
|
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
|
|
|
|
|
log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
|
|
|
|
|
@ -318,12 +338,13 @@ func Completions(c *gin.Context) {
|
|
|
|
|
keyList := strings.Split(serverInfo.AvailableKey, ",")
|
|
|
|
|
if len(keyList) > 1 {
|
|
|
|
|
// 随机数种子
|
|
|
|
|
rand.Seed(time.Now().UnixNano())
|
|
|
|
|
source := rand.NewSource(time.Now().UnixNano())
|
|
|
|
|
random := rand.New(source)
|
|
|
|
|
// 从数组中随机选择一个元素
|
|
|
|
|
serverKey = keyList[rand.Intn(len(keyList))]
|
|
|
|
|
serverKey = keyList[random.Intn(len(keyList))]
|
|
|
|
|
}
|
|
|
|
|
req.Header.Set("Authorization", "Bearer "+serverKey)
|
|
|
|
|
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
|
|
|
|
|
req.Body = io.NopCloser(bytes.NewReader(newReqBody))
|
|
|
|
|
}
|
|
|
|
|
sss, err := json.Marshal(chatRequest)
|
|
|
|
|
if err != nil {
|
|
|
|
|
@ -336,7 +357,10 @@ func Completions(c *gin.Context) {
|
|
|
|
|
log.Printf("流式回应 http status code: %v", resp.StatusCode)
|
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
|
|
|
//退回预扣除的余额
|
|
|
|
|
service.CheckBlanceReturn(key, chatRequest.Model)
|
|
|
|
|
err = service.CheckBlanceReturn(key, chatRequest.Model, chatRequest.MaxTokens)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
chatRequestId := ""
|
|
|
|
|
@ -361,7 +385,10 @@ func Completions(c *gin.Context) {
|
|
|
|
|
//去除回应中的data:前缀
|
|
|
|
|
var trimStr = strings.Trim(string(chunk), "data: ")
|
|
|
|
|
if trimStr != "\n" {
|
|
|
|
|
json.Unmarshal([]byte(trimStr), &chatResponse)
|
|
|
|
|
err := json.Unmarshal([]byte(trimStr), &chatResponse)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if chatResponse.Choices != nil {
|
|
|
|
|
if chatResponse.Choices[0].Text != "" {
|
|
|
|
|
reqContent += chatResponse.Choices[0].Text
|
|
|
|
|
@ -382,16 +409,19 @@ func Completions(c *gin.Context) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if chatRequest.Model == "text-davinci-003" {
|
|
|
|
|
prompt_tokens = common.NumTokensFromString(chatRequest.Prompt, chatRequest.Model)
|
|
|
|
|
promptTokens = common.NumTokensFromString(chatRequest.Prompt, chatRequest.Model)
|
|
|
|
|
} else {
|
|
|
|
|
prompt_tokens = common.NumTokensFromMessages(chatRequest.Messages, chatRequest.Model)
|
|
|
|
|
promptTokens = common.NumTokensFromMessages(chatRequest.Messages, chatRequest.Model)
|
|
|
|
|
}
|
|
|
|
|
completion_tokens = common.NumTokensFromString(reqContent, chatRequest.Model)
|
|
|
|
|
completionTokens = common.NumTokensFromString(reqContent, chatRequest.Model)
|
|
|
|
|
log.Printf("返回内容:%v", reqContent)
|
|
|
|
|
total_tokens = prompt_tokens + completion_tokens
|
|
|
|
|
log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens)
|
|
|
|
|
totalTokens = promptTokens + completionTokens
|
|
|
|
|
log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
|
|
|
|
|
//消费余额
|
|
|
|
|
service.Consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatRequestId)
|
|
|
|
|
_, err := service.Consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatRequestId)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
@ -399,19 +429,25 @@ func Completions(c *gin.Context) {
|
|
|
|
|
proxy.ModifyResponse = func(resp *http.Response) error {
|
|
|
|
|
resp.Header.Set("Openai-Organization", "api2gpt")
|
|
|
|
|
var chatResponse model.ChatResponse
|
|
|
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("非流式回应,处理 err: %v", err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
json.Unmarshal(body, &chatResponse)
|
|
|
|
|
prompt_tokens = chatResponse.Usage.PromptTokens
|
|
|
|
|
completion_tokens = chatResponse.Usage.CompletionTokens
|
|
|
|
|
total_tokens = chatResponse.Usage.TotalTokens
|
|
|
|
|
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
|
|
|
|
|
log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens)
|
|
|
|
|
err = json.Unmarshal(body, &chatResponse)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
promptTokens = chatResponse.Usage.PromptTokens
|
|
|
|
|
completionTokens = chatResponse.Usage.CompletionTokens
|
|
|
|
|
totalTokens = chatResponse.Usage.TotalTokens
|
|
|
|
|
resp.Body = io.NopCloser(bytes.NewReader(body))
|
|
|
|
|
log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
|
|
|
|
|
//消费余额
|
|
|
|
|
service.Consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatResponse.Id)
|
|
|
|
|
_, err = service.Consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatResponse.Id)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@ -419,7 +455,7 @@ func Completions(c *gin.Context) {
|
|
|
|
|
proxy.ServeHTTP(c.Writer, c.Request)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Handling of preflight OPTIONS requests for the API.
|
|
|
|
|
// HandleOptions handles preflight OPTIONS requests for the API.
|
|
|
|
|
func HandleOptions(c *gin.Context) {
|
|
|
|
|
// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1
|
|
|
|
|
c.Header("Access-Control-Allow-Origin", "*")
|
|
|
|
|
|