diff --git a/refactored.py b/refactored.py index a607b1f..8d176e1 100644 --- a/refactored.py +++ b/refactored.py @@ -1,13 +1,20 @@ +#/bin/python +# AI Assistant in the terminal +import argparse +import os +import sys + +import json from ollama import Client import re import pyperclip -import sys -import argparse + import pygments from pygments.lexers import get_lexer_by_name, guess_lexer from pygments.formatters import TerminalFormatter -import os -import json + +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit import PromptSession class AIAssistant: def __init__(self, server="http://localhost:11434", model="gemma3:12b"): @@ -17,7 +24,6 @@ class AIAssistant: self.temperature = 0.2 self.num_ctx = 4096 self.history = [self.system_prompt()] - self.load_history() def set_host(self, host): self.server = host @@ -39,6 +45,33 @@ class AIAssistant: with open(path, 'w+') as f: json.dump(self.history, f) + def determine_lexer(self, code_block): + lexer_name = None + lines = code_block.split('\n') + for line in lines: + if line.strip().startswith('```'): + lexer_part = line.strip().split('```')[1].strip() + if lexer_part: + lexer_name = lexer_part + break + elif line.strip().startswith('lang:'): + lexer_part = line.strip().split(':')[1].strip() + if lexer_part: + lexer_name = lexer_part + break + return lexer_name + + def highlight_code(self, lexer_name, code): + try: + lexer = get_lexer_by_name(lexer_name) if lexer_name else guess_lexer(code) + except ValueError: + lexer = guess_lexer('\n'.join(code.split('\n')[1:-1])) + if not lexer: + lexer = get_lexer_by_name('bash') + formatter = TerminalFormatter() + highlighted_code = pygments.highlight(code, lexer, formatter) + return highlighted_code + def chat(self, message, stream=True): self.history.append({"role": "user", "content": message}) completion = self.client.chat( @@ -49,11 +82,28 @@ class AIAssistant: ) result = '' large_chunk = [] + language = None + for chunk in completion: text = chunk['message']['content'] large_chunk.append(text) + large_text = ''.join(large_chunk) + + if ('```' in large_text) and ('\n' in large_text.split('```')[1]): + language = large_text.split('```')[1].split('\n')[0] + large_chunk = [] + if language == '': + print(large_text, end='', flush=True) + language = None + if stream: - print(text, end='', flush=True) + if language and ('\n' in large_text) and large_chunk: + output = self.highlight_code(language, large_text) + print(output, end='', flush=True) + large_chunk = [] + elif not language or not large_chunk: + print(text, end='', flush=True) + if not stream: result = completion['message']['content'] else: @@ -82,9 +132,17 @@ class CommandLineParser: def parse(self): return self.parser.parse_args() +# Keybindings +bindings = KeyBindings() + +@bindings.add('c-d') +def _(event): + event.current_buffer.validate_and_handle() + class InputHandler: def __init__(self, assistant): self.assistant = assistant + self.session = PromptSession(multiline=True, prompt_continuation='', key_bindings=bindings) def handle_input(self, args): if not sys.stdin.isatty(): @@ -98,14 +156,33 @@ class InputHandler: if args.copy: query += 'Answer the question using a codeblock for any code or shell scripts\n' if args.follow_up: - second_input = input('> ') + second_input = improved_input() query += f'\n{second_input}' result = self.assistant.chat(query, stream=False) blocks = self.extract_code_block(result) if args.copy and len(blocks): pyperclip.copy(blocks[0]) + def arg_shell(args): + query = ''' +Form a shell command based on the following description. Only output a working shell command. Format the command like this: `command` + +Description: + ''' + if args.shell != True: + query += args.shell + else: + query += self.improved_input() + result = self.assistant.chat(query, stream=False) + result = blocks[0] if len(blocks := self.extract_code_block(result)) else result + print(result) + copy_string_to_clipboard(result) + def handle_interactive_input(self, args): + if args.shell: + self.arg_shell(args) + exit() + print("\033[91massistant\033[0m: Type your message (press Ctrl+D to send):") while True: try: @@ -114,32 +191,23 @@ class InputHandler: break if full_input.strip() == '': continue - print("\033[91massistant\033[0m: ", end='') result = self.assistant.chat(full_input) print() except (EOFError, KeyboardInterrupt): print("\nExiting...") break - def improved_input(self): - lines = [] - print("> ", end='', flush=True) - while True: - try: - line = sys.stdin.readline() - if not line: - print(flush=True) - break - if line.strip() == '': - continue - lines.append(line) - if len(lines) > 1 and lines[-2] == '\n' and lines[-1] == '\n': - break - except KeyboardInterrupt: - print("\nUser aborted input") - return None - full_input = ''.join(lines).rstrip('\n') - return full_input + def improved_input(self, prompt="> "): + """ + Returns the full text (including embedded newlines) when you press Ctrl-D. + Arrow keys edit within or across lines automatically. + """ + try: + text = self.session.prompt(prompt) + return text + except KeyboardInterrupt: + print("\nUser aborted input") + return None def extract_code_block(self, text): pattern = r'```[a-z]*\n[\s\S]*?\n```' @@ -147,8 +215,8 @@ class InputHandler: matches = re.finditer(pattern, text) for match in matches: code_block = match.group(0) - lexer_name = self.determine_lexer(code_block) - highlighted_code = self.highlight_code(lexer_name, code_block) + lexer_name = self.assistant.determine_lexer(code_block) + highlighted_code = self.assistant.highlight_code(lexer_name, code_block) code_blocks.append(highlighted_code) if not code_blocks: line_pattern = r'`[a-z]*[\s\S]*?`' @@ -158,33 +226,6 @@ class InputHandler: code_blocks.append(code_block[1:-1]) return code_blocks - def determine_lexer(self, code_block): - lexer_name = None - lines = code_block.split('\n') - for line in lines: - if line.strip().startswith('```'): - lexer_part = line.strip().split('```')[1].strip() - if lexer_part: - lexer_name = lexer_part - break - elif line.strip().startswith('lang:'): - lexer_part = line.strip().split(':')[1].strip() - if lexer_part: - lexer_name = lexer_part - break - return lexer_name - - def highlight_code(self, lexer_name, code): - try: - lexer = get_lexer_by_name(lexer_name) if lexer_name else guess_lexer(code) - except ValueError: - lexer = guess_lexer('\n'.join(code.split('\n')[1:-1])) - if not lexer: - lexer = get_lexer_by_name('bash') - formatter = TerminalFormatter() - highlighted_code = pygments.highlight(code, lexer, formatter) - return highlighted_code - class CommandParser: def __init__(self): self.commands = { @@ -252,6 +293,8 @@ def main(): if args.new: assistant.history = [assistant.system_prompt()] assistant.save_history() + else: + assistant.load_history() input_handler = InputHandler(assistant) input_handler.handle_input(args)