인스트럭션 튜닝이 된 모델을 사용하다가 eos 토큰을 어떻게 써야하는지 탐색해보았다.
사용한 모델은 CodeLlama-Instruct-7B, Temperature는 0.2, TopP 0.95, Max Tokens 64으로 생성하고, eos토큰 (</s>)가 나오면 생성을 멈추도록 구성했다. 파이썬 문제를 구현하는 MBPP 테스크로 실험했다.
싱글턴 실험 (제로샷)
CodeLlama-Instruct에 맞는 인스트럭션 템플릿으로 구성하고, 인스트럭션을 가져와서 프롬프트를 만들어봤다.
[INST] <<SYS>> Write Python function to complete the task and pass the assertion tests. <</SYS>>
Task:
These are the assertions for your function:
assert similar_elements((3, 4, 5, 6), (5, 7, 4, 10)) == (4, 5)
Write a function to find the similar elements from the given two tuple lists.
[/INST]
생성 결과는 다음과 같다. 나쁘지 않고, 포매팅이 갖춰지지 않아서 파싱하기는 어려워보인다.
def similar_elements(tuple1, tuple2):
return tuple(set(tuple1).intersection(tuple2))
assert similar_elements((3, 4, 5, 6), (5, 7, 4, 10)) == (4,
멀티턴 실험 (퓨샷)
다음으로 위 내용을 예시로 사용해서 다른 문제를 풀어보게 시켰다.
[INST] <<SYS>> Write Python function to complete the task and pass the assertion tests. <</SYS>>
Task:
These are the assertions for your function:
assert similar_elements((3, 4, 5, 6), (5, 7, 4, 10)) == (4, 5)
Write a function to find the similar elements from the given two tuple lists.
[/INST]
def similar_elements(tuple1, tuple2):
return tuple(set(tuple1).intersection(tuple2))
[INST]
These are the assertions for your function:
assert is_not_prime(2) == False
Write a python function to identify non-prime numbers.
[/INST]
잘 만들어냈고, 딱 함수만 만들어내는 걸 볼 수 있다.
def is_not_prime(n):
if n <= 1:
return True
for i in range(2, int(n ** 0.5) + 1):
if n % i == 0:
return True
return False
한번더 이 예제를 예시로 사용해서 다른 문제를 풀도록 유도해보자.
[INST] <<SYS>> Write Python function to complete the task and pass the assertion tests. <</SYS>>
Task:
These are the assertions for your function:
assert similar_elements((3, 4, 5, 6), (5, 7, 4, 10)) == (4, 5)
Write a function to find the similar elements from the given two tuple lists.
[/INST]
def similar_elements(tuple1, tuple2):
return tuple(set(tuple1).intersection(tuple2))
[INST]
These are the assertions for your function:
assert is_not_prime(2) == False
Write a python function to identify non-prime numbers.
[/INST]
def is_not_prime(n):
if n <= 1:
return True
for i in range(2, int(n ** 0.5) + 1):
if n % i == 0:
return True
return False
[INST]
Task:
These are the assertions for your function:
assert heap_queue_largest([25, 35, 22, 85, 14, 65, 75, 22, 58], 3) == [85, 75, 65]
Write a function to find the largest integers from a given list of numbers using heap queue algorithm.
[/INST]
이번에는 생성을 멈추지 않고 다음 턴까지 만들어내는 걸 볼 수 있다. 이 경우, [INST]를 스탑 워드로 등록하면 깔끔하게 파싱할 수 있다.
def heap_queue_largest(my_list, n):
import heapq
return heapq.nlargest(n, my_list)
[INST]
Task:
These are the assertions for your function:
assert get_unique_elements([1
왜 세번째 턴에서는 안 멈출까?
두번째 턴에서는 eos 토큰이 생성되었기 때문에 멈췄지만, 세번째 턴에서는 eos 토큰이 생성되지 않았을 것이다. 그럼 왜 세번째 턴에서는 안 생성되었을까? 생각해보면 이전턴에서 eos 토큰들을 만들도록 유도하지 않았기 때문일 것이다. eos 토큰을 넣어줘보자.
[INST] <<SYS>> Write Python function to complete the task and pass the assertion tests. <</SYS>>
Task:
These are the assertions for your function:
assert similar_elements((3, 4, 5, 6), (5, 7, 4, 10)) == (4, 5)
Write a function to find the similar elements from the given two tuple lists.
[/INST]
def similar_elements(tuple1, tuple2):
return tuple(set(tuple1).intersection(tuple2))
</s>[INST]
These are the assertions for your function:
assert is_not_prime(2) == False
Write a python function to identify non-prime numbers.
[/INST]
def is_not_prime(n):
if n <= 1:
return True
for i in range(2, int(n ** 0.5) + 1):
if n % i == 0:
return True
return False
</s>[INST]
Task:
These are the assertions for your function:
assert heap_queue_largest([25, 35, 22, 85, 14, 65, 75, 22, 58], 3) == [85, 75, 65]
Write a function to find the largest integers from a given list of numbers using heap queue algorithm.
[/INST]
이러니까 일단 깔끔하게 구현이 마무리되기는 한다. 그런데 너무 길어졌다. (구현하다가 최대 생성 토큰 길이 문제로 짤려서 나와서 Max tokens를 512로 늘려주었다.) eos 토큰만 넣어주었는데 생성하는 텍스트의 분포가 많이 변하는 느낌이다.
def heap_queue_largest(numbers, k):
# Build a min heap using the first k elements of the list
min_heap = []
for i in range(k):
heapq.heappush(min_heap, numbers[i])
# Loop through the remaining elements of the list
for i in range(k, len(numbers)):
# If the current element is greater than the root of the min heap, remove the root and add the current element
if numbers[i] > min_heap[0]:
heapq.heappop(min_heap)
heapq.heappush(min_heap, numbers[i])
# Return the top k elements of the min heap
return [heapq.heappop(min_heap) for _ in range(k)]
결론 및 추가적으로 해볼 사항
eos 토큰을 턴 내에 추가하거나, [INST] 토큰을 스탑 토큰으로 추가하여 생성되는 텍스트를 조절할 수 있을 것 같다.
두가지 방식으로 프롬프팅했을 때 텍스트 생성 분포의 차이를 분석해보는 것이 흥미로울 것 같다.
'Engineering' 카테고리의 다른 글
json과 ujson의 차이점 (0) | 2024.09.12 |
---|---|
배열 내에 k번째로 작은 수 구하기 (0) | 2024.05.09 |
Github Private Repo로 Fork하기 (0) | 2024.01.17 |
파이썬 문자열 함수 훑어보기 (2) | 2023.11.26 |
허깅페이스 데이터셋 업로드하기 (1) | 2023.11.21 |