# 현재 작업 디렉토리 변경 및 LLaVA 소스 코드 클론
%cd /content
!git clone -b v1.0 https://github.com/camenduru/LLaVA
%cd /content/LLaVA
# 필요한 라이브러리 설치
!pip install -q transformers==4.36.2
!pip install -q gradio .
# 필요한 패키지 및 모듈 임포트
from transformers import AutoTokenizer, BitsAndBytesConfig
from llava.model import LlavaLlamaForCausalLM
import torch
# 모델 및 토크나이저 초기화
model_path = "4bit/llava-v1.5-13b-3GB"
kwargs = {"device_map": "auto"}
kwargs['load_in_4bit'] = True
kwargs['quantization_config'] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4'
)
model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
# 모델에서 비전 타워 및 관련 구성 요소 초기화
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
vision_tower.load_model()
vision_tower.to(device='cuda')
image_processor = vision_tower.image_processor
# 필요한 라이브러리 및 모듈 임포트
import os
import requests
from PIL import Image
from io import BytesIO
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from transformers import TextStreamer
# 이미지 캡션 생성 함수 정의
def caption_image(image_file, prompt):
# 이미지 파일 불러오기
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
# Torch 초기화 비활성화
disable_torch_init()
# 대화 템플릿 및 역할 설정
conv_mode = "llava_v0"
conv = conv_templates[conv_mode].copy()
roles = conv.roles
# 이미지 전처리 및 텐서로 변환
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
# 입력 생성 및 대화에 추가
inp = f"{roles[0]}: {prompt}"
inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
# 토큰화 및 생성을 위한 입력 데이터 준비
raw_prompt = conv.get_prompt()
input_ids = tokenizer_image_token(raw_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
# 생성 중단 기준 설정
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
# Torch 추론 모드에서 생성
with torch.inference_mode():
output_ids = model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2,
max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria])
# 출력 디코딩 및 정리
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
conv.messages[-1][-1] = outputs
output = outputs.rsplit('</s>', 1)[0]
return image, o