Compare commits

..

3 commits

2 changed files with 121 additions and 85 deletions

View file

@ -4,6 +4,9 @@ from ollama import Client
import re import re
import pyperclip import pyperclip
import sys import sys
import tty
import termios
import signal
import argparse import argparse
import pygments import pygments
from pygments.lexers import get_lexer_by_name, guess_lexer from pygments.lexers import get_lexer_by_name, guess_lexer
@ -11,6 +14,9 @@ from pygments.formatters import TerminalFormatter
import os import os
import json import json
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit import PromptSession
server = 'localhost:11434' server = 'localhost:11434'
model = 'gemma3:12b' model = 'gemma3:12b'
reasoning_model='deepseek-r1:14b' reasoning_model='deepseek-r1:14b'
@ -344,37 +350,24 @@ def handle_piped_input(args):
if args.copy and len(blocks): if args.copy and len(blocks):
copy_string_to_clipboard(blocks[0]) copy_string_to_clipboard(blocks[0])
def improved_input(): kb = KeyBindings()
@kb.add('c-d')
def _(event):
event.current_buffer.validate_and_handle()
session = PromptSession(multiline=True, prompt_continuation='', key_bindings=kb)
def improved_input(prompt="> "):
""" """
Handles multi-line input and prevents accidental sending of pasted content. Returns the full text (including embedded newlines) when you press Ctrl-D.
Returns the complete input text when the user indicates they're done. Arrow keys edit within or across lines automatically.
""" """
lines = []
print("> ", end='', flush=True)
while True:
try: try:
line = sys.stdin.readline() text = session.prompt(prompt)
if not line: return text
print("", flush=True) # Some visual feedback that user input has ended
break # EOF
if line.strip() == '':
continue # ignore empty lines
lines.append(line)
# Check for termination signal (Ctrl+D or two consecutive empty lines)
if len(lines) > 1 and lines[-2] == '\n' and lines[-1] == '\n':
break
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nUser aborted input") print("\nUser aborted input")
return None return None
except Exception as e:
print(f"Error reading input: {e}")
return None
# Join the lines and strip any trailing newlines
full_input = ''.join(lines).rstrip('\n')
return full_input
def handle_non_piped_input(args): def handle_non_piped_input(args):
if args.shell: if args.shell:

View file

