init project

main
lvxiu_ext 3 years ago
commit fd8cb3c976

1
.gitignore vendored

@ -0,0 +1 @@
gin.log

@ -0,0 +1,18 @@
# Build stage: compile the Go binary with the full toolchain image.
FROM golang:1.18 AS builder
ENV GOPROXY=https://goproxy.cn,direct
RUN mkdir -p /build
WORKDIR /build
COPY . .
# The module path is "main", so `go build` writes the binary to /build/main.
RUN go build
# Runtime stage: slim image that carries only CA certificates and the binary.
FROM debian:buster-slim
# No trailing backslash after the last command: a continuation here would
# swallow the EXPOSE instruction below into this RUN line.
RUN set -x && apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
    ca-certificates && \
    rm -rf /var/lib/apt/lists/*
EXPOSE 8080
WORKDIR /app
COPY --from=builder /build/main /app/api2u-go
ENTRYPOINT ["/app/api2u-go"]

@ -0,0 +1,22 @@
## docker 打包镜像
docker stop api2u-go
docker rm api2u-go
docker build -t api2u-go .
## docker 运行镜像
docker run -d -p 8081:8080 --name=api2u-go --env REDIS_ADDRESS=172.17.0.1:6379 api2u-go
docker run -p 8081:8080 --name=api2u-go --env REDIS_ADDRESS=172.17.0.1:6379 api2u-go
## nginx 配置
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_cache off;
proxy_cache_bypass $http_pragma;
proxy_cache_revalidate on;
proxy_http_version 1.1;
proxy_buffering off;
proxy_pass http://localhost:8081/;

102
api.go

@ -0,0 +1,102 @@
package main
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http"
"github.com/gin-gonic/gin"
)
// BalanceInfo is the JSON response body that balanceInquiry decodes.
// mockBalanceInquiry returns a payload with the same four keys.
type BalanceInfo struct {
// ServerAddress is decoded from "server_address".
ServerAddress string `json:"server_address"`
// AvailableKey is decoded from "available_key".
AvailableKey string `json:"available_key"`
// UserBalance is decoded from "user_balance".
UserBalance float64 `json:"user_balance"`
// TokenRatio is decoded from "token_ratio".
TokenRatio float64 `json:"token_ratio"`
}
// Consumption is the JSON request body that balanceConsumption posts to the
// usage-record endpoint. Unlike the other payloads in this file its keys are
// camelCase.
type Consumption struct {
// SecretKey carries the caller's key.
SecretKey string `json:"secretKey"`
// Model is the model name the usage is recorded against.
Model string `json:"model"`
// MsgId is the message identifier passed in by the caller.
MsgId string `json:"msgId"`
// PromptTokens, CompletionTokens and TotalTokens are the token counts to record.
PromptTokens int `json:"promptTokens"`
CompletionTokens int `json:"completionTokens"`
TotalTokens int `json:"totalTokens"`
}
// mockBalanceInquiry is a stub handler that always answers 200 with a fixed
// balance payload whose keys match the BalanceInfo JSON tags.
//
// NOTE(review): the available_key value looks like a real credential committed
// to source; confirm it is a dummy or rotate it.
func mockBalanceInquiry(c *gin.Context) {
	payload := gin.H{}
	payload["server_address"] = "https://gptp.any-door.cn"
	payload["available_key"] = "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"
	payload["user_balance"] = 10000
	payload["token_ratio"] = 1000
	c.JSON(http.StatusOK, payload)
}
// mockBalanceConsumption is a stub handler that always answers 200 with
// {"success": "true"}; note the value is the string "true", not a boolean.
func mockBalanceConsumption(c *gin.Context) {
	reply := gin.H{"success": "true"}
	c.JSON(http.StatusOK, reply)
}
// balanceInquiryURL is the endpoint balanceInquiry posts to. It is a variable
// rather than an inline literal so that tests and deployments can repoint it.
var balanceInquiryURL = "http://localhost:8080/mock1"

// balanceInquiry calls the balance-inquiry API for the given key and model and
// decodes the JSON response into a BalanceInfo.
//
// key and model are sent as URL query parameters on a POST request with no
// body. Every failure (building the request, transport, reading or decoding
// the body) is returned as an error; the function never panics.
func balanceInquiry(key string, model string) (*BalanceInfo, error) {
	req, err := http.NewRequest(http.MethodPost, balanceInquiryURL, nil)
	if err != nil {
		return nil, err
	}
	// Encode the parameters instead of concatenating them, so characters such
	// as '&', '=' or spaces in a caller-supplied key cannot alter the query.
	query := req.URL.Query()
	query.Set("key", key)
	query.Set("model", model)
	req.URL.RawQuery = query.Encode()
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	var balanceInfo BalanceInfo
	if err := json.Unmarshal(body, &balanceInfo); err != nil {
		return nil, err
	}
	return &balanceInfo, nil
}
// balanceConsumptionURL is the usage-record endpoint balanceConsumption posts
// to. It is a variable rather than an inline literal so that tests and
// deployments can repoint it.
var balanceConsumptionURL = "http://121.4.100.155:8080/other/usageRecord"

// balanceConsumption reports one usage record to the usage-record endpoint and
// returns the raw response body as a string.
//
// The arguments are sent as a JSON-encoded Consumption. Every failure
// (encoding, building the request, transport, reading the body) is returned as
// an error with an empty string; the function never panics.
func balanceConsumption(key string, model string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) {
	payload, err := json.Marshal(Consumption{
		SecretKey:        key,
		Model:            model,
		MsgId:            msg_id,
		PromptTokens:     prompt_tokens,
		CompletionTokens: completion_tokens,
		TotalTokens:      total_tokens,
	})
	if err != nil {
		return "", err
	}

	// Build the POST request carrying the JSON body.
	req, err := http.NewRequest(http.MethodPost, balanceConsumptionURL, bytes.NewReader(payload))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	return string(body), nil
}

@ -0,0 +1,37 @@
module main

go 1.20

// Direct dependencies: imported by the package's own source files.
require (
	github.com/gin-gonic/gin v1.9.0
	github.com/pkoukk/tiktoken-go v0.1.0
	github.com/redis/go-redis/v9 v9.0.3
)

require (
	github.com/bytedance/sonic v1.8.0 // indirect
	github.com/cespare/xxhash/v2 v2.2.0 // indirect
	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
	github.com/dlclark/regexp2 v1.8.1 // indirect
	github.com/gin-contrib/sse v0.1.0 // indirect
	github.com/go-playground/locales v0.14.1 // indirect
	github.com/go-playground/universal-translator v0.18.1 // indirect
	github.com/go-playground/validator/v10 v10.11.2 // indirect
	github.com/go-redis/redis v6.15.9+incompatible // indirect
	github.com/goccy/go-json v0.10.0 // indirect
	github.com/google/uuid v1.3.0 // indirect
	github.com/json-iterator/go v1.1.12 // indirect
	github.com/klauspost/cpuid/v2 v2.0.9 // indirect
	github.com/leodido/go-urn v1.2.1 // indirect
	github.com/mattn/go-isatty v0.0.17 // indirect
	github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
	github.com/modern-go/reflect2 v1.0.2 // indirect
	github.com/pelletier/go-toml/v2 v2.0.6 // indirect
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
	github.com/ugorji/go/codec v1.2.9 // indirect
	golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect
	golang.org/x/crypto v0.5.0 // indirect
	golang.org/x/net v0.7.0 // indirect
	golang.org/x/sys v0.5.0 // indirect
	golang.org/x/text v0.7.0 // indirect
	google.golang.org/protobuf v1.28.1 // indirect
	gopkg.in/yaml.v3 v3.0.1 // indirect
)

@ -0,0 +1,87 @@
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.8.0 h1:ea0Xadu+sHlu7x5O3gKhRpQ1IKiMrSiHttPF0ybECuA=
github.com/bytedance/sonic v1.8.0/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.0 h1:OjyFBKICoexlu99ctXNR2gg+c5pKrKMuyjgARg9qeY8=
github.com/gin-gonic/gin v1.9.0/go.mod h1:W1Me9+hsUSyj3CePGrd1/QrKJMSJ1Tu/0hFEH89961k=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.11.2 h1:q3SHpufmypg+erIExEKUmsgmhDTyhcJ38oeKGACXohU=
github.com/go-playground/validator/v10 v10.11.2/go.mod h1:NieE624vt4SCTJtD87arVLvdmjPAeV8BQlHtMnw9D7s=
github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg=
github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA=
github.com/goccy/go-json v0.10.0 h1:mXKd9Qw4NuzShiRlOXKews24ufknHO7gx30lsDyokKA=
github.com/goccy/go-json v0.10.0/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng=
github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU=
github.com/pelletier/go-toml/v2 v2.0.6/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek=
github.com/pkoukk/tiktoken-go v0.1.0 h1:X1uP3+Nd8C3xe6AIGRWjchrylyaye0FDDTG22cxNQZs=
github.com/pkoukk/tiktoken-go v0.1.0/go.mod h1:BijIqAP84FMYC4XbdJgjyMpiSjusU8x0Y0W9K2t0QtU=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis v6.15.9+incompatible h1:F+tnlesQSl3h9V8DdmtcYFdvkHLhbb7AgcLW6UJxnC4=
github.com/redis/go-redis v6.15.9+incompatible/go.mod h1:ic6dLmR0d9rkHSzaa0Ab3QVRZcjopJ9hSSPCrecj/+s=
github.com/redis/go-redis/v9 v9.0.3 h1:+7mmR26M0IvyLxGZUHxu4GiBkJkVDid0Un+j4ScYu4k=
github.com/redis/go-redis/v9 v9.0.3/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.9 h1:rmenucSohSTiyL09Y+l2OCk+FrMxGMzho2+tjr5ticU=
github.com/ugorji/go/codec v1.2.9/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE=
golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

@ -0,0 +1,562 @@
package main
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"math/rand"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"github.com/redis/go-redis/v9"
)
// Message is the element type of ChatRequest.Messages and the type of
// Choice.Message. Every field is omitted from the JSON encoding when empty.
type Message struct {
Role string `json:"role,omitempty"`
Name string `json:"name,omitempty"`
// Content is the text that numTokensFromMessages encodes to count tokens.
Content string `json:"content,omitempty"`
}
// ChatRequest is the request body that the embeddings and completions handlers
// bind from JSON and then re-marshal before forwarding upstream.
//
// NOTE(review): because the handlers forward the re-marshalled struct, only
// the fields listed here reach the upstream server, and omitempty drops zero
// values (for example a temperature of 0) — confirm this is intended.
type ChatRequest struct {
// Stream selects the streaming branch in completions.
Stream bool `json:"stream,omitempty"`
// Model is used for the balance check, token counting and billing.
Model string `json:"model,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
// Messages is token-counted for every model other than "text-davinci-003".
Messages []Message `json:"messages,omitempty"`
// Prompt is token-counted when Model is "text-davinci-003".
Prompt string `json:"prompt,omitempty"`
Input string `json:"input,omitempty"`
}
// ChatResponse is the part of an upstream response that the handlers decode:
// the id passed on as the message id when billing, the usage counters, and the
// choices read while streaming.
type ChatResponse struct {
Id string `json:"id,omitempty"`
Object string `json:"object,omitempty"`
Created int64 `json:"created,omitempty"`
Model string `json:"model,omitempty"`
Usage Usage `json:"usage,omitempty"`
Choices []Choice `json:"choices,omitempty"`
}
// Usage holds the token counters of a ChatResponse. The non-streaming handlers
// pass these values to consumption.
type Usage struct {
PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
// Choice is one element of ChatResponse.Choices.
type Choice struct {
Message Message `json:"message,omitempty"`
// Delta.Content is what the streaming branch of completions accumulates to
// count completion tokens.
Delta Delta `json:"delta,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Index int `json:"index,omitempty"`
Text string `json:"text,omitempty"`
}
// Delta is the incremental content carried by a Choice in a streamed response.
type Delta struct {
Content string `json:"content,omitempty"`
}
// CreditSummary is the JSON body returned by the balance handler.
type CreditSummary struct {
// Object is set to "credit_grant" by the balance handler.
Object string `json:"object"`
// TotalGranted is a fixed 999999 in the balance handler.
TotalGranted float64 `json:"total_granted"`
// TotalUsed is computed as TotalGranted minus the current balance.
TotalUsed float64 `json:"total_used"`
// TotalRemaining is the current balance.
TotalRemaining float64 `json:"total_remaining"`
}
// ListModelResponse is the JSON body returned by handleGetModels.
type ListModelResponse struct {
// Object is set to "list" by handleGetModels.
Object string `json:"object"`
Data []Model `json:"data"`
}
// Model is one entry of ListModelResponse.Data.
type Model struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []ModelPermission `json:"permission"`
// Root is set to the same value as ID by handleGetModels.
Root string `json:"root"`
// Parent is always nil in handleGetModels, so it encodes as JSON null.
Parent any `json:"parent"`
}
// ModelPermission is one entry of Model.Permission. handleGetModels fills it
// with fixed values for every listed model.
type ModelPermission struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
// Group is always nil in handleGetModels, so it encodes as JSON null.
Group any `json:"group"`
IsBlocking bool `json:"is_blocking"`
}
// UserInfo is the JSON value stored in Redis under "user:<key>".
type UserInfo struct {
UID string `json:"uid"`
// SID is used to build the "server:<SID>" and "user:<SID>:balance" Redis keys.
SID string `json:"sid"`
}
// ServerInfo is the JSON value stored in Redis under "server:<SID>" and
// returned by checkBlance.
type ServerInfo struct {
// ServerAddress is parsed as a URL and used as the reverse-proxy target.
ServerAddress string `json:"server_address"`
// AvailableKey is a single key or a comma-separated list; when it holds more
// than one key the handlers pick one at random per request.
AvailableKey string `json:"available_key"`
}
// ModelInfo is the JSON value stored in Redis under "model:<name>".
type ModelInfo struct {
ModelName string `json:"model_name"`
// ModelPrice is multiplied by a token count to obtain a balance amount.
ModelPrice float64 `json:"model_price"`
// ModelPrepayment is the token count whose price checkBlance deducts up
// front and consumption credits back when settling.
ModelPrepayment int `json:"model_prepayment"`
}
var (
// Redis is the shared client used by the handlers and balance helpers.
// main assigns it from InitRedis.
Redis *redis.Client
// RedisAddress is the address InitRedis connects to. init replaces the
// default with the REDIS_ADDRESS environment variable when it is non-empty.
RedisAddress = "localhost:6379"
)
// init overrides RedisAddress with the REDIS_ADDRESS environment variable when
// that variable is non-empty, then logs the address that will be used.
func init() {
	//gin.SetMode(gin.ReleaseMode)
	addr := os.Getenv("REDIS_ADDRESS")
	if len(addr) > 0 {
		RedisAddress = addr
	}
	log.Printf("loading redis address: %s", RedisAddress)
}
// InitRedis creates a Redis client for RedisAddress (no password, DB 0, pool
// size 10) and verifies the connection with PING.
//
// It returns the client on success. On failure it logs the reason, closes the
// client so its connection pool is not leaked, and returns nil; callers must
// check for nil before using the result.
func InitRedis() *redis.Client {
	rdb := redis.NewClient(&redis.Options{
		Addr:     RedisAddress,
		Password: "", // no password set
		DB:       0,  // use default DB
		PoolSize: 10,
	})
	pong, err := rdb.Ping(context.Background()).Result()
	fmt.Println("redis ping:", pong)
	if err != nil || pong != "PONG" {
		// The connection is unusable: report why instead of failing silently.
		log.Printf("redis ping to %s failed: reply=%q err=%v", RedisAddress, pong, err)
		if cerr := rdb.Close(); cerr != nil {
			log.Printf("closing redis client: %v", cerr)
		}
		return nil
	}
	return rdb
}
// numTokensFromMessages counts the tokens of a message list for the given
// model: the encoded length of every message's Content plus 6 per message,
// plus a fixed 3 for the whole list.
//
// It panics when tiktoken has no encoding for the model.
func numTokensFromMessages(messages []Message, model string) int {
	encoder, err := tiktoken.EncodingForModel(model)
	if err != nil {
		panic(fmt.Errorf("getEncoding: %v", err))
	}
	total := 3
	for i := range messages {
		total += 6 + len(encoder.Encode(messages[i].Content, nil, nil))
	}
	return total
}
// numTokensFromString counts the tokens of a single string for the given
// model: the encoded length plus 1 for "text-davinci-003" and plus 9 for every
// other model.
//
// It panics when tiktoken has no encoding for the model.
func numTokensFromString(msg string, model string) int {
	encoder, err := tiktoken.EncodingForModel(model)
	if err != nil {
		panic(fmt.Errorf("getEncoding: %v", err))
	}
	overhead := 9
	if model == "text-davinci-003" {
		overhead = 1
	}
	return len(encoder.Encode(msg, nil, nil)) + overhead
}
// embeddings proxies an embeddings request to the upstream server configured
// for the caller and settles the caller's balance from the reported usage.
//
// Flow: bind the JSON body, reserve a prepayment through checkBlance, forward
// the re-marshalled body through a reverse proxy with an upstream key swapped
// into the Authorization header, then read the usage counters from the
// response and hand them to consumption.
//
// NOTE(review): the prepayment reserved by checkBlance is only settled inside
// ModifyResponse; the early-return paths after checkBlance do not refund it —
// confirm whether that is acceptable.
func embeddings(c *gin.Context) {
	var chatRequest ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}

	// TrimPrefix removes exactly the "Bearer " scheme. strings.Trim would treat
	// the argument as a character set and also eat leading or trailing
	// 'B', 'e', 'a', 'r' and ' ' characters that belong to the key itself.
	key := strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer ")

	// Check (and reserve) the caller's balance for this model. The max_tokens
	// parameter could be taken into account here later.
	serverInfo, err := checkBlance(key, chatRequest.Model)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)

	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		// A malformed stored address is a server-side problem; answer 500
		// instead of panicking inside the handler.
		log.Printf("invalid server address %q: %v", serverInfo.ServerAddress, err)
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	proxy := httputil.NewSingleHostReverseProxy(remote)
	proxy.Director = func(req *http.Request) {
		req.Header = c.Request.Header
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(newReqBody))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")

		serverKey := serverInfo.AvailableKey
		keyList := strings.Split(serverInfo.AvailableKey, ",")
		if len(keyList) > 1 {
			// Seed, then pick one of the configured upstream keys at random.
			rand.Seed(time.Now().UnixNano())
			serverKey = keyList[rand.Intn(len(keyList))]
		}
		req.Header.Set("Authorization", "Bearer "+serverKey)
		req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
	}

	// %s, not %d: the argument is the JSON text of the forwarded request.
	log.Printf("开始处理返回逻辑 %s", newReqBody)

	proxy.ModifyResponse = func(resp *http.Response) error {
		body, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			return err
		}
		// The original body is fully consumed; close it before replacing it
		// with an in-memory copy for the proxy to send to the client.
		resp.Body.Close()
		resp.Body = ioutil.NopCloser(bytes.NewReader(body))

		// Decoding is best-effort: a body without usage data leaves the
		// counters at zero and the settlement below still runs.
		var chatResponse ChatResponse
		if err := json.Unmarshal(body, &chatResponse); err != nil {
			log.Printf("decoding upstream response: %v", err)
		}
		promptTokens := chatResponse.Usage.PromptTokens
		totalTokens := chatResponse.Usage.TotalTokens
		log.Printf("prompt_tokens: %v total_tokens: %v", promptTokens, totalTokens)

		// Embedding responses are billed under a timestamp-based message id.
		timestampID := fmt.Sprintf("emb-%d", time.Now().Unix())
		if _, err := consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID); err != nil {
			log.Printf("settling balance for %s: %v", timestampID, err)
		}
		return nil
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
// completions proxies a chat or text completion request to the upstream server
// configured for the caller and settles the caller's balance afterwards.
//
// Flow: bind the JSON body, reserve a prepayment through checkBlance, forward
// the re-marshalled body through a reverse proxy with an upstream key swapped
// into the Authorization header, then bill:
//   - stream requests: the response is relayed line by line, the streamed
//     delta content is accumulated, and token counts are computed locally;
//   - non-stream requests: the usage counters of the response are used.
//
// NOTE(review): the prepayment reserved by checkBlance is only settled inside
// ModifyResponse; the early-return paths after checkBlance do not refund it —
// confirm whether that is acceptable.
// NOTE(review): the streaming branch only accumulates Choices[0].Delta.Content;
// if "text-davinci-003" streams its output in Choices[0].Text the completion
// would be counted as empty — verify against the upstream format.
func completions(c *gin.Context) {
	var chatRequest ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}

	// TrimPrefix removes exactly the "Bearer " scheme. strings.Trim would treat
	// the argument as a character set and also eat leading or trailing
	// 'B', 'e', 'a', 'r' and ' ' characters that belong to the key itself.
	key := strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer ")

	// Check (and reserve) the caller's balance for this model. The max_tokens
	// parameter could be taken into account here later.
	serverInfo, err := checkBlance(key, chatRequest.Model)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)

	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		// A malformed stored address is a server-side problem; answer 500
		// instead of panicking inside the handler.
		log.Printf("invalid server address %q: %v", serverInfo.ServerAddress, err)
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	proxy := httputil.NewSingleHostReverseProxy(remote)
	proxy.Director = func(req *http.Request) {
		req.Header = c.Request.Header
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(newReqBody))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")

		serverKey := serverInfo.AvailableKey
		keyList := strings.Split(serverInfo.AvailableKey, ",")
		if len(keyList) > 1 {
			// Seed, then pick one of the configured upstream keys at random.
			rand.Seed(time.Now().UnixNano())
			serverKey = keyList[rand.Intn(len(keyList))]
		}
		req.Header.Set("Authorization", "Bearer "+serverKey)
		req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
	}

	// %s, not %d: the argument is the JSON text of the forwarded request.
	log.Printf("开始处理返回逻辑 %s", newReqBody)

	if chatRequest.Stream {
		// Streaming response: relay each line as it arrives.
		proxy.ModifyResponse = func(resp *http.Response) error {
			chatRequestId := ""
			var reqContent strings.Builder

			for k, v := range resp.Header {
				if len(v) > 0 {
					c.Writer.Header().Set(k, v[0])
				}
			}
			// The body is re-framed below, so an upstream length would be wrong.
			c.Writer.Header().Del("Content-Length")
			// Pass the upstream status through; writing the body first would
			// otherwise default the status to 200.
			c.Writer.WriteHeader(resp.StatusCode)

			reader := bufio.NewReader(resp.Body)
			for {
				chunk, readErr := reader.ReadBytes('\n')
				if readErr != nil && readErr != io.EOF {
					return readErr
				}
				// Strip the "data:" prefix of the event line. Blank lines are
				// skipped here and re-created by the extra newline written
				// after every relayed line.
				payload := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(string(chunk)), "data:"))
				if payload != "" {
					var chatResponse ChatResponse
					// Best-effort: non-JSON payloads such as "[DONE]" fail to
					// decode and simply contribute no content.
					_ = json.Unmarshal([]byte(payload), &chatResponse)
					if len(chatResponse.Choices) > 0 {
						reqContent.WriteString(chatResponse.Choices[0].Delta.Content)
						chatRequestId = chatResponse.Id
					}
					// Write the line back to the client and flush immediately.
					if _, err := c.Writer.Write(append(chunk, '\n')); err != nil {
						return err
					}
					c.Writer.Flush()
				}
				if readErr == io.EOF {
					// A final line without a trailing newline has been
					// relayed above before leaving the loop.
					break
				}
			}

			var promptTokens int
			if chatRequest.Model == "text-davinci-003" {
				promptTokens = numTokensFromString(chatRequest.Prompt, chatRequest.Model)
			} else {
				promptTokens = numTokensFromMessages(chatRequest.Messages, chatRequest.Model)
			}
			completionTokens := numTokensFromString(reqContent.String(), chatRequest.Model)
			totalTokens := promptTokens + completionTokens
			log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)

			// Settle the balance.
			if _, err := consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatRequestId); err != nil {
				log.Printf("settling balance for %s: %v", chatRequestId, err)
			}
			return nil
		}
	} else {
		// Non-streaming response: bill from the usage block of the body.
		proxy.ModifyResponse = func(resp *http.Response) error {
			body, err := ioutil.ReadAll(resp.Body)
			if err != nil {
				return err
			}
			// The original body is fully consumed; close it before replacing
			// it with an in-memory copy for the proxy to send to the client.
			resp.Body.Close()
			resp.Body = ioutil.NopCloser(bytes.NewReader(body))

			// Decoding is best-effort: a body without usage data leaves the
			// counters at zero and the settlement below still runs.
			var chatResponse ChatResponse
			if err := json.Unmarshal(body, &chatResponse); err != nil {
				log.Printf("decoding upstream response: %v", err)
			}
			promptTokens := chatResponse.Usage.PromptTokens
			completionTokens := chatResponse.Usage.CompletionTokens
			totalTokens := chatResponse.Usage.TotalTokens
			log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)

			// Settle the balance.
			if _, err := consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatResponse.Id); err != nil {
				log.Printf("settling balance for %s: %v", chatResponse.Id, err)
			}
			return nil
		}
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
// handleGetModels answers 200 with a fixed model list. Every listed model gets
// the same metadata and a single permission entry with all Allow* flags set.
//
// Only the two gpt-3.5-turbo variants are advertised; the fuller list
// (gpt-4*, text-davinci-003, text-embedding-ada-002) is currently disabled.
func handleGetModels(c *gin.Context) {
	ids := []string{"gpt-3.5-turbo", "gpt-3.5-turbo-0301"}

	result := ListModelResponse{
		Object: "list",
		Data:   make([]Model, 0, len(ids)),
	}
	for _, id := range ids {
		// ID, Group and IsBlocking are left at their zero values.
		permission := ModelPermission{
			Object:             "model",
			Created:            1679602087,
			AllowCreateEngine:  true,
			AllowSampling:      true,
			AllowLogprobs:      true,
			AllowSearchIndices: true,
			AllowView:          true,
			AllowFineTuning:    true,
			Organization:       "*",
		}
		// Parent is left nil so it encodes as JSON null.
		entry := Model{
			ID:         id,
			Object:     "model",
			Created:    1677649963,
			OwnedBy:    "openai",
			Permission: []ModelPermission{permission},
			Root:       id,
		}
		result.Data = append(result.Data, entry)
	}
	c.JSON(200, result)
}
// handleOptions answers OPTIONS preflight requests with permissive CORS
// headers and status 200.
// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1
func handleOptions(c *gin.Context) {
	corsHeaders := [][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE"},
		{"Access-Control-Allow-Headers", "Content-Type, Authorization"},
	}
	for _, header := range corsHeaders {
		c.Header(header[0], header[1])
	}
	c.Status(200)
}
// balance reports the caller's remaining balance as a CreditSummary.
//
// The key is taken from the Authorization header. When the balance lookup
// fails the handler answers 400 with the error and stops; otherwise it answers
// 200 with a fixed grant of 999999 and the used amount derived from it.
func balance(c *gin.Context) {
	// TrimPrefix removes exactly the "Bearer " scheme. strings.Trim would treat
	// the argument as a character set and also eat leading or trailing
	// 'B', 'e', 'a', 'r' and ' ' characters that belong to the key itself.
	key := strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer ")

	remaining, err := queryBlance(key)
	if err != nil {
		c.JSON(400, gin.H{"error": err.Error()})
		// Stop here; falling through would append a second 200 body.
		return
	}

	c.JSON(200, CreditSummary{
		Object:         "credit_grant",
		TotalGranted:   999999,
		TotalUsed:      999999 - remaining,
		TotalRemaining: remaining,
	})
}
// checkKeyMid returns a middleware that validates the caller's key.
//
// A request without an Authorization header is aborted with 401 and code
// 40001. Otherwise the key is checked by checkKeyAndTimeCount; on failure the
// request is aborted with the status and code that function returns. Only
// requests that pass continue to the next handler.
func checkKeyMid() gin.HandlerFunc {
	return func(c *gin.Context) {
		auth := c.Request.Header.Get("Authorization")
		if auth == "" {
			c.AbortWithStatusJSON(401, gin.H{"code": 40001})
			return
		}

		// TrimPrefix removes exactly the "Bearer " scheme. strings.Trim would
		// treat the argument as a character set and also eat leading or
		// trailing 'B', 'e', 'a', 'r' and ' ' characters of the key itself.
		key := strings.TrimPrefix(auth, "Bearer ")
		status, err := checkKeyAndTimeCount(key)
		if err != nil {
			c.AbortWithStatusJSON(status, gin.H{"code": err.Error()})
			return
		}

		// Run the remaining handlers.
		c.Next()
	}
}
// main configures logging, registers the routes, connects to Redis, writes the
// test fixtures (servers, model prices, users, balances) and starts the HTTP
// server on 0.0.0.0:8080.
//
// NOTE(review): the fixtures are written on every start, and each start adds
// another 1000 to both user balances — confirm this is wanted outside testing.
// SECURITY(review): the AvailableKey values below look like real upstream
// credentials committed to source; rotate them and load them from
// configuration instead.
func main() {
	// Disable console colors; they are not wanted in a log file.
	gin.DisableConsoleColor()

	// Log to gin.log and the console. If the file cannot be created, fall back
	// to the console alone rather than writing through a nil file.
	f, err := os.Create("gin.log")
	if err != nil {
		log.Printf("cannot create gin.log, logging to stdout only: %v", err)
		gin.DefaultWriter = os.Stdout
	} else {
		defer f.Close()
		gin.DefaultWriter = io.MultiWriter(f, os.Stdout)
	}
	log.SetOutput(gin.DefaultWriter)

	r := gin.Default()
	r.GET("/v1/models", handleGetModels)
	r.GET("/proxy/v1/models", handleGetModels)
	r.GET("/dashboard/billing/credit_grants", checkKeyMid(), balance)
	r.GET("/proxy/dashboard/billing/credit_grants", checkKeyMid(), balance)
	r.OPTIONS("/v1/*path", handleOptions)
	r.OPTIONS("/proxy/v1/*path", handleOptions)
	r.POST("/v1/chat/completions", checkKeyMid(), completions)
	r.POST("/proxy/v1/chat/completions", checkKeyMid(), completions)
	r.POST("/v1/completions", checkKeyMid(), completions)
	r.POST("/proxy/v1/completions", checkKeyMid(), completions)
	r.POST("/v1/embeddings", checkKeyMid(), embeddings)
	r.POST("/proxy/v1/embeddings", checkKeyMid(), embeddings)
	r.POST("/mock1", mockBalanceInquiry)
	r.POST("/mock2", mockBalanceConsumption)

	Redis = InitRedis()
	if Redis == nil {
		// InitRedis returns nil when the ping fails; every handler needs Redis.
		log.Fatalf("redis is not reachable at %s", RedisAddress)
	}

	ctx := context.Background()
	// seed stores value as JSON under redisKey without expiry and stops the
	// process when the fixture cannot be written.
	seed := func(redisKey string, value any) {
		encoded, err := json.Marshal(value)
		if err != nil {
			log.Fatalf("encoding fixture %s: %v", redisKey, err)
		}
		if err := Redis.Set(ctx, redisKey, encoded, 0).Err(); err != nil {
			log.Fatalf("writing fixture %s: %v", redisKey, err)
		}
	}

	// Redis test data: upstream servers.
	seed("server:1", &ServerInfo{
		ServerAddress: "https://gptp.any-door.cn",
		AvailableKey:  "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,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,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",
	})
	seed("server:2", &ServerInfo{
		ServerAddress: "https://azure.any-door.cn",
		AvailableKey:  "6c4d2c65970b40e482e7cd27adb0d119",
	})

	// Redis test data: model prices.
	chatModel := &ModelInfo{
		ModelName:       "gpt-3.5-turbo",
		ModelPrice:      0.0001,
		ModelPrepayment: 4000,
	}
	seed("model:gpt-3.5-turbo", chatModel)
	seed("model:gpt-3.5-turbo-0301", chatModel)
	seed("model:text-davinci-003", &ModelInfo{
		ModelName:       "text-davinci-003",
		ModelPrice:      0.001,
		ModelPrepayment: 4000,
	})
	seed("model:text-embedding-ada-002", &ModelInfo{
		// Was "text-davinci-003", copied from the entry above.
		ModelName:       "text-embedding-ada-002",
		ModelPrice:      0.001,
		ModelPrepayment: 4000,
	})

	// Redis test data: users and their balances.
	seed("user:8aeb3747-715c-48e8-8b80-aec815949f22", &UserInfo{UID: "1", SID: "1"})
	seed("user:key2", &UserInfo{UID: "2", SID: "2"})
	for _, balanceKey := range []string{"user:1:balance", "user:2:balance"} {
		if err := Redis.IncrByFloat(ctx, balanceKey, 1000).Err(); err != nil {
			log.Fatalf("crediting %s: %v", balanceKey, err)
		}
	}

	//r.Run("127.0.0.1:8080")
	// Listen on all interfaces so the server is reachable inside Docker.
	if err := r.Run("0.0.0.0:8080"); err != nil {
		log.Fatalf("http server stopped: %v", err)
	}
}

@ -0,0 +1,125 @@
package main
import (
"context"
"encoding/json"
"errors"
"log"
"time"
)
// checkKeyAndTimeCount verifies that key belongs to a known user and that the
// user is still within the per-minute request quota (30 requests per 60
// seconds, counted in Redis under "user:count:<key>").
//
// It returns an HTTP status and an error whose message is the code or text to
// send to the client:
//   - 401 "40003": no user is stored for the key (or the lookup failed)
//   - 401 "40004": the stored user record cannot be decoded
//   - 500: the counter could not be incremented or given its expiry
//   - 429 "42901": the quota for the current window is exhausted
//   - 200 and a nil error: the request may proceed
//
// NOTE(review): the counter is incremented and given its expiry in two
// separate commands; if the expiry step is ever missed the counter never
// resets — consider making the pair atomic.
func checkKeyAndTimeCount(key string) (int, error) {
	const (
		window       = 60 * time.Second
		maxPerWindow = 30
	)
	ctx := context.Background()

	stored, err := Redis.Get(ctx, "user:"+key).Result()
	if err != nil {
		// The user does not exist.
		return 401, errors.New("40003")
	}
	var userInfo UserInfo
	if err = json.Unmarshal([]byte(stored), &userInfo); err != nil {
		// The user record is in an abnormal state.
		return 401, errors.New("40004")
	}

	counterKey := "user:count:" + key
	count, err := Redis.Incr(ctx, counterKey).Result()
	if err != nil {
		return 500, errors.New("系统计数器设置异常")
	}
	// The first hit of a window starts the expiry clock for the counter.
	if count == 1 {
		if _, err := Redis.Expire(ctx, counterKey, window).Result(); err != nil {
			return 500, errors.New("系统计数器异常")
		}
	}
	if count > maxPerWindow {
		// Too many requests: the per-minute quota is exceeded.
		return 429, errors.New("42901")
	}
	return 200, nil
}
// queryBlance returns the current balance of the user that owns key.
//
// The user record is read from "user:<key>" and the balance from
// "user:<SID>:balance". An error is returned when the user is unknown, when
// the record cannot be decoded, or when the balance cannot be read.
func queryBlance(key string) (float64, error) {
	ctx := context.Background()

	userInfoStr, err := Redis.Get(ctx, "user:"+key).Result()
	if err != nil {
		// Without this check an unknown key would fall through with an empty
		// SID and read (and create) the "user::balance" counter.
		return 0, errors.New("用户信息不存在")
	}
	var userInfo UserInfo
	if err := json.Unmarshal([]byte(userInfoStr), &userInfo); err != nil {
		return 0, errors.New("用户信息解析失败")
	}

	// Incrementing by 0 returns the current value as a float.
	balance, err := Redis.IncrByFloat(ctx, "user:"+userInfo.SID+":balance", 0).Result()
	if err != nil {
		return 0, errors.New("余额计算失败")
	}
	return balance, nil
}
// checkBlance reserves the prepayment for one request and returns the upstream
// server the caller should be routed to.
//
// It loads the user ("user:<key>"), the user's server ("server:<SID>") and the
// model price ("model:<model>"), then deducts ModelPrepayment * ModelPrice
// from "user:<SID>:balance". When the balance would become negative the
// deduction is credited back and an error is returned. The amount reserved
// here is settled later by consumption.
func checkBlance(key string, model string) (ServerInfo, error) {
	var serverInfo ServerInfo
	ctx := context.Background()

	// Load the user record.
	userInfoStr, err := Redis.Get(ctx, "user:"+key).Result()
	if err != nil {
		return serverInfo, errors.New("用户信息不存在")
	}
	var userInfo UserInfo
	if err := json.Unmarshal([]byte(userInfoStr), &userInfo); err != nil {
		return serverInfo, errors.New("用户信息解析失败")
	}

	// Load the server record.
	serverInfoStr, err := Redis.Get(ctx, "server:"+userInfo.SID).Result()
	if err != nil {
		return serverInfo, errors.New("服务器信息不存在")
	}
	if err := json.Unmarshal([]byte(serverInfoStr), &serverInfo); err != nil {
		return serverInfo, errors.New("服务器信息解析失败")
	}

	// Load the model price.
	modelPriceStr, err := Redis.Get(ctx, "model:"+model).Result()
	if err != nil {
		return serverInfo, errors.New("模型信息不存在")
	}
	var modelInfo ModelInfo
	if err := json.Unmarshal([]byte(modelPriceStr), &modelInfo); err != nil {
		return serverInfo, errors.New("模型信息解析失败")
	}

	// Deduct the prepayment first, then check what is left.
	balanceKey := "user:" + userInfo.SID + ":balance"
	prepayment := float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice
	balance, err := Redis.IncrByFloat(ctx, balanceKey, -prepayment).Result()
	if err != nil {
		return serverInfo, errors.New("余额计算失败")
	}
	if balance < 0 {
		// Not enough funds: give the prepayment back before rejecting.
		if _, err := Redis.IncrByFloat(ctx, balanceKey, prepayment).Result(); err != nil {
			log.Printf("refunding prepayment on %s: %v", balanceKey, err)
		}
		return serverInfo, errors.New("用户余额不足")
	}
	return serverInfo, nil
}
// consumption settles one finished request against the caller's balance and
// reports the usage through balanceConsumption.
//
// The prepayment reserved by checkBlance (ModelPrepayment * ModelPrice) is
// credited back and the real cost (total_tokens * ModelPrice) is charged, in a
// single adjustment of "user:<SID>:balance". It returns the body returned by
// balanceConsumption, or an error when the user or model cannot be loaded,
// the balance cannot be adjusted, or the usage report fails.
func consumption(key string, model string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) {
	ctx := context.Background()

	// Load the user record.
	userInfoStr, err := Redis.Get(ctx, "user:"+key).Result()
	if err != nil {
		return "", errors.New("用户信息不存在")
	}
	var userInfo UserInfo
	if err := json.Unmarshal([]byte(userInfoStr), &userInfo); err != nil {
		return "", errors.New("用户信息解析失败")
	}

	// Load the model price.
	modelPriceStr, err := Redis.Get(ctx, "model:"+model).Result()
	if err != nil {
		return "", errors.New("模型信息不存在")
	}
	var modelInfo ModelInfo
	if err := json.Unmarshal([]byte(modelPriceStr), &modelInfo); err != nil {
		return "", errors.New("模型信息解析失败")
	}

	// Refund the prepayment and charge the actual usage in one step.
	adjustment := float64(modelInfo.ModelPrepayment)*modelInfo.ModelPrice - (float64(total_tokens) * modelInfo.ModelPrice)
	if _, err := Redis.IncrByFloat(ctx, "user:"+userInfo.SID+":balance", adjustment).Result(); err != nil {
		return "", errors.New("余额计算失败")
	}

	// Send the usage record.
	result, err := balanceConsumption(key, model, prompt_tokens, completion_tokens, total_tokens, msg_id)
	if err != nil {
		return "", err
	}
	log.Printf("扣费KEY: %s 扣费token数: %d 扣费结果 %s", key, total_tokens, result)
	return result, nil
}
Loading…
Cancel
Save