add back highlighting into refactored version

This commit is contained in:
Hayden Johnson 2025-04-26 14:20:58 -07:00
parent 4b38d5cb6a
commit c0d83584c7

View file

@ -1,13 +1,20 @@
#/bin/python
# AI Assistant in the terminal
import argparse
import os
import sys
import json
from ollama import Client
import re
import pyperclip
import sys
import argparse
import pygments
from pygments.lexers import get_lexer_by_name, guess_lexer
from pygments.formatters import TerminalFormatter
import os
import json
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit import PromptSession
class AIAssistant:
def __init__(self, server="http://localhost:11434", model="gemma3:12b"):
@ -17,7 +24,6 @@ class AIAssistant:
self.temperature = 0.2
self.num_ctx = 4096
self.history = [self.system_prompt()]
self.load_history()
def set_host(self, host):
self.server = host
@ -39,6 +45,33 @@ class AIAssistant:
with open(path, 'w+') as f:
json.dump(self.history, f)
def determine_lexer(self, code_block):
lexer_name = None
lines = code_block.split('\n')
for line in lines:
if line.strip().startswith('```'):
lexer_part = line.strip().split('```')[1].strip()
if lexer_part:
lexer_name = lexer_part
break
elif line.strip().startswith('lang:'):
lexer_part = line.strip().split(':')[1].strip()
if lexer_part:
lexer_name = lexer_part
break
return lexer_name
def highlight_code(self, lexer_name, code):
try:
lexer = get_lexer_by_name(lexer_name) if lexer_name else guess_lexer(code)
except ValueError:
lexer = guess_lexer('\n'.join(code.split('\n')[1:-1]))
if not lexer:
lexer = get_lexer_by_name('bash')
formatter = TerminalFormatter()
highlighted_code = pygments.highlight(code, lexer, formatter)
return highlighted_code
def chat(self, message, stream=True):
self.history.append({"role": "user", "content": message})
completion = self.client.chat(
@ -49,11 +82,28 @@ class AIAssistant:
)
result = ''
large_chunk = []
language = None
for chunk in completion:
text = chunk['message']['content']
large_chunk.append(text)
large_text = ''.join(large_chunk)
if ('```' in large_text) and ('\n' in large_text.split('```')[1]):
language = large_text.split('```')[1].split('\n')[0]
large_chunk = []
if language == '':
print(large_text, end='', flush=True)
language = None
if stream:
if language and ('\n' in large_text) and large_chunk:
output = self.highlight_code(language, large_text)
print(output, end='', flush=True)
large_chunk = []
elif not language or not large_chunk:
print(text, end='', flush=True)
if not stream:
result = completion['message']['content']
else:
@ -82,9 +132,17 @@ class CommandLineParser:
def parse(self):
return self.parser.parse_args()
# Keybindings
bindings = KeyBindings()
@bindings.add('c-d')
def _(event):
event.current_buffer.validate_and_handle()
class InputHandler:
def __init__(self, assistant):
self.assistant = assistant
self.session = PromptSession(multiline=True, prompt_continuation='', key_bindings=bindings)
def handle_input(self, args):
if not sys.stdin.isatty():
@ -98,14 +156,33 @@ class InputHandler:
if args.copy:
query += 'Answer the question using a codeblock for any code or shell scripts\n'
if args.follow_up:
second_input = input('> ')
second_input = improved_input()
query += f'\n{second_input}'
result = self.assistant.chat(query, stream=False)
blocks = self.extract_code_block(result)
if args.copy and len(blocks):
pyperclip.copy(blocks[0])
def arg_shell(args):
query = '''
Form a shell command based on the following description. Only output a working shell command. Format the command like this: `command`
Description:
'''
if args.shell != True:
query += args.shell
else:
query += self.improved_input()
result = self.assistant.chat(query, stream=False)
result = blocks[0] if len(blocks := self.extract_code_block(result)) else result
print(result)
copy_string_to_clipboard(result)
def handle_interactive_input(self, args):
if args.shell:
self.arg_shell(args)
exit()
print("\033[91massistant\033[0m: Type your message (press Ctrl+D to send):")
while True:
try:
@ -114,32 +191,23 @@ class InputHandler:
break
if full_input.strip() == '':
continue
print("\033[91massistant\033[0m: ", end='')
result = self.assistant.chat(full_input)
print()
except (EOFError, KeyboardInterrupt):
print("\nExiting...")
break
def improved_input(self):
lines = []
print("> ", end='', flush=True)
while True:
def improved_input(self, prompt="> "):
"""
Returns the full text (including embedded newlines) when you press Ctrl-D.
Arrow keys edit within or across lines automatically.
"""
try:
line = sys.stdin.readline()
if not line:
print(flush=True)
break
if line.strip() == '':
continue
lines.append(line)
if len(lines) > 1 and lines[-2] == '\n' and lines[-1] == '\n':
break
text = self.session.prompt(prompt)
return text
except KeyboardInterrupt:
print("\nUser aborted input")
return None
full_input = ''.join(lines).rstrip('\n')
return full_input
def extract_code_block(self, text):
pattern = r'```[a-z]*\n[\s\S]*?\n```'
@ -147,8 +215,8 @@ class InputHandler:
matches = re.finditer(pattern, text)
for match in matches:
code_block = match.group(0)
lexer_name = self.determine_lexer(code_block)
highlighted_code = self.highlight_code(lexer_name, code_block)
lexer_name = self.assistant.determine_lexer(code_block)
highlighted_code = self.assistant.highlight_code(lexer_name, code_block)
code_blocks.append(highlighted_code)
if not code_blocks:
line_pattern = r'`[a-z]*[\s\S]*?`'
@ -158,33 +226,6 @@ class InputHandler:
code_blocks.append(code_block[1:-1])
return code_blocks
def determine_lexer(self, code_block):
lexer_name = None
lines = code_block.split('\n')
for line in lines:
if line.strip().startswith('```'):
lexer_part = line.strip().split('```')[1].strip()
if lexer_part:
lexer_name = lexer_part
break
elif line.strip().startswith('lang:'):
lexer_part = line.strip().split(':')[1].strip()
if lexer_part:
lexer_name = lexer_part
break
return lexer_name
def highlight_code(self, lexer_name, code):
try:
lexer = get_lexer_by_name(lexer_name) if lexer_name else guess_lexer(code)
except ValueError:
lexer = guess_lexer('\n'.join(code.split('\n')[1:-1]))
if not lexer:
lexer = get_lexer_by_name('bash')
formatter = TerminalFormatter()
highlighted_code = pygments.highlight(code, lexer, formatter)
return highlighted_code
class CommandParser:
def __init__(self):
self.commands = {
@ -252,6 +293,8 @@ def main():
if args.new:
assistant.history = [assistant.system_prompt()]
assistant.save_history()
else:
assistant.load_history()
input_handler = InputHandler(assistant)
input_handler.handle_input(args)