Stripped out ui/api

This commit is contained in:
Benjamin Sherriff
2024-09-03 22:32:43 -04:00
committed by Benjamin Sherriff
parent c83d398ce0
commit 96fe3fc0e5
152 changed files with 110 additions and 10056 deletions

3
src/bot/oai/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
mod model;
pub use model::*;

160
src/bot/oai/model.rs Normal file
View File

@@ -0,0 +1,160 @@
use serde::{Serialize, Deserialize};
use serde_json::Value;
use siren::ServiceError;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GPTRole {
#[serde(rename = "system")]
System,
#[serde(rename = "user")]
User,
#[serde(rename = "assistant")]
Assistant,
#[serde(rename = "function")]
Function,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatCompletionMessage>,
/// Value between 0 and 2
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
/// Value between 0 and 1
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<i64>,
/// Value between -2.0 and 2.0
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f64>,
/// Value between -2.0 and 2.0
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionMessage {
pub role: GPTRole,
pub content: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub system_fingerprint: Option<String>,
pub created: i64,
pub model: String,
pub usage: Usage,
pub choices: Vec<Choice>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: i64,
pub completion_tokens: i64,
pub total_tokens: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
pub message: ChatCompletionMessage,
pub finish_reason: String,
pub index: i64,
pub logprobs: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum ResponseEvent {
ChatCompletionResponse(ChatCompletionResponse),
ResponseError(ResponseError),
// ChatCompletionResponse {
// id: String,
// object: String,
// system_fingerprint: Option<String>,
// created: i64,
// model: String,
// usage: Usage,
// choices: Vec<Choice>,
// },
// ResponseError {
// error: Option<ErrorDetails>,
// message: Option<String>,
// param: Option<String>,
// #[serde(rename = "type")]
// error_type: Option<String>,
// },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ResponseError {
error: Option<ErrorDetails>,
message: Option<String>,
param: Option<String>,
#[serde(rename = "type")]
error_type: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ErrorDetails {
code: Option<String>,
message: Option<String>,
}
pub struct OAI {
pub client: reqwest::Client,
pub base_url: String,
pub service_url: String,
pub max_attempts: i64,
pub token: String,
pub max_tokens: i64,
pub default_model: String,
pub max_context_questions: i64,
}
impl OAI {
pub async fn chat_completion(
&self,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, ServiceError> {
let url = format!("{}/chat/completions", self.base_url);
let response = self
.client
.post(&url)
.bearer_auth(&self.token)
.header("Content-Type", "application/json".to_string())
.json(&request)
.send()
.await;
match response {
Ok(response) => {
let value = response.json::<Value>().await?;
let event: ResponseEvent = serde_json::from_value::<ResponseEvent>(value)?;
match event {
ResponseEvent::ChatCompletionResponse(response) => {
return Ok(response);
},
ResponseEvent::ResponseError(error) => {
return Err(ServiceError {
status: 500,
message: format!("Error: {}", error.message.unwrap()),
});
},
}
}
Err(err) => {
return Err(ServiceError {
status: 500,
message: format!("Error: {}", err),
})
}
}
}
}