|
1 | 1 | import {Construct} from "constructs" |
| 2 | +import * as crypto from "crypto" |
2 | 3 | import { |
3 | 4 | BedrockFoundationModel, |
| 5 | + ChatMessage, |
4 | 6 | Prompt, |
5 | 7 | PromptVariant |
6 | 8 | } from "@cdklabs/generative-ai-cdk-constructs/lib/cdk-lib/bedrock" |
7 | 9 | import {BedrockPromptSettings} from "./BedrockPromptSettings" |
| 10 | +import {CfnPrompt} from "aws-cdk-lib/aws-bedrock" |
8 | 11 |
|
/** Props for {@link BedrockPromptResources}. */
export interface BedrockPromptResourcesProps {
  /** Stack name, used as the prefix for generated prompt names. */
  readonly stackName: string
  /** Prompt texts and per-prompt inference configurations. */
  readonly settings: BedrockPromptSettings
}
13 | 16 |
|
14 | 17 | export class BedrockPromptResources extends Construct { |
15 | | - public readonly queryReformulationPrompt: Prompt |
| 18 | + public readonly reformulationPrompt: Prompt |
16 | 19 | public readonly ragResponsePrompt: Prompt |
17 | | - public readonly ragModelId: string |
18 | | - public readonly queryReformulationModelId: string |
| 20 | + public readonly modelId: string |
19 | 21 |
|
20 | 22 | constructor(scope: Construct, id: string, props: BedrockPromptResourcesProps) { |
21 | 23 | super(scope, id) |
22 | 24 |
|
23 | | - const ragModel = new BedrockFoundationModel("meta.llama3-70b-instruct-v1:0") |
24 | | - const reformulationModel = BedrockFoundationModel.AMAZON_NOVA_LITE_V1 |
| 25 | + const aiModel = new BedrockFoundationModel("meta.llama3-70b-instruct-v1:0") |
25 | 26 |
|
26 | | - const queryReformulationPromptVariant = PromptVariant.text({ |
27 | | - variantName: "default", |
28 | | - model: reformulationModel, |
29 | | - promptVariables: ["topic"], |
30 | | - promptText: props.settings.reformulationPrompt.text |
31 | | - }) |
| 27 | + // Create Prompts |
| 28 | + this.reformulationPrompt = this.createPrompt( |
| 29 | + "ReformulationPrompt", |
| 30 | + `${props.stackName}-reformulation`, |
| 31 | + "Prompt for reformulation queries to improve RAG inference", |
| 32 | + aiModel, |
| 33 | + "", |
| 34 | + [props.settings.reformulationPrompt], |
| 35 | + props.settings.reformulationInferenceConfig |
| 36 | + ) |
32 | 37 |
|
33 | | - const queryReformulationPrompt = new Prompt(this, "QueryReformulationPrompt", { |
34 | | - promptName: `${props.stackName}-queryReformulation`, |
35 | | - description: "Prompt for reformulating user queries to improve RAG retrieval", |
36 | | - defaultVariant: queryReformulationPromptVariant, |
37 | | - variants: [queryReformulationPromptVariant] |
38 | | - }) |
| 38 | + this.ragResponsePrompt = this.createPrompt( |
| 39 | + "RagResponsePrompt", |
| 40 | + `${props.stackName}-ragResponse`, |
| 41 | + "Prompt for generating RAG responses with knowledge base context and system instructions", |
| 42 | + aiModel, |
| 43 | + props.settings.systemPrompt.text, |
| 44 | + [props.settings.userPrompt], |
| 45 | + props.settings.ragInferenceConfig |
| 46 | + ) |
| 47 | + |
| 48 | + this.modelId = aiModel.modelId |
| 49 | + } |
39 | 50 |
|
40 | | - const ragResponsePromptVariant = PromptVariant.chat({ |
| 51 | + private createPrompt( |
| 52 | + id: string, |
| 53 | + promptName: string, |
| 54 | + description: string, |
| 55 | + model: BedrockFoundationModel, |
| 56 | + systemPromptText: string, |
| 57 | + messages: [ChatMessage], |
| 58 | + inferenceConfig: CfnPrompt.PromptModelInferenceConfigurationProperty |
| 59 | + ): Prompt { |
| 60 | + |
| 61 | + const variant = PromptVariant.chat({ |
41 | 62 | variantName: "default", |
42 | | - model: ragModel, |
43 | | - promptVariables: ["query", "search_results"], |
44 | | - system: props.settings.systemPrompt.text, |
45 | | - messages: [props.settings.userPrompt] |
| 63 | + model: model, |
| 64 | + promptVariables: ["prompt", "search_results"], |
| 65 | + system: systemPromptText, |
| 66 | + messages: messages |
46 | 67 | }) |
47 | 68 |
|
48 | | - ragResponsePromptVariant.inferenceConfiguration = { |
49 | | - text: props.settings.inferenceConfig |
| 69 | + variant.inferenceConfiguration = { |
| 70 | + text: inferenceConfig |
50 | 71 | } |
51 | 72 |
|
52 | | - const ragPrompt = new Prompt(this, "ragResponsePrompt", { |
53 | | - promptName: `${props.stackName}-ragResponse`, |
54 | | - description: "Prompt for generating RAG responses with knowledge base context and system instructions", |
55 | | - defaultVariant: ragResponsePromptVariant, |
56 | | - variants: [ragResponsePromptVariant] |
57 | | - }) |
58 | | - |
59 | | - // expose model IDs for use in Lambda environment variables |
60 | | - this.ragModelId = ragModel.modelId |
61 | | - this.queryReformulationModelId = reformulationModel.modelId |
| 73 | + const hash = crypto.createHash("md5") |
| 74 | + .update(JSON.stringify(variant)) |
| 75 | + .digest("hex") |
| 76 | + .substring(0, 6) |
62 | 77 |
|
63 | | - this.queryReformulationPrompt = queryReformulationPrompt |
64 | | - this.ragResponsePrompt = ragPrompt |
| 78 | + return new Prompt(this, id, { |
| 79 | + promptName: `${promptName}-${hash}`, |
| 80 | + description, |
| 81 | + defaultVariant: variant, |
| 82 | + variants: [variant] |
| 83 | + }) |
65 | 84 | } |
66 | 85 | } |
0 commit comments