Fix issue with double newlines in codeblocks

This commit is contained in:
Hayden Johnson 2024-09-25 12:44:28 -07:00
parent b2f218639f
commit f454e9e17b

View file

@ -41,6 +41,7 @@ def highlight_code(language_name, code):
lexer = get_lexer_by_name('bash') lexer = get_lexer_by_name('bash')
else: else:
# If no language is specified, guess the lexer # If no language is specified, guess the lexer
print("LEXER NAME " + lexer_name)
lexer = guess_lexer(code.split('\n')[1:-1]) lexer = guess_lexer(code.split('\n')[1:-1])
if not lexer: if not lexer:
# If no lexer is guessed, default to bash # If no lexer is guessed, default to bash
@ -48,16 +49,15 @@ def highlight_code(language_name, code):
formatter = TerminalFormatter() formatter = TerminalFormatter()
just_code = '' just_code = code.split('\n')[0]
newlines = '\n'.join(code.split('\n')[1:])
# if code is a code block, strip surrounding block markers
lines = code.split('\n') lines = code.split('\n')
# Just a single line of code, without code blocks around it if (len(lines) > 2) and ('```' in lines[0]) and ('```' in lines[-1]):
if len(lines) == 2:
just_code = code
else:
just_code = '\n'.join(code.split('\n')[1:-1]) just_code = '\n'.join(code.split('\n')[1:-1])
highlighted_code = pygments.highlight(just_code, lexer, formatter) highlighted_code = pygments.highlight(just_code, lexer, formatter)
return highlighted_code return highlighted_code + newlines
def extract_code_block(markdown_text): def extract_code_block(markdown_text):
# Use the regular expression pattern to find all matches in the markdown text # Use the regular expression pattern to find all matches in the markdown text
@ -113,23 +113,22 @@ def chat(message, stream=True):
large_chunk.append(text) large_chunk.append(text)
large_text = ''.join(large_chunk) large_text = ''.join(large_chunk)
# Syntax highlight if possible # update language if entering or leaving code block
# check if highlighting can be done
if ('\n' in large_text) and ('```' in large_text): if ('\n' in large_text) and ('```' in large_text):
language = large_text.split('```')[1].split('\n')[0] language = large_text.split('```')[1].split('\n')[0]
if language == '':
language = None
print(large_text, end='', flush=True) print(large_text, end='', flush=True)
large_chunk = [] large_chunk = []
large_text = '' large_text = ''
if language == '':
language = None
# Only print full lines
if '\n' in large_text: if '\n' in large_text:
output = large_text output = large_text
if language != None: if language:
output = highlight_code(language, output) output = highlight_code(language, output)
print(output, end='', flush=True) print(output, end='', flush=True)
large_chunk = [] large_chunk = []
# print(highlighted_text, end='', flush=True)
result += text result += text
if not stream: if not stream:
result = completion['message']['content'] result = completion['message']['content']
@ -194,7 +193,6 @@ def handle_non_piped_input(args):
if args.follow_up: if args.follow_up:
user_input = arg_follow_up(args) user_input = arg_follow_up(args)
result = chat(user_input) result = chat(user_input)
code_blocks = extract_code_block(result)
exit() exit()
while True: while True:
try: try: