Add syntax highlighting to output

This commit is contained in:
Hayden Johnson 2024-09-25 11:56:48 -07:00
parent 63c38a371f
commit 19f430e408

View file

@ -5,6 +5,10 @@ import re
import pyperclip import pyperclip
import sys import sys
import argparse import argparse
import pygments
from pygments.lexers import get_lexer_by_name
from pygments.formatters import TerminalFormatter
model = 'llama3.1:8b-instruct-q8_0' model = 'llama3.1:8b-instruct-q8_0'
temp = 0.2 temp = 0.2
@ -12,6 +16,49 @@ temp = 0.2
pattern = r'```[a-z]*\n[\s\S]*?\n```' pattern = r'```[a-z]*\n[\s\S]*?\n```'
line_pattern = r'`[a-z]*[\s\S]*?`' line_pattern = r'`[a-z]*[\s\S]*?`'
def highlight_code(language_name, code):
    """Return ``code`` highlighted with ANSI colours for terminal output.

    Args:
        language_name: Pygments lexer alias (e.g. ``'python'``), or ``None``
            to detect the language from ``code`` itself. Detection looks for
            the first line that starts with a ``` fence or with ``lang:``.
        code: Text to highlight. When it splits into exactly two lines it is
            highlighted as-is; otherwise the first and last lines (expected
            to be the opening and closing fences) are dropped first.

    Returns:
        The highlighted text produced by ``pygments.highlight`` with a
        ``TerminalFormatter``.
    """
    # Imported here because the module-level imports only bring in
    # get_lexer_by_name; without this the guessing fallback is a NameError.
    from pygments.lexers import guess_lexer

    lines = code.split('\n')

    # Work out which lexer name to try, if the caller did not supply one.
    lexer_name = language_name
    if lexer_name is None:
        for line in lines:
            stripped = line.strip()
            if stripped.startswith('```'):
                lexer_name = stripped.split('```')[1].strip()
                break
            if stripped.startswith('lang:'):
                lexer_name = stripped.split(':')[1].strip()
                break

    # Exactly two lines means a bare line of code (text + trailing newline),
    # not a fenced block, so there are no fence lines to strip.
    if len(lines) == 2:
        just_code = code
    else:
        just_code = '\n'.join(lines[1:-1])

    lexer = None
    if lexer_name:
        try:
            lexer = get_lexer_by_name(lexer_name)
        except ValueError:
            # Pygments signals an unknown alias with ClassNotFound, a
            # ValueError subclass; fall through to guessing.
            lexer = None

    if lexer is None:
        try:
            # guess_lexer expects the text itself, not a list of lines.
            lexer = guess_lexer(just_code)
        except ValueError:
            # Guessing failed as well: default to bash.
            lexer = get_lexer_by_name('bash')

    return pygments.highlight(just_code, lexer, TerminalFormatter())
def extract_code_block(markdown_text): def extract_code_block(markdown_text):
# Use the regular expression pattern to find all matches in the markdown text # Use the regular expression pattern to find all matches in the markdown text
matches = re.finditer(pattern, markdown_text) matches = re.finditer(pattern, markdown_text)
@ -21,8 +68,11 @@ def extract_code_block(markdown_text):
for match in matches: for match in matches:
code_block = match.group(0) code_block = match.group(0)
# Add the code block to the list of code blocks highlighted_code = highlight_code(None, code_block)
code_blocks.append('\n'.join(code_block.split('\n')[1:-1]))
# Add the highlighted code block to the list of code blocks
code_blocks.append(highlighted_code)
if len(code_blocks) == 0: if len(code_blocks) == 0:
line_matches = re.finditer(line_pattern, markdown_text) line_matches = re.finditer(line_pattern, markdown_text)
@ -55,10 +105,32 @@ def chat(message, stream=True):
stream=stream stream=stream
) )
result = '' result = ''
language = ''
large_chunk = []
for chunk in completion: for chunk in completion:
if stream: if stream:
print(chunk['message']['content'], end='', flush=True) text = chunk['message']['content']
result += chunk['message']['content'] large_chunk.append(text)
large_text = ''.join(large_chunk)
# Syntax highlight if possible
# check if highlighting can be done
if ('\n' in large_text) and ('```' in large_text):
language = large_text.split('```')[1].split('\n')[0]
print(large_text, end='', flush=True)
large_chunk = []
large_text = ''
if language == '':
language = None
if '\n' in large_text:
output = large_text
if language != None:
output = highlight_code(language, output)
print(output, end='', flush=True)
large_chunk = []
# print(highlighted_text, end='', flush=True)
result += text
if not stream: if not stream:
result = completion['message']['content'] result = completion['message']['content']
if stream: if stream:
@ -98,7 +170,7 @@ def arg_shell(args):
query += input('> ') query += input('> ')
result = chat(query, False) result = chat(query, False)
result = blocks[0] if len(blocks := extract_code_block(result)) else result result = blocks[0] if len(blocks := extract_code_block(result)) else result
print(result) print(blocks)
copy_string_to_clipboard(result) copy_string_to_clipboard(result)
def handle_piped_input(args): def handle_piped_input(args):
@ -121,7 +193,8 @@ def handle_non_piped_input(args):
exit() exit()
if args.follow_up: if args.follow_up:
user_input = arg_follow_up(args) user_input = arg_follow_up(args)
chat(user_input) result = chat(user_input)
code_blocks = extract_code_block(result)
exit() exit()
while True: while True:
try: try: