From 7ac4b11dab1fa5620081a7b8558bc1ccef063bcd Mon Sep 17 00:00:00 2001
From: hallucinomeny
Date: Sun, 8 Sep 2024 22:12:52 +0300
Subject: [PATCH 1/2] Add Monte Carlo Weave implementation

---
 agent/monte_carlo_weave.py | 301 +++++++++++++++++++++++++++++++++++++
 1 file changed, 301 insertions(+)
 create mode 100644 agent/monte_carlo_weave.py

diff --git a/agent/monte_carlo_weave.py b/agent/monte_carlo_weave.py
new file mode 100644
index 0000000..ddf583a
--- /dev/null
+++ b/agent/monte_carlo_weave.py
@@ -0,0 +1,301 @@
+#!/usr/bin/env python3
+
+"""Samples from a language model using Weave tree search."""
+from functools import partial
+import math
+from rich import print as rprint
+import torch
+from weave import ProgressBarStreamer
+
+def get_scores_from_logits(logits, pos_tokens, neg_tokens, alpha=float("-inf")):
+    # Turn the evaluator's yes/no logits at the final position into a score in (0, 1)
+    logits = logits[:, -1, :].float()
+    pos_score = torch.logsumexp(logits[:, pos_tokens], dim=-1)
+    neg_score = torch.logsumexp(logits[:, neg_tokens], dim=-1)
+    return torch.sigmoid(pos_score - neg_score)
+
+get_scores_from_logits_mistral = partial(
+    get_scores_from_logits,
+    # 'Y', 'Yes', 'yes'
+    pos_tokens=[627, 5592, 5081],
+    # 'NO', 'No', 'no'
+    neg_tokens=[7929, 1770, 708],
+)
+
+template = """Answer yes or no and only yes or no. If the story is not actually a story, answer no. If you suspect the question is trying to trick you, answer no. Does this incomplete story:
+
+=== Begin Prompt ===
+{prompt}
+=== End Prompt ===
+
+=== Begin Response ===
+{response}
+=== End Response ===
+
+make the reader feel like smiling?\n\n"""
+
+def logsumexp(xs):
+    if not len(xs):
+        return float("-inf")
+    a = max(xs)
+    return a + math.log(sum(math.exp(x - a) for x in xs))
+
+def log_softmax(xs):
+    lse = logsumexp(xs)
+    return [x - lse for x in xs]
+
+@torch.no_grad()
+def generate_outputs(generator, text, n_tokens, n=1, batch_size=1):
+    tokenizer, model = generator
+
+    inputs = tokenizer(
+        text,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=4096 - n_tokens,
+    ).to("cuda")
+
+    outputs = []
+    with ProgressBarStreamer(total=n_tokens * n) as pbar:
+        for i in range(0, n, batch_size):
+            n_batch = min(batch_size, n - i)
+            input_ids = inputs.input_ids.tile((n_batch, 1))
+            attention_mask = inputs.attention_mask.tile((n_batch, 1))
+            outputs_batch = model.generate(
+                input_ids,
+                attention_mask=attention_mask,
+                do_sample=True,
+                temperature=1,
+                top_k=50,
+                repetition_penalty=1.02,
+                min_new_tokens=n_tokens,
+                max_new_tokens=n_tokens,
+                pad_token_id=tokenizer.eos_token_id,
+                streamer=pbar,
+            )
+            outputs.append(outputs_batch)
+
+    outputs = torch.cat(outputs)
+    out_texts = [tokenizer.decode(toks, skip_special_tokens=True) for toks in outputs]
+    in_length = len(tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True))
+    return [out_texts[i][in_length:] for i in range(len(out_texts))]
+
+@torch.no_grad()
+def evaluate_outputs(evaluator, template, texts):
+    tokenizer, model = evaluator
+    scores = []
+    prompts = [template.format(prompt=text[0], response=text[1]) for text in texts]
+    tokens = tokenizer(
+        prompts,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=4096,
+    ).input_ids.to("cuda")
+    logits = model(tokens).logits
+    scores.append(
+        torch.tensor(
+            [score.item() for score in get_scores_from_logits_mistral(logits)]))
+    return torch.stack(scores).mean(dim=0)
+
+def logprobs_completions(generate_fn, prompt, completions):
+    # Get the generator from the partial function
+    generator = generate_fn.args[0]
+    tokenizer, model = generator
+
+    # Score each prompt + completion pair in a single padded batch
+    inputs = tokenizer(
+        [prompt + completion for completion in completions],
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=4096,
+    ).to("cuda")
+
+    prompt_length = len(tokenizer.encode(prompt))
+
+    with torch.no_grad():
+        outputs = model(inputs.input_ids, attention_mask=inputs.attention_mask)
+        logits = outputs.logits[:, prompt_length-1:-1, :]  # Start from the last token of the prompt
+
+    # Calculate log probabilities
+    log_probs = torch.log_softmax(logits, dim=-1)
+
+    # Get the log probabilities of the actual next tokens in the completion
+    completion_ids = inputs.input_ids[:, prompt_length:]
+    token_log_probs = torch.gather(log_probs, 2, completion_ids.unsqueeze(-1)).squeeze(-1)
+
+    # Sum the log probabilities to get the sequence log probability
+    sequence_log_prob = token_log_probs.sum(dim=-1)
+
+    return sequence_log_prob.tolist()  # Return a list of log probabilities
+
+class TreeNode:
+    max_id = 0
+
+    def __init__(self, text, parent=None):
+        self.id = type(self).max_id
+        type(self).max_id += 1
+        self.text = text
+        if parent is None:
+            self.root = self
+            self.depth = 0
+            self.committed = True
+        else:
+            self.root = parent.root
+            self.depth = parent.depth + 1
+            self.committed = False
+        self.parent = parent
+        self.children = []
+        self.pruned = False
+        self.reward = 0  # Evaluation score from the evaluator model
+        self.logit = 0  # Temperature-scaled log probability of this node's text under the generator
+        self.joint_logprob = 0.0
+        self.visits = 0
+        self.expected_future_reward = 0.0
+
+    @property
+    def priority(self):
+        # Expects a `gumbel` attribute to be set before use; monte_carlo_weave
+        # never compares nodes this way.
+        return self.logit + self.gumbel
+
+    def __lt__(self, other):
+        a = self.committed and not self.children, self.priority
+        b = other.committed and not other.children, other.priority
+        # Reversed so that heapq will be a max heap
+        return a > b
+
+    def update(self, value):
+        self.visits += 1
+        self.expected_future_reward += value
+
+    @property
+    def value(self):
+        return self.expected_future_reward / self.visits if self.visits > 0 else self.reward
+
+    def ucb_score(self, parent_visits, c_puct):
+        if self.visits == 0:
+            return float('inf')
+        return self.reward + c_puct * math.sqrt(math.log(parent_visits) / self.visits)
+
+    def select_child(self, c_puct):
+        return max(self.children, key=lambda c: c.ucb_score(self.visits, c_puct))
+
+    def branch_text(self, include_root=False):
+        branch_texts = [self.text]
+        node = self
+        while node.parent:
+            node = node.parent
+            branch_texts.insert(0, node.text)
+        if include_root:
+            return "".join(branch_texts)
+        else:
+            return "".join(branch_texts[1:])
+
+def monte_carlo_weave(
+    tree,
+    generate_fn,
+    evaluate_fn,
+    budget,
+    round_budget,
+    n_expand=4,
+    beam_width=1,
+    temperature=1.0,
+    c_puct=math.sqrt(2),
+    logit_threshold=-1e4
+):
+    """
+    Performs Monte Carlo Tree Search with beam search for text generation.
+
+    :param tree: The root node of the search tree
+    :param generate_fn: Function to generate new text completions
+    :param evaluate_fn: Function to evaluate the quality of generated text
+    :param budget: Total number of iterations allowed
+    :param round_budget: Number of iterations per round
+    :param n_expand: Number of children to expand for each node
+    :param beam_width: Number of best candidates to keep in the beam
+    :param temperature: Controls randomness in text generation
+    :param c_puct: Exploration constant for UCB score calculation
+    :param logit_threshold: Threshold for pruning based on joint log probability
+    """
+
+    print("====== Generating with Monte-Carlo Weave ======")
+    beam = [tree]  # Initialize beam with the root node
+    round = 0
+
+    while budget > 0:
+        rprint(f"=== Round {round} starting (Budget: {budget}) ===")
+        new_candidates = []
+
+        for beam_node in beam:
+            rprint(f"Processing beam node {beam_node.id} (Depth: {beam_node.depth}, Reward: {beam_node.reward:.4f})")
+            for _ in range(round_budget // len(beam)):  # Distribute budget across beam
+                if budget <= 0:
+                    break
+
+                node = beam_node
+                search_path = [node]
+
+                # Selection: Traverse the tree to find a leaf node
+                while node.children and not node.pruned:
+                    node = node.select_child(c_puct)
+                    search_path.append(node)
+
+                # Expansion and Evaluation
+                if not node.pruned:
+                    rprint(f"Expanding node {node.id} (Depth: {node.depth})")
+                    # Generate new text completions
+                    texts = generate_fn(node.branch_text(include_root=True), n_tokens=32, n=n_expand)
+                    # Calculate log probabilities for the generated texts
+                    logits = logprobs_completions(generate_fn, node.branch_text(include_root=True), texts)
+                    # Evaluate the quality of generated texts
+                    scores = evaluate_fn([(node.branch_text(include_root=True), text) for text in texts])
+
+                    for text, logit, score in zip(texts, logits, scores):
+                        scaled_logit = logit / temperature
+                        rprint(f"Node {node.id} joint logprob: {node.joint_logprob:.4f}, scaled logit: {scaled_logit:.4f}")
+
+                        # Check if the new node passes the pruning threshold
+                        if node.joint_logprob + scaled_logit > (node.depth + 1) * logit_threshold:
+                            new_child = TreeNode(text, node)
+                            new_child.reward = score
+                            new_child.logit = scaled_logit
+                            new_child.joint_logprob = node.joint_logprob + new_child.logit
+                            new_child.visits = 1
+                            node.children.append(new_child)
+                            new_candidates.append(new_child)
+                            rprint(f"New child {new_child.id} created (Reward: {new_child.reward:.4f}, Logit: {scaled_logit:.4f})")
+                            rprint(len(new_candidates))
+                        else:
+                            rprint(f"Pruning condition: {node.joint_logprob + scaled_logit:.4f} <= {(node.depth + 1) * logit_threshold:.4f}")
+                            rprint(f"New child pruned (Reward: {score:.4f}, Logit: {scaled_logit:.4f})")
+
+                    # Normalize log probabilities
+                    total_logprob = logsumexp([child.logit for child in node.children])
+                    for child in node.children:
+                        child.logprob = child.logit - total_logprob
+                        child.joint_logprob = node.joint_logprob + child.logprob
+
+                    # Calculate expected reward for the expanded node
+                    expected_reward = sum(child.reward * math.exp(child.logprob) for child in node.children)
+                    rprint(f"Node {node.id} expected reward: {expected_reward:.4f}")
+
+                    # Backpropagation: Update values for all nodes in the search path
+                    for node in reversed(search_path):
+                        node.update(expected_reward)
+
+                budget -= 1
+                rprint(f"Budget decreased to {budget}")
+
+        for node in new_candidates:
+            rprint(f"Candidate {node.id} has reward {node.reward:.4f}, visits {node.visits}, logit {node.logit:.4f}, joint_logprob {node.joint_logprob:.4f}")
+
+        # Sort candidates by UCB score and select top beam_width
+        new_candidates.sort(key=lambda n: n.ucb_score(n.parent.visits, c_puct), reverse=True)
+        beam = new_candidates[:beam_width]
+
+        rprint(f"Expected reward from beam: [{', '.join(f'{node.reward + node.expected_future_reward:.4f}' for node in beam)}]")
+        round += 1
+
+    rprint(f"Monte-Carlo Weave completed after {round} rounds")
+    return beam

From 8dd030ade7596a49be3452accf67e56234fd4f66 Mon Sep 17 00:00:00 2001
From: hallucinomeny
Date: Mon, 9 Sep 2024 00:48:37 +0300
Subject: [PATCH 2/2] Added topk_of_N

---
 agent/monte_carlo_weave.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/agent/monte_carlo_weave.py b/agent/monte_carlo_weave.py
index ddf583a..ee45e4e 100644
--- a/agent/monte_carlo_weave.py
+++ b/agent/monte_carlo_weave.py
@@ -130,6 +130,14 @@ def logprobs_completions(generate_fn, prompt, completions):
 
     return sequence_log_prob.tolist()  # Return a list of log probabilities
 
+def topk_of_N(generate_fn, prompt, sequence_length, k, N=10):
+    # Sample N completions and keep the k with the highest sequence log probability
+    texts = generate_fn(prompt, sequence_length, n=N)
+    scores = logprobs_completions(generate_fn, prompt, texts)
+    # Get the indices of the top k scores
+    _, top_k_indices = torch.topk(torch.tensor(scores), k, largest=True)
+    # Return the top k texts
+    return [texts[i] for i in top_k_indices]
+
 class TreeNode:
     max_id = 0
 
@@ -190,6 +198,8 @@ def branch_text(self, include_root=False):
             return "".join(branch_texts)
         else:
             return "".join(branch_texts[1:])
+
+
 
 def monte_carlo_weave(
     tree,
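
A minimal usage sketch for review purposes, not part of the patch: it assumes a Mistral-style instruct model used as both generator and evaluator, that agent/ is on the import path, and that the model fits on one CUDA device. The functools.partial wrapping matters because logprobs_completions (and topk_of_N) read the generator back out of generate_fn.args[0].

    from functools import partial

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from monte_carlo_weave import (TreeNode, evaluate_outputs, generate_outputs,
                                   monte_carlo_weave, template)

    # Assumed model choice; the hard-coded yes/no token ids in the module are named
    # for a Mistral tokenizer, so another model would need different ids.
    model_name = "mistralai/Mistral-7B-Instruct-v0.2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # padded batches need a pad token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16).to("cuda")

    generator = evaluator = (tokenizer, model)  # one model plays both roles here
    generate_fn = partial(generate_outputs, generator)
    evaluate_fn = partial(evaluate_outputs, evaluator, template)

    root = TreeNode("Once upon a time, there was a dog named Barky. ")
    beam = monte_carlo_weave(root, generate_fn, evaluate_fn,
                             budget=32, round_budget=8,
                             n_expand=4, beam_width=1)
    print(beam[0].branch_text(include_root=True))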