// NOTE: removed Gitea web-UI residue (topics/size/ambiguous-Unicode warnings)
// that was accidentally captured ahead of the package clause; non-comment
// text before `package` does not compile.
package main
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"math/rand"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"github.com/redis/go-redis/v9"
)
// Message is a single chat message (role/name/content) in the OpenAI
// chat-completion wire format.
type Message struct {
Role string `json:"role,omitempty"`
Name string `json:"name,omitempty"`
Content string `json:"content,omitempty"`
}
// ChatRequest is the union of the request bodies accepted by the proxied
// endpoints: chat completions (Messages), plain completions (Prompt) and
// embeddings (Input). Unused fields stay zero and are omitted from JSON.
type ChatRequest struct {
Stream bool `json:"stream,omitempty"`
Model string `json:"model,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt string `json:"prompt,omitempty"`
Input string `json:"input,omitempty"`
}
// ChatResponse mirrors the OpenAI response envelope; this proxy reads
// Id, Usage and Choices from it when billing.
type ChatResponse struct {
Id string `json:"id,omitempty"`
Object string `json:"object,omitempty"`
Created int64 `json:"created,omitempty"`
Model string `json:"model,omitempty"`
Usage Usage `json:"usage,omitempty"`
Choices []Choice `json:"choices,omitempty"`
}
// Usage is the upstream-reported token accounting used for billing
// non-streaming responses.
type Usage struct {
PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
// Choice is one completion alternative; Delta carries streamed fragments,
// Message/Text carry full (non-streamed) results.
type Choice struct {
Message Message `json:"message,omitempty"`
Delta Delta `json:"delta,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Index int `json:"index,omitempty"`
Text string `json:"text,omitempty"`
}
// Delta is the incremental content fragment of one streaming chunk.
type Delta struct {
Content string `json:"content,omitempty"`
}
// CreditSummary is the billing shape returned by the balance endpoint,
// mimicking OpenAI's credit_grants payload.
type CreditSummary struct {
Object string `json:"object"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
TotalRemaining float64 `json:"total_remaining"`
}
// ListModelResponse is the envelope for GET /v1/models.
type ListModelResponse struct {
Object string `json:"object"`
Data []Model `json:"data"`
}
// Model describes one entry of the static model list served by
// handleGetModels, mirroring OpenAI's model object.
type Model struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []ModelPermission `json:"permission"`
Root string `json:"root"`
Parent any `json:"parent"`
}
// ModelPermission mirrors OpenAI's model permission object; values served
// here are static.
type ModelPermission struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group any `json:"group"`
IsBlocking bool `json:"is_blocking"`
}
// UserInfo is the Redis-stored record for an API key — presumably user id
// and session/server id; verify against checkBlance/consumption.
type UserInfo struct {
UID string `json:"uid"`
SID string `json:"sid"`
}
// ServerInfo describes an upstream: its base address plus one or more
// comma-separated upstream API keys (the handlers split on ",").
type ServerInfo struct {
ServerAddress string `json:"server_address"`
AvailableKey string `json:"available_key"`
}
// ModelInfo is the Redis-stored pricing record for a model — presumably
// a per-token price and a prepayment amount; TODO confirm units against
// the billing helpers.
type ModelInfo struct {
ModelName string `json:"model_name"`
ModelPrice float64 `json:"model_price"`
ModelPrepayment int `json:"model_prepayment"`
}
// Package-level Redis handle and its configurable address.
var (
// Redis is the shared client; set by InitRedis in main.
Redis *redis.Client
// RedisAddress is the server address; overridable via REDIS_ADDRESS (see init).
RedisAddress = "localhost:6379"
)
// init overrides the default Redis address from the REDIS_ADDRESS
// environment variable (when set) and logs the address in use.
func init() {
	//gin.SetMode(gin.ReleaseMode)
	addr := os.Getenv("REDIS_ADDRESS")
	if addr != "" {
		RedisAddress = addr
	}
	log.Printf("loading redis address: %s", RedisAddress)
}
// Redis initialisation.
// InitRedis builds a Redis client for RedisAddress and verifies
// connectivity with a PING. It returns nil when the server does not
// answer PONG, so callers must nil-check before use.
func InitRedis() *redis.Client {
	client := redis.NewClient(&redis.Options{
		Addr:     RedisAddress,
		Password: "", // no password set
		DB:       0,  // use default DB
		PoolSize: 10,
	})
	pong := client.Ping(context.Background()).Val()
	fmt.Println("redis ping:", pong)
	if pong != "PONG" {
		// connection problem
		return nil
	}
	return client
}
// Count the tokens in a slice of Messages.
// numTokensFromMessages estimates the prompt token count for a slice of
// chat messages with the tokenizer associated with model. Each message
// costs its encoded content plus a fixed 6-token overhead, and the
// conversation adds 3 more (chat-format framing — heuristic; confirm
// against upstream accounting). Panics when the model has no encoding.
func numTokensFromMessages(messages []Message, model string) int {
	tkm, err := tiktoken.EncodingForModel(model)
	if err != nil {
		err = fmt.Errorf("getEncoding: %v", err)
		panic(err)
	}
	total := 3 // reply framing
	for _, m := range messages {
		total += len(tkm.Encode(m.Content, nil, nil)) + 6
	}
	return total
}
// Count the tokens in a raw string.
// numTokensFromString estimates the token count of a raw string for the
// given model. Completion-style text-davinci-003 adds 1 framing token;
// everything else adds 9 — presumably chat-format overhead, TODO confirm
// against upstream accounting. Panics when the model has no encoding.
func numTokensFromString(msg string, model string) int {
	tkm, err := tiktoken.EncodingForModel(model)
	if err != nil {
		// Same panic contract as numTokensFromMessages.
		panic(fmt.Errorf("getEncoding: %v", err))
	}
	tokens := len(tkm.Encode(msg, nil, nil))
	if model == "text-davinci-003" {
		return tokens + 1
	}
	return tokens + 9
}
// embeddings proxies POST /v1/embeddings to the upstream chosen by
// checkBlance, swaps in an upstream API key, and bills the caller from
// the usage block reported in the upstream response.
func embeddings(c *gin.Context) {
	var promptTokens int
	var totalTokens int
	var chatRequest ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	auth := c.Request.Header.Get("Authorization")
	// BUGFIX: strings.Trim treats "Bearer " as a character set and also
	// strips those characters from the END of the key; TrimPrefix removes
	// only the literal scheme prefix.
	key := strings.TrimPrefix(auth, "Bearer ")
	// Check the user's balance for this key; later the max_tokens
	// parameter could be taken into account as well.
	serverInfo, err := checkBlance(key, chatRequest.Model)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		// BUGFIX: a malformed upstream address used to panic the handler;
		// report it as a gateway error instead.
		c.AbortWithStatusJSON(http.StatusBadGateway, gin.H{"error": err.Error()})
		return
	}
	proxy := httputil.NewSingleHostReverseProxy(remote)
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	proxy.Director = func(req *http.Request) {
		req.Header = c.Request.Header
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(newReqBody))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")
		serverKey := serverInfo.AvailableKey
		keyList := strings.Split(serverInfo.AvailableKey, ",")
		if len(keyList) > 1 {
			// Pick one upstream key at random. The global source is
			// auto-seeded; per-request rand.Seed is deprecated and was
			// dropped.
			serverKey = keyList[rand.Intn(len(keyList))]
		}
		req.Header.Set("Authorization", "Bearer "+serverKey)
		req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
	}
	// BUGFIX: the original marshalled the request a second time, ignored
	// the error, and logged the string with %d; reuse newReqBody with %s.
	log.Printf("开始处理返回逻辑 %s", string(newReqBody))
	proxy.ModifyResponse = func(resp *http.Response) error {
		resp.Header.Set("Openai-Organization", "api2gpt")
		var chatResponse ChatResponse
		body, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			return err
		}
		if err := json.Unmarshal(body, &chatResponse); err != nil {
			// Non-JSON upstream bodies are still relayed; usage stays zero.
			log.Printf("unmarshal upstream response: %v", err)
		}
		promptTokens = chatResponse.Usage.PromptTokens
		totalTokens = chatResponse.Usage.TotalTokens
		// The body was fully consumed above; restore it for the client.
		resp.Body = ioutil.NopCloser(bytes.NewReader(body))
		log.Printf("prompt_tokens: %v total_tokens: %v", promptTokens, totalTokens)
		timestampID := fmt.Sprintf("emb-%d", time.Now().Unix())
		// Charge the consumed balance (embeddings have no completion tokens).
		consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID)
		return nil
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
// completions proxies POST /v1/chat/completions and /v1/completions to the
// upstream chosen by checkBlance, swapping in an upstream API key.
// Streaming responses are relayed chunk-by-chunk and billed from locally
// counted tokens; non-streaming responses are billed from the upstream's
// usage block.
func completions(c *gin.Context) {
	var promptTokens int
	var completionTokens int
	var totalTokens int
	var chatRequest ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	auth := c.Request.Header.Get("Authorization")
	// BUGFIX: strings.Trim treats "Bearer " as a character set and can eat
	// characters off the END of the key; TrimPrefix strips only the scheme.
	key := strings.TrimPrefix(auth, "Bearer ")
	// Check the user's balance for this key; later the max_tokens
	// parameter could be taken into account as well.
	serverInfo, err := checkBlance(key, chatRequest.Model)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		// BUGFIX: a malformed upstream address used to panic the handler.
		c.AbortWithStatusJSON(http.StatusBadGateway, gin.H{"error": err.Error()})
		return
	}
	proxy := httputil.NewSingleHostReverseProxy(remote)
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	proxy.Director = func(req *http.Request) {
		req.Header = c.Request.Header
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(newReqBody))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")
		serverKey := serverInfo.AvailableKey
		keyList := strings.Split(serverInfo.AvailableKey, ",")
		if len(keyList) > 1 {
			// Pick one upstream key at random (global source is
			// auto-seeded; deprecated per-request rand.Seed dropped).
			serverKey = keyList[rand.Intn(len(keyList))]
		}
		req.Header.Set("Authorization", "Bearer "+serverKey)
		req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
	}
	// BUGFIX: the original marshalled a second time, ignored the error, and
	// logged a string with %d; reuse newReqBody with %s.
	log.Printf("开始处理返回逻辑 %s", string(newReqBody))
	if chatRequest.Stream {
		// Streaming response handling.
		proxy.ModifyResponse = func(resp *http.Response) error {
			chatRequestID := ""
			reqContent := ""
			reader := bufio.NewReader(resp.Body)
			for k, v := range resp.Header {
				c.Writer.Header().Set(k, v[0])
			}
			c.Writer.Header().Set("Openai-Organization", "api2gpt")
			for {
				chunk, err := reader.ReadBytes('\n')
				if err != nil {
					if err != io.EOF {
						log.Printf("httpError1 %v", err.Error())
					}
					break
				}
				var chatResponse ChatResponse
				// Strip the SSE "data: " marker.
				// BUGFIX: Trim with a cutset could also strip trailing
				// 'd'/'a'/'t'/':'/' ' characters; TrimPrefix removes only
				// the literal prefix.
				trimStr := strings.TrimPrefix(string(chunk), "data: ")
				if trimStr != "\n" {
					// Best-effort parse; "[DONE]" and keep-alives simply
					// leave chatResponse zero.
					json.Unmarshal([]byte(trimStr), &chatResponse)
					// BUGFIX: guard the index — a non-nil but empty
					// Choices slice used to panic the stream.
					if len(chatResponse.Choices) > 0 {
						reqContent += chatResponse.Choices[0].Delta.Content
						chatRequestID = chatResponse.Id
					}
					// Relay the chunk to the client.
					if _, err := c.Writer.Write([]byte(string(chunk) + "\n")); err != nil {
						log.Printf("httpError2 %v", err.Error())
						return err
					}
					c.Writer.(http.Flusher).Flush()
				}
			}
			if chatRequest.Model == "text-davinci-003" {
				promptTokens = numTokensFromString(chatRequest.Prompt, chatRequest.Model)
			} else {
				promptTokens = numTokensFromMessages(chatRequest.Messages, chatRequest.Model)
			}
			completionTokens = numTokensFromString(reqContent, chatRequest.Model)
			totalTokens = promptTokens + completionTokens
			log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
			// Charge the consumed balance.
			consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatRequestID)
			return nil
		}
	} else {
		// Non-streaming response handling.
		proxy.ModifyResponse = func(resp *http.Response) error {
			resp.Header.Set("Openai-Organization", "api2gpt")
			var chatResponse ChatResponse
			body, err := ioutil.ReadAll(resp.Body)
			if err != nil {
				return err
			}
			if err := json.Unmarshal(body, &chatResponse); err != nil {
				// Non-JSON bodies are still relayed; usage stays zero.
				log.Printf("unmarshal upstream response: %v", err)
			}
			promptTokens = chatResponse.Usage.PromptTokens
			completionTokens = chatResponse.Usage.CompletionTokens
			totalTokens = chatResponse.Usage.TotalTokens
			// The body was fully consumed above; restore it for the client.
			resp.Body = ioutil.NopCloser(bytes.NewReader(body))
			log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
			// Charge the consumed balance.
			consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatResponse.Id)
			return nil
		}
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
// Model listing endpoint.
// handleGetModels serves GET /v1/models with a static list of the models
// this proxy exposes, mimicking OpenAI's list-models response shape.
func handleGetModels(c *gin.Context) {
	// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/3
	//models := []string{"gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002"}
	supported := []string{"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002"}
	// One shared, read-only permission record for every model.
	permissions := []ModelPermission{{
		ID:                 "",
		Object:             "model",
		Created:            1679602087,
		AllowCreateEngine:  true,
		AllowSampling:      true,
		AllowLogprobs:      true,
		AllowSearchIndices: true,
		AllowView:          true,
		AllowFineTuning:    true,
		Organization:       "*",
		Group:              nil,
		IsBlocking:         false,
	}}
	result := ListModelResponse{Object: "list"}
	for _, name := range supported {
		result.Data = append(result.Data, Model{
			ID:         name,
			Object:     "model",
			Created:    1677649963,
			OwnedBy:    "openai",
			Permission: permissions,
			Root:       name,
			Parent:     nil,
		})
	}
	c.JSON(200, result)
}
// Handling for CORS preflight OPTIONS requests.
// handleOptions answers CORS preflight requests for every /v1 route.
func handleOptions(c *gin.Context) {
	// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1
	cors := map[string]string{
		"Access-Control-Allow-Origin":  "*",
		"Access-Control-Allow-Methods": "POST, GET, OPTIONS, PUT, DELETE",
		"Access-Control-Allow-Headers": "Content-Type, Authorization",
	}
	for name, value := range cors {
		c.Header(name, value)
	}
	c.Status(200)
}
// Balance inquiry.
// balance serves the credit_grants endpoint: it reports the caller's
// remaining balance wrapped in OpenAI's CreditSummary shape. The granted
// total is a synthetic constant; only the remainder is real.
func balance(c *gin.Context) {
	auth := c.Request.Header.Get("Authorization")
	// BUGFIX: TrimPrefix instead of Trim (cutset) so key bytes are not eaten.
	key := strings.TrimPrefix(auth, "Bearer ")
	remaining, err := queryBlance(key)
	if err != nil {
		// BUGFIX: the handler used to fall through and emit a second JSON
		// body (with a zero balance) after reporting the error.
		c.JSON(400, gin.H{"error": err.Error()})
		return
	}
	creditSummary := CreditSummary{
		Object:         "credit_grant",
		TotalGranted:   999999,
		TotalUsed:      999999 - remaining,
		TotalRemaining: remaining,
	}
	c.JSON(200, creditSummary)
}
// API-key checking middleware.
// checkKeyMid is a gin middleware that rejects requests lacking a valid
// Authorization key before the proxy handlers run.
func checkKeyMid() gin.HandlerFunc {
	return func(c *gin.Context) {
		auth := c.Request.Header.Get("Authorization")
		log.Printf("auth: %v", auth)
		if auth == "" {
			c.AbortWithStatusJSON(401, gin.H{"code": 40001})
			return
		}
		// BUGFIX: TrimPrefix instead of Trim (cutset) so key bytes are not eaten.
		key := strings.TrimPrefix(auth, "Bearer ")
		if msg, err := checkKeyAndTimeCount(key); err != nil {
			c.AbortWithStatusJSON(msg, gin.H{"code": err.Error()})
			// Return instead of falling through to c.Next() after an abort
			// (the old code relied on Abort making Next a no-op).
			return
		}
		// Run the next handler.
		c.Next()
	}
}
// main wires up logging, routes, and the Redis connection, then serves the
// proxy on :8080.
func main() {
	// Disable console colors; logs also go to a file.
	gin.DisableConsoleColor()
	// BUGFIX: the error from os.Create was silently discarded; with a nil
	// file every log write through the MultiWriter would fail.
	f, err := os.Create("gin.log")
	if err != nil {
		log.Fatalf("create gin.log: %v", err)
	}
	//gin.DefaultWriter = io.MultiWriter(f)
	// Write logs to both the file and stdout.
	gin.DefaultWriter = io.MultiWriter(f, os.Stdout)
	log.SetOutput(gin.DefaultWriter)
	r := gin.Default()
	//r.GET("/dashboard/billing/credit_grants", checkKeyMid(), balance)
	r.GET("/v1/models", handleGetModels)
	r.OPTIONS("/v1/*path", handleOptions)
	r.POST("/v1/chat/completions", checkKeyMid(), completions)
	r.POST("/v1/completions", checkKeyMid(), completions)
	r.POST("/v1/embeddings", checkKeyMid(), embeddings)
	r.POST("/mock1", mockBalanceInquiry)
	r.POST("/mock2", mockBalanceConsumption)
	Redis = InitRedis()
	// BUGFIX: InitRedis returns nil when the server is unreachable; the old
	// code continued anyway and every request would nil-pointer later.
	if Redis == nil {
		log.Fatalf("unable to connect to redis at %s", RedisAddress)
	}
	// NOTE(review): a large commented-out seed-data block that lived here
	// contained hard-coded API keys; it was removed — keep secrets out of
	// source control and seed test data via a separate script instead.
	//r.Run("127.0.0.1:8080")
	// Bind all interfaces for use inside docker.
	if err := r.Run("0.0.0.0:8080"); err != nil {
		log.Fatalf("server exited: %v", err)
	}
}