빅데이타 & 머신러닝/생성형 AI (ChatGPT etc)

생성형 AI로 코드 품질을 높이는 방법

Terry Cho 2025. 1. 4. 13:54

쓰레드에서 @choi.openai 라는 분이 LLM 모델로 코드를 생성할때, 

"LLM에게 "더 나은 코드를 작성해달라"고 반복적으로 요구하면 실제로 더 나은 코드를 생성할 수 있다는 결과가 나왔습니다."

라는 이야기가 있어서 테스트를 해봤다. 이 내용의 원본은 https://minimaxir.com/2025/01/write-better-code/ 이다.

 

 

Gemini를 이용해서 다음과 같이 간단한 REST API를 FastAPI로 만드는 프롬프트를 작성하였다.

Create the REST API code with fastAPI.
- get request with HTTP POST. Input body has username,address,gender fields.
- have proper error handling logic

 

다음은 첫번째 생성된 코드이다.

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, field_validator

app = FastAPI()


class User(BaseModel):
    username: str
    address: str
    gender: str

    @field_validator("gender")
    def validate_gender(cls, value):
        if value.lower() not in ["male", "female", "other"]:
            raise ValueError("Gender must be 'male', 'female', or 'other'")
        return value


@app.post("/users/", status_code=201)
async def create_user(request: Request, user: User):
    """
    Creates a new user with the provided details.

    **Input Body:**

    - **username** (str): The username of the user.
    - **address** (str): The address of the user.
    - **gender** (str): The gender of the user (male, female, other).

    **Returns:**

    - **JSON Response:** A dictionary containing the created user details or an error message.

    **Error Handling:**

    - **400 Bad Request:** If the input body is invalid or missing fields.
    - **500 Internal Server Error:** For any unexpected server-side errors.
    """
    try:
        # Accessing the raw request body as JSON (optional, for debugging)
        # raw_body = await request.json()
        # print(f"Raw request body: {raw_body}")

        # Process the user data (e.g., save to a database)
        print(f"Creating user: {user.dict()}")

        # For demonstration purposes, return the user details
        return {"message": "User created successfully", "user": user.dict()}

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        # Log the error for debugging purposes
        print(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")

 

이후에, "make the code to be better." 프롬프트를 두번 더 실행하여, 코드를 두번 더 개선하게한 결과 이다. 

FastAPI 코드내에 예제도 추가하고, 에러 처리 로직도 훨씬 정교해졌다. 

결론적으로, 코드 생성은 한번 프롬프트를 주고 끝나는 것이 아니라, 2~4번 정도 추가로 코드 개선을 요청하는 것이 좋은 품질의 코드를 얻을 수 있는 것으로 나타난다. 

from fastapi import FastAPI, HTTPException, Request, status
from pydantic import BaseModel, Field, field_validator, ValidationError
from typing import Dict, List
import logging
import uuid

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

app = FastAPI(
    title="User Management API",
    description="An API for managing user data.",
    version="0.1.0",
)

# In-memory database (replace with a real database in production)
DATABASE: Dict[str, "User"] = {}


class User(BaseModel):
    user_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique identifier for the user")
    username: str = Field(..., min_length=3, max_length=50, example="johndoe", description="The username of the user")
    address: str = Field(..., example="123 Main St", description="The address of the user")
    gender: str = Field(..., example="male", description="The gender of the user")

    @field_validator("gender")
    def validate_gender(cls, value):
        allowed_genders = ["male", "female", "other"]
        if value.lower() not in allowed_genders:
            raise ValueError(f"Gender must be one of: {', '.join(allowed_genders)}")
        return value.lower()

    class Config:
        json_schema_extra = {
            "example": {
                "user_id": "f47ac10b-58cc-4372-a567-0e02b2c3d479",
                "username": "janedoe",
                "address": "456 Oak Ave",
                "gender": "female",
            }
        }


@app.post(
    "/users/",
    status_code=status.HTTP_201_CREATED,
    response_model=User,
    summary="Create a new user",
    description="Creates a new user with the provided details.",
    tags=["users"],
)
async def create_user(user: User):
    """
    Creates a new user with the provided details.

    **Input Body:**

    - **username** (str): The username of the user (3-50 characters).
    - **address** (str): The address of the user.
    - **gender** (str): The gender of the user (male, female, other).

    **Returns:**

    - **JSON Response:** The created user details, including the generated `user_id`.

    **Error Handling:**

    - **400 Bad Request:** If the input body is invalid or missing fields.
    - **409 Conflict:** If a user with the same username already exists.
    - **500 Internal Server Error:** For any unexpected server-side errors.
    """
    try:
        # Check for username conflict
        if any(u.username == user.username for u in DATABASE.values()):
            logger.warning(f"Username conflict: {user.username}")
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=f"User with username '{user.username}' already exists",
            )

        # Store the user in the in-memory database
        DATABASE[user.user_id] = user
        logger.info(f"User created: {user.user_id} - {user.username}")

        return user

    except ValidationError as e:
        logger.warning(f"Validation error: {e}")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.errors())

    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )


