You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

837 lines
26 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package main
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"math/rand"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"github.com/redis/go-redis/v9"
)
// Message is one entry of a chat conversation in an OpenAI chat request.
//
// Handlers decode the client body into ChatRequest and json.Marshal it again
// before forwarding, so any field not declared here is silently dropped on
// the way upstream.
type Message struct {
	Role    string `json:"role,omitempty"`
	Name    string `json:"name,omitempty"`
	Content string `json:"content,omitempty"`
	// FunctionCall is an assistant message's function call when a
	// function-calling conversation is replayed. It is kept opaque so it is
	// forwarded exactly as received.
	FunctionCall interface{} `json:"function_call,omitempty"`
}

// ChatRequest is the union of request bodies accepted by the proxied
// endpoints (chat/completions, completions, embeddings, edits). Only the
// fields declared here survive re-encoding, so commonly used OpenAI
// parameters are listed even though the proxy itself never reads them.
type ChatRequest struct {
	Stream          bool    `json:"stream,omitempty"`
	Model           string  `json:"model,omitempty"`
	PresencePenalty float64 `json:"presence_penalty,omitempty"`
	// NOTE(review): with omitempty an explicit "temperature": 0 is dropped and
	// the upstream default applies. Fixing that needs *float64, which changes
	// the field type for any other code using it — confirm before changing.
	Temperature float64     `json:"temperature,omitempty"`
	Messages    []Message   `json:"messages,omitempty"`
	Prompt      string      `json:"prompt,omitempty"`
	Input       interface{} `json:"input,omitempty"`
	Instruction string      `json:"instruction,omitempty"`
	N           int         `json:"n,omitempty"`
	Functions   interface{} `json:"functions,omitempty"`

	// Pass-through parameters: not read by the proxy, declared only so they
	// reach the upstream API instead of being dropped on re-encode.
	MaxTokens        int         `json:"max_tokens,omitempty"`
	TopP             float64     `json:"top_p,omitempty"`
	FrequencyPenalty float64     `json:"frequency_penalty,omitempty"`
	Stop             interface{} `json:"stop,omitempty"`
	LogitBias        interface{} `json:"logit_bias,omitempty"`
	User             string      `json:"user,omitempty"`
	FunctionCall     interface{} `json:"function_call,omitempty"`
}
// ImagesRequest is the body of POST /v1/images/generations. The handler
// decodes it and re-encodes it before forwarding, so undeclared fields would
// be dropped on the way upstream.
type ImagesRequest struct {
	Prompt string `json:"prompt,omitempty"`
	N      int    `json:"n,omitempty"`
	Size   string `json:"size,omitempty"`
	// ResponseFormat ("url" or "b64_json") and User are pass-through fields;
	// without them a client asking for b64_json would receive URLs instead.
	ResponseFormat string `json:"response_format,omitempty"`
	User           string `json:"user,omitempty"`
}
// ChatResponse holds the fields the handlers read from an upstream
// completion, chat, edit or embedding response, or from a single streamed
// chunk: Id (billing reference), Usage (token counts of non-streaming
// responses) and Choices (streamed reply text).
// Note: omitempty has no effect on struct-typed fields such as Usage;
// encoding/json never treats a struct value as empty.
type ChatResponse struct {
Id string `json:"id,omitempty"`
Object string `json:"object,omitempty"`
Created int64 `json:"created,omitempty"`
Model string `json:"model,omitempty"`
Usage Usage `json:"usage,omitempty"`
Choices []Choice `json:"choices,omitempty"`
}
// ImagesResponse is the upstream image-generation response. The handler only
// uses len(Data), as the number of images to bill.
type ImagesResponse struct {
Created int64 `json:"created,omitempty"`
Data []ImagesData `json:"data,omitempty"`
}
// ImagesData is one generated image; only its URL form is decoded.
type ImagesData struct {
Url string `json:"url,omitempty"`
}
// Usage is the token accounting reported by upstream; the handlers read it
// for non-streaming responses and pass it to billing.
type Usage struct {
PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
// Choice is one completion candidate. The streaming handler reads Text
// (legacy completions) and falls back to Delta.Content (chat chunks).
// Message is presumably filled by non-streaming chat responses; it is not
// read by the handlers here.
type Choice struct {
Message Message `json:"message,omitempty"`
Delta Delta `json:"delta,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Index int `json:"index,omitempty"`
Text string `json:"text,omitempty"`
}
// Delta is the incremental content carried by one streamed chat chunk.
type Delta struct {
Content string `json:"content,omitempty"`
}
// CreditSummary is the OpenAI-style body returned by the balance handler for
// GET /dashboard/billing/credit_grants.
type CreditSummary struct {
Object string `json:"object"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
TotalRemaining float64 `json:"total_remaining"`
}
// ListModelResponse is the body of GET /v1/models, built by handleGetModels.
type ListModelResponse struct {
Object string `json:"object"`
Data []Model `json:"data"`
}
// Model is one entry of the advertised model list. Parent and the
// permission Group are untyped so they can be encoded as JSON null.
type Model struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []ModelPermission `json:"permission"`
Root string `json:"root"`
Parent any `json:"parent"`
}
// ModelPermission mirrors the OpenAI permission record attached to each
// Model; handleGetModels fills every entry with the same permissive values.
type ModelPermission struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group any `json:"group"`
IsBlocking bool `json:"is_blocking"`
}
// UserInfo is stored as JSON in Redis, presumably under "user:<api key>"
// (test_redis writes "user:key0"). test_redis pairs UID "0" with the balance
// key "user:0:balance". SID presumably selects the upstream "server:<sid>"
// entry — TODO confirm against the key/balance helpers.
type UserInfo struct {
UID string `json:"uid"`
SID string `json:"sid"`
}
// ServerInfo describes one upstream server (presumably stored under
// "server:<id>"). AvailableKey may hold several comma-separated upstream API
// keys; the handlers pick one at random for each request.
type ServerInfo struct {
ServerAddress string `json:"server_address"`
AvailableKey string `json:"available_key"`
}
// ModelInfo is per-model billing configuration (presumably stored under
// "model:<name>").
// NOTE(review): ModelPrice/ModelPrice2 look like prompt and completion unit
// prices, and ModelPrepayment like the amount pre-deducted per request;
// ServerId's meaning is not visible here — confirm in the balance helpers.
type ModelInfo struct {
ModelName string `json:"model_name"`
ModelPrice float64 `json:"model_price"`
ModelPrice2 float64 `json:"model_price2"`
ModelPrepayment int `json:"model_prepayment"`
ServerId int `json:"server_id"`
}
var (
// Redis is the shared client assigned by main from InitRedis; it is nil if
// the startup PING did not return PONG.
Redis *redis.Client
// RedisAddress is the host:port of the Redis server. init replaces this
// default with $REDIS_ADDRESS when that variable is non-empty.
RedisAddress = "localhost:6379"
)
// init applies the REDIS_ADDRESS environment override (when non-empty) to
// RedisAddress and logs the address that will be used.
func init() {
	// Release mode is left off: gin.SetMode(gin.ReleaseMode)
	addr := RedisAddress
	if fromEnv := os.Getenv("REDIS_ADDRESS"); fromEnv != "" {
		addr = fromEnv
	}
	RedisAddress = addr
	log.Printf("loading redis address: %s", addr)
}
// InitRedis creates a Redis client for RedisAddress (no password, DB 0, a
// pool of 10 connections) and checks it with PING.
//
// If the server does not answer "PONG", the ping error is logged, the client
// is closed so its connection pool is not leaked, and nil is returned.
// Callers must check for nil.
func InitRedis() *redis.Client {
	rdb := redis.NewClient(&redis.Options{
		Addr:     RedisAddress,
		Password: "", // no password set
		DB:       0,  // use default DB
		PoolSize: 10,
	})
	result := rdb.Ping(context.Background())
	fmt.Println("redis ping:", result.Val())
	if result.Val() != "PONG" {
		// Connection problem: report why, then release the client.
		log.Printf("redis ping to %s failed: %v", RedisAddress, result.Err())
		if err := rdb.Close(); err != nil {
			log.Printf("closing redis client: %v", err)
		}
		return nil
	}
	return rdb
}
// numTokensFromMessages estimates the prompt tokens of a chat request with
// tiktoken. Model names containing "gpt-3.5" or "gpt-4" are collapsed onto
// the base model so dated and variant names resolve to an encoding.
//
// Each message counts the tokens of its Content plus a fixed 6-token
// overhead, and 3 tokens are added once for the whole list; Role and Name
// are not encoded. If no encoding is known for model, the error is logged
// and 0 is returned.
func numTokensFromMessages(messages []Message, model string) int {
	if strings.Contains(model, "gpt-3.5") {
		model = "gpt-3.5-turbo"
	}
	if strings.Contains(model, "gpt-4") {
		model = "gpt-4"
	}
	tkm, err := tiktoken.EncodingForModel(model)
	if err != nil {
		// Surface the failure: callers feed this count straight into billing,
		// so a silent 0 would go unnoticed.
		log.Printf("getEncoding: %v", err)
		return 0
	}
	numTokens := 0
	for _, message := range messages {
		numTokens += len(tkm.Encode(message.Content, nil, nil))
		numTokens += 6
	}
	numTokens += 3
	return numTokens
}
// numTokensFromString estimates the tokens in msg with tiktoken. Model names
// containing "gpt-3.5" or "gpt-4" are collapsed onto the base model first.
//
// text-davinci-003 adds 1 token of overhead to the encoded length; every
// other model adds 9. If no encoding is known for model, the error is logged
// and 0 is returned.
func numTokensFromString(msg string, model string) int {
	if strings.Contains(model, "gpt-3.5") {
		model = "gpt-3.5-turbo"
	}
	if strings.Contains(model, "gpt-4") {
		model = "gpt-4"
	}
	tkm, err := tiktoken.EncodingForModel(model)
	if err != nil {
		// Surface the failure: callers feed this count straight into billing.
		log.Printf("getEncoding: %v", err)
		return 0
	}
	encoded := len(tkm.Encode(msg, nil, nil))
	if model == "text-davinci-003" {
		return encoded + 1
	}
	return encoded + 9
}
// images proxies POST /v1/images/generations to the upstream server assigned
// to the caller's key and bills per generated image.
//
// checkBlanceForImages presumably pre-deducts the cost of N images (the
// refund paths below undo it). Every exit that does not bill refunds via
// checkBlanceReturnForImages:
//   - a local failure before forwarding;
//   - a transport error (ErrorHandler);
//   - a non-200 upstream status.
//
// A successful response is billed by the number of images it contains.
func images(c *gin.Context) {
	var imagesRequest ImagesRequest
	var model = "images-generations"
	if err := c.ShouldBindJSON(&imagesRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	auth := c.Request.Header.Get("Authorization")
	// The key is whatever follows the 7-byte "Bearer " prefix; reject shorter
	// headers instead of panicking on the slice expression.
	if len(auth) < 7 {
		c.AbortWithStatusJSON(401, gin.H{"error": "invalid Authorization header"})
		return
	}
	key := auth[7:]
	serverInfo, err := checkBlanceForImages(key, model, imagesRequest.N)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, model)
	refund := func() { checkBlanceReturnForImages(key, model, imagesRequest.N) }

	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		refund()
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	newReqBody, err := json.Marshal(imagesRequest)
	if err != nil {
		refund()
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	proxy := httputil.NewSingleHostReverseProxy(remote)
	proxy.Director = func(req *http.Request) {
		// req is the proxy's clone of c.Request, headers included. Edit it in
		// place; aliasing c.Request.Header would copy the upstream key back
		// into the client's request.
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(newReqBody))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")
		// AvailableKey may list several comma-separated upstream keys; pick one
		// at random. The global math/rand source needs no per-request Seed
		// (auto-seeded since Go 1.20; Seed is a no-op by default since 1.24).
		serverKey := serverInfo.AvailableKey
		if keyList := strings.Split(serverInfo.AvailableKey, ","); len(keyList) > 1 {
			serverKey = keyList[rand.Intn(len(keyList))]
		}
		req.Header.Set("Authorization", "Bearer "+serverKey)
		req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
	}
	log.Printf("开始处理返回逻辑 %s", newReqBody)

	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		// Upstream unreachable, or ModifyResponse failed before billing:
		// nothing was charged, so return the pre-deducted balance.
		log.Printf("http: proxy error: %v", err)
		refund()
		w.WriteHeader(http.StatusBadGateway)
	}
	proxy.ModifyResponse = func(resp *http.Response) error {
		if resp.StatusCode != http.StatusOK {
			// Upstream rejected the request: refund the pre-deducted balance.
			refund()
			return nil
		}
		resp.Header.Set("Openai-Organization", "api2gpt")
		body, err := ioutil.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			log.Printf("读取返回数据出错: %v", err)
			return err
		}
		// Hand the buffered bytes back to the proxy for the client.
		resp.Body = ioutil.NopCloser(bytes.NewReader(body))

		var imagesResponse ImagesResponse
		if err := json.Unmarshal(body, &imagesResponse); err != nil {
			log.Printf("decode images response: %v", err)
		}
		log.Printf("image size: %v", len(imagesResponse.Data))
		timestamp := time.Now().Unix()
		timestampID := "img-" + fmt.Sprintf("%d", timestamp)
		// Bill for the images actually returned.
		consumptionForImages(key, model, imagesRequest.N, len(imagesResponse.Data), timestampID)
		return nil
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
// edit proxies POST /v1/edits to the upstream server assigned to the caller's
// key and bills from the "usage" object of the upstream response.
//
// checkBlance is called first (presumably pre-deducting a prepayment).
// Failures before any billing refund it via checkBlanceReturn:
//   - a local error before forwarding;
//   - a transport error (ErrorHandler).
func edit(c *gin.Context) {
	var chatRequest ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	auth := c.Request.Header.Get("Authorization")
	// The key is whatever follows the 7-byte "Bearer " prefix; reject shorter
	// headers instead of panicking on the slice expression.
	if len(auth) < 7 {
		c.AbortWithStatusJSON(401, gin.H{"error": "invalid Authorization header"})
		return
	}
	key := auth[7:]
	// Check the caller's balance for this model and get the upstream server.
	// (Sizing the check from max_tokens could be considered later.)
	serverInfo, err := checkBlance(key, chatRequest.Model)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
	refund := func() { checkBlanceReturn(key, chatRequest.Model) }

	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		refund()
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		refund()
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	proxy := httputil.NewSingleHostReverseProxy(remote)
	proxy.Director = func(req *http.Request) {
		// req is the proxy's clone of c.Request, headers included; edit it in
		// place rather than aliasing (and mutating) c.Request.Header.
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(newReqBody))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")
		// Pick one of the comma-separated upstream keys at random; the global
		// math/rand source needs no per-request Seed.
		serverKey := serverInfo.AvailableKey
		if keyList := strings.Split(serverInfo.AvailableKey, ","); len(keyList) > 1 {
			serverKey = keyList[rand.Intn(len(keyList))]
		}
		req.Header.Set("Authorization", "Bearer "+serverKey)
		req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
	}
	log.Printf("开始处理返回逻辑 %s", newReqBody)

	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		// Nothing was billed; return the pre-deducted balance.
		log.Printf("http: proxy error: %v", err)
		refund()
		w.WriteHeader(http.StatusBadGateway)
	}
	proxy.ModifyResponse = func(resp *http.Response) error {
		resp.Header.Set("Openai-Organization", "api2gpt")
		body, err := ioutil.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			log.Printf("读取返回数据出错: %v", err)
			return err
		}
		resp.Body = ioutil.NopCloser(bytes.NewReader(body))

		var chatResponse ChatResponse
		if err := json.Unmarshal(body, &chatResponse); err != nil {
			log.Printf("decode edit response: %v", err)
		}
		promptTokens := chatResponse.Usage.PromptTokens
		completionTokens := chatResponse.Usage.CompletionTokens
		totalTokens := chatResponse.Usage.TotalTokens
		log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
		timestamp := time.Now().Unix()
		timestampID := "edit-" + fmt.Sprintf("%d", timestamp)
		// Edits return generated text, so bill completion tokens the same way
		// completions does (only embeddings have none).
		consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, timestampID)
		return nil
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
// embeddings proxies POST /v1/embeddings to the upstream server assigned to
// the caller's key and bills the prompt tokens from the upstream "usage"
// object (embeddings produce no completion tokens, so 0 is billed for them).
//
// checkBlance is called first (presumably pre-deducting a prepayment).
// Failures before any billing refund it via checkBlanceReturn:
//   - a local error before forwarding;
//   - a transport error (ErrorHandler).
func embeddings(c *gin.Context) {
	var chatRequest ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	auth := c.Request.Header.Get("Authorization")
	// The key is whatever follows the 7-byte "Bearer " prefix; reject shorter
	// headers instead of panicking on the slice expression.
	if len(auth) < 7 {
		c.AbortWithStatusJSON(401, gin.H{"error": "invalid Authorization header"})
		return
	}
	key := auth[7:]
	// Check the caller's balance for this model and get the upstream server.
	serverInfo, err := checkBlance(key, chatRequest.Model)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
	refund := func() { checkBlanceReturn(key, chatRequest.Model) }

	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		refund()
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		refund()
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	proxy := httputil.NewSingleHostReverseProxy(remote)
	proxy.Director = func(req *http.Request) {
		// req is the proxy's clone of c.Request, headers included; edit it in
		// place rather than aliasing (and mutating) c.Request.Header.
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(newReqBody))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")
		// Pick one of the comma-separated upstream keys at random; the global
		// math/rand source needs no per-request Seed.
		serverKey := serverInfo.AvailableKey
		if keyList := strings.Split(serverInfo.AvailableKey, ","); len(keyList) > 1 {
			serverKey = keyList[rand.Intn(len(keyList))]
		}
		req.Header.Set("Authorization", "Bearer "+serverKey)
		req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
	}
	log.Printf("开始处理返回逻辑 %s", newReqBody)

	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		// Nothing was billed; return the pre-deducted balance.
		log.Printf("http: proxy error: %v", err)
		refund()
		w.WriteHeader(http.StatusBadGateway)
	}
	proxy.ModifyResponse = func(resp *http.Response) error {
		resp.Header.Set("Openai-Organization", "api2gpt")
		body, err := ioutil.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			log.Printf("读取返回数据出错: %v", err)
			return err
		}
		resp.Body = ioutil.NopCloser(bytes.NewReader(body))

		var chatResponse ChatResponse
		if err := json.Unmarshal(body, &chatResponse); err != nil {
			log.Printf("decode embeddings response: %v", err)
		}
		promptTokens := chatResponse.Usage.PromptTokens
		totalTokens := chatResponse.Usage.TotalTokens
		log.Printf("prompt_tokens: %v total_tokens: %v", promptTokens, totalTokens)
		timestamp := time.Now().Unix()
		timestampID := "emb-" + fmt.Sprintf("%d", timestamp)
		consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID)
		return nil
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
// completions proxies chat (/v1/chat/completions) and legacy text
// (/v1/completions) completion requests to the upstream server assigned to
// the caller's key, then bills the caller.
//
// Streaming requests (stream=true) are relayed line by line here so the reply
// text can be accumulated. Tokens are then estimated locally with tiktoken
// from the prompt/messages and the streamed reply. Non-streaming responses
// are billed from the upstream "usage" object.
//
// checkBlance is called first (presumably pre-deducting a prepayment). Exits
// that bill nothing refund it via checkBlanceReturn:
//   - a local error before forwarding;
//   - a transport error (ErrorHandler);
//   - a non-200 streaming response.
func completions(c *gin.Context) {
	var chatRequest ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	auth := c.Request.Header.Get("Authorization")
	// The key is whatever follows the 7-byte "Bearer " prefix; reject shorter
	// headers instead of panicking on the slice expression.
	if len(auth) < 7 {
		c.AbortWithStatusJSON(401, gin.H{"error": "invalid Authorization header"})
		return
	}
	key := auth[7:]
	// Check the caller's balance for this model and get the upstream server.
	// (Sizing the check from max_tokens could be considered later.)
	serverInfo, err := checkBlance(key, chatRequest.Model)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
	refund := func() { checkBlanceReturn(key, chatRequest.Model) }

	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		refund()
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		refund()
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	proxy := httputil.NewSingleHostReverseProxy(remote)
	proxy.Director = func(req *http.Request) {
		// req is the proxy's clone of c.Request, headers included; edit it in
		// place rather than aliasing (and mutating) c.Request.Header.
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(newReqBody))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")
		// Pick one of the comma-separated upstream keys at random; the global
		// math/rand source needs no per-request Seed.
		serverKey := serverInfo.AvailableKey
		if keyList := strings.Split(serverInfo.AvailableKey, ","); len(keyList) > 1 {
			serverKey = keyList[rand.Intn(len(keyList))]
		}
		req.Header.Set("Authorization", "Bearer "+serverKey)
		req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
	}
	log.Printf("开始处理返回逻辑: %v", string(newReqBody))

	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		// Nothing was billed; return the pre-deducted balance.
		log.Printf("http: proxy error: %v", err)
		refund()
		w.WriteHeader(http.StatusBadGateway)
	}

	if chatRequest.Stream {
		proxy.ModifyResponse = func(resp *http.Response) error {
			log.Printf("流式回应 http status code: %v", resp.StatusCode)
			if resp.StatusCode != http.StatusOK {
				// Upstream rejected the request: refund the pre-deducted balance.
				refund()
				return nil
			}
			// Headers must be in place before the first Write sends the
			// status line.
			for k, v := range resp.Header {
				if len(v) > 0 {
					c.Writer.Header().Set(k, v[0])
				}
			}
			c.Writer.Header().Set("Openai-Organization", "api2gpt")

			responseID := ""
			var reply strings.Builder
			reader := bufio.NewReader(resp.Body)
			for {
				chunk, readErr := reader.ReadBytes('\n')
				// A final line without '\n' arrives together with io.EOF;
				// process it before looking at the error.
				if len(chunk) > 0 {
					line := string(chunk)
					// Strip the SSE "data:" field name. Lines that are empty
					// afterwards are separators and are not forwarded; each
					// forwarded line gets its own blank terminator below.
					payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
					if payload != "" {
						var chatResponse ChatResponse
						// Non-JSON payloads such as "[DONE]" leave chatResponse empty.
						_ = json.Unmarshal([]byte(payload), &chatResponse)
						// Check the length, not nil: some upstreams (e.g. Azure)
						// send chunks with "choices": [].
						if len(chatResponse.Choices) > 0 {
							if text := chatResponse.Choices[0].Text; text != "" {
								reply.WriteString(text)
							} else {
								reply.WriteString(chatResponse.Choices[0].Delta.Content)
							}
							responseID = chatResponse.Id
						}
						if _, werr := c.Writer.Write([]byte(line + "\n")); werr != nil {
							// Client went away: stop relaying but still bill
							// what upstream has already produced.
							log.Printf("写回数据 err: %v", werr.Error())
							break
						}
						c.Writer.Flush()
					}
				}
				if readErr != nil {
					if readErr != io.EOF {
						log.Printf("流式回应,处理 err %v:", readErr.Error())
					}
					break
				}
			}
			// The stream has been consumed (or abandoned) here; close it and
			// leave the proxy nothing further to copy.
			resp.Body.Close()
			resp.Body = http.NoBody

			var promptTokens int
			if chatRequest.Model == "text-davinci-003" {
				promptTokens = numTokensFromString(chatRequest.Prompt, chatRequest.Model)
			} else {
				promptTokens = numTokensFromMessages(chatRequest.Messages, chatRequest.Model)
			}
			replyText := reply.String()
			completionTokens := numTokensFromString(replyText, chatRequest.Model)
			log.Printf("返回内容:%v", replyText)
			totalTokens := promptTokens + completionTokens
			log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
			consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, responseID)
			return nil
		}
	} else {
		proxy.ModifyResponse = func(resp *http.Response) error {
			resp.Header.Set("Openai-Organization", "api2gpt")
			body, err := ioutil.ReadAll(resp.Body)
			resp.Body.Close()
			if err != nil {
				log.Printf("非流式回应,处理 err: %v", err)
				return err
			}
			resp.Body = ioutil.NopCloser(bytes.NewReader(body))

			var chatResponse ChatResponse
			if err := json.Unmarshal(body, &chatResponse); err != nil {
				log.Printf("decode completion response: %v", err)
			}
			usage := chatResponse.Usage
			log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
			consumption(key, chatRequest.Model, usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens, chatResponse.Id)
			return nil
		}
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
// handleGetModels serves GET /v1/models with a fixed, OpenAI-shaped list of
// the model ids this proxy advertises. Every entry shares the same creation
// timestamps and one permissive permission record; nothing is fetched from
// upstream.
// Related: https://github.com/diemus/azure-openai-proxy/issues/3
func handleGetModels(c *gin.Context) {
	ids := []string{
		"gpt-4", "gpt-4-0314", "gpt-4-0613",
		"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613",
		"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
		"text-davinci-003", "text-embedding-ada-002",
		"text-davinci-edit-001", "code-davinci-edit-001",
		"images-generations",
	}
	data := make([]Model, 0, len(ids))
	for i := range ids {
		// Zero-valued fields (ID, Group, IsBlocking, Parent) are left implicit.
		perm := ModelPermission{
			Object:             "model",
			Created:            1679602087,
			AllowCreateEngine:  true,
			AllowSampling:      true,
			AllowLogprobs:      true,
			AllowSearchIndices: true,
			AllowView:          true,
			AllowFineTuning:    true,
			Organization:       "*",
		}
		data = append(data, Model{
			ID:         ids[i],
			Object:     "model",
			Created:    1677649963,
			OwnedBy:    "openai",
			Permission: []ModelPermission{perm},
			Root:       ids[i],
		})
	}
	c.JSON(200, ListModelResponse{Object: "list", Data: data})
}
// handleOptions answers CORS preflight requests on /v1/*path with
// permissive Access-Control-* headers and an empty 200 response.
// NOTE: the global Cors() middleware aborts every OPTIONS request with 204
// before route handlers run, so this is effectively a fallback.
// Related: https://github.com/diemus/azure-openai-proxy/issues/1
func handleOptions(c *gin.Context) {
	preflight := map[string]string{
		"Access-Control-Allow-Origin":  "*",
		"Access-Control-Allow-Methods": "POST, GET, OPTIONS, PUT, DELETE",
		"Access-Control-Allow-Headers": "Content-Type, Authorization",
	}
	for name, value := range preflight {
		c.Header(name, value)
	}
	c.Status(200)
}
// balance serves GET /dashboard/billing/credit_grants in the OpenAI
// credit-summary shape. queryBlance supplies the remaining balance; the
// granted total is reported as a fixed 999999, so "used" is 999999 minus the
// remaining balance. A lookup error is answered with 400 and nothing else.
func balance(c *gin.Context) {
	auth := c.Request.Header.Get("Authorization")
	// Extract the key exactly as the other handlers do: everything after the
	// 7-byte "Bearer " prefix. (strings.Trim(auth, "Bearer ") would treat the
	// prefix as a character set and also strip trailing 'B','e','a','r',' '
	// characters from the key itself.)
	if len(auth) < 7 {
		c.JSON(401, gin.H{"error": "invalid Authorization header"})
		return
	}
	key := auth[7:]
	remaining, err := queryBlance(key)
	if err != nil {
		c.JSON(400, gin.H{"error": err.Error()})
		return
	}
	c.JSON(200, CreditSummary{
		Object:         "credit_grant",
		TotalGranted:   999999,
		TotalUsed:      999999 - remaining,
		TotalRemaining: remaining,
	})
}
// checkKeyMid returns middleware that authenticates the caller's API key.
//
// A missing Authorization header, or one shorter than the 7-byte "Bearer "
// prefix, is rejected with 401 {"code": 40001}. Otherwise the key after the
// prefix is validated by checkKeyAndTimeCount. On error, the request is
// aborted with the status that helper returns and the error text as "code";
// on success, the next handler runs.
func checkKeyMid() gin.HandlerFunc {
	return func(c *gin.Context) {
		auth := c.Request.Header.Get("Authorization")
		// NOTE(review): this logs the full client key; consider masking it.
		log.Printf("auth: %v", auth)
		// Shorter headers cannot hold a key and would make auth[7:] panic.
		if len(auth) < 7 {
			c.AbortWithStatusJSON(401, gin.H{"code": 40001})
			return
		}
		key := auth[7:]
		log.Printf("key: %v", key)
		status, err := checkKeyAndTimeCount(key)
		if err != nil {
			log.Printf("checkKeyMid err: %v", err)
			// net/http panics on a status outside 100-999; fall back to 401 if
			// the helper ever returns such a value (e.g. 0).
			if status < 100 || status > 999 {
				status = http.StatusUnauthorized
			}
			c.AbortWithStatusJSON(status, gin.H{"code": err.Error()})
			return
		}
		log.Printf("auth check end")
		c.Next()
	}
}
// Cors returns middleware that adds permissive CORS response headers to any
// request carrying an Origin header, and answers every OPTIONS (preflight)
// request itself with 204 No Content without running later handlers.
func Cors() gin.HandlerFunc {
	return func(c *gin.Context) {
		if c.Request.Header.Get("Origin") != "" {
			// "*" may be replaced with a specific allowed domain.
			c.Header("Access-Control-Allow-Origin", "*")
			c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
			c.Header("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept, Authorization")
			c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Cache-Control, Content-Language, Content-Type")
			c.Header("Access-Control-Allow-Credentials", "true")
		}
		if c.Request.Method != "OPTIONS" {
			c.Next()
			return
		}
		c.AbortWithStatus(http.StatusNoContent)
	}
}
func test_redis() {
//添加reids测试数据
//var serverInfo ServerInfo = ServerInfo{
// ServerAddress: "https://gptp.any-door.cn",
// AvailableKey: "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",
//}
// var serverInfo2 ServerInfo = ServerInfo{
// ServerAddress: "https://azure.any-door.cn",
// AvailableKey: "6c4d2c65970b40e482e7cd27adb0d119",
// }
//serverInfoStr, _ := json.Marshal(&serverInfo)
// serverInfoStr2, _ := json.Marshal(&serverInfo2)
//Redis.Set(context.Background(), "server:1", serverInfoStr, 0)
// Redis.Set(context.Background(), "server:2", serverInfoStr2, 0)
// var modelInfo ModelInfo = ModelInfo{
// ModelName: "gpt-3.5-turbo-0613",
// ModelPrice: 0.0001,
// ModelPrepayment: 4000,
// }
// modelInfoStr, _ := json.Marshal(&modelInfo)
// Redis.Set(context.Background(), "model:gpt-3.5-turbo-0613", modelInfoStr, 0)
// var modelInfo1 ModelInfo = ModelInfo{
// ModelName: "gpt-3.5-turbo-16k",
// ModelPrice: 0.0001,
// ModelPrepayment: 4000,
// }
// modelInfoStr1, _ := json.Marshal(&modelInfo1)
// Redis.Set(context.Background(), "model:gpt-3.5-turbo-16k", modelInfoStr1, 0)
// var modelInfo2 ModelInfo = ModelInfo{
// ModelName: "gpt-3.5-turbo-16k-061",
// ModelPrice: 0.0001,
// ModelPrepayment: 4000,
// }
// modelInfoStr2, _ := json.Marshal(&modelInfo2)
// Redis.Set(context.Background(), "model:gpt-3.5-turbo-16k-0613", modelInfoStr2, 0)
// var modelInfo2 ModelInfo = ModelInfo{
// ModelName: "text-davinci-003",
// ModelPrice: 0.001,
// ModelPrepayment: 4000,
// }
// modelInfoStr2, _ := json.Marshal(&modelInfo2)
// var modelInfo3 ModelInfo = ModelInfo{
// ModelName: "text-davinci-003",
// ModelPrice: 0.001,
// ModelPrepayment: 4000,
// }
// modelInfoStr3, _ := json.Marshal(&modelInfo3)
// Redis.Set(context.Background(), "model:gpt-3.5-turbo", modelInfoStr, 0)
// Redis.Set(context.Background(), "model:gpt-3.5-turbo-0301", modelInfoStr, 0)
// Redis.Set(context.Background(), "model:text-davinci-003", modelInfoStr2, 0)
// Redis.Set(context.Background(), "model:text-embedding-ada-002", modelInfoStr3, 0)
// var modelInfo2 ModelInfo = ModelInfo{
// ModelName: "images-generations",
// ModelPrice: 0.01,
// ModelPrepayment: 1000,
// }
// modelInfoStr2, _ := json.Marshal(&modelInfo2)
// Redis.Set(context.Background(), "model:images-generations", modelInfoStr2, 0)
var userInfo UserInfo = UserInfo{
UID: "0",
SID: "1",
}
// var userInfo2 UserInfo = UserInfo{
// UID: "2",
// SID: "2",
// }
userInfoStr, _ := json.Marshal(&userInfo)
// userInfoStr2, _ := json.Marshal(&userInfo2)
Redis.Set(context.Background(), "user:key0", userInfoStr, 0)
// Redis.Set(context.Background(), "user:AK-7d8ab782-a152-4cc1-9972-568713465c96", userInfoStr2, 0)
Redis.IncrByFloat(context.Background(), "user:0:balance", 1000).Result()
// Redis.IncrByFloat(context.Background(), "user:2:balance", 1000).Result()
}
// main configures logging, registers the OpenAI-compatible routes, connects
// to Redis and serves HTTP on 0.0.0.0:8080.
func main() {
	// Console colors are noise once logs also go to a file.
	gin.DisableConsoleColor()

	// Log to logs/gin<timestamp>.log and stdout. If the file cannot be
	// created, fall back to stdout alone: io.MultiWriter stops at the first
	// failing writer, so a nil *os.File would silence stdout as well.
	// The file stays open for the life of the process.
	var logOut io.Writer = os.Stdout
	filename := time.Now().Format("20060102150405") + ".log"
	if err := os.MkdirAll("logs", 0o755); err != nil {
		log.Printf("create log directory: %v; logging to stdout only", err)
	} else if f, err := os.Create("logs/gin" + filename); err != nil {
		log.Printf("create log file: %v; logging to stdout only", err)
	} else {
		logOut = io.MultiWriter(f, os.Stdout)
	}
	gin.DefaultWriter = logOut
	log.SetOutput(gin.DefaultWriter)

	r := gin.Default()
	// CORS support for browser clients.
	r.Use(Cors())
	r.GET("/dashboard/billing/credit_grants", checkKeyMid(), balance)
	r.GET("/v1/models", handleGetModels)
	r.OPTIONS("/v1/*path", handleOptions)
	r.POST("/v1/chat/completions", checkKeyMid(), completions)
	r.POST("/v1/completions", checkKeyMid(), completions)
	r.POST("/v1/embeddings", checkKeyMid(), embeddings)
	r.POST("/v1/edits", checkKeyMid(), edit)
	r.POST("/v1/images/generations", checkKeyMid(), images)
	//r.POST("/mock1", mockBalanceInquiry)
	//r.POST("/mock2", mockBalanceConsumption)
	// Simple liveness probe.
	r.GET("/ping", func(c *gin.Context) {
		c.JSON(200, gin.H{
			"message": "pong from api2gpt",
		})
	})

	Redis = InitRedis()
	if Redis == nil {
		// The shared client is presumably needed by the key/balance helpers;
		// refuse to start rather than serve requests against a nil client.
		log.Fatalf("redis at %s is unavailable", RedisAddress)
	}
	// Seed local test data:
	//test_redis()
	if err := r.Run("0.0.0.0:8080"); err != nil {
		log.Fatalf("server stopped: %v", err)
	}
}