You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
51 lines
1.1 KiB
51 lines
1.1 KiB
|
3 years ago
|
package common
|
||
|
|
|
||
|
|
import (
	"fmt"
	"log"
	"strings"

	"api2gpt-mid/model"

	"github.com/pkoukk/tiktoken-go"
)
|
||
|
|
|
||
|
|
// 计算Messages中的token数量
|
||
|
|
func NumTokensFromMessages(messages []model.Message, model string) int {
|
||
|
|
if strings.Contains(model, "gpt-3.5") {
|
||
|
|
model = "gpt-3.5-turbo"
|
||
|
|
}
|
||
|
|
if strings.Contains(model, "gpt-4") {
|
||
|
|
model = "gpt-4"
|
||
|
|
}
|
||
|
|
tkm, err := tiktoken.EncodingForModel(model)
|
||
|
|
if err != nil {
|
||
|
|
err = fmt.Errorf("getEncoding: %v", err)
|
||
|
|
return 0
|
||
|
|
}
|
||
|
|
numTokens := 0
|
||
|
|
for _, message := range messages {
|
||
|
|
numTokens += len(tkm.Encode(message.Content, nil, nil))
|
||
|
|
numTokens += 6
|
||
|
|
}
|
||
|
|
numTokens += 3
|
||
|
|
return numTokens
|
||
|
|
}
|
||
|
|
|
||
|
|
// 计算String中的token数量
|
||
|
|
func NumTokensFromString(msg string, model string) int {
|
||
|
|
if strings.Contains(model, "gpt-3.5") {
|
||
|
|
model = "gpt-3.5-turbo"
|
||
|
|
}
|
||
|
|
if strings.Contains(model, "gpt-4") {
|
||
|
|
model = "gpt-4"
|
||
|
|
}
|
||
|
|
tkm, err := tiktoken.EncodingForModel(model)
|
||
|
|
if err != nil {
|
||
|
|
err = fmt.Errorf("getEncoding: %v", err)
|
||
|
|
return 0
|
||
|
|
}
|
||
|
|
if model == "text-davinci-003" {
|
||
|
|
return len(tkm.Encode(msg, nil, nil)) + 1
|
||
|
|
} else {
|
||
|
|
return len(tkm.Encode(msg, nil, nil)) + 9
|
||
|
|
}
|
||
|
|
}
|