Update to new version

dev
Kelvin 3 years ago
parent 671bdb2168
commit 2725f4ab00

api.go (102 changed lines)

@ -1,102 +0,0 @@
package main
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http"
"github.com/gin-gonic/gin"
)
type BalanceInfo struct {
ServerAddress string `json:"server_address"`
AvailableKey string `json:"available_key"`
UserBalance float64 `json:"user_balance"`
TokenRatio float64 `json:"token_ratio"`
}
type Consumption struct {
SecretKey string `json:"secretKey"`
Model string `json:"model"`
MsgId string `json:"msgId"`
PromptTokens int `json:"promptTokens"`
CompletionTokens int `json:"completionTokens"`
TotalTokens int `json:"totalTokens"`
}
func mockBalanceInquiry(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"server_address": "https://gptp.any-door.cn",
"available_key": "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",
"user_balance": 10000,
"token_ratio": 1000,
})
}
func mockBalanceConsumption(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": "true",
})
}
// balance inquiry API call
func balanceInquiry(key string, model string) (*BalanceInfo, error) {
url := "http://localhost:8080/mock1?key=" + key + "&model=" + model
req, err := http.NewRequest("POST", url, nil)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
panic(err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var balanceInfo BalanceInfo
if err := json.Unmarshal(body, &balanceInfo); err != nil {
return nil, err
}
return &balanceInfo, nil
}
// balance consumption
func balanceConsumption(key string, model string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) {
var data = Consumption{
SecretKey: key,
Model: model,
MsgId: msg_id,
PromptTokens: prompt_tokens,
CompletionTokens: completion_tokens,
TotalTokens: total_tokens,
}
jsonData, err := json.Marshal(data)
// build the POST request body
reqBody := bytes.NewBuffer(jsonData)
url := "http://172.17.0.1:8080/other/usageRecord"
req2, err := http.NewRequest("POST", url, reqBody)
// set the HTTP request headers
req2.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req2)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
return string(body), nil
}

@ -0,0 +1,44 @@
package api
import (
"api2gpt-mid/model"
"bytes"
"encoding/json"
"io/ioutil"
"net/http"
)
// BalanceConsumption reports token usage to the accounting service and returns its raw response.
func BalanceConsumption(key string, modelStr string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) {
var data = model.Consumption{
SecretKey: key,
Model: modelStr,
MsgId: msg_id,
PromptTokens: prompt_tokens,
CompletionTokens: completion_tokens,
TotalTokens: total_tokens,
}
jsonData, err := json.Marshal(data)
if err != nil {
return "", err
}
// build the POST request body
reqBody := bytes.NewBuffer(jsonData)
url := "http://172.17.0.1:8080/other/usageRecord"
req2, err := http.NewRequest("POST", url, reqBody)
if err != nil {
return "", err
}
// set the HTTP request headers
req2.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req2)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
return string(body), nil
}
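A minimal usage sketch of the reporting call above; the key, model name, token counts and message id are placeholder values, and the accounting endpoint at 172.17.0.1:8080 is assumed to be reachable as configured in the function.

package main

import (
    "api2gpt-mid/api"
    "log"
)

func main() {
    // Hypothetical values for illustration only.
    result, err := api.BalanceConsumption("sk-xxx", "gpt-3.5-turbo", 120, 80, 200, "chatcmpl-123")
    if err != nil {
        log.Printf("usage report failed: %v", err)
        return
    }
    log.Printf("usage report response: %s", result)
}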

@ -0,0 +1,50 @@
package common
import (
"api2gpt-mid/model"
"fmt"
"github.com/pkoukk/tiktoken-go"
"strings"
)
// NumTokensFromMessages counts the tokens contained in a slice of chat messages.
func NumTokensFromMessages(messages []model.Message, model string) int {
if strings.Contains(model, "gpt-3.5") {
model = "gpt-3.5-turbo"
}
if strings.Contains(model, "gpt-4") {
model = "gpt-4"
}
tkm, err := tiktoken.EncodingForModel(model)
if err != nil {
err = fmt.Errorf("getEncoding: %v", err)
return 0
}
numTokens := 0
for _, message := range messages {
numTokens += len(tkm.Encode(message.Content, nil, nil))
numTokens += 6
}
numTokens += 3
return numTokens
}
// NumTokensFromString counts the tokens contained in a plain string.
func NumTokensFromString(msg string, model string) int {
if strings.Contains(model, "gpt-3.5") {
model = "gpt-3.5-turbo"
}
if strings.Contains(model, "gpt-4") {
model = "gpt-4"
}
tkm, err := tiktoken.EncodingForModel(model)
if err != nil {
err = fmt.Errorf("getEncoding: %v", err)
return 0
}
if model == "text-davinci-003" {
return len(tkm.Encode(msg, nil, nil)) + 1
} else {
return len(tkm.Encode(msg, nil, nil)) + 9
}
}
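A short sketch of how these counters are used; the message contents are made up, and the exact counts depend on the tiktoken-go encodings available at runtime.

package main

import (
    "api2gpt-mid/common"
    "api2gpt-mid/model"
    "fmt"
)

func main() {
    msgs := []model.Message{
        {Role: "system", Content: "You are a helpful assistant."},
        {Role: "user", Content: "Hello!"},
    }
    // Prompt-side estimate for a chat request.
    fmt.Println(common.NumTokensFromMessages(msgs, "gpt-3.5-turbo-0613"))
    // Completion-side estimate for streamed text that carries no usage block.
    fmt.Println(common.NumTokensFromString("Hi there, how can I help?", "gpt-3.5-turbo-0613"))
}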

@ -0,0 +1,30 @@
package controller
import (
"api2gpt-mid/service"
"github.com/gin-gonic/gin"
"strings"
)
type CreditSummary struct {
Object string `json:"object"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
TotalRemaining float64 `json:"total_remaining"`
}
// Balance handles the credit-grants balance query.
func Balance(c *gin.Context) {
auth := c.Request.Header.Get("Authorization")
key := strings.TrimPrefix(auth, "Bearer ")
balance, err := service.QueryBlance(key)
if err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
var creditSummary CreditSummary
creditSummary.Object = "credit_grant"
creditSummary.TotalGranted = 999999
creditSummary.TotalUsed = 999999 - balance
creditSummary.TotalRemaining = balance
c.JSON(200, creditSummary)
}
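For reference, a sketch of the response body shape this handler produces, assuming an illustrative remaining balance of 9000 credits:

package main

import (
    "api2gpt-mid/model"
    "encoding/json"
    "fmt"
)

func main() {
    // Illustrative numbers only.
    summary := model.CreditSummary{
        Object:         "credit_grant",
        TotalGranted:   999999,
        TotalUsed:      999999 - 9000,
        TotalRemaining: 9000,
    }
    out, _ := json.Marshal(summary)
    fmt.Println(string(out))
    // {"object":"credit_grant","total_granted":999999,"total_used":990999,"total_remaining":9000}
}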

@ -0,0 +1,458 @@
package controller
import (
"fmt"
"github.com/gin-gonic/gin"
)
// https://platform.openai.com/docs/api-reference/models/list
type OpenAIModelPermission struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group *string `json:"group"`
IsBlocking bool `json:"is_blocking"`
}
type OpenAIModels struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []OpenAIModelPermission `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
}
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type OpenAIErrorWithStatusCode struct {
OpenAIError
StatusCode int `json:"status_code"`
}
var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels
func init() {
var permission []OpenAIModelPermission
permission = append(permission, OpenAIModelPermission{
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
Object: "model_permission",
Created: 1626777600,
AllowCreateEngine: true,
AllowSampling: true,
AllowLogprobs: true,
AllowSearchIndices: false,
AllowView: true,
AllowFineTuning: false,
Organization: "*",
Group: nil,
IsBlocking: false,
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{
{
Id: "dall-e",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "dall-e",
Parent: nil,
},
{
Id: "whisper-1",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "whisper-1",
Parent: nil,
},
{
Id: "gpt-3.5-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0301",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0301",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k-0613",
Parent: nil,
},
{
Id: "gpt-4",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4",
Parent: nil,
},
{
Id: "gpt-4-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0314",
Parent: nil,
},
{
Id: "gpt-4-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0613",
Parent: nil,
},
{
Id: "gpt-4-32k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k",
Parent: nil,
},
{
Id: "gpt-4-32k-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0314",
Parent: nil,
},
{
Id: "gpt-4-32k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0613",
Parent: nil,
},
{
Id: "text-embedding-ada-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-embedding-ada-002",
Parent: nil,
},
{
Id: "text-davinci-003",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-003",
Parent: nil,
},
{
Id: "text-davinci-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-002",
Parent: nil,
},
{
Id: "text-curie-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-curie-001",
Parent: nil,
},
{
Id: "text-babbage-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-babbage-001",
Parent: nil,
},
{
Id: "text-ada-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-ada-001",
Parent: nil,
},
{
Id: "text-moderation-latest",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-latest",
Parent: nil,
},
{
Id: "text-moderation-stable",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-stable",
Parent: nil,
},
{
Id: "text-davinci-edit-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-edit-001",
Parent: nil,
},
{
Id: "code-davinci-edit-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "code-davinci-edit-001",
Parent: nil,
},
{
Id: "claude-instant-1",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-instant-1",
Parent: nil,
},
{
Id: "claude-2",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2",
Parent: nil,
},
{
Id: "ERNIE-Bot",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot",
Parent: nil,
},
{
Id: "ERNIE-Bot-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-turbo",
Parent: nil,
},
{
Id: "Embedding-V1",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "Embedding-V1",
Parent: nil,
},
{
Id: "PaLM-2",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "PaLM-2",
Parent: nil,
},
{
Id: "chatglm_pro",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_pro",
Parent: nil,
},
{
Id: "chatglm_std",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_std",
Parent: nil,
},
{
Id: "chatglm_lite",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_lite",
Parent: nil,
},
{
Id: "qwen-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-v1",
Parent: nil,
},
{
Id: "qwen-plus-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-plus-v1",
Parent: nil,
},
{
Id: "SparkDesk",
Object: "model",
Created: 1677649963,
OwnedBy: "xunfei",
Permission: permission,
Root: "SparkDesk",
Parent: nil,
},
{
Id: "360GPT_S2_V9",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "360GPT_S2_V9",
Parent: nil,
},
{
Id: "embedding-bert-512-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "embedding-bert-512-v1",
Parent: nil,
},
{
Id: "embedding_s1_v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "embedding_s1_v1",
Parent: nil,
},
{
Id: "semantic_similarity_s1_v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "semantic_similarity_s1_v1",
Parent: nil,
},
{
Id: "360GPT_S2_V9.4",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "360GPT_S2_V9.4",
Parent: nil,
},
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
}
}
func ListModels(c *gin.Context) {
c.JSON(200, gin.H{
"object": "list",
"data": openAIModels,
})
}
func RetrieveModel(c *gin.Context) {
modelId := c.Param("model")
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
} else {
openAIError := OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error",
Param: "model",
Code: "model_not_found",
}
c.JSON(200, gin.H{
"error": openAIError,
})
}
}
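A small, self-contained sketch of exercising the two handlers above with Go's httptest package; the route paths are assumed to mirror the ones registered later in router.SetApiRouter.

package main

import (
    "api2gpt-mid/controller"
    "fmt"
    "net/http"
    "net/http/httptest"

    "github.com/gin-gonic/gin"
)

func main() {
    gin.SetMode(gin.TestMode)
    r := gin.New()
    r.GET("/v1/models", controller.ListModels)
    r.GET("/v1/models/:model", controller.RetrieveModel)

    // Look up a model that exists in the package-level model map.
    req := httptest.NewRequest(http.MethodGet, "/v1/models/gpt-3.5-turbo", nil)
    w := httptest.NewRecorder()
    r.ServeHTTP(w, req)
    fmt.Println(w.Code, w.Body.String())

    // An unknown id comes back with an invalid_request_error body.
    req = httptest.NewRequest(http.MethodGet, "/v1/models/does-not-exist", nil)
    w = httptest.NewRecorder()
    r.ServeHTTP(w, req)
    fmt.Println(w.Code, w.Body.String())
}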

@ -0,0 +1,430 @@
package controller
import (
"api2gpt-mid/common"
"api2gpt-mid/model"
"api2gpt-mid/service"
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"io/ioutil"
"log"
"math/rand"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
)
func Images(c *gin.Context) {
var imagesRequest model.ImagesRequest
var modelStr = "images-generations"
if err := c.ShouldBindJSON(&imagesRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
auth := c.Request.Header.Get("Authorization")
key := auth[7:]
serverInfo, err := service.CheckBlanceForImages(key, modelStr, imagesRequest.N)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
log.Printf("request error KEY: %v Model: %v ERROR: %v", key, modelStr, err)
return
}
log.Printf("incoming request KEY: %v Model: %v", key, modelStr)
remote, err := url.Parse(serverInfo.ServerAddress)
if err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
proxy := httputil.NewSingleHostReverseProxy(remote)
newReqBody, err := json.Marshal(imagesRequest)
if err != nil {
log.Printf("http request err: %v", err)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
proxy.Director = func(req *http.Request) {
req.Header = c.Request.Header
req.Host = remote.Host
req.URL.Scheme = remote.Scheme
req.URL.Host = remote.Host
req.URL.Path = c.Request.URL.Path
req.ContentLength = int64(len(newReqBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Encoding", "")
serverKey := serverInfo.AvailableKey
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// seed the random number generator
rand.Seed(time.Now().UnixNano())
// pick one key at random from the list
serverKey = keyList[rand.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(imagesRequest)
log.Printf("start handling response, request body: %s", string(sss))
proxy.ModifyResponse = func(resp *http.Response) error {
if resp.StatusCode != http.StatusOK {
// refund the pre-deducted balance
service.CheckBlanceReturnForImages(key, modelStr, imagesRequest.N)
return nil
}
resp.Header.Set("Openai-Organization", "api2gpt")
var imagesResponse model.ImagesResponse
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Printf("failed to read upstream response body: %v", err)
return err
}
json.Unmarshal(body, &imagesResponse)
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
log.Printf("image size: %v", len(imagesResponse.Data))
timestamp := time.Now().Unix()
timestampID := "img-" + fmt.Sprintf("%d", timestamp)
// deduct the consumed balance
service.ConsumptionForImages(key, modelStr, imagesRequest.N, len(imagesResponse.Data), timestampID)
return nil
}
proxy.ServeHTTP(c.Writer, c.Request)
}
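The same comma-separated key-selection block is repeated in every Director in this file; a small helper along these lines (hypothetical, not part of this commit; it relies only on the strings, math/rand and time imports already present here) would capture the behaviour in one place:

// pickKey returns one key from a comma-separated list,
// chosen at random when more than one key is configured.
func pickKey(available string) string {
    keys := strings.Split(available, ",")
    if len(keys) <= 1 {
        return available
    }
    rand.Seed(time.Now().UnixNano())
    return keys[rand.Intn(len(keys))]
}

// Inside a Director it would then be used as:
//   req.Header.Set("Authorization", "Bearer "+pickKey(serverInfo.AvailableKey))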
func Edit(c *gin.Context) {
var prompt_tokens int
var complate_tokens int
var total_tokens int
var chatRequest model.ChatRequest
if err := c.ShouldBindJSON(&chatRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
auth := c.Request.Header.Get("Authorization")
key := auth[7:]
// Use the key to check the user's balance before relaying; the max_tokens parameter could be factored in later.
serverInfo, err := service.CheckBlance(key, chatRequest.Model)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
log.Printf("request error KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
return
}
log.Printf("incoming request KEY: %v Model: %v", key, chatRequest.Model)
remote, err := url.Parse(serverInfo.ServerAddress)
if err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
proxy := httputil.NewSingleHostReverseProxy(remote)
newReqBody, err := json.Marshal(chatRequest)
if err != nil {
log.Printf("http request err: %v", err)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
proxy.Director = func(req *http.Request) {
req.Header = c.Request.Header
req.Host = remote.Host
req.URL.Scheme = remote.Scheme
req.URL.Host = remote.Host
req.URL.Path = c.Request.URL.Path
req.ContentLength = int64(len(newReqBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Encoding", "")
serverKey := serverInfo.AvailableKey
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// seed the random number generator
rand.Seed(time.Now().UnixNano())
// pick one key at random from the list
serverKey = keyList[rand.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(chatRequest)
log.Printf("start handling response, request body: %s", string(sss))
proxy.ModifyResponse = func(resp *http.Response) error {
resp.Header.Set("Openai-Organization", "api2gpt")
var chatResponse model.ChatResponse
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Printf("failed to read upstream response body: %v", err)
return err
}
json.Unmarshal(body, &chatResponse)
prompt_tokens = chatResponse.Usage.PromptTokens
complate_tokens = chatResponse.Usage.CompletionTokens
total_tokens = chatResponse.Usage.TotalTokens
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", prompt_tokens, complate_tokens, total_tokens)
timestamp := time.Now().Unix()
timestampID := "edit-" + fmt.Sprintf("%d", timestamp)
// deduct the consumed balance
service.Consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID)
return nil
}
proxy.ServeHTTP(c.Writer, c.Request)
}
func Embeddings(c *gin.Context) {
var prompt_tokens int
var total_tokens int
var chatRequest model.ChatRequest
if err := c.ShouldBindJSON(&chatRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
auth := c.Request.Header.Get("Authorization")
//key := strings.Trim(auth, "Bearer ")
key := auth[7:]
// Use the key to check the user's balance before relaying; the max_tokens parameter could be factored in later.
serverInfo, err := service.CheckBlance(key, chatRequest.Model)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
log.Printf("request error KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
return
}
log.Printf("incoming request KEY: %v Model: %v", key, chatRequest.Model)
remote, err := url.Parse(serverInfo.ServerAddress)
if err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
proxy := httputil.NewSingleHostReverseProxy(remote)
newReqBody, err := json.Marshal(chatRequest)
if err != nil {
log.Printf("http request err: %v", err)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
proxy.Director = func(req *http.Request) {
req.Header = c.Request.Header
req.Host = remote.Host
req.URL.Scheme = remote.Scheme
req.URL.Host = remote.Host
req.URL.Path = c.Request.URL.Path
req.ContentLength = int64(len(newReqBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Encoding", "")
serverKey := serverInfo.AvailableKey
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// seed the random number generator
rand.Seed(time.Now().UnixNano())
// pick one key at random from the list
serverKey = keyList[rand.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(chatRequest)
log.Printf("start handling response, request body: %s", string(sss))
proxy.ModifyResponse = func(resp *http.Response) error {
resp.Header.Set("Openai-Organization", "api2gpt")
var chatResponse model.ChatResponse
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Printf("failed to read upstream response body: %v", err)
return err
}
json.Unmarshal(body, &chatResponse)
prompt_tokens = chatResponse.Usage.PromptTokens
total_tokens = chatResponse.Usage.TotalTokens
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
log.Printf("prompt_tokens: %v total_tokens: %v", prompt_tokens, total_tokens)
timestamp := time.Now().Unix()
timestampID := "emb-" + fmt.Sprintf("%d", timestamp)
// deduct the consumed balance
service.Consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID)
return nil
}
proxy.ServeHTTP(c.Writer, c.Request)
}
func Completions(c *gin.Context) {
var prompt_tokens int
var completion_tokens int
var total_tokens int
var chatRequest model.ChatRequest
if err := c.ShouldBindJSON(&chatRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
auth := c.Request.Header.Get("Authorization")
key := auth[7:]
// Use the key to check the user's balance before relaying; the max_tokens parameter could be factored in later.
serverInfo, err := service.CheckBlance(key, chatRequest.Model)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
log.Printf("request error KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
return
}
log.Printf("incoming request KEY: %v Model: %v", key, chatRequest.Model)
remote, err := url.Parse(serverInfo.ServerAddress)
if err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
proxy := httputil.NewSingleHostReverseProxy(remote)
newReqBody, err := json.Marshal(chatRequest)
if err != nil {
log.Printf("http request err: %v", err)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
proxy.Director = func(req *http.Request) {
req.Header = c.Request.Header
req.Host = remote.Host
req.URL.Scheme = remote.Scheme
req.URL.Host = remote.Host
req.URL.Path = c.Request.URL.Path
req.ContentLength = int64(len(newReqBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Encoding", "")
serverKey := serverInfo.AvailableKey
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// seed the random number generator
rand.Seed(time.Now().UnixNano())
// pick one key at random from the list
serverKey = keyList[rand.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(chatRequest)
if err != nil {
log.Printf("failed to marshal chatRequest: %v", err)
}
log.Printf("start handling response: %v", string(sss))
if chatRequest.Stream {
// handle the streaming response
proxy.ModifyResponse = func(resp *http.Response) error {
log.Printf("streaming response http status code: %v", resp.StatusCode)
if resp.StatusCode != http.StatusOK {
// refund the pre-deducted balance
service.CheckBlanceReturn(key, chatRequest.Model)
return nil
}
chatRequestId := ""
reqContent := ""
reader := bufio.NewReader(resp.Body)
headers := resp.Header
for k, v := range headers {
c.Writer.Header().Set(k, v[0])
}
c.Writer.Header().Set("Openai-Organization", "api2gpt")
for {
chunk, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
log.Printf("streaming response handling err: %v", err.Error())
break
//return err
}
var chatResponse model.ChatResponse
// strip the "data: " prefix from the response chunk
var trimStr = strings.Trim(string(chunk), "data: ")
if trimStr != "\n" {
json.Unmarshal([]byte(trimStr), &chatResponse)
if chatResponse.Choices != nil {
if chatResponse.Choices[0].Text != "" {
reqContent += chatResponse.Choices[0].Text
} else {
reqContent += chatResponse.Choices[0].Delta.Content
}
chatRequestId = chatResponse.Id
}
// write the chunk back to the client
_, err = c.Writer.Write([]byte(string(chunk) + "\n"))
if err != nil {
log.Printf("failed to write chunk back to client: %v", err.Error())
return err
}
c.Writer.(http.Flusher).Flush()
}
}
if chatRequest.Model == "text-davinci-003" {
prompt_tokens = common.NumTokensFromString(chatRequest.Prompt, chatRequest.Model)
} else {
prompt_tokens = common.NumTokensFromMessages(chatRequest.Messages, chatRequest.Model)
}
completion_tokens = common.NumTokensFromString(reqContent, chatRequest.Model)
log.Printf("response content: %v", reqContent)
total_tokens = prompt_tokens + completion_tokens
log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens)
// deduct the consumed balance
service.Consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatRequestId)
return nil
}
} else {
// handle the non-streaming response
proxy.ModifyResponse = func(resp *http.Response) error {
resp.Header.Set("Openai-Organization", "api2gpt")
var chatResponse model.ChatResponse
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Printf("non-streaming response handling err: %v", err)
return err
}
json.Unmarshal(body, &chatResponse)
prompt_tokens = chatResponse.Usage.PromptTokens
completion_tokens = chatResponse.Usage.CompletionTokens
total_tokens = chatResponse.Usage.TotalTokens
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens)
// deduct the consumed balance
service.Consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatResponse.Id)
return nil
}
}
proxy.ServeHTTP(c.Writer, c.Request)
}
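The streaming branch above removes the SSE prefix with strings.Trim, which strips any characters from the cutset rather than the literal "data: " prefix; a stricter per-line parse could look like the following sketch (hypothetical, keeping the "data: " / "[DONE]" conventions used by OpenAI-style streams, and relying only on the strings import already present in this file):

// parseSSELine extracts the JSON payload from one server-sent-event line.
func parseSSELine(line string) (payload string, done bool, ok bool) {
    line = strings.TrimRight(line, "\r\n")
    if line == "" {
        return "", false, false // separator / keep-alive line
    }
    if !strings.HasPrefix(line, "data: ") {
        return "", false, false // comments or other SSE fields
    }
    payload = strings.TrimPrefix(line, "data: ")
    if payload == "[DONE]" {
        return "", true, true
    }
    return payload, false, true
}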
// HandleOptions answers CORS preflight OPTIONS requests.
func HandleOptions(c *gin.Context) {
// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
c.Status(200)
return
}

@ -3,681 +3,28 @@ package main
import (
"api2gpt-mid/common"
"api2gpt-mid/middleware"
"bufio"
"bytes"
"api2gpt-mid/model"
"api2gpt-mid/router"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"math/rand"
"net/http"
"net/http/httputil"
"net/url"
"github.com/gin-gonic/gin"
"os"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
)
type Message struct {
Role string `json:"role,omitempty"`
Name string `json:"name,omitempty"`
Content string `json:"content,omitempty"`
}
type ChatRequest struct {
Stream bool `json:"stream,omitempty"`
Model string `json:"model,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt string `json:"prompt,omitempty"`
Input interface{} `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
N int `json:"n,omitempty"`
Functions interface{} `json:"functions,omitempty"`
}
type ImagesRequest struct {
Prompt string `json:"prompt,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
}
type ChatResponse struct {
Id string `json:"id,omitempty"`
Object string `json:"object,omitempty"`
Created int64 `json:"created,omitempty"`
Model string `json:"model,omitempty"`
Usage Usage `json:"usage,omitempty"`
Choices []Choice `json:"choices,omitempty"`
}
type ImagesResponse struct {
Created int64 `json:"created,omitempty"`
Data []ImagesData `json:"data,omitempty"`
}
type ImagesData struct {
Url string `json:"url,omitempty"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
type Choice struct {
Message Message `json:"message,omitempty"`
Delta Delta `json:"delta,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Index int `json:"index,omitempty"`
Text string `json:"text,omitempty"`
}
type Delta struct {
Content string `json:"content,omitempty"`
}
type CreditSummary struct {
Object string `json:"object"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
TotalRemaining float64 `json:"total_remaining"`
}
type ListModelResponse struct {
Object string `json:"object"`
Data []Model `json:"data"`
}
type Model struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []ModelPermission `json:"permission"`
Root string `json:"root"`
Parent any `json:"parent"`
}
type ModelPermission struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group any `json:"group"`
IsBlocking bool `json:"is_blocking"`
}
type UserInfo struct {
UID string `json:"uid"`
SID string `json:"sid"`
}
type ServerInfo struct {
ServerAddress string `json:"server_address"`
AvailableKey string `json:"available_key"`
}
type ModelInfo struct {
ModelName string `json:"model_name"`
ModelPrice float64 `json:"model_price"`
ModelPrice2 float64 `json:"model_price2"`
ModelPrepayment int `json:"model_prepayment"`
ServerId int `json:"server_id"`
}
// count the tokens in a slice of messages
func numTokensFromMessages(messages []Message, model string) int {
if strings.Contains(model, "gpt-3.5") {
model = "gpt-3.5-turbo"
}
if strings.Contains(model, "gpt-4") {
model = "gpt-4"
}
tkm, err := tiktoken.EncodingForModel(model)
if err != nil {
err = fmt.Errorf("getEncoding: %v", err)
return 0
}
numTokens := 0
for _, message := range messages {
numTokens += len(tkm.Encode(message.Content, nil, nil))
numTokens += 6
}
numTokens += 3
return numTokens
}
// count the tokens in a string
func numTokensFromString(msg string, model string) int {
if strings.Contains(model, "gpt-3.5") {
model = "gpt-3.5-turbo"
}
if strings.Contains(model, "gpt-4") {
model = "gpt-4"
}
tkm, err := tiktoken.EncodingForModel(model)
if err != nil {
err = fmt.Errorf("getEncoding: %v", err)
return 0
}
if model == "text-davinci-003" {
return len(tkm.Encode(msg, nil, nil)) + 1
} else {
return len(tkm.Encode(msg, nil, nil)) + 9
}
}
func images(c *gin.Context) {
var imagesRequest ImagesRequest
var model = "images-generations"
if err := c.ShouldBindJSON(&imagesRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
auth := c.Request.Header.Get("Authorization")
key := auth[7:]
serverInfo, err := checkBlanceForImages(key, model, imagesRequest.N)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
log.Printf("request error KEY: %v Model: %v ERROR: %v", key, model, err)
return
}
log.Printf("incoming request KEY: %v Model: %v", key, model)
remote, err := url.Parse(serverInfo.ServerAddress)
if err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
proxy := httputil.NewSingleHostReverseProxy(remote)
newReqBody, err := json.Marshal(imagesRequest)
if err != nil {
log.Printf("http request err: %v", err)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
proxy.Director = func(req *http.Request) {
req.Header = c.Request.Header
req.Host = remote.Host
req.URL.Scheme = remote.Scheme
req.URL.Host = remote.Host
req.URL.Path = c.Request.URL.Path
req.ContentLength = int64(len(newReqBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Encoding", "")
serverKey := serverInfo.AvailableKey
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// seed the random number generator
rand.Seed(time.Now().UnixNano())
// pick one key at random from the list
serverKey = keyList[rand.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(imagesRequest)
log.Printf("start handling response, request body: %s", string(sss))
proxy.ModifyResponse = func(resp *http.Response) error {
if resp.StatusCode != http.StatusOK {
// refund the pre-deducted balance
checkBlanceReturnForImages(key, model, imagesRequest.N)
return nil
}
resp.Header.Set("Openai-Organization", "api2gpt")
var imagesResponse ImagesResponse
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Printf("failed to read upstream response body: %v", err)
return err
}
json.Unmarshal(body, &imagesResponse)
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
log.Printf("image size: %v", len(imagesResponse.Data))
timestamp := time.Now().Unix()
timestampID := "img-" + fmt.Sprintf("%d", timestamp)
// deduct the consumed balance
consumptionForImages(key, model, imagesRequest.N, len(imagesResponse.Data), timestampID)
return nil
}
proxy.ServeHTTP(c.Writer, c.Request)
}
func edit(c *gin.Context) {
var prompt_tokens int
var complate_tokens int
var total_tokens int
var chatRequest ChatRequest
if err := c.ShouldBindJSON(&chatRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
auth := c.Request.Header.Get("Authorization")
key := auth[7:]
// Use the key to check the user's balance before relaying; the max_tokens parameter could be factored in later.
serverInfo, err := checkBlance(key, chatRequest.Model)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
log.Printf("request error KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
return
}
log.Printf("incoming request KEY: %v Model: %v", key, chatRequest.Model)
remote, err := url.Parse(serverInfo.ServerAddress)
if err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
proxy := httputil.NewSingleHostReverseProxy(remote)
newReqBody, err := json.Marshal(chatRequest)
if err != nil {
log.Printf("http request err: %v", err)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
proxy.Director = func(req *http.Request) {
req.Header = c.Request.Header
req.Host = remote.Host
req.URL.Scheme = remote.Scheme
req.URL.Host = remote.Host
req.URL.Path = c.Request.URL.Path
req.ContentLength = int64(len(newReqBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Encoding", "")
serverKey := serverInfo.AvailableKey
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// seed the random number generator
rand.Seed(time.Now().UnixNano())
// pick one key at random from the list
serverKey = keyList[rand.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(chatRequest)
log.Printf("start handling response, request body: %s", string(sss))
proxy.ModifyResponse = func(resp *http.Response) error {
resp.Header.Set("Openai-Organization", "api2gpt")
var chatResponse ChatResponse
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Printf("failed to read upstream response body: %v", err)
return err
}
json.Unmarshal(body, &chatResponse)
prompt_tokens = chatResponse.Usage.PromptTokens
complate_tokens = chatResponse.Usage.CompletionTokens
total_tokens = chatResponse.Usage.TotalTokens
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", prompt_tokens, complate_tokens, total_tokens)
timestamp := time.Now().Unix()
timestampID := "edit-" + fmt.Sprintf("%d", timestamp)
// deduct the consumed balance
consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID)
return nil
}
proxy.ServeHTTP(c.Writer, c.Request)
}
func embeddings(c *gin.Context) {
var prompt_tokens int
var total_tokens int
var chatRequest ChatRequest
if err := c.ShouldBindJSON(&chatRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
auth := c.Request.Header.Get("Authorization")
//key := strings.Trim(auth, "Bearer ")
key := auth[7:]
// Use the key to check the user's balance before relaying; the max_tokens parameter could be factored in later.
serverInfo, err := checkBlance(key, chatRequest.Model)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
log.Printf("request error KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
return
}
log.Printf("incoming request KEY: %v Model: %v", key, chatRequest.Model)
remote, err := url.Parse(serverInfo.ServerAddress)
if err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
proxy := httputil.NewSingleHostReverseProxy(remote)
newReqBody, err := json.Marshal(chatRequest)
if err != nil {
log.Printf("http request err: %v", err)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
proxy.Director = func(req *http.Request) {
req.Header = c.Request.Header
req.Host = remote.Host
req.URL.Scheme = remote.Scheme
req.URL.Host = remote.Host
req.URL.Path = c.Request.URL.Path
req.ContentLength = int64(len(newReqBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Encoding", "")
serverKey := serverInfo.AvailableKey
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// seed the random number generator
rand.Seed(time.Now().UnixNano())
// pick one key at random from the list
serverKey = keyList[rand.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(chatRequest)
log.Printf("start handling response, request body: %s", string(sss))
proxy.ModifyResponse = func(resp *http.Response) error {
resp.Header.Set("Openai-Organization", "api2gpt")
var chatResponse ChatResponse
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Printf("failed to read upstream response body: %v", err)
return err
}
json.Unmarshal(body, &chatResponse)
prompt_tokens = chatResponse.Usage.PromptTokens
total_tokens = chatResponse.Usage.TotalTokens
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
log.Printf("prompt_tokens: %v total_tokens: %v", prompt_tokens, total_tokens)
timestamp := time.Now().Unix()
timestampID := "emb-" + fmt.Sprintf("%d", timestamp)
// deduct the consumed balance
consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID)
return nil
}
proxy.ServeHTTP(c.Writer, c.Request)
}
func completions(c *gin.Context) {
var prompt_tokens int
var completion_tokens int
var total_tokens int
var chatRequest ChatRequest
if err := c.ShouldBindJSON(&chatRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
auth := c.Request.Header.Get("Authorization")
key := auth[7:]
// Use the key to check the user's balance before relaying; the max_tokens parameter could be factored in later.
serverInfo, err := checkBlance(key, chatRequest.Model)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
log.Printf("request error KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
return
}
log.Printf("incoming request KEY: %v Model: %v", key, chatRequest.Model)
remote, err := url.Parse(serverInfo.ServerAddress)
if err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
proxy := httputil.NewSingleHostReverseProxy(remote)
newReqBody, err := json.Marshal(chatRequest)
if err != nil {
log.Printf("http request err: %v", err)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
proxy.Director = func(req *http.Request) {
req.Header = c.Request.Header
req.Host = remote.Host
req.URL.Scheme = remote.Scheme
req.URL.Host = remote.Host
req.URL.Path = c.Request.URL.Path
req.ContentLength = int64(len(newReqBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Encoding", "")
serverKey := serverInfo.AvailableKey
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// seed the random number generator
rand.Seed(time.Now().UnixNano())
// pick one key at random from the list
serverKey = keyList[rand.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(chatRequest)
if err != nil {
log.Printf("failed to marshal chatRequest: %v", err)
}
log.Printf("start handling response: %v", string(sss))
if chatRequest.Stream {
// handle the streaming response
proxy.ModifyResponse = func(resp *http.Response) error {
log.Printf("streaming response http status code: %v", resp.StatusCode)
if resp.StatusCode != http.StatusOK {
// refund the pre-deducted balance
checkBlanceReturn(key, chatRequest.Model)
return nil
}
chatRequestId := ""
reqContent := ""
reader := bufio.NewReader(resp.Body)
headers := resp.Header
for k, v := range headers {
c.Writer.Header().Set(k, v[0])
}
c.Writer.Header().Set("Openai-Organization", "api2gpt")
for {
chunk, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
log.Printf("streaming response handling err: %v", err.Error())
break
//return err
}
var chatResponse ChatResponse
// strip the "data: " prefix from the response chunk
var trimStr = strings.Trim(string(chunk), "data: ")
if trimStr != "\n" {
json.Unmarshal([]byte(trimStr), &chatResponse)
if chatResponse.Choices != nil {
if chatResponse.Choices[0].Text != "" {
reqContent += chatResponse.Choices[0].Text
} else {
reqContent += chatResponse.Choices[0].Delta.Content
}
chatRequestId = chatResponse.Id
}
// write the chunk back to the client
_, err = c.Writer.Write([]byte(string(chunk) + "\n"))
if err != nil {
log.Printf("failed to write chunk back to client: %v", err.Error())
return err
}
c.Writer.(http.Flusher).Flush()
}
}
if chatRequest.Model == "text-davinci-003" {
prompt_tokens = numTokensFromString(chatRequest.Prompt, chatRequest.Model)
} else {
prompt_tokens = numTokensFromMessages(chatRequest.Messages, chatRequest.Model)
}
completion_tokens = numTokensFromString(reqContent, chatRequest.Model)
log.Printf("response content: %v", reqContent)
total_tokens = prompt_tokens + completion_tokens
log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens)
// deduct the consumed balance
consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatRequestId)
return nil
}
} else {
// handle the non-streaming response
proxy.ModifyResponse = func(resp *http.Response) error {
resp.Header.Set("Openai-Organization", "api2gpt")
var chatResponse ChatResponse
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Printf("non-streaming response handling err: %v", err)
return err
}
json.Unmarshal(body, &chatResponse)
prompt_tokens = chatResponse.Usage.PromptTokens
completion_tokens = chatResponse.Usage.CompletionTokens
total_tokens = chatResponse.Usage.TotalTokens
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens)
// deduct the consumed balance
consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatResponse.Id)
return nil
}
}
proxy.ServeHTTP(c.Writer, c.Request)
}
// list of supported models
func handleGetModels(c *gin.Context) {
// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/3
//models := []string{"gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002"}
models := []string{"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "text-davinci-003", "text-embedding-ada-002", "text-davinci-edit-001", "code-davinci-edit-001", "images-generations"}
result := ListModelResponse{
Object: "list",
}
for _, model := range models {
result.Data = append(result.Data, Model{
ID: model,
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: []ModelPermission{
{
ID: "",
Object: "model",
Created: 1679602087,
AllowCreateEngine: true,
AllowSampling: true,
AllowLogprobs: true,
AllowSearchIndices: true,
AllowView: true,
AllowFineTuning: true,
Organization: "*",
Group: nil,
IsBlocking: false,
},
},
Root: model,
Parent: nil,
})
}
c.JSON(200, result)
}
// handle CORS preflight OPTIONS requests
func handleOptions(c *gin.Context) {
// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
c.Status(200)
return
}
// balance query
func balance(c *gin.Context) {
auth := c.Request.Header.Get("Authorization")
key := strings.Trim(auth, "Bearer ")
balance, err := queryBlance(key)
if err != nil {
c.JSON(400, gin.H{"error": err.Error()})
}
var creditSummary CreditSummary
creditSummary.Object = "credit_grant"
creditSummary.TotalGranted = 999999
creditSummary.TotalUsed = 999999 - balance
creditSummary.TotalRemaining = balance
c.JSON(200, creditSummary)
}
// key validation middleware
func checkKeyMid() gin.HandlerFunc {
return func(c *gin.Context) {
auth := c.Request.Header.Get("Authorization")
log.Printf("auth: %v", auth)
if auth == "" {
c.AbortWithStatusJSON(401, gin.H{"code": 40001})
} else {
//key := strings.Trim(auth, "Bearer ")
key := auth[7:]
log.Printf("key: %v", key)
msg, err := checkKeyAndTimeCount(key)
if err != nil {
log.Printf("checkKeyMid err: %v", err)
c.AbortWithStatusJSON(msg, gin.H{"code": err.Error()})
}
}
log.Printf("auth check end")
// continue to the next handler
c.Next()
}
}
func test_redis() {
// seed redis with test data
//var serverInfo ServerInfo = ServerInfo{
// ServerAddress: "https://gptp.any-door.cn",
// AvailableKey: "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",
//}
var serverInfo model.ServerInfo = model.ServerInfo{
ServerAddress: "https://gptp.any-door.cn",
AvailableKey: "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",
}
// var serverInfo2 ServerInfo = ServerInfo{
// ServerAddress: "https://azure.any-door.cn",
// AvailableKey: "6c4d2c65970b40e482e7cd27adb0d119",
// }
//serverInfoStr, _ := json.Marshal(&serverInfo)
serverInfoStr, _ := json.Marshal(&serverInfo)
// serverInfoStr2, _ := json.Marshal(&serverInfo2)
//Redis.Set(context.Background(), "server:1", serverInfoStr, 0)
common.RedisSet("server:0", string(serverInfoStr), 0)
common.RedisSet("server:1", string(serverInfoStr), 0)
// Redis.Set(context.Background(), "server:2", serverInfoStr2, 0)
// var modelInfo ModelInfo = ModelInfo{
@ -725,7 +72,7 @@ func test_redis() {
// modelInfoStr2, _ := json.Marshal(&modelInfo2)
// Redis.Set(context.Background(), "model:images-generations", modelInfoStr2, 0)
var userInfo UserInfo = UserInfo{
var userInfo model.UserInfo = model.UserInfo{
UID: "0",
SID: "1",
}
@ -763,25 +110,10 @@ func main() {
server := gin.Default()
server.Use(middleware.CORS())
server.GET("/dashboard/billing/credit_grants", checkKeyMid(), balance)
server.GET("/v1/models", handleGetModels)
server.OPTIONS("/v1/*path", handleOptions)
server.POST("/v1/chat/completions", checkKeyMid(), completions)
server.POST("/v1/completions", checkKeyMid(), completions)
server.POST("/v1/embeddings", checkKeyMid(), embeddings)
server.POST("/v1/edits", checkKeyMid(), edit)
server.POST("/v1/images/generations", checkKeyMid(), images)
// define a GET test endpoint
server.GET("/ping", func(c *gin.Context) {
c.JSON(200, gin.H{
"message": "pong from api2gpt",
})
})
router.SetRouter(server)
// seed test data
//test_redis()
test_redis()
var port = os.Getenv("PORT")
if port == "" {

@ -0,0 +1,29 @@
package middleware
import (
"api2gpt-mid/service"
"github.com/gin-gonic/gin"
"log"
)
func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) {
auth := c.Request.Header.Get("Authorization")
log.Printf("auth: %v", auth)
if auth == "" {
c.AbortWithStatusJSON(401, gin.H{"code": 40001})
} else {
//key := strings.Trim(auth, "Bearer ")
key := auth[7:]
log.Printf("key: %v", key)
msg, err := service.CheckKeyAndTimeCount(key)
if err != nil {
log.Printf("checkKeyMid err: %v", err)
c.AbortWithStatusJSON(msg, gin.H{"code": err.Error()})
}
}
log.Printf("auth check end")
// continue to the next handler
c.Next()
}
}
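A minimal variant of the middleware above with explicit early returns and a prefix check, sketched as a hypothetical alternative (not part of this commit; it assumes the Authorization header carries a "Bearer " prefix and would additionally need the strings import):

func TokenAuthStrict() gin.HandlerFunc {
    return func(c *gin.Context) {
        auth := c.Request.Header.Get("Authorization")
        if !strings.HasPrefix(auth, "Bearer ") {
            c.AbortWithStatusJSON(401, gin.H{"code": 40001})
            return
        }
        key := strings.TrimPrefix(auth, "Bearer ")
        // Reject early if the key is unknown or over its per-minute quota.
        if status, err := service.CheckKeyAndTimeCount(key); err != nil {
            c.AbortWithStatusJSON(status, gin.H{"code": err.Error()})
            return
        }
        c.Next()
    }
}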

@ -0,0 +1,10 @@
package model
type Consumption struct {
SecretKey string `json:"secretKey"`
Model string `json:"model"`
MsgId string `json:"msgId"`
PromptTokens int `json:"promptTokens"`
CompletionTokens int `json:"completionTokens"`
TotalTokens int `json:"totalTokens"`
}
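Marshalled, a Consumption record yields the camelCase payload sent to the usage-record endpoint; an illustrative instance (all values made up):

package main

import (
    "api2gpt-mid/model"
    "encoding/json"
    "fmt"
)

func main() {
    c := model.Consumption{
        SecretKey:        "sk-xxx",
        Model:            "gpt-3.5-turbo",
        MsgId:            "chatcmpl-123",
        PromptTokens:     120,
        CompletionTokens: 80,
        TotalTokens:      200,
    }
    b, _ := json.Marshal(c)
    fmt.Println(string(b))
    // {"secretKey":"sk-xxx","model":"gpt-3.5-turbo","msgId":"chatcmpl-123","promptTokens":120,"completionTokens":80,"totalTokens":200}
}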

@ -0,0 +1,9 @@
package model
type ModelInfo struct {
ModelName string `json:"model_name"`
ModelPrice float64 `json:"model_price"`
ModelPrice2 float64 `json:"model_price2"`
ModelPrepayment int `json:"model_prepayment"`
ServerId int `json:"server_id"`
}

@ -0,0 +1,98 @@
package model
type Message struct {
Role string `json:"role,omitempty"`
Name string `json:"name,omitempty"`
Content string `json:"content,omitempty"`
}
type ChatRequest struct {
Stream bool `json:"stream,omitempty"`
Model string `json:"model,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt string `json:"prompt,omitempty"`
Input interface{} `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
N int `json:"n,omitempty"`
Functions interface{} `json:"functions,omitempty"`
}
type ImagesRequest struct {
Prompt string `json:"prompt,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
}
type ChatResponse struct {
Id string `json:"id,omitempty"`
Object string `json:"object,omitempty"`
Created int64 `json:"created,omitempty"`
Model string `json:"model,omitempty"`
Usage Usage `json:"usage,omitempty"`
Choices []Choice `json:"choices,omitempty"`
}
type ImagesResponse struct {
Created int64 `json:"created,omitempty"`
Data []ImagesData `json:"data,omitempty"`
}
type ImagesData struct {
Url string `json:"url,omitempty"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
type Choice struct {
Message Message `json:"message,omitempty"`
Delta Delta `json:"delta,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Index int `json:"index,omitempty"`
Text string `json:"text,omitempty"`
}
type Delta struct {
Content string `json:"content,omitempty"`
}
type CreditSummary struct {
Object string `json:"object"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
TotalRemaining float64 `json:"total_remaining"`
}
type ListModelResponse struct {
Object string `json:"object"`
Data []Model `json:"data"`
}
type Model struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []ModelPermission `json:"permission"`
Root string `json:"root"`
Parent any `json:"parent"`
}
type ModelPermission struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group any `json:"group"`
IsBlocking bool `json:"is_blocking"`
}

@ -0,0 +1,6 @@
package model
type ServerInfo struct {
ServerAddress string `json:"server_address"`
AvailableKey string `json:"available_key"`
}

@ -0,0 +1,6 @@
package model
type UserInfo struct {
UID string `json:"uid"`
SID string `json:"sid"`
}

@ -1,7 +1,36 @@
package router
import "github.com/gin-gonic/gin"
import (
"api2gpt-mid/controller"
"api2gpt-mid/middleware"
"github.com/gin-gonic/gin"
)
func SetApiRouter(router *gin.Engine) {
modelsRouter := router.Group("/v1/models")
modelsRouter.Use(middleware.TokenAuth())
{
modelsRouter.GET("", controller.ListModels)
modelsRouter.GET("/:model", controller.RetrieveModel)
}
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth())
{
relayV1Router.POST("/completions", controller.Completions)
relayV1Router.POST("/chat/completions", controller.Completions)
relayV1Router.POST("/embeddings", controller.Embeddings)
relayV1Router.POST("/edits", controller.Edit)
relayV1Router.POST("/images/generations", controller.Images)
}
dashboardRouter := router.Group("/dashboard")
dashboardRouter.Use(middleware.TokenAuth())
{
dashboardRouter.GET("/billing/credit_grants", controller.Balance)
}
router.OPTIONS("/v1/*path", controller.HandleOptions)
router.GET("/ping", func(c *gin.Context) {
c.JSON(200, gin.H{
"message": "pong from api2gpt",
})
})
}
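A sketch of how this router is expected to be wired from main. Note that main.go calls router.SetRouter(server), which is not visible in this hunk; the sketch uses the visible SetApiRouter directly and assumes the listen port normally comes from the PORT environment variable.

package main

import (
    "api2gpt-mid/middleware"
    "api2gpt-mid/router"

    "github.com/gin-gonic/gin"
)

func main() {
    server := gin.Default()
    server.Use(middleware.CORS())
    // Register all /v1, /dashboard and /ping routes.
    router.SetApiRouter(server)
    _ = server.Run(":8080") // placeholder port for illustration
}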

@ -1,7 +1,9 @@
package main
package service
import (
"api2gpt-mid/api"
"api2gpt-mid/common"
"api2gpt-mid/model"
"encoding/json"
"errors"
"log"
@ -10,7 +12,7 @@ import (
)
// check that the key exists and has not exceeded the per-minute request limit
func checkKeyAndTimeCount(key string) (int, error) {
func CheckKeyAndTimeCount(key string) (int, error) {
var timeOut = 60 * time.Second
var timeCount = 30
userInfoStr, err := common.RedisGet("user:" + key)
@ -19,7 +21,7 @@ func checkKeyAndTimeCount(key string) (int, error) {
// user does not exist
return 401, errors.New("40003")
}
var userInfo UserInfo
var userInfo model.UserInfo
err = json.Unmarshal([]byte(userInfoStr), &userInfo)
if err != nil {
// user state is invalid
@ -46,9 +48,9 @@ func checkKeyAndTimeCount(key string) (int, error) {
return 200, nil
}
func queryBlance(key string) (float64, error) {
func QueryBlance(key string) (float64, error) {
userInfoStr, err := common.RedisGet("user:" + key)
var userInfo UserInfo
var userInfo model.UserInfo
err = json.Unmarshal([]byte(userInfoStr), &userInfo)
balance, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", 0)
if err != nil {
@ -58,15 +60,15 @@ func queryBlance(key string) (float64, error) {
}
// balance check
func checkBlance(key string, model string) (ServerInfo, error) {
var serverInfo ServerInfo
func CheckBlance(key string, modelStr string) (model.ServerInfo, error) {
var serverInfo model.ServerInfo
// fetch the model price
modelPriceStr, err := common.RedisGet("model:" + model)
modelPriceStr, err := common.RedisGet("model:" + modelStr)
if err != nil {
return serverInfo, errors.New("model info not found")
}
var modelInfo ModelInfo
var modelInfo model.ModelInfo
err = json.Unmarshal([]byte(modelPriceStr), &modelInfo)
if err != nil {
return serverInfo, errors.New("failed to parse model info")
@ -74,7 +76,7 @@ func checkBlance(key string, model string) (ServerInfo, error) {
// fetch the user info
userInfoStr, err := common.RedisGet("user:" + key)
var userInfo UserInfo
var userInfo model.UserInfo
err = json.Unmarshal([]byte(userInfoStr), &userInfo)
// fetch the upstream server info
serverInfoStr, err := common.RedisGet("server:" + strconv.Itoa(modelInfo.ServerId))
@ -101,15 +103,15 @@ func checkBlance(key string, model string) (ServerInfo, error) {
}
// balance check for images
func checkBlanceForImages(key string, model string, n int) (ServerInfo, error) {
var serverInfo ServerInfo
func CheckBlanceForImages(key string, modelStr string, n int) (model.ServerInfo, error) {
var serverInfo model.ServerInfo
// fetch the model price
modelPriceStr, err := common.RedisGet("model:" + model)
modelPriceStr, err := common.RedisGet("model:" + modelStr)
if err != nil {
return serverInfo, errors.New("model info not found")
}
var modelInfo ModelInfo
var modelInfo model.ModelInfo
err = json.Unmarshal([]byte(modelPriceStr), &modelInfo)
if err != nil {
return serverInfo, errors.New("failed to parse model info")
@ -117,7 +119,7 @@ func checkBlanceForImages(key string, model string, n int) (ServerInfo, error) {
// fetch the user info
userInfoStr, err := common.RedisGet("user:" + key)
var userInfo UserInfo
var userInfo model.UserInfo
err = json.Unmarshal([]byte(userInfoStr), &userInfo)
// fetch the upstream server info
serverInfoStr, err := common.RedisGet("server:" + strconv.Itoa(modelInfo.ServerId))
@ -144,18 +146,18 @@ func checkBlanceForImages(key string, model string, n int) (ServerInfo, error) {
}
// refund the pre-deducted amount
func checkBlanceReturn(key string, model string) error {
func CheckBlanceReturn(key string, modelStr string) error {
// fetch the user info
userInfoStr, err := common.RedisGet("user:" + key)
var userInfo UserInfo
var userInfo model.UserInfo
err = json.Unmarshal([]byte(userInfoStr), &userInfo)
// fetch the model price
modelPriceStr, err := common.RedisGet("model:" + model)
modelPriceStr, err := common.RedisGet("model:" + modelStr)
if err != nil {
return errors.New("model info not found")
}
var modelInfo ModelInfo
var modelInfo model.ModelInfo
err = json.Unmarshal([]byte(modelPriceStr), &modelInfo)
if err != nil {
return errors.New("failed to parse model info")
@ -166,17 +168,17 @@ func checkBlanceReturn(key string, model string) error {
}
// refund the pre-deducted amount for images
func checkBlanceReturnForImages(key string, model string, n int) error {
func CheckBlanceReturnForImages(key string, modelStr string, n int) error {
// fetch the user info
userInfoStr, err := common.RedisGet("user:" + key)
var userInfo UserInfo
var userInfo model.UserInfo
err = json.Unmarshal([]byte(userInfoStr), &userInfo)
// fetch the model price
modelPriceStr, err := common.RedisGet("model:" + model)
modelPriceStr, err := common.RedisGet("model:" + modelStr)
if err != nil {
return errors.New("model info not found")
}
var modelInfo ModelInfo
var modelInfo model.ModelInfo
err = json.Unmarshal([]byte(modelPriceStr), &modelInfo)
if err != nil {
return errors.New("failed to parse model info")
@ -187,23 +189,23 @@ func checkBlanceReturnForImages(key string, model string, n int) error {
}
// balance consumption
func consumption(key string, model string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) {
func Consumption(key string, modelStr string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) {
// fetch the user info
userInfoStr, err := common.RedisGet("user:" + key)
if err != nil {
return "", errors.New("用户信息不存在")
}
var userInfo UserInfo
var userInfo model.UserInfo
err = json.Unmarshal([]byte(userInfoStr), &userInfo)
if err != nil {
return "", errors.New("用户信息解析失败")
}
// fetch the model price
modelPriceStr, err := common.RedisGet("model:" + model)
modelPriceStr, err := common.RedisGet("model:" + modelStr)
if err != nil {
return "", errors.New("模型信息不存在")
}
var modelInfo ModelInfo
var modelInfo model.ModelInfo
err = json.Unmarshal([]byte(modelPriceStr), &modelInfo)
if err != nil {
return "", errors.New("模型信息解析失败")
@ -211,7 +213,7 @@ func consumption(key string, model string, prompt_tokens int, completion_tokens
balance, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment)*modelInfo.ModelPrice-(float64(prompt_tokens)*modelInfo.ModelPrice+float64(completion_tokens)*modelInfo.ModelPrice2))
// send the consumption log request
result, err := balanceConsumption(key, model, prompt_tokens, completion_tokens, total_tokens, msg_id)
result, err := api.BalanceConsumption(key, modelStr, prompt_tokens, completion_tokens, total_tokens, msg_id)
log.Printf("user balance: %f charged key: %s charged tokens: %d charge: %f consumption log result: %s", balance, key, total_tokens, float64(modelInfo.ModelPrepayment)*modelInfo.ModelPrice-(float64(prompt_tokens)*modelInfo.ModelPrice+float64(completion_tokens)*modelInfo.ModelPrice2), result)
if err != nil {
log.Printf("%s consumption log request failed %v", key, err)
@ -221,23 +223,23 @@ func consumption(key string, model string, prompt_tokens int, completion_tokens
}
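The balance adjustment in Consumption refunds the amount pre-deducted at check time and charges the actual usage in one RedisIncrByFloat call. A worked example with made-up prices (prepayment of 1000 tokens, 0.001 per prompt token, 0.002 per completion token, actual usage 120 + 80 tokens):

package main

import "fmt"

func main() {
    // Hypothetical pricing, for illustration only.
    prepayment, price, price2 := 1000.0, 0.001, 0.002
    promptTokens, completionTokens := 120.0, 80.0
    delta := prepayment*price - (promptTokens*price + completionTokens*price2)
    fmt.Println(delta) // roughly 0.72: the 1.0 prepayment is returned and the 0.28 actual cost is charged
}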
// balance consumption for images
func consumptionForImages(key string, model string, n int, dataNum int, msg_id string) (string, error) {
func ConsumptionForImages(key string, modelStr string, n int, dataNum int, msg_id string) (string, error) {
// fetch the user info
userInfoStr, err := common.RedisGet("user:" + key)
if err != nil {
return "", errors.New("用户信息不存在")
}
var userInfo UserInfo
var userInfo model.UserInfo
err = json.Unmarshal([]byte(userInfoStr), &userInfo)
if err != nil {
return "", errors.New("用户信息解析失败")
}
// fetch the model price
modelPriceStr, err := common.RedisGet("model:" + model)
modelPriceStr, err := common.RedisGet("model:" + modelStr)
if err != nil {
return "", errors.New("模型信息不存在")
}
var modelInfo ModelInfo
var modelInfo model.ModelInfo
err = json.Unmarshal([]byte(modelPriceStr), &modelInfo)
if err != nil {
return "", errors.New("模型信息解析失败")
@ -245,7 +247,7 @@ func consumptionForImages(key string, model string, n int, dataNum int, msg_id s
balance, err := common.RedisIncrByFloat("user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment*n)*modelInfo.ModelPrice-(float64(1000*dataNum)*modelInfo.ModelPrice))
// send the consumption log request
result, err := balanceConsumption(key, model, 0, 1000*dataNum, 1000*dataNum, msg_id)
result, err := api.BalanceConsumption(key, modelStr, 0, 1000*dataNum, 1000*dataNum, msg_id)
log.Printf("user balance: %f charged key: %s charged tokens: %d charge: %f consumption log result: %s", balance, key, 1000*dataNum, float64(1000*dataNum)*modelInfo.ModelPrice, result)
if err != nil {
log.Printf("%s consumption log request failed %v", key, err)