manopriyonath committed on
Commit
de68b77
·
verified ·
1 Parent(s): 74a1346

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +51 -14
main.py CHANGED
@@ -1,26 +1,63 @@
1
- from fastapi import FastAPI
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel
4
  from transformers import pipeline
5
 
6
- # Load a chat model (you can choose a bigger one)
7
- generator = pipeline('text-generation', model='gpt2') # gpt2 is small; for better, use GPT-Neo
 
 
8
 
9
- app = FastAPI()
10
-
11
- # Allow CORS for your frontend
12
  app.add_middleware(
13
  CORSMiddleware,
14
  allow_origins=["*"],
 
15
  allow_methods=["*"],
16
- allow_headers=["*"]
17
  )
18
 
19
- class Prompt(BaseModel):
20
- prompt: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  @app.get("/chat")
23
- def chat(prompt: str):
24
- # Generate response
25
- response = generator(prompt, max_length=100, do_sample=True)[0]['generated_text']
26
- return {"reply": response}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+ from fastapi import FastAPI, Query
3
  from fastapi.middleware.cors import CORSMiddleware
 
4
  from transformers import pipeline
5
 
6
# ---------------------------
# Initialize FastAPI app
# ---------------------------
app = FastAPI(title="Smart AI Backend with Memory")

# Allow all origins (so you can call from any frontend).
# NOTE(review): browsers reject credentialed CORS responses that use the
# wildcard origin, so allow_origins=["*"] with allow_credentials=True will
# not work for requests that actually send credentials — confirm whether
# credentials are needed, and pin explicit origins if so.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
19
 
20
# ---------------------------
# Load small AI model
# ---------------------------
# Built once at import time, so requests don't pay the model
# download/initialization cost. distilgpt2 is a small text-generation model.
generator = pipeline("text-generation", model="distilgpt2")

# ---------------------------
# Conversation memory
# ---------------------------
# Process-global transcript: a single list shared by ALL clients and
# requests for the lifetime of the process (there is no per-user session).
conversation_history = []
29
+
30
+ # ---------------------------
31
+ # API endpoints
32
+ # ---------------------------
33
+
34
@app.get("/")
def home():
    """Liveness probe: confirm the backend is reachable."""
    status = {"message": "Backend is working"}
    return status
37
 
38
@app.get("/chat")
def chat(prompt: str = Query(..., description="User input text")):
    """
    Generate an AI reply to *prompt*, using recent conversation turns as
    context, and record both turns in the in-memory history.

    Returns:
        {"reply": <generated text>} on success,
        {"error": <message>} if generation fails.
    """
    try:
        # Record the user's turn.
        conversation_history.append(f"User: {prompt}")

        # Use only the most recent turns as model context. Feeding the
        # entire unbounded history eventually exceeds distilgpt2's
        # 1024-token context window and makes generation fail outright.
        input_text = "\n".join(conversation_history[-10:])

        # max_new_tokens bounds only the *generated* continuation. The
        # original max_length=150 counted the input too, so generation
        # broke as soon as the history alone reached ~150 tokens.
        output = generator(input_text, max_new_tokens=100, do_sample=True)
        response_text = output[0]["generated_text"]

        # generated_text echoes the prompt verbatim at the start; slice the
        # prefix off rather than str.replace, which would also delete any
        # later occurrence of the same text inside the reply.
        response_clean = response_text[len(input_text):].strip()

        # Record the model's turn, then cap the history so long-running
        # processes don't leak memory one request at a time.
        conversation_history.append(f"AI: {response_clean}")
        del conversation_history[:-20]

        return {"reply": response_clean}
    except Exception as e:
        # Endpoint boundary: report the failure to the client rather than
        # surfacing a bare 500. NOTE(review): the shared global list is not
        # safe under concurrent requests — confirm single-worker deployment.
        return {"error": str(e)}