From 19f430e4086fc8c2ab022af86cc538b451849dc8 Mon Sep 17 00:00:00 2001
From: Hayden Johnson
Date: Wed, 25 Sep 2024 11:56:48 -0700
Subject: [PATCH] Add syntax highlighting to output

---
Reviewer notes (text between '---' and the diffstat is ignored by git am):
- Import guess_lexer; the first revision called it without importing it.
- highlight_code(): pass guess_lexer() a string, treat its ValueError as
  "no match", and strip fence lines only when the input really is fenced.
- extract_code_block() keeps returning plain text, because arg_shell()
  copies its result to the clipboard; colouring happens at print time.
- chat(): stream line by line and flush an unterminated last line.
- arg_shell(): print the selected block, not the list of all blocks.
- Dropped the unused extract_code_block() call in handle_non_piped_input().
- Context-line indentation was reconstructed from a whitespace-mangled copy;
  use 'git apply --ignore-whitespace' if it does not match, and regenerate
  the post-image id on the index line.

 assistant.py | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 75 insertions(+), 3 deletions(-)

diff --git a/assistant.py b/assistant.py
index 823a5f6..506d87c 100755
--- a/assistant.py
+++ b/assistant.py
@@ -5,6 +5,11 @@ import re
 import pyperclip
 import sys
 import argparse
+import pygments
+from pygments.lexers import get_lexer_by_name, guess_lexer
+from pygments.formatters import TerminalFormatter
+
 
 model = 'llama3.1:8b-instruct-q8_0'
 temp = 0.2
@@ -12,6 +17,52 @@ temp = 0.2
 pattern = r'```[a-z]*\n[\s\S]*?\n```'
 line_pattern = r'`[a-z]*[\s\S]*?`'
 
+def highlight_code(language_name, code):
+    """Return code coloured with ANSI escapes for display in a terminal.
+
+    language_name is a Pygments lexer alias. When it is None the alias is
+    read from the first Markdown fence line or 'lang:' line in code; if
+    that fails the lexer is guessed, with bash as the last resort.
+    A leading fence line and its closing fence are left out of the result.
+    """
+    lines = code.split('\n')
+
+    lexer_name = language_name
+    if lexer_name is None:
+        for line in lines:
+            if line.strip().startswith('```'):
+                lexer_name = line.strip().split('```')[1].strip()
+                break
+            elif line.strip().startswith('lang:'):
+                lexer_name = line.strip().split(':')[1].strip()
+                break
+
+    # Strip the fence only when there is one: single lines streamed from
+    # chat() and already-extracted blocks must keep every line.
+    if lines[0].strip().startswith('```'):
+        body = lines[1:]
+        if body and body[-1].strip() == '```':
+            body = body[:-1]
+        just_code = '\n'.join(body)
+    else:
+        just_code = code
+
+    lexer = None
+    if lexer_name:
+        try:
+            lexer = get_lexer_by_name(lexer_name)
+        except ValueError:
+            # Unknown alias (ClassNotFound is a ValueError): guess below.
+            pass
+    if lexer is None:
+        try:
+            # guess_lexer() takes the source text itself, not a list of lines.
+            lexer = guess_lexer(just_code)
+        except ValueError:
+            lexer = get_lexer_by_name('bash')
+
+    return pygments.highlight(just_code, lexer, TerminalFormatter())
+
 def extract_code_block(markdown_text):
     # Use the regular expression pattern to find all matches in the markdown text
     matches = re.finditer(pattern, markdown_text)
@@ -55,10 +106,30 @@ def chat(message, stream=True):
         stream=stream
     )
     result = ''
+    language = None  # lexer alias of the fenced block being streamed
+    pending = ''  # text received so far that is not yet a complete line
     for chunk in completion:
         if stream:
-            print(chunk['message']['content'], end='', flush=True)
-            result += chunk['message']['content']
+            text = chunk['message']['content']
+            result += text
+            pending += text
+            # Emit whole lines only, so each one can be classified as a
+            # fence, code or prose before it is printed.
+            while '\n' in pending:
+                line, pending = pending.split('\n', 1)
+                marker = line.strip()
+                if marker.startswith('```'):
+                    # An opening fence names the language; a bare or
+                    # closing fence turns highlighting off again.
+                    language = marker[3:].strip() or None
+                    print(line, flush=True)
+                elif language is not None:
+                    print(highlight_code(language, line + '\n'), end='', flush=True)
+                else:
+                    print(line, flush=True)
+    if stream and pending:
+        # The reply did not end with a newline: show the unfinished line.
+        print(pending, end='', flush=True)
     if not stream:
         result = completion['message']['content']
     if stream:
@@ -98,7 +169,8 @@ def arg_shell(args):
         query += input('> ')
     result = chat(query, False)
     result = blocks[0] if len(blocks := extract_code_block(result)) else result
-    print(result)
+    # Colour only what is shown; the clipboard copy stays free of ANSI codes.
+    print(highlight_code(None, result).rstrip('\n') if blocks else result)
     copy_string_to_clipboard(result)
 
 def handle_piped_input(args):