@app.get(
    "/users/",
    response_model=List[User],
    summary="Get all users",
    description="Retrieves a list of all users.",
    tags=["users"],
)
async def get_all_users():
    """
    Retrieves a list of all users.

    **Returns:**

    - **JSON Response:** A list of all users.

    **Error Handling:**

    - **500 Internal Server Error:** For any unexpected server-side errors.
    """
    try:
        logger.info("Retrieving all users")
        return list(DATABASE.values())
    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )


@app.get(
    "/users/{user_id}",
    response_model=User,
    summary="Get user by ID",
    description="Retrieves a user by their ID.",
    tags=["users"],
)
async def get_user(user_id: str):
    """
    Retrieves a user by their ID.

    **Path Parameters:**

    - **user_id** (str): The ID of the user to retrieve.

    **Returns:**

    - **JSON Response:** The user details if found.

    **Error Handling:**

    - **400 Bad Request:** If user_id is not a valid UUID format.
    - **404 Not Found:** If no user is found with the given ID.
    - **500 Internal Server Error:** For any unexpected server-side errors.
    """
    try:
        # Check for valid user_id format
        uuid.UUID(user_id)
    except ValueError:
        logger.warning(f"Invalid user_id format: {user_id}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid user_id format. Must be a valid UUID."
        )

    try:
        user = DATABASE[user_id]
        logger.info(f"Retrieving user: {user_id}")
        return user
    except KeyError:
        logger.warning(f"User not found: {user_id}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"User with ID '{user_id}' not found",
        )
    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        )

 

아래는 원문에서 실험한 반복 개선 요청을 한 실험 내용입니다.

 

LLM을 활용한 코드 개선 실험

실험 설계

  • 대상 모델: Claude 3.5 Sonnet
  • 프로그래밍 언어: Python
  • 테스트 프롬프트:
    • "1과 100,000 사이의 100만 개의 임의 정수 목록에서 각 자릿수의 합이 30이 되는 가장 작은 수와 가장 큰 수의 차이를 찾는 Python 코드를 작성하세요."
    • 이 프롬프트는 초보적인 수준의 개발자도 이해하고 구현할 수 있을 만큼 간단하지만, 최적화의 여지가 많은 문제입니다. 또한, LeetCode나 HackerRank와 같은 코딩 테스트에서 가져온 것이 아닌 완전히 독창적인 프롬프트여야 합니다. (LLM이 해당 문제들을 학습했을 가능성이 높아, 암기한 답을 그대로 출력하여 "cheating"할 가능성이 있기 때문입니다.)
  • 반복 프롬프트: "더 나은 코드를 작성해 줘" (write better code)
  • 평가 지표:
    • 코드 실행 시간 (M3 Pro Macbook Pro 기준)
    • 코드의 정확성
    • 코드의 가독성 및 복잡성

실험 과정 및 결과 (캐주얼 프롬프팅)

  1. 초기 요청: Claude는 기본적인 Python 코드를 생성했습니다. 이 코드는 정확했지만, digit_sum() 함수에서 str과 int 사이의 형 변환으로 인한 불필요한 오버헤드가 발생하는 등 최적화되지 않았습니다. (평균 실행 시간: 657ms)
  2. 반복 #1: Claude는 코드를 클래스로 리팩토링하고, 정수 연산을 사용하여 자릿수 합계를 계산하고, 모든 자릿수 합계를 미리 계산하여 바이트 배열에 저장하는 등 두 가지 중요한 알고리즘 개선을 수행했습니다. (평균 실행 시간: 243ms, 초기 구현 대비 2.7배 빠름)
  3. 반복 #2: Claude는 멀티스레딩과 numpy 벡터 연산을 도입하여 추가적인 최적화를 수행했습니다. 그러나, 이 구현은 병렬 처리 부분에서 버그가 있었고(수정 필요), generator 사용 문제와 numpy 배열 타입 혼용 문제도 발생했습니다. (수정 후 평균 실행 시간: 129ms, 초기 구현 대비 5.1배 빠름)
  4. 반복 #3: Claude는 큰 알고리즘 개선 없이 코드를 더 복잡하게 만들었고, 성능이 약간 저하되었습니다. (평균 실행 시간: 160ms, 초기 구현 대비 4.1배 빠름)
  5. 반복 #4: Claude는 "엔터프라이즈 수준 기능"을 추가하면서 코드를 더욱 복잡하게 만들었습니다. Numba를 사용한 JIT 컴파일과 asyncio를 사용한 병렬 처리를 도입했고, Prometheus를 사용한 구조화된 메트릭 로깅, 강제 종료 시 코드를 안전하게 종료하기 위한 시그널 핸들러, rich 테이블을 사용한 벤치마킹 결과 표시 등의 기능을 추가했습니다. 코드는 버그 없이 실행되었고, 성능이 크게 향상되었습니다. (평균 실행 시간: 6ms, 초기 구현 대비 100배 빠름)