niinii/src/app.rs (13 changes: 5 additions & 8 deletions)
@@ -165,7 +165,7 @@ impl App {
}
}

fn transition(&mut self, ui: &Ui, state: State) {
fn transition(&mut self, _ui: &Ui, state: State) {
self.state = state;
}

@@ -213,14 +213,11 @@ impl App {
}
}

match &self.state {
State::Completed => {
if let Some(request_gloss_text) = self.request_gloss_text.clone() {
self.request_gloss_text = None;
self.request_parse(ui, &request_gloss_text);
}
if let State::Completed = &self.state {
if let Some(request_gloss_text) = self.request_gloss_text.clone() {
self.request_gloss_text = None;
self.request_parse(ui, &request_gloss_text);
}
_ => (),
};

if self.settings.watch_clipboard {
niinii/src/settings.rs (2 changes: 2 additions & 0 deletions)
@@ -36,6 +36,7 @@ pub struct ChatSettings {
pub presence_penalty: Option<f32>,
pub connection_timeout: u64,
pub timeout: u64,
pub stream: bool,
}
impl Default for ChatSettings {
fn default() -> Self {
@@ -51,6 +52,7 @@ impl Default for ChatSettings {
presence_penalty: None,
connection_timeout: 3000,
timeout: 10000,
stream: true,
}
}
}
niinii/src/translator/chat.rs (69 changes: 38 additions & 31 deletions)
@@ -53,7 +53,7 @@ impl Translator for ChatTranslator {
let chatgpt = &settings.chat;

let permit = self.semaphore.clone().acquire_owned().await.unwrap();
let mut exchange = {
let exchange = {
let mut chat = self.chat.lock().await;
chat.start_exchange(
Message {
@@ -74,47 +74,54 @@
messages: exchange.prompt(),
temperature: chatgpt.temperature,
top_p: chatgpt.top_p,
max_tokens: chatgpt.max_tokens,
max_completion_tokens: chatgpt.max_tokens,
presence_penalty: chatgpt.presence_penalty,
..Default::default()
};

let exchange = Arc::new(Mutex::new(exchange));
let mut stream = self.client.stream(chat_request).await?;
let token = CancellationToken::new();
let chat = &self.chat;
tokio::spawn(
enclose! { (chat, token, exchange, chatgpt.max_context_tokens => max_context_tokens) async move {
// Hold permit: We are not allowed to begin another translation
// request until this one is complete.
let _permit = permit;
loop {
tokio::select! {
msg = stream.next() => match msg {
Some(Ok(completion)) => {
let mut exchange = exchange.lock().await;
let message = &completion.choices.first().unwrap().delta;
exchange.append(message)
if chatgpt.stream {
let mut stream = self.client.stream(chat_request).await?;
let chat = &self.chat;
tokio::spawn(
enclose! { (chat, token, exchange, chatgpt.max_context_tokens => max_context_tokens) async move {
// Hold permit: We are not allowed to begin another translation
// request until this one is complete.
let _permit = permit;
loop {
tokio::select! {
msg = stream.next() => match msg {
Some(Ok(completion)) => {
let mut exchange = exchange.lock().await;
let message = &completion.choices.first().unwrap().delta;
exchange.append_partial(message)
},
Some(Err(err)) => {
tracing::error!(%err, "stream");
break
},
None => {
let mut chat = chat.lock().await;
let mut exchange = exchange.lock().await;
chat.commit(&mut exchange);
chat.enforce_context_limit(max_context_tokens);
break
}
},
Some(Err(err)) => {
tracing::error!(%err, "stream");
break
},
None => {
let mut chat = chat.lock().await;
let mut exchange = exchange.lock().await;
chat.commit(&mut exchange);
chat.enforce_context_limit(max_context_tokens);
_ = token.cancelled() => {
break
}
},
_ = token.cancelled() => {
break
}
}
}
}.instrument(tracing::Span::current())},
);
}.instrument(tracing::Span::current())},
);
} else {
let response = self.client.chat(chat_request).await?;
let mut e = exchange.lock().await;
e.set_complete(response.choices.first().unwrap().message.clone());
self.chat.lock().await.commit(&mut e);
}

Ok(Box::new(ChatTranslation {
model: chatgpt.model,
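The translator now takes one of two paths depending on the new `stream` setting: with streaming enabled it spawns a task that appends each SSE delta to the exchange and commits when the stream ends; otherwise it issues a single blocking request, stores the complete message, and commits it directly. Below is a minimal, self-contained sketch of that control flow; the `Exchange` type and the fake delta stream are simplified stand-ins, not the real niinii/openai-chat types.

```rust
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio_stream::StreamExt;

// Simplified stand-in for openai-chat's `Exchange`.
#[derive(Debug, Default)]
struct Exchange {
    response: String,
}

impl Exchange {
    fn append_partial(&mut self, delta: &str) {
        self.response.push_str(delta);
    }
    fn set_complete(&mut self, message: String) {
        self.response = message;
    }
}

#[tokio::main]
async fn main() {
    let exchange = Arc::new(Mutex::new(Exchange::default()));
    let streaming = true; // corresponds to the new `ChatSettings::stream` flag

    if streaming {
        // Streaming path: consume deltas as they arrive in a spawned task.
        let mut deltas = tokio_stream::iter(["Hel", "lo, ", "world"]);
        let task_exchange = exchange.clone();
        tokio::spawn(async move {
            while let Some(delta) = deltas.next().await {
                task_exchange.lock().await.append_partial(delta);
            }
            // The real code commits the exchange to the chat buffer and
            // enforces the context-token limit here.
        })
        .await
        .unwrap();
    } else {
        // Non-streaming path: one response, stored whole.
        exchange.lock().await.set_complete("Hello, world".into());
    }

    println!("{:?}", exchange.lock().await);
}
```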
niinii/src/view/translator/chat.rs (21 changes: 15 additions & 6 deletions)
@@ -4,7 +4,7 @@ use openai_chat::chat::{Model, Role};
use crate::{
settings::Settings,
translator::chat::{ChatTranslation, ChatTranslator},
view::mixins::drag_handle,
view::mixins::{drag_handle, help_marker},
};

use crate::view::{
@@ -30,7 +30,10 @@ impl View for ViewChatTranslator<'_> {
}
});
if ui.collapsing_header("Tuning", TreeNodeFlags::DEFAULT_OPEN) {
if let Some(_token) = ui.begin_table("##", 2) {
if let Some(_token) = ui.begin_table("##", 3) {
ui.table_next_column();
ui.set_next_item_width(ui.current_font_size() * -8.0);
combo_enum(ui, "Model", &mut chatgpt.model);
ui.table_next_column();
ui.set_next_item_width(ui.current_font_size() * -8.0);
ui.input_scalar("Max context tokens", &mut chatgpt.max_context_tokens)
@@ -75,8 +78,12 @@
},
);
ui.table_next_column();
ui.set_next_item_width(ui.current_font_size() * -8.0);
combo_enum(ui, "Model", &mut chatgpt.model);
ui.checkbox("Stream", &mut chatgpt.stream);
ui.same_line();
help_marker(
ui,
"Use streaming API (may require ID verification for some models)",
);
}
}
ui.child_window("context_window").build(|| {
@@ -194,13 +201,15 @@ impl View for ViewChatTranslation<'_> {
let _wrap_token = ui.push_text_wrap_pos_with_pos(0.0);
ui.text(""); // anchor for line wrapping
ui.same_line();
let ChatTranslation { exchange, .. } = self.0;
let ChatTranslation {
model, exchange, ..
} = self.0;
let exchange = exchange.blocking_lock();
let draw_list = ui.get_window_draw_list();
stroke_text_with_highlight(
ui,
&draw_list,
"[ChatGPT]",
&format!("[{}]", std::convert::Into::<&'static str>::into(model)),
1.0,
Some(StyleColor::NavHighlight),
);
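The translation header now displays the configured model (e.g. "[gpt-4o-mini]") instead of a hard-coded "[ChatGPT]" label, which relies on `Model` converting into `&'static str`. A hypothetical sketch of one way such a conversion could look is below; the variant and string names are illustrative assumptions, and the real crate may derive the conversion instead of writing it by hand.

```rust
// Illustrative only: these variants are assumptions, not the real `Model` enum.
#[derive(Clone, Copy, Debug)]
enum Model {
    Gpt4oMini,
    Gpt4o,
}

impl From<Model> for &'static str {
    fn from(model: Model) -> Self {
        match model {
            Model::Gpt4oMini => "gpt-4o-mini",
            Model::Gpt4o => "gpt-4o",
        }
    }
}

fn main() {
    let model = Model::Gpt4oMini;
    // Mirrors the call in the view above.
    println!("[{}]", std::convert::Into::<&'static str>::into(model));
}
```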
openai-chat/src/chat/chat_buffer.rs (6 changes: 5 additions & 1 deletion)
@@ -89,7 +89,7 @@ pub struct Exchange {
usage: Option<Usage>,
}
impl Exchange {
pub fn append(&mut self, partial: &PartialMessage) {
pub fn append_partial(&mut self, partial: &PartialMessage) {
if let Some(last) = &mut self.response {
if let Some(content) = &mut last.content {
content.push_str(&partial.content)
@@ -104,6 +104,10 @@ impl Exchange {
}
}

pub fn set_complete(&mut self, message: Message) {
self.response = Some(message);
}

pub fn prompt(&self) -> Vec<Message> {
let mut messages = vec![];
messages.push(self.system.clone());
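With these two methods an `Exchange` response can be filled either incrementally (streaming deltas via `append_partial`) or in one shot (`set_complete`). The sketch below models that with simplified `Message`/`PartialMessage` shapes based on the fields visible in this diff; the real protocol structs likely carry additional fields.

```rust
#[derive(Clone, Debug, Default)]
struct Message {
    content: Option<String>,
}

#[derive(Debug)]
struct PartialMessage {
    content: String,
}

#[derive(Default)]
struct Exchange {
    response: Option<Message>,
}

impl Exchange {
    // Streaming path: concatenate each delta onto the buffered response.
    fn append_partial(&mut self, partial: &PartialMessage) {
        let last = self.response.get_or_insert_with(Message::default);
        last.content
            .get_or_insert_with(String::new)
            .push_str(&partial.content);
    }

    // Non-streaming path: the API returned the whole message at once.
    fn set_complete(&mut self, message: Message) {
        self.response = Some(message);
    }
}

fn main() {
    let mut exchange = Exchange::default();
    exchange.append_partial(&PartialMessage { content: "Hello, ".into() });
    exchange.append_partial(&PartialMessage { content: "world".into() });
    assert_eq!(
        exchange.response.unwrap().content.as_deref(),
        Some("Hello, world")
    );
}
```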
openai-chat/src/chat/mod.rs (91 changes: 60 additions & 31 deletions)
@@ -11,7 +11,7 @@ pub use crate::protocol::chat::{Message, Model, PartialMessage, Role, Usage};
pub use chat_buffer::{ChatBuffer, Exchange};

use crate::{
protocol::chat::{self, StreamResponse},
protocol::chat::{self, ChatResponse, StreamOptions, StreamResponse},
Client, Error,
};

@@ -37,7 +37,7 @@ pub struct Request {
/// The maximum number of tokens allowed for the generated answer. By
/// default, the number of tokens the model can return will be (4096 - prompt
/// tokens).
pub max_tokens: Option<u32>,
pub max_completion_tokens: Option<u32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based
/// on whether they appear in the text so far, increasing the model's
/// likelihood to talk about new topics.
@@ -51,7 +51,7 @@ impl From<Request> for chat::Request {
temperature,
top_p,
n,
max_tokens,
max_completion_tokens,
presence_penalty,
} = value;
chat::Request {
@@ -60,7 +60,7 @@ impl From<Request> for chat::Request {
temperature,
top_p,
n,
max_tokens,
max_completion_tokens,
presence_penalty,
..Default::default()
}
@@ -71,7 +71,7 @@ impl Client {
#[tracing::instrument(level = Level::DEBUG, skip_all, err)]
pub async fn chat(&self, request: Request) -> Result<chat::Completion, Error> {
let request: chat::Request = request.into();
tracing::trace!(?request);
tracing::debug!(?request);
let response: chat::ChatResponse = self
.shared
.request(Method::POST, "/v1/chat/completions")
@@ -80,7 +80,7 @@ impl Client {
.await?
.json()
.await?;
tracing::trace!(?response);
tracing::debug!(?response);
Ok(response.0?)
}

@@ -91,39 +91,68 @@
) -> Result<impl Stream<Item = Result<chat::PartialCompletion, Error>>, Error> {
let mut request: chat::Request = request.into();
request.stream = Some(true);
request.stream_options = Some(StreamOptions {
include_obfuscation: false,
include_usage: false,
});

tracing::trace!(?request);
let stream = self
tracing::debug!(?request);
let response = self
.shared
.request(Method::POST, "/v1/chat/completions")
.body(&request)
.send()
.await?
.bytes_stream()
.eventsource();
Ok(stream.map_while(|event| match event {
Ok(event) => {
if event.data == "[DONE]" {
None
} else {
let response = match serde_json::from_str::<StreamResponse>(&event.data) {
Ok(response) => {
tracing::trace!(?response);
Ok::<_, Error>(response.0)
}
Err(err) => {
tracing::error!(?err, ?event.data);
Err(err.into())
.await?;
let status = response.status();

if status.is_success() {
// HTTP success: Expect SSE response
let stream = response.bytes_stream().eventsource();
Ok(stream.map_while(|event| {
tracing::trace!(?event);
match event {
Ok(event) => {
if event.data == "[DONE]" {
None
} else {
let response = match serde_json::from_str::<StreamResponse>(&event.data)
{
Ok(response) => {
tracing::debug!(?response);
Ok::<_, Error>(response.0)
}
Err(err) => {
// Serde error
tracing::error!(?err, ?event.data);
Err(err.into())
}
};
Some(response)
}
};
Some(response)
}
Err(err) => {
// SSE error
tracing::error!(?err);
Some(Err(err.into()))
}
}
}))
} else {
// HTTP error: Expect JSON response
let response_err = response.error_for_status_ref().unwrap_err();
let chat_response = response.json::<ChatResponse>().await;
match chat_response {
Ok(err) => {
// OpenAI application error
Err(Error::Protocol(err.0.unwrap_err()))
}
Err(err) => {
// Not application error, return HTTP error
tracing::error!(?response_err, ?err, "unexpected stream response");
Err(response_err.into())
}
}
Err(err) => {
tracing::error!(?err);
Some(Err(err.into()))
}
}))
}
}
}

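The streaming client now distinguishes three cases per SSE event (the terminal "[DONE]" sentinel, a deserializable chunk, and a deserialization error), and on a non-2xx response it tries to parse the JSON error body before falling back to the plain HTTP error. The per-event decision can be isolated as a small pure function, sketched below; `StreamResponse` here is a stand-in for the real protocol type, which carries more fields.

```rust
use serde::Deserialize;

// Stand-in for the real `StreamResponse`; only `id` is modeled here.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct StreamResponse {
    id: String,
}

// `None` ends the stream ("[DONE]"), `Some(Ok)` yields a parsed chunk,
// and `Some(Err)` surfaces a deserialization error to the consumer.
fn handle_event(data: &str) -> Option<Result<StreamResponse, serde_json::Error>> {
    if data == "[DONE]" {
        None
    } else {
        Some(serde_json::from_str::<StreamResponse>(data))
    }
}

fn main() {
    assert!(handle_event(r#"{"id":"chatcmpl-123"}"#).unwrap().is_ok());
    assert!(handle_event("not json").unwrap().is_err());
    assert!(handle_event("[DONE]").is_none());
}
```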