@ -1,13 +1,20 @@
#/bin/python
# AI Assistant in the terminal
import argparse
import os
import sys
import json
from ollama import Client from ollama import Client
import re import re
import pyperclip import pyperclip
import sys
import argparse
import pygments import pygments
from pygments.lexers import get_lexer_by_name, guess_lexer from pygments.lexers import get_lexer_by_name, guess_lexer
from pygments.formatters import TerminalFormatter from pygments.formatters import TerminalFormatter
import os
import json from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit import PromptSession
class AIAssistant: class AIAssistant:
def __init__(self, server="http://localhost:11434", model="gemma3:12b"): def __init__(self, server="http://localhost:11434", model="gemma3:12b"):
@ -17,7 +24,6 @@ class AIAssistant:
self.temperature = 0.2 self.temperature = 0.2
self.num_ctx = 4096 self.num_ctx = 4096
self.history = [self.system_prompt()] self.history = [self.system_prompt()]
self.load_history()
def set_host(self, host): def set_host(self, host):
self.server = host self.server = host
@ -39,6 +45,33 @@ class AIAssistant:
with open(path, 'w+') as f: with open(path, 'w+') as f:
json.dump(self.history, f) json.dump(self.history, f)
def determine_lexer(self, code_block):
lexer_name = None
lines = code_block.split('\n')
for line in lines:
if line.strip().startswith('```'):
lexer_part = line.strip().split('```')[1].strip()
if lexer_part:
lexer_name = lexer_part
break
elif line.strip().startswith('lang:'):
lexer_part = line.strip().split(':')[1].strip()
if lexer_part:
lexer_name = lexer_part
break
return lexer_name
def highlight_code(self, lexer_name, code):
try:
lexer = get_lexer_by_name(lexer_name) if lexer_name else guess_lexer(code)
except ValueError:
lexer = guess_lexer('\n'.join(code.split('\n')[1:-1]))
if not lexer:
lexer = get_lexer_by_name('bash')
formatter = TerminalFormatter()
highlighted_code = pygments.highlight(code, lexer, formatter)
return highlighted_code
def chat(self, message, stream=True): def chat(self, message, stream=True):
self.history.append({"role": "user", "content": message}) self.history.append({"role": "user", "content": message})
completion = self.client.chat( completion = self.client.chat(
@ -49,11 +82,28 @@ class AIAssistant:
) )
result = '' result = ''
large_chunk = [] large_chunk = []
language = None
for chunk in completion: for chunk in completion:
text = chunk['message']['content'] text = chunk['message']['content']
large_chunk.append(text) large_chunk.append(text)
large_text = ''.join(large_chunk)
if ('```' in large_text) and ('\n' in large_text.split('```')[1]):
language = large_text.split('```')[1].split('\n')[0]
large_chunk = []
if language == '':
print(large_text, end='', flush=True)
language = None
if stream: if stream:
if language and ('\n' in large_text) and large_chunk:
output = self.highlight_code(language, large_text)
print(output, end='', flush=True)
large_chunk = []
elif not language or not large_chunk:
print(text, end='', flush=True) print(text, end='', flush=True)
if not stream: if not stream:
result = completion['message']['content'] result = completion['message']['content']
else: else:
@ -82,9 +132,17 @@ class CommandLineParser:
def parse(self): def parse(self):
return self.parser.parse_args() return self.parser.parse_args()
# Keybindings
bindings = KeyBindings()
@bindings.add('c-d')
def _(event):
event.current_buffer.validate_and_handle()
class InputHandler: class InputHandler:
def __init__(self, assistant): def __init__(self, assistant):
self.assistant = assistant self.assistant = assistant
self.session = PromptSession(multiline=True, prompt_continuation='', key_bindings=bindings)
def handle_input(self, args): def handle_input(self, args):
if not sys.stdin.isatty(): if not sys.stdin.isatty():
@ -98,14 +156,33 @@ class InputHandler:
if args.copy: if args.copy:
query += 'Answer the question using a codeblock for any code or shell scripts\n' query += 'Answer the question using a codeblock for any code or shell scripts\n'
if args.follow_up: if args.follow_up:
second_input = input('> ') second_input = improved_input()
query += f'\n{second_input}' query += f'\n{second_input}'
result = self.assistant.chat(query, stream=False) result = self.assistant.chat(query, stream=False)
blocks = self.extract_code_block(result) blocks = self.extract_code_block(result)
if args.copy and len(blocks): if args.copy and len(blocks):
pyperclip.copy(blocks[0]) pyperclip.copy(blocks[0])
def arg_shell(args):
query = '''
Form a shell command based on the following description. Only output a working shell command. Format the command like this: `command`
Description:
'''
if args.shell != True:
query += args.shell
else:
query += self.improved_input()
result = self.assistant.chat(query, stream=False)
result = blocks[0] if len(blocks := self.extract_code_block(result)) else result
print(result)
copy_string_to_clipboard(result)
def handle_interactive_input(self, args): def handle_interactive_input(self, args):
if args.shell:
self.arg_shell(args)
exit()
print("\033[91massistant\033[0m: Type your message (press Ctrl+D to send):") print("\033[91massistant\033[0m: Type your message (press Ctrl+D to send):")
while True: while True:
try: try:
@ -114,32 +191,23 @@ class InputHandler:
break break
if full_input.strip() == '': if full_input.strip() == '':
continue continue
print("\033[91massistant\033[0m: ", end='')
result = self.assistant.chat(full_input) result = self.assistant.chat(full_input)
print() print()
except (EOFError, KeyboardInterrupt): except (EOFError, KeyboardInterrupt):
print("\nExiting...") print("\nExiting...")
break break
def improved_input(self): def improved_input(self, prompt="> "):
lines = [] """
print("> ", end='', flush=True) Returns the full text (including embedded newlines) when you press Ctrl-D.
while True: Arrow keys edit within or across lines automatically.
"""
try: try:
line = sys.stdin.readline() text = self.session.prompt(prompt)
if not line: return text
print(flush=True)
break
if line.strip() == '':
continue
lines.append(line)
if len(lines) > 1 and lines[-2] == '\n' and lines[-1] == '\n':
break
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nUser aborted input") print("\nUser aborted input")
return None return None
full_input = ''.join(lines).rstrip('\n')
return full_input
def extract_code_block(self, text): def extract_code_block(self, text):
pattern = r'```[a-z]*\n[\s\S]*?\n```' pattern = r'```[a-z]*\n[\s\S]*?\n```'
@ -147,8 +215,8 @@ class InputHandler:
matches = re.finditer(pattern, text) matches = re.finditer(pattern, text)
for match in matches: for match in matches:
code_block = match.group(0) code_block = match.group(0)
lexer_name = self.determine_lexer(code_block) lexer_name = self.assistant.determine_lexer(code_block)
highlighted_code = self.highlight_code(lexer_name, code_block) highlighted_code = self.assistant.highlight_code(lexer_name, code_block)
code_blocks.append(highlighted_code) code_blocks.append(highlighted_code)
if not code_blocks: if not code_blocks:
line_pattern = r'`[a-z]*[\s\S]*?`' line_pattern = r'`[a-z]*[\s\S]*?`'
@ -158,33 +226,6 @@ class InputHandler:
code_blocks.append(code_block[1:-1]) code_blocks.append(code_block[1:-1])
return code_blocks return code_blocks
def determine_lexer(self, code_block):
lexer_name = None
lines = code_block.split('\n')
for line in lines:
if line.strip().startswith('```'):
lexer_part = line.strip().split('```')[1].strip()
if lexer_part:
lexer_name = lexer_part
break
elif line.strip().startswith('lang:'):
lexer_part = line.strip().split(':')[1].strip()
if lexer_part:
lexer_name = lexer_part
break
return lexer_name
def highlight_code(self, lexer_name, code):
try:
lexer = get_lexer_by_name(lexer_name) if lexer_name else guess_lexer(code)
except ValueError:
lexer = guess_lexer('\n'.join(code.split('\n')[1:-1]))
if not lexer:
lexer = get_lexer_by_name('bash')
formatter = TerminalFormatter()
highlighted_code = pygments.highlight(code, lexer, formatter)
return highlighted_code
class CommandParser: class CommandParser:
def __init__(self): def __init__(self):
self.commands = { self.commands = {
@ -252,6 +293,8 @@ def main():
if args.new: if args.new:
assistant.history = [assistant.system_prompt()] assistant.history = [assistant.system_prompt()]
assistant.save_history() assistant.save_history()
else:
assistant.load_history()
input_handler = InputHandler(assistant) input_handler = InputHandler(assistant)
input_handler.handle_input(args) input_handler.handle_input(args)