You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

467 lines
14 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package controller
import (
"api2gpt-mid/common"
"api2gpt-mid/model"
"api2gpt-mid/service"
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"log"
"math/rand"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
)
// Images proxies an image-generation request to the upstream OpenAI-compatible
// server. It pre-deducts the caller's balance for the requested number of
// images, forwards the request with a server-side key, and on success settles
// the bill against the number of images actually generated; on upstream
// failure the pre-deduction is refunded.
func Images(c *gin.Context) {
	var imagesRequest model.ImagesRequest
	var modelStr = "images-generations"
	if err := c.ShouldBindJSON(&imagesRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	// Extract the caller key from "Bearer <key>". TrimPrefix avoids the
	// out-of-range panic auth[7:] causes when the header is missing or short.
	auth := c.Request.Header.Get("Authorization")
	key := strings.TrimPrefix(auth, "Bearer ")
	// Pre-deduct balance for N images.
	serverInfo, err := service.CheckBlanceForImages(key, modelStr, imagesRequest.N)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, modelStr, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, modelStr)
	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	proxy := httputil.NewSingleHostReverseProxy(remote)
	// Re-marshal the bound request so the forwarded body matches what we parsed.
	newReqBody, err := json.Marshal(imagesRequest)
	if err != nil {
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	proxy.Director = func(req *http.Request) {
		req.Header = c.Request.Header
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(newReqBody))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")
		// When several upstream keys are configured (comma-separated), pick
		// one at random to spread load.
		serverKey := serverInfo.AvailableKey
		keyList := strings.Split(serverInfo.AvailableKey, ",")
		if len(keyList) > 1 {
			// Use the shared global source instead of re-seeding per request.
			serverKey = keyList[rand.Intn(len(keyList))]
		}
		req.Header.Set("Authorization", "Bearer "+serverKey)
		req.Body = io.NopCloser(bytes.NewReader(newReqBody))
	}
	// Log the outgoing payload. %s (was %d) renders the JSON text, and the
	// marshal error is no longer silently discarded.
	if sss, merr := json.Marshal(imagesRequest); merr != nil {
		log.Printf("imagesRequest 转化出错 %v", merr)
	} else {
		log.Printf("开始处理返回逻辑 %s", string(sss))
	}
	proxy.ModifyResponse = func(resp *http.Response) error {
		if resp.StatusCode != http.StatusOK {
			// 退回预扣除的余额 — refund the pre-deduction; the upstream error
			// body is passed through to the client unchanged.
			return service.CheckBlanceReturnForImages(key, modelStr, imagesRequest.N)
		}
		resp.Header.Set("Openai-Organization", "api2gpt")
		var imagesResponse model.ImagesResponse
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			log.Printf("读取返回数据出错: %v", err)
			return err
		}
		if err := json.Unmarshal(body, &imagesResponse); err != nil {
			log.Printf("json解析数据出错: %v", err)
			return err
		}
		// Restore the body for the client after consuming it for accounting.
		resp.Body = io.NopCloser(bytes.NewReader(body))
		log.Printf("image size: %v", len(imagesResponse.Data))
		timestampID := fmt.Sprintf("img-%d", time.Now().Unix())
		// 消费余额 — settle against the number of images actually returned.
		_, err = service.ConsumptionForImages(key, modelStr, imagesRequest.N, len(imagesResponse.Data), timestampID)
		return err
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
// Edit proxies an edit-style completion request to the upstream server. The
// caller's balance is pre-deducted from the model/max_tokens estimate, and the
// bill is settled from the token usage the upstream reports in its response.
func Edit(c *gin.Context) {
	var promptTokens int
	var complateTokens int
	var totalTokens int
	var chatRequest model.ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	// Extract the caller key from "Bearer <key>". TrimPrefix avoids the
	// out-of-range panic auth[7:] causes when the header is missing or short.
	auth := c.Request.Header.Get("Authorization")
	key := strings.TrimPrefix(auth, "Bearer ")
	// Check/pre-deduct balance by model and max_tokens; tightening this with
	// other parameters is a possible later refinement.
	serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	proxy := httputil.NewSingleHostReverseProxy(remote)
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	proxy.Director = func(req *http.Request) {
		req.Header = c.Request.Header
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(newReqBody))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")
		// When several upstream keys are configured, pick one at random.
		serverKey := serverInfo.AvailableKey
		keyList := strings.Split(serverInfo.AvailableKey, ",")
		if len(keyList) > 1 {
			// Use the shared global source instead of re-seeding per request.
			serverKey = keyList[rand.Intn(len(keyList))]
		}
		req.Header.Set("Authorization", "Bearer "+serverKey)
		req.Body = io.NopCloser(bytes.NewReader(newReqBody))
	}
	// Log the outgoing payload. %s (was %d) renders the JSON text, and the
	// marshal error is no longer silently discarded.
	if sss, merr := json.Marshal(chatRequest); merr != nil {
		log.Printf("chatRequest 转化出错 %v", merr)
	} else {
		log.Printf("开始处理返回逻辑 %s", string(sss))
	}
	proxy.ModifyResponse = func(resp *http.Response) error {
		if resp.StatusCode != http.StatusOK {
			// Refund the pre-deducted balance on upstream failure, matching
			// the handling in Images and Completions (previously this handler
			// never refunded, leaving the caller charged for a failed call).
			return service.CheckBlanceReturn(key, chatRequest.Model, chatRequest.MaxTokens)
		}
		resp.Header.Set("Openai-Organization", "api2gpt")
		var chatResponse model.ChatResponse
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			log.Printf("读取返回数据出错: %v", err)
			return err
		}
		if err := json.Unmarshal(body, &chatResponse); err != nil {
			log.Printf("json解析数据出错: %v", err)
			return err
		}
		promptTokens = chatResponse.Usage.PromptTokens
		complateTokens = chatResponse.Usage.CompletionTokens
		totalTokens = chatResponse.Usage.TotalTokens
		// Restore the body for the client after consuming it for accounting.
		resp.Body = io.NopCloser(bytes.NewReader(body))
		log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", promptTokens, complateTokens, totalTokens)
		timestampID := fmt.Sprintf("edit-%d", time.Now().Unix())
		// 消费余额 — NOTE(review): completion tokens are billed as 0 even
		// though the upstream reported complateTokens; confirm intentional.
		_, err = service.Consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID)
		return err
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
// Embeddings proxies an embeddings request to the upstream server. The
// caller's balance is pre-deducted from the model/max_tokens estimate, and the
// bill is settled from the token usage the upstream reports in its response.
func Embeddings(c *gin.Context) {
	var promptTokens int
	var totalTokens int
	var chatRequest model.ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	// Extract the caller key from "Bearer <key>". TrimPrefix avoids the
	// out-of-range panic auth[7:] causes when the header is missing or short.
	auth := c.Request.Header.Get("Authorization")
	key := strings.TrimPrefix(auth, "Bearer ")
	serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	proxy := httputil.NewSingleHostReverseProxy(remote)
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	proxy.Director = func(req *http.Request) {
		req.Header = c.Request.Header
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(newReqBody))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")
		// When several upstream keys are configured, pick one at random.
		serverKey := serverInfo.AvailableKey
		keyList := strings.Split(serverInfo.AvailableKey, ",")
		if len(keyList) > 1 {
			// Use the shared global source instead of re-seeding per request.
			serverKey = keyList[rand.Intn(len(keyList))]
		}
		req.Header.Set("Authorization", "Bearer "+serverKey)
		req.Body = io.NopCloser(bytes.NewReader(newReqBody))
	}
	// Log the outgoing payload. %s (was %d) renders the JSON text, and the
	// marshal error is no longer silently discarded.
	if sss, merr := json.Marshal(chatRequest); merr != nil {
		log.Printf("chatRequest 转化出错 %v", merr)
	} else {
		log.Printf("开始处理返回逻辑 %s", string(sss))
	}
	proxy.ModifyResponse = func(resp *http.Response) error {
		if resp.StatusCode != http.StatusOK {
			// Refund the pre-deducted balance on upstream failure, matching
			// the handling in Images and Completions (previously this handler
			// never refunded, leaving the caller charged for a failed call).
			return service.CheckBlanceReturn(key, chatRequest.Model, chatRequest.MaxTokens)
		}
		resp.Header.Set("Openai-Organization", "api2gpt")
		var chatResponse model.ChatResponse
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			log.Printf("读取返回数据出错: %v", err)
			return err
		}
		if err := json.Unmarshal(body, &chatResponse); err != nil {
			log.Printf("json解析数据出错: %v", err)
			return err
		}
		promptTokens = chatResponse.Usage.PromptTokens
		totalTokens = chatResponse.Usage.TotalTokens
		// Restore the body for the client after consuming it for accounting.
		resp.Body = io.NopCloser(bytes.NewReader(body))
		log.Printf("prompt_tokens: %v total_tokens: %v", promptTokens, totalTokens)
		timestampID := fmt.Sprintf("emb-%d", time.Now().Unix())
		// 消费余额 — embeddings have no completion tokens, hence 0.
		_, err = service.Consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID)
		return err
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
// Completions proxies chat/text completion requests to the upstream server,
// handling both streaming (SSE) and non-streaming responses. Balance is
// pre-deducted up front; streaming responses are billed from locally counted
// tokens (upstream streams carry no usage block), non-streaming responses
// from the usage the upstream reports.
func Completions(c *gin.Context) {
	var promptTokens int
	var completionTokens int
	var totalTokens int
	var chatRequest model.ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	// Extract the caller key from "Bearer <key>". TrimPrefix avoids the
	// out-of-range panic auth[7:] causes when the header is missing or short.
	auth := c.Request.Header.Get("Authorization")
	key := strings.TrimPrefix(auth, "Bearer ")
	serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	proxy := httputil.NewSingleHostReverseProxy(remote)
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	proxy.Director = func(req *http.Request) {
		req.Header = c.Request.Header
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(newReqBody))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")
		// When several upstream keys are configured, pick one at random.
		serverKey := serverInfo.AvailableKey
		keyList := strings.Split(serverInfo.AvailableKey, ",")
		if len(keyList) > 1 {
			// Use the shared global source instead of re-seeding per request.
			serverKey = keyList[rand.Intn(len(keyList))]
		}
		req.Header.Set("Authorization", "Bearer "+serverKey)
		req.Body = io.NopCloser(bytes.NewReader(newReqBody))
	}
	if sss, merr := json.Marshal(chatRequest); merr != nil {
		log.Printf("chatRequest 转化出错 %v", merr)
	} else {
		log.Printf("开始处理返回逻辑: %v", string(sss))
	}
	if chatRequest.Stream {
		// Streaming (SSE) response: relay chunks to the client as they arrive
		// while accumulating the generated text for token accounting.
		proxy.ModifyResponse = func(resp *http.Response) error {
			log.Printf("流式回应 http status code: %v", resp.StatusCode)
			if resp.StatusCode != http.StatusOK {
				// 退回预扣除的余额 — refund the pre-deduction on failure.
				return service.CheckBlanceReturn(key, chatRequest.Model, chatRequest.MaxTokens)
			}
			chatRequestId := ""
			reqContent := ""
			reader := bufio.NewReader(resp.Body)
			for k, v := range resp.Header {
				c.Writer.Header().Set(k, v[0])
			}
			c.Writer.Header().Set("Openai-Organization", "api2gpt")
			for {
				chunk, rerr := reader.ReadBytes('\n')
				if rerr != nil {
					if rerr != io.EOF {
						log.Printf("流式回应,处理 err %v:", rerr.Error())
					}
					break
				}
				line := string(chunk)
				// Strip the SSE "data: " prefix. TrimPrefix replaces the
				// original strings.Trim, whose cutset semantics wrongly
				// stripped any leading/trailing 'd','a','t',':',' ' bytes.
				payload := strings.TrimPrefix(line, "data: ")
				trimmed := strings.TrimSpace(payload)
				if trimmed == "" {
					// Blank SSE separator line; the "\n" appended on each
					// write below restores the event framing, so skip it.
					continue
				}
				if trimmed != "[DONE]" {
					// The "[DONE]" terminator is not JSON; previously it made
					// Unmarshal fail and aborted the stream before billing.
					var chatResponse model.ChatResponse
					if uerr := json.Unmarshal([]byte(payload), &chatResponse); uerr != nil {
						// Forward unparseable lines instead of killing the
						// whole response (and the billing that follows).
						log.Printf("流式数据解析失败,原样转发: %v", uerr)
					} else if len(chatResponse.Choices) > 0 {
						// len check also guards the Choices[0] access that the
						// original would panic on for an empty choices slice.
						if chatResponse.Choices[0].Text != "" {
							reqContent += chatResponse.Choices[0].Text
						} else {
							reqContent += chatResponse.Choices[0].Delta.Content
						}
						chatRequestId = chatResponse.Id
					}
				}
				// Relay the raw chunk; the extra "\n" re-adds the SSE event
				// separator dropped by the blank-line skip above.
				if _, werr := c.Writer.Write([]byte(line + "\n")); werr != nil {
					log.Printf("写回数据 err: %v", werr.Error())
					return werr
				}
				// Guarded assertion: don't panic if the writer can't flush.
				if flusher, ok := c.Writer.(http.Flusher); ok {
					flusher.Flush()
				}
			}
			// Streams carry no usage block, so count tokens locally.
			if chatRequest.Model == "text-davinci-003" {
				promptTokens = common.NumTokensFromString(chatRequest.Prompt, chatRequest.Model)
			} else {
				promptTokens = common.NumTokensFromMessages(chatRequest.Messages, chatRequest.Model)
			}
			completionTokens = common.NumTokensFromString(reqContent, chatRequest.Model)
			log.Printf("返回内容:%v", reqContent)
			totalTokens = promptTokens + completionTokens
			log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
			// 消费余额 — settle against the measured usage.
			_, cerr := service.Consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatRequestId)
			return cerr
		}
	} else {
		// Non-streaming: read the whole body, bill from the reported usage,
		// then restore the body for the client.
		proxy.ModifyResponse = func(resp *http.Response) error {
			resp.Header.Set("Openai-Organization", "api2gpt")
			var chatResponse model.ChatResponse
			body, err := io.ReadAll(resp.Body)
			if err != nil {
				log.Printf("非流式回应,处理 err: %v", err)
				return err
			}
			if err := json.Unmarshal(body, &chatResponse); err != nil {
				return err
			}
			promptTokens = chatResponse.Usage.PromptTokens
			completionTokens = chatResponse.Usage.CompletionTokens
			totalTokens = chatResponse.Usage.TotalTokens
			resp.Body = io.NopCloser(bytes.NewReader(body))
			log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
			// 消费余额
			_, err = service.Consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatResponse.Id)
			return err
		}
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
// HandleOptions answers CORS preflight (OPTIONS) requests with permissive
// headers so browser clients can call the API cross-origin.
// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1
func HandleOptions(c *gin.Context) {
	c.Header("Access-Control-Allow-Origin", "*")
	c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
	c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
	// http.StatusOK instead of magic 200; the redundant trailing bare
	// `return` was removed (staticcheck S1023).
	c.Status(http.StatusOK)